# In [None]:

import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModel, AutoTokenizer, logging
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from sklearn.utils.class_weight import compute_class_weight
from tqdm.auto import tqdm

# Silence noisy library output: Python warnings and transformers' own logger
# (only errors from transformers are still reported).
import warnings

warnings.filterwarnings("ignore")
logging.set_verbosity_error()




class Config:
    """Static hyper-parameters and paths shared by every stage of the pipeline."""

    # --- Data: CSV location and column names ---------------------------------
    CSV_PATH = 'DATA_PATH'      # placeholder: replace with the actual CSV path
    TEXT_COL = 'text'           # utterance text of each turn
    LABEL_COL = 'sentiment'     # per-turn label
    THREAD_COL = 'thread_id'    # rows sharing this value form one thread
    TURN_COL = 'turn_index'     # ordering of turns inside a thread
    PARENT_COL = 'parent_id'    # non-zero (after NaN -> 0) marks a reply

    # --- Backbone / tokenisation --------------------------------------------
    PLM_NAME = 'roberta-base'
    MAX_LEN = 256               # tokens per turn (matched to baseline)
    MAX_TURNS = 50              # size of the turn-position embedding table

    # --- Architecture --------------------------------------------------------
    PROJ_DIM = 256
    HIDDEN_DIM = 256            # matched to baseline
    DROPOUT = 0.3               # matched to baseline

    # --- Runtime -------------------------------------------------------------
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED = 42

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with ``seed``.

    CUDA generators on every visible device are also seeded when CUDA is
    available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed all global RNGs once, before any model weights are created or data is shuffled.
set_seed(seed=Config.SEED)
print("Device set to: {}".format(Config.DEVICE))


## 2. Data Loading
def load_data_split(config):
    """Read the CSV, encode labels and split whole threads into train/val/test.

    Rows sharing ``config.THREAD_COL`` are kept together. The list of threads
    is split 70% train / 30% held out, and the held-out part is halved into
    val and test, both with ``random_state=config.SEED``. Labels are mapped to
    contiguous indices in sorted order and stored in a new ``label_idx`` column.

    Returns:
        ``(train, val, test, num_classes, weights)``. The first three are lists
        of per-thread DataFrames. ``weights`` is a float tensor of 'balanced'
        class weights on ``config.DEVICE``, or ``None`` when some class never
        appears in the training threads.

    Raises:
        FileNotFoundError: if ``config.CSV_PATH`` does not exist.
    """
    csv_path = config.CSV_PATH
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"CSV file not found at {csv_path}")

    frame = pd.read_csv(csv_path)
    # Tolerate stray whitespace around header names.
    frame.columns = frame.columns.str.strip()

    label_to_idx = {
        label: idx for idx, label in enumerate(sorted(frame[config.LABEL_COL].unique()))
    }
    frame['label_idx'] = frame[config.LABEL_COL].map(label_to_idx)

    threads = [thread for _, thread in frame.groupby(config.THREAD_COL)]
    train, held_out = train_test_split(threads, test_size=0.3, random_state=config.SEED)
    val, test = train_test_split(held_out, test_size=0.5, random_state=config.SEED)

    train_labels = [label for thread in train for label in thread['label_idx'].values]
    seen_classes = np.unique(train_labels)

    weights = None
    # Balanced weights are only computed when every class occurs in train.
    if len(seen_classes) >= len(label_to_idx):
        cw = compute_class_weight('balanced', classes=seen_classes, y=train_labels)
        weights = torch.tensor(cw, dtype=torch.float).to(config.DEVICE)
        print(f"Class Weights: {weights}")

    return train, val, test, len(label_to_idx), weights

class ThreadDataset(Dataset):
    """Map-style dataset returning one tokenised conversation thread per item.

    Each element of ``thread_list`` is a DataFrame holding the turns of one
    thread. It must contain ``config.TEXT_COL`` and a ``label_idx`` column.
    ``config.TURN_COL`` and ``config.PARENT_COL`` are optional:

    * without ``TURN_COL`` the rows keep their given order and get positional
      turn ids;
    * without ``PARENT_COL`` every turn is treated as a non-reply.
    """

    def __init__(self, thread_list, tokenizer, config):
        self.threads = thread_list
        self.tokenizer = tokenizer
        self.config = config

    def __len__(self):
        return len(self.threads)

    def __getitem__(self, idx):
        """Return the tensors for thread ``idx``.

        Keys: ``input_ids``/``attention_mask`` come straight from the tokenizer
        (presumably (num_turns, MAX_LEN) given max-length padding; verify for
        the tokenizer in use). ``turn_ids``, ``reply_ids`` and ``labels`` are
        1-D long tensors with one entry per turn.
        """
        cfg = self.config
        df = self.threads[idx]
        has_turns = cfg.TURN_COL in df.columns
        if has_turns:
            # Stable sort: turns sharing an index keep their original order.
            df = df.sort_values(cfg.TURN_COL, kind='mergesort')

        texts = df[cfg.TEXT_COL].astype(str).tolist()
        enc = self.tokenizer(texts, padding='max_length', truncation=True,
                             max_length=cfg.MAX_LEN, return_tensors='pt')

        # Both explicit and positional turn ids are clipped into the
        # MAX_TURNS-sized embedding table; later turns share the last slot.
        raw_turns = df[cfg.TURN_COL].values if has_turns else np.arange(len(df))
        turns = np.clip(raw_turns, 0, cfg.MAX_TURNS - 1)

        # A missing parent (NaN) counts as 0, i.e. "not a reply".
        if cfg.PARENT_COL in df.columns:
            parents = df[cfg.PARENT_COL].fillna(0).values
        else:
            parents = np.zeros(len(df))
        is_reply = (parents != 0).astype(int)

        return {
            'input_ids': enc['input_ids'],
            'attention_mask': enc['attention_mask'],
            'turn_ids': torch.tensor(turns, dtype=torch.long),
            'reply_ids': torch.tensor(is_reply, dtype=torch.long),
            'labels': torch.tensor(df['label_idx'].values, dtype=torch.long),
        }

def collate_fn(batch):
    """Pad a list of per-thread dicts along the turn axis into batch tensors.

    Threads with fewer turns are padded at the end. Every field is filled with
    0 except ``labels``, which uses -100 so padded turns are ignored by the
    loss and dropped from evaluation.
    """
    fill_values = {
        'input_ids': 0,
        'attention_mask': 0,
        'turn_ids': 0,
        'reply_ids': 0,
        'labels': -100,
    }
    return {
        key: pad_sequence([sample[key] for sample in batch], batch_first=True, padding_value=fill)
        for key, fill in fill_values.items()
    }

## 3. Model Architecture
class LiquidCell(nn.Module):
    """One recurrent step with an input- and state-dependent time constant.

    The update ``h_new = (1 - 1/tau) * h_prev + g`` is a forward-Euler step
    (dt = 1) of ``dh/dt = -h/tau + g``. Here ``g = tanh(W_g z + U_g h)`` and
    ``tau = min_tau + softplus(W_tau z + U_tau h) + epsilon``.

    Euler with dt = 1 is only stable while ``tau >= 0.5``. Below that the
    decay factor ``1 - 1/tau`` drops under -1 and ``h`` grows without bound.
    The default ``min_tau=1.0`` keeps the decay in [0, 1), so ``h`` stays a
    bounded leaky integrator of ``g`` (itself in [-1, 1]). Passing
    ``min_tau=0.0`` reproduces the unclamped update.

    Args:
        input_dim: size of the per-step input ``z_i``.
        hidden_dim: size of the state ``h``.
        min_tau: lower offset added to the time constant.
    """

    def __init__(self, input_dim, hidden_dim, min_tau=1.0):
        super().__init__()
        self.W_tau, self.U_tau = nn.Linear(input_dim, hidden_dim), nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_g, self.U_g = nn.Linear(input_dim, hidden_dim), nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.epsilon = 1e-6
        self.min_tau = min_tau

    def forward(self, z_i, h_prev):
        """Advance the state one step: ``z_i`` is (B, input_dim), ``h_prev`` is (B, hidden_dim)."""
        # Time constant, bounded below so the Euler step cannot overshoot.
        tau = self.min_tau + F.softplus(self.W_tau(z_i) + self.U_tau(h_prev)) + self.epsilon
        # Candidate drive in [-1, 1].
        g = torch.tanh(self.W_g(z_i) + self.U_g(h_prev))
        decay = 1.0 - 1.0 / tau
        return decay * h_prev + g

class PLITSNet_Hybrid(nn.Module):
    """Per-turn classifier: PLM turn encoder followed by a recurrence over turns.

    Each turn is encoded by a pretrained language model, mean-pooled over its
    attended tokens and projected to ``config.PROJ_DIM``. It is then
    concatenated with learned turn-position and is-reply embeddings and fed
    one turn at a time through a ``LiquidCell``. The classifier sees the
    recurrent state concatenated with the dropped-out turn embedding.
    ``forward_pretrain`` exposes a next-turn embedding prediction objective.

    Args:
        config: provides PLM_NAME, PROJ_DIM, HIDDEN_DIM, DROPOUT and MAX_TURNS.
        num_classes: number of output labels per turn.
    """

    def __init__(self, config, num_classes):
        super().__init__()
        # PLM backbone, frozen until unfreeze_plm() is called.
        self.plm = AutoModel.from_pretrained(config.PLM_NAME)
        for p in self.plm.parameters():
            p.requires_grad = False

        self.plm_dim = self.plm.config.hidden_size
        self.proj = nn.Linear(self.plm_dim, config.PROJ_DIM)
        self.dropout = nn.Dropout(config.DROPOUT)

        # Structure embeddings (16 dims each): clipped turn position, reply flag.
        self.turn_emb = nn.Embedding(config.MAX_TURNS, 16)
        self.reply_emb = nn.Embedding(2, 16)

        # Recurrent encoder; input = projected text + 2 x 16 structure dims.
        self.liquid = LiquidCell(config.PROJ_DIM + 32, config.HIDDEN_DIM)
        self.hidden_dim = config.HIDDEN_DIM

        # Heads: turn classifier and next-turn embedding predictor.
        self.classifier = nn.Linear(config.HIDDEN_DIM + config.PROJ_DIM, num_classes)
        self.next_step_pred = nn.Linear(config.HIDDEN_DIM, config.PROJ_DIM)

    def unfreeze_plm(self):
        """Make every PLM parameter trainable."""
        print(">>> Unfreezing PLM for Fine-tuning...")
        for p in self.plm.parameters():
            p.requires_grad = True

    def _get_embeddings(self, input_ids, attention_mask):
        """Encode every turn; return projected embeddings of shape (B, T, PROJ_DIM).

        ``input_ids`` and ``attention_mask`` are (B, T, L). Turns whose mask is
        all zero (thread padding) are not run through the PLM. Their pooled
        vector is zero, so after projection they equal ``proj.bias``, which is
        the same value masked mean pooling yields when such a turn is encoded.
        """
        B, T, L = input_ids.shape
        flat_ids = input_ids.reshape(B * T, L)
        flat_mask = attention_mask.reshape(B * T, L)
        is_real = flat_mask.sum(dim=1) > 0

        if is_real.any():
            real_mask = flat_mask[is_real]
            out = self.plm(flat_ids[is_real], attention_mask=real_mask)
            weights = real_mask.unsqueeze(-1).float()
            # Masked mean pooling over tokens.
            real_pooled = (out.last_hidden_state * weights).sum(1) / torch.clamp(weights.sum(1), min=1e-9)
            pooled = real_pooled.new_zeros(B * T, real_pooled.shape[-1])
            pooled[is_real] = real_pooled
        else:
            pooled = torch.zeros(B * T, self.proj.in_features, device=input_ids.device)

        e_i = self.proj(pooled)
        return e_i.view(B, T, -1)

    def forward_features(self, input_ids, attention_mask, turn_ids, reply_ids):
        """Run the turn encoder and the recurrence.

        Returns ``(h_seq, e_i_drop, e_i)``:
        * ``h_seq``: recurrent states, (B, T, HIDDEN_DIM);
        * ``e_i_drop``: dropped-out turn embeddings, (B, T, PROJ_DIM);
        * ``e_i``: raw turn embeddings, (B, T, PROJ_DIM).
        """
        e_i = self._get_embeddings(input_ids, attention_mask)
        e_i_drop = self.dropout(e_i)

        # Structure fusion: text embedding + turn-position + reply-flag embeddings.
        s_i = torch.cat([self.turn_emb(turn_ids), self.reply_emb(reply_ids)], dim=-1)
        z_i = torch.cat([e_i_drop, s_i], dim=-1)

        B, T, _ = z_i.shape
        h = z_i.new_zeros(B, self.hidden_dim)
        states = []
        # Strictly causal: the state at step t depends only on turns <= t, so
        # padding appended after a thread's last turn never alters its states.
        for t in range(T):
            h = self.liquid(z_i[:, t, :], h)
            states.append(h)
        h_seq = torch.stack(states, dim=1)

        return h_seq, e_i_drop, e_i

    def forward_pretrain(self, input_ids, attention_mask, turn_ids, reply_ids):
        """Predict each next turn's raw embedding from the current state.

        Returns ``(preds, targets)``, both (B, T-1, PROJ_DIM). Position t
        pairs the prediction made at turn t with the raw embedding of turn
        t+1. Positions whose target is a padded turn are NOT removed here.
        """
        h_seq, _, e_i_raw = self.forward_features(input_ids, attention_mask, turn_ids, reply_ids)
        preds = self.next_step_pred(h_seq)
        preds_shifted = preds[:, :-1, :]
        targets_shifted = e_i_raw[:, 1:, :]
        return preds_shifted, targets_shifted

    def forward_classify(self, input_ids, attention_mask, turn_ids, reply_ids):
        """Return per-turn logits of shape (B, T, num_classes)."""
        h_seq, e_i_drop, _ = self.forward_features(input_ids, attention_mask, turn_ids, reply_ids)
        # Classifier sees both the contextual state and the turn's own embedding.
        logits = self.classifier(torch.cat([h_seq, e_i_drop], dim=-1))
        return logits

## 4. Training Logic
def train_stage1_unsupervised(model, dataloader, epochs, device):
    """Stage 1: next-turn embedding prediction while the PLM is frozen.

    For every thread the model predicts the embedding of turn t+1 from its
    recurrent state at turn t. The loss is the MSE between the two, computed
    only where turn t+1 is a real turn. Turns added as thread padding carry an
    all-zero attention mask and are excluded.

    Args:
        model: provides ``forward_pretrain(input_ids, attention_mask,
            turn_ids, reply_ids) -> (preds, targets)``, expected to be
            (B, T-1, D) tensors aligned so position t targets turn t+1.
        dataloader: iterable of batch dicts (``labels`` is ignored).
        epochs: number of passes over ``dataloader``.
        device: device every batch tensor is moved to.
    """
    print("\n--- Stage 1: Structure Learning (Frozen PLM) ---")
    # Only parameters that still require grad (i.e. not the frozen PLM) are optimised.
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
    criterion = nn.MSELoss()

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        n_updates = 0
        for batch in dataloader:
            optimizer.zero_grad()
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}

            preds, targets = model.forward_pretrain(inputs['input_ids'], inputs['attention_mask'],
                                                    inputs['turn_ids'], inputs['reply_ids'])

            # (B, T): a turn is real when at least one of its tokens is attended.
            turn_is_real = inputs['attention_mask'].sum(dim=-1) > 0
            # Position t of preds/targets predicts turn t+1.
            keep = turn_is_real[:, 1:]
            if not keep.any():
                # Every thread in this batch has a single turn: nothing to predict.
                continue

            # Targets are detached: gradients flow only through the prediction path.
            loss = criterion(preds[keep], targets[keep].detach())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            n_updates += 1

        if (epoch + 1) % 5 == 0:
            print(f"  Epoch {epoch+1}/{epochs} | MSE: {total_loss / max(n_updates, 1):.4f}")

def _evaluate_turn_level(model, loader, device):
    """Return ``(accuracy, macro_f1)`` over every non-padded turn in ``loader``."""
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for batch in loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)
            logits = model.forward_classify(inputs['input_ids'], inputs['attention_mask'],
                                            inputs['turn_ids'], inputs['reply_ids'])
            # Padded turns carry label -100 and are excluded from the metrics.
            mask = labels != -100
            preds.extend(torch.argmax(logits, -1)[mask].cpu().numpy())
            trues.extend(labels[mask].cpu().numpy())
    return accuracy_score(trues, preds), f1_score(trues, preds, average='macro')


def train_stage2_supervised(model, train_loader, val_loader, epochs, device, class_weights):
    """Stage 2: supervised fine-tuning of the whole model with the PLM unfrozen.

    AdamW uses two learning rates: 2e-5 for the PLM and 1e-3 for every other
    parameter. ``ReduceLROnPlateau`` halves them when validation macro-F1
    stalls for more than 2 epochs. Padded turns (label -100) are ignored by
    the loss and the metrics.

    When training finishes, the weights from the epoch with the highest
    validation macro-F1 are loaded back into ``model``. This makes the
    reported "Best Val F1" describe the model callers go on to use. If F1
    never rises above 0, the last-epoch weights are kept.

    Args:
        model: provides ``unfreeze_plm()``, ``plm`` and ``forward_classify``.
        train_loader, val_loader: iterables of batch dicts with ``labels``.
        epochs: number of training epochs.
        device: device every batch tensor is moved to.
        class_weights: optional per-class weight tensor for the loss (or None).
    """
    print("\n--- Stage 2: Supervised Fine-tuning (Unfrozen PLM) ---")

    # Unfreeze the PLM for full fine-tuning.
    model.unfreeze_plm()

    plm_param_ids = {id(p) for p in model.plm.parameters()}
    head_params = [p for p in model.parameters() if id(p) not in plm_param_ids]

    # Differential LR: small for the pretrained backbone, large for the new layers.
    optimizer = torch.optim.AdamW([
        {'params': model.plm.parameters(), 'lr': 2e-5},
        {'params': head_params, 'lr': 1e-3},
    ])

    # Default ignore_index=-100 matches the label padding of padded turns.
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

    best_f1 = 0
    best_acc = 0
    best_state = None

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)

            logits = model.forward_classify(inputs['input_ids'], inputs['attention_mask'],
                                            inputs['turn_ids'], inputs['reply_ids'])

            loss = criterion(logits.view(-1, logits.shape[-1]), labels.view(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()

        acc, f1 = _evaluate_turn_level(model, val_loader, device)

        # param_groups[0] is the PLM group.
        current_lr = optimizer.param_groups[0]['lr']
        print(f"  Epoch {epoch+1}/{epochs} | Loss: {total_loss/len(train_loader):.4f} | Val F1: {f1:.4f} | LR: {current_lr:.2e}")

        scheduler.step(f1)

        if f1 > best_f1:
            best_f1 = f1
            best_acc = acc
            # Snapshot on CPU so the copy does not double accelerator memory use.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    print(f"--- Best Val F1: {best_f1:.4f} (Acc: {best_acc:.4f}) ---")

## 5. Execution
def _predict_turns(net, loader, device):
    """Return ``(predictions, gold_labels)`` for every non-padded turn in ``loader``."""
    net.eval()
    pred_ids, gold_ids = [], []
    with torch.no_grad():
        for batch in loader:
            feats = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            gold = batch['labels'].to(device)
            out = net.forward_classify(feats['input_ids'], feats['attention_mask'],
                                       feats['turn_ids'], feats['reply_ids'])
            keep = gold != -100
            pred_ids.extend(torch.argmax(out, -1)[keep].cpu().numpy())
            gold_ids.extend(gold[keep].cpu().numpy())
    return pred_ids, gold_ids


train_threads, val_threads, test_threads, num_classes, weights = load_data_split(Config)
tokenizer = AutoTokenizer.from_pretrained(Config.PLM_NAME)

# One dataset per split, plus a label-free pool over every thread for Stage 1.
train_ds, val_ds, test_ds = (
    ThreadDataset(split, tokenizer, Config)
    for split in (train_threads, val_threads, test_threads)
)
unsupervised_ds = ThreadDataset(train_threads + val_threads + test_threads, tokenizer, Config)

# All loaders share batch size and collate function; only the training ones shuffle.
loader_args = {'batch_size': 8, 'collate_fn': collate_fn}
unsupervised_loader = DataLoader(unsupervised_ds, shuffle=True, **loader_args)
train_loader = DataLoader(train_ds, shuffle=True, **loader_args)
val_loader, test_loader = (DataLoader(ds, shuffle=False, **loader_args) for ds in (val_ds, test_ds))

model = PLITSNet_Hybrid(Config, num_classes).to(Config.DEVICE)

# Stage 1: structure warm-up with the PLM frozen.
train_stage1_unsupervised(model, unsupervised_loader, epochs=10, device=Config.DEVICE)

# Stage 2: full fine-tuning, PLM unfrozen with its own (smaller) learning rate.
train_stage2_supervised(model, train_loader, val_loader, epochs=15, device=Config.DEVICE, class_weights=weights)

# Final evaluation on the held-out test threads.
preds, trues = _predict_turns(model, test_loader, Config.DEVICE)

banner = "=" * 30
print("\n" + banner)
print(f"Final Test Accuracy: {accuracy_score(trues, preds):.4f}")
print(f"Final Test Macro-F1: {f1_score(trues, preds, average='macro'):.4f}")
print(banner)