In [9]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoConfig, AutoModel
import os
import re
import sys
from tqdm import tqdm
from safetensors.torch import load_file as safe_load_file

In [None]:
# ====================================================
# 1. CONFIGURATION
# ====================================================
class Config:
    """Static settings for Subtask 1 inference.

    All values are class attributes computed once, while the class body runs:
    filesystem locations derived from ``BASE_DIR``, tokenizer/batching sizes,
    the compute device and the mixture-of-experts head dimensions.
    """

    # Prefer the folder containing this file. `__file__` does not exist in an
    # interactive kernel, in which case the current working directory is used.
    try:
        BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        BASE_DIR = os.getcwd()
        # If we were started one level above, step into "Subtask1" when present.
        if "Subtask1" not in BASE_DIR:
            if os.path.exists(os.path.join(BASE_DIR, "Subtask1")):
                BASE_DIR = os.path.join(BASE_DIR, "Subtask1")

    print(f"Working Directory: {BASE_DIR}")

    # --- Filesystem layout (everything hangs off BASE_DIR) ---
    OUTPUT_FILE = os.path.join(BASE_DIR, "submission_subtask1.csv")
    DATA_DIR = os.path.join(BASE_DIR, "data")
    WEIGHTS_DIR = os.path.join(BASE_DIR, "weights")

    # --- Backbone / tokenization / batching ---
    base_model_name = "cardiffnlp/twitter-roberta-base-sentiment-latest"
    max_seq_length = 512
    batch_size = 32
    window_size = 4

    # --- Compute device: GPU when torch can see one, otherwise CPU ---
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # --- Mixture-of-experts head ---
    num_experts = 4
    top_k = 2

In [11]:
# ============================================================
# 2. MODEL ARCHITECTURE (SOFT MOE + MEAN POOLING)
# ============================================================
class MeanPooling(nn.Module):
    """Mask-aware average of token embeddings along dim 1; has no parameters."""

    def __init__(self):
        super().__init__()

    def forward(self, last_hidden_state, attention_mask):
        """Average `last_hidden_state` over dim 1, counting only unmasked positions.

        Args:
            last_hidden_state: token embeddings; presumably
                (batch, seq_len, hidden) -- confirm against the backbone.
            attention_mask: mask with one dimension fewer than
                `last_hidden_state`; non-zero entries mark positions to keep.

        Returns:
            Tensor with dim 1 reduced away. A row whose mask is entirely zero
            yields zeros, because the divisor is clamped to 1e-9 rather than 0.
        """
        # Broadcast the mask across the embedding dimension and make it float.
        mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        masked_total = (last_hidden_state * mask).sum(dim=1)
        # Clamp so fully masked rows do not divide by zero.
        kept_count = mask.sum(dim=1).clamp(min=1e-9)
        return masked_total / kept_count

class SoftMoEHead(nn.Module):
    """Dense ("soft") mixture-of-experts head.

    Every expert is evaluated on every input; a linear gate followed by a
    softmax produces per-expert weights that blend the expert outputs.
    """

    def __init__(self, hidden_size, num_experts=4, output_dim=1):
        """Create the gate and `num_experts` identical expert MLPs.

        Args:
            hidden_size: width of the input feature vector (and of each
                expert's hidden layer).
            num_experts: number of expert MLPs and gate outputs.
            output_dim: size of each expert's output.
        """
        super().__init__()
        # The gate is created before the experts, keeping the order in which
        # parameters are initialized.
        self.gate = nn.Linear(hidden_size, num_experts)

        expert_mlps = []
        for _ in range(num_experts):
            expert_mlps.append(
                nn.Sequential(
                    nn.Linear(hidden_size, hidden_size),
                    nn.LayerNorm(hidden_size),
                    nn.GELU(),
                    nn.Dropout(0.3),
                    nn.Linear(hidden_size, output_dim),
                )
            )
        self.experts = nn.ModuleList(expert_mlps)

    def forward(self, x):
        """Return the gate-weighted sum of all expert outputs for `x`.

        The softmax is taken over dim 1 of the gate logits and the expert
        outputs are stacked on dim 1, so `x` is expected to be 2-D
        (presumably (batch, hidden_size) -- confirm with the caller).
        """
        # (batch, num_experts) -> (batch, num_experts, 1) so it broadcasts over output_dim.
        mixing = torch.softmax(self.gate(x), dim=1).unsqueeze(-1)
        per_expert = torch.stack([expert(x) for expert in self.experts], dim=1)
        # Weighted sum across the expert axis.
        return (mixing * per_expert).sum(dim=1)

class Subtask1Model(nn.Module):
    """Pretrained backbone + mean pooling + two SoftMoE heads (valence, arousal)."""

    def __init__(self, model_name, num_experts=4):
        """Load the backbone named `model_name` and attach the two heads.

        Args:
            model_name: identifier or path passed to
                `AutoConfig.from_pretrained` / `AutoModel.from_pretrained`.
            num_experts: number of experts in each SoftMoE head.
        """
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name)
        self.backbone = AutoModel.from_pretrained(model_name)
        self.pooler = MeanPooling()

        # Both heads share the same shape: backbone width in, one value out.
        width = self.config.hidden_size
        head_kwargs = dict(num_experts=num_experts, output_dim=1)
        self.valence_moe = SoftMoEHead(width, **head_kwargs)
        self.arousal_moe = SoftMoEHead(width, **head_kwargs)

    def forward(self, input_ids, attention_mask):
        """Encode, mean-pool and run both heads.

        Returns:
            Tensor formed by concatenating the valence head output and the
            arousal head output along dim 1 (valence first).
        """
        encoded = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.pooler(encoded.last_hidden_state, attention_mask)

        heads = (self.valence_moe, self.arousal_moe)
        return torch.cat([head(pooled) for head in heads], dim=1)

# ============================================================
# 3. DATA PROCESSING
# ============================================================
def fix_spacing(text):
    """Normalize whitespace in `text` and re-attach detached apostrophes.

    The value is converted with `str()` first, every whitespace run becomes a
    single space, an apostrophe that has whitespace on both sides is glued to
    its neighbours (``"don ' t"`` -> ``"don't"``), and the result is stripped.
    """
    single_spaced = re.sub(r"\s+", " ", str(text))
    return re.sub(r"\s+'\s+", "'", single_spaced).strip()

def prepare_test_data(df, window_size):
    """Build one context-augmented model input per row with a per-user sliding window.

    For every row, the texts of up to ``window_size - 1`` preceding rows of the
    same user are prepended to the row's own text, all joined with ``' </s> '``.
    When a ``timestamp`` column exists, rows are ordered by
    ``(user_id, timestamp)`` first; otherwise the input order within each user
    is kept. Output rows are grouped by ``user_id`` (in ``groupby`` key order).

    The caller's DataFrame is never modified.

    Args:
        df: DataFrame with a ``user_id`` column and a ``text`` column (or one
            of the aliases ``tweet``/``content``/``post``/``sentence``/
            ``message``). Optional columns: ``timestamp``, ``text_id`` or ``id``.
        window_size: total number of texts per input, counting the current
            one. Values <= 1 produce inputs without context.

    Returns:
        DataFrame with columns ``user_id``, ``text_id`` and ``input_text``.
        The columns are present even when `df` has no rows.

    Raises:
        KeyError: if no text column (or alias) or no ``user_id`` column exists.
    """
    print(f"   -> Processing Sliding Window (Size={window_size})...")

    # Auto-fix column names: the first alias found is renamed to 'text'
    # (later aliases are skipped because 'text' then exists).
    possible_names = ['tweet', 'content', 'post', 'sentence', 'message']
    for col in possible_names:
        if col in df.columns and 'text' not in df.columns:
            print(f"Renaming column '{col}' -> 'text'")
            df = df.rename(columns={col: 'text'})

    missing = [col for col in ('text', 'user_id') if col not in df.columns]
    if missing:
        raise KeyError(
            f"Input data is missing required column(s) {missing}; "
            f"available columns: {list(df.columns)}"
        )

    # assign() returns a new frame, so the caller's DataFrame stays untouched.
    df = df.assign(text=df['text'].apply(fix_spacing))
    if 'timestamp' in df.columns:
        df = df.assign(timestamp=pd.to_datetime(df['timestamp']))
        # The index is intentionally not reset: original row labels are the
        # fallback identifiers below.
        df = df.sort_values(by=['user_id', 'timestamp'])

    # Pick the ID column: prefer 'text_id', then 'id'; otherwise fall back to
    # the row's index label so identifiers stay unique across users.
    id_col = 'text_id' if 'text_id' in df.columns else ('id' if 'id' in df.columns else None)

    context_size = max(window_size - 1, 0)
    new_data = []
    for uid, group in df.groupby('user_id'):
        texts = [str(t) for t in group['text'].values]
        ids = group[id_col].values if id_col else group.index.values

        for i in range(len(texts)):
            # Up to `context_size` previous texts, oldest first, then the current text.
            window = texts[max(0, i - context_size):i + 1]
            new_data.append({
                'user_id': uid,
                'text_id': ids[i],
                'input_text': ' </s> '.join(window)
            })

    # Explicit columns keep the schema stable when there are no rows at all.
    return pd.DataFrame(new_data, columns=['user_id', 'text_id', 'input_text'])

class InferenceDataset(Dataset):
    """Map-style dataset that tokenizes the `input_text` column one row at a time."""

    def __init__(self, df, tokenizer, max_len):
        """Keep the texts, the tokenizer callable and the fixed sequence length.

        Args:
            df: DataFrame providing an ``input_text`` column.
            tokenizer: callable invoked as ``tokenizer(text, truncation=...,
                padding=..., max_length=..., return_tensors=...)`` whose result
                is indexable by ``'input_ids'`` and ``'attention_mask'``.
            max_len: value passed as ``max_length`` on every call.
        """
        self.texts = df['input_text'].values
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize item `idx` and return its flattened ids and attention mask."""
        # NOTE(review): with truncation enabled, which end of an over-long
        # input is dropped depends on the tokenizer's truncation side --
        # confirm it matches what was used for training.
        encoded = self.tokenizer(
            str(self.texts[idx]),
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
        )
        # The tokenizer output carries a leading batch axis; flatten removes it.
        return {
            field: encoded[field].flatten()
            for field in ('input_ids', 'attention_mask')
        }

In [None]:
# ============================================================
# 4. MAIN INFERENCE
# ============================================================
def _select_weight_file(filenames):
    """Pick the checkpoint file name to load from a directory listing, or None.

    ``.safetensors`` is preferred over ``.bin``; within one extension the
    alphabetically first name wins so the choice does not depend on
    `os.listdir` ordering. Files whose names start with prefixes commonly used
    for non-model training artifacts (e.g. ``training_args.bin``) are skipped.
    """
    non_weight_prefixes = ("training_args", "optimizer", "scheduler", "rng_state", "scaler")
    candidates = sorted(
        name for name in filenames
        if name.endswith((".safetensors", ".bin")) and not name.startswith(non_weight_prefixes)
    )
    for extension in (".safetensors", ".bin"):
        for name in candidates:
            if name.endswith(extension):
                return name
    return None


def _locate_weights(weights_dir):
    """Find the checkpoint in `weights_dir` or one of its immediate subfolders.

    Returns ``(directory, filename)``, or None after printing the reason.
    """
    if not os.path.isdir(weights_dir):
        print(f"ERROR: Weights folder not found: {weights_dir}")
        return None

    entries = sorted(os.listdir(weights_dir))
    top_level = _select_weight_file(entries)
    if top_level is not None:
        return weights_dir, top_level

    subfolders = [e for e in entries if os.path.isdir(os.path.join(weights_dir, e))]
    if not subfolders:
        print("ERROR: Weights folder empty!")
        return None

    for sub in subfolders:
        sub_path = os.path.join(weights_dir, sub)
        nested = _select_weight_file(os.listdir(sub_path))
        if nested is not None:
            print(f"Detected nested folder. Going into: {sub}")
            return sub_path, nested

    print("No model file found!")
    return None


def _find_test_file(data_dir):
    """Return the first test CSV found under `data_dir` (recursive), or None.

    A file qualifies when its lower-cased name contains "test" or "subtask1",
    contains neither "train" nor "val", and ends with ".csv".
    NOTE(review): the "val" exclusion also rejects names such as
    "valence..." -- confirm that is intended.
    """
    for root, dirs, files in os.walk(data_dir):
        dirs.sort()  # in-place sort makes the traversal order deterministic
        candidates = sorted(
            f for f in files
            if ("test" in f.lower() or "subtask1" in f.lower())
            and "train" not in f.lower() and "val" not in f.lower() and f.endswith(".csv")
        )
        if candidates:
            return os.path.join(root, candidates[0])
    return None


def predict():
    """Run Subtask 1 inference end to end and write the submission CSV.

    Steps: locate the checkpoint and the test CSV, load tokenizer and model,
    build sliding-window inputs, predict in batches, clip valence to
    [-2, 2] and arousal to [0, 2], and save to ``Config.OUTPUT_FILE``.

    Every failure that can be detected up front (missing folders/files, empty
    data, model construction errors) is reported with a printed message and
    the function returns None without raising. Always returns None.
    """
    print("="*60)
    print("STARTING INFERENCE SUBTASK 1 (Soft MoE)")
    print(f"Working Directory: {Config.BASE_DIR}")
    print("="*60)

    # Resolve the checkpoint before anything expensive (model download) happens.
    located = _locate_weights(Config.WEIGHTS_DIR)
    if located is None:
        return
    actual_weights_path, weights_file = located
    print(f"Target Weights Path: {actual_weights_path}")

    # --- [STEP 1] LOAD DATA (RECURSIVE SEARCH) ---
    print(">>> [1/4] Looking for Data Files...")
    if not os.path.isdir(Config.DATA_DIR):
        print("ERROR: No data folder!")
        return

    test_file_path = _find_test_file(Config.DATA_DIR)
    if not test_file_path:
        print(f"ERROR: No test CSV found under: {Config.DATA_DIR}")
        return
    print(f" Found Test File: {test_file_path}")

    # --- [STEP 2] LOAD TOKENIZER & MODEL ---
    print(">>> [2/4] Loading Tokenizer & Model...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(actual_weights_path, local_files_only=True, use_fast=False)
        print("Loaded tokenizer (Local/Slow).")
    except Exception as e:
        # Best effort: fall back to the base model's tokenizer.
        print(f"Local tokenizer failed: {e}. Downloading base...")
        tokenizer = AutoTokenizer.from_pretrained(Config.base_model_name, use_fast=False)

    try:
        model = Subtask1Model(Config.base_model_name, num_experts=Config.num_experts)
    except Exception as e:
        print(f"Error init model: {e}")
        return

    w_path = os.path.join(actual_weights_path, weights_file)
    print(f"   Loading weights from: {weights_file}")

    if w_path.endswith(".safetensors"):
        state_dict = safe_load_file(w_path)
    else:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load .bin checkpoints from trusted sources.
        state_dict = torch.load(w_path, map_location="cpu")

    # strict=False tolerates key mismatches, so surface them: silently missing
    # head weights would leave randomly initialized layers in the model.
    load_report = model.load_state_dict(state_dict, strict=False)
    if load_report.missing_keys:
        print(f"WARNING: {len(load_report.missing_keys)} model keys missing from checkpoint, "
              f"e.g. {load_report.missing_keys[:5]}")
    if load_report.unexpected_keys:
        print(f"WARNING: {len(load_report.unexpected_keys)} unexpected checkpoint keys ignored, "
              f"e.g. {load_report.unexpected_keys[:5]}")

    model.to(Config.device).eval()

    # --- [STEP 3] PROCESS DATA & PREDICT ---
    print(">>> [3/4] Processing & Predicting...")
    df_raw = pd.read_csv(test_file_path)
    df_proc = prepare_test_data(df_raw, Config.window_size)
    if df_proc.empty:
        print(f"ERROR: No rows to predict in: {test_file_path}")
        return

    test_ds = InferenceDataset(df_proc, tokenizer, Config.max_seq_length)
    test_loader = DataLoader(test_ds, batch_size=Config.batch_size, shuffle=False)

    batch_preds = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            input_ids = batch['input_ids'].to(Config.device)
            attention_mask = batch['attention_mask'].to(Config.device)

            logits = model(input_ids, attention_mask)
            batch_preds.append(logits.cpu().numpy())

    # One (n_rows, 2) array: column 0 = valence, column 1 = arousal.
    all_preds = np.concatenate(batch_preds, axis=0)

    # --- [STEP 4] POST-PROCESSING & SAVE ---
    print(">>> [4/4] Clipping & Saving...")

    all_preds[:, 0] = np.clip(all_preds[:, 0], -2.0, 2.0)
    all_preds[:, 1] = np.clip(all_preds[:, 1], 0.0, 2.0)

    submission = pd.DataFrame({
        'user_id': df_proc['user_id'],
        'text_id': df_proc['text_id'],
        'pred_valence': all_preds[:, 0],
        'pred_arousal': all_preds[:, 1]
    })

    submission.to_csv(Config.OUTPUT_FILE, index=False)
    print(f"DONE! Saved to: {Config.OUTPUT_FILE}")
    print(submission.head())

# Entry point: run the full inference pipeline only when this module is the
# one being executed, not when it is imported from elsewhere.
if __name__ == "__main__":
    predict()