# In [None]:
# Cell 1: Environment Setup & Imports
import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, classification_report
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm

# Set random seed
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Also forces cuDNN into deterministic mode and disables its autotuner so
    repeated runs pick the same kernels.

    Args:
        seed: Value handed to every seeding call.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

# Check for GPU: use CUDA when PyTorch can see a device, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Cell 2: Configuration
class Config:
    """Static experiment configuration, read as attributes of the ``config`` instance."""

    # Model backbone
    MODEL_NAME = "hfl/chinese-roberta-wwm-ext"
    DATA_PATH = "DATA_PATH"  # Modify to the actual data file name and path after parsing.

    # --- FIXED Hyperparameters ---
    BATCH_SIZE = 32  # threads per batch (each dataset item is one thread)
    LEARNING_RATE = 0.0003  # LR for the non-BERT parameters (fusion, BiLSTM, attention, classifier)
    EPOCHS = 10
    DROPOUT = 0.3
    MAX_POST_LEN = 256  # tokens each post is padded/truncated to
    GRADIENT_CLIPPING = 1.0  # max global gradient norm

    # --- Architecture Parameters ---
    EMBEDDING_DIM = 768  # must equal the backbone's pooler_output width
    HIDDEN_DIM = 768  # Match BERT output; the BiLSTM uses HIDDEN_DIM // 2 per direction
    LSTM_LAYERS = 2
    NUM_ATTENTION_HEADS = 8  # Multi-head attention; must divide HIDDEN_DIM

    # --- Fine-tuning Config ---
    UNFREEZE_LAYERS = 3  # Unfreeze last 3 BERT layers
    # NOTE(review): 3e-4 is roughly an order of magnitude above the learning rates
    # commonly used to fine-tune BERT-style encoders (~2e-5 to 5e-5); confirm intended.
    BERT_LR = 0.0003  # Same as base LR for simplicity

    # --- Advanced Training Config ---
    WARMUP_RATIO = 0.15  # fraction of total optimizer steps used for linear warmup
    USE_FOCAL_LOSS = False
    FOCAL_ALPHA = 0.25
    FOCAL_GAMMA = 2.0
    LABEL_SMOOTHING = 0.1

    LABEL_MAP = {"Negative": 0, "Neutral": 1, "Positive": 2}
    # Derived from LABEL_MAP so the class count can never drift from the label set.
    NUM_CLASSES = len(LABEL_MAP)

config = Config()

# Cell 3: Data Loading & Robust Thread Structuring
def load_thread_data(file_path, label_map=None):
    """Load a post-level CSV and group it into turn-ordered conversation threads.

    The CSV must provide ``text``, ``sentiment``, ``thread_id`` and
    ``turn_index`` columns. Rows missing any of these, or whose ``sentiment``
    is not a key of ``label_map``, are dropped (the latter with a printed count).

    Args:
        file_path: Path or file-like object accepted by ``pd.read_csv``.
        label_map: Mapping from sentiment string to integer class id.
            Defaults to ``config.LABEL_MAP``.

    Returns:
        (threads, labels, all_labels):
        ``threads`` is a list of per-thread lists of post texts ordered by
        ``turn_index``; ``labels`` is the parallel list of per-post integer
        labels; ``all_labels`` is a NumPy array with the label of every kept row.
    """
    if label_map is None:
        label_map = config.LABEL_MAP

    df = pd.read_csv(file_path)

    # Critical: Drop rows where ANY essential info is missing
    df = df.dropna(subset=['text', 'sentiment', 'thread_id', 'turn_index'])

    # Map sentiment strings to ids; unknown strings become NaN and are dropped.
    # assign()/copy() produce independent frames, so the column writes below do
    # not trigger pandas' SettingWithCopyWarning.
    df = df.assign(label=df['sentiment'].map(label_map))
    n_unmapped = int(df['label'].isna().sum())
    if n_unmapped:
        print(f"Warning: dropping {n_unmapped} rows with unrecognised sentiment values.")
    df = df.dropna(subset=['label']).copy()
    df['label'] = df['label'].astype(int)
    df['text'] = df['text'].astype(str)

    # Sort strictly so posts inside each thread follow turn order.
    df = df.sort_values(by=['thread_id', 'turn_index'])

    threads = []
    labels = []
    # Texts and labels are read from the same grouped rows, so their lengths
    # always agree; no per-thread consistency check is needed.
    for _, group in df.groupby('thread_id'):
        threads.append(group['text'].tolist())
        labels.append(group['label'].tolist())

    return threads, labels, df['label'].values

print("Loading data with integrity checks...")
threads, thread_labels, all_labels = load_thread_data(config.DATA_PATH)
print(f"Successfully loaded {len(threads)} threads.")

# Split Dataset at the thread level (all posts of a thread land in one split):
# 20% test, then 10% of the remainder as validation.
train_texts, test_texts, train_y, test_y = train_test_split(
    threads, thread_labels, test_size=0.2, random_state=42
)
train_texts, val_texts, train_y, val_y = train_test_split(
    train_texts, train_y, test_size=0.1, random_state=42
)

print(f"Train Threads: {len(train_texts)}, Val: {len(val_texts)}, Test: {len(test_texts)}")

# Class Weights
# Fit on the training split only: all_labels also contains validation and test
# posts, and using it would leak held-out label frequencies into the loss.
train_post_labels = np.concatenate([np.asarray(y, dtype=int) for y in train_y])
present_classes = np.unique(train_post_labels)
# Classes absent from the training split keep a neutral weight of 1.0 so the
# vector always has NUM_CLASSES entries, as the loss's `weight` argument requires.
class_weight_values = np.ones(config.NUM_CLASSES, dtype=float)
class_weight_values[present_classes] = compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=train_post_labels
)
class_weights = torch.tensor(class_weight_values, dtype=torch.float).to(device)
print(f"Class Weights: {class_weights}")

# Cell 4: Dataset and Collate Function
class ThreadDataset(Dataset):
    """Map-style dataset yielding one tokenized conversation thread per item.

    Args:
        texts: Sequence of threads; each thread is a list of post texts.
        labels: Parallel sequence of per-post label lists (one label per post).
        tokenizer: Tokenizer exposing ``encode_plus`` (the file passes a
            HuggingFace ``BertTokenizer``).
        max_len: Length every post is padded/truncated to.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        """Tokenize thread ``index``.

        Returns:
            dict with ``encoded_posts`` (list of {'input_ids', 'attention_mask'}
            1-D tensors, one per post), ``labels`` (the per-post label list)
            and ``thread_len`` (number of posts).

        Raises:
            ValueError: if the thread's post and label counts differ.
        """
        thread_posts = self.texts[index]
        thread_labels = self.labels[index]

        # Explicit check rather than `assert`, so it still runs under `python -O`.
        if len(thread_posts) != len(thread_labels):
            raise ValueError(
                f"Sample {index} mismatch: {len(thread_posts)} posts vs {len(thread_labels)} labels"
            )

        encoded_posts = [self._encode_post(post) for post in thread_posts]

        return {
            'encoded_posts': encoded_posts,
            'labels': thread_labels,
            'thread_len': len(thread_posts)
        }

    def _encode_post(self, post):
        """Tokenize one post, padded/truncated to ``max_len``, into 1-D tensors."""
        encoded = self.tokenizer.encode_plus(
            str(post),
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        # return_tensors='pt' adds a leading batch dimension of size 1; drop it.
        return {
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0)
        }

def collate_fn(batch):
    """Pad a list of ``ThreadDataset`` items into dense batch tensors.

    Threads are padded to the longest thread in the batch. Posts are padded to
    the longest tokenized post in the batch, which equals the dataset's
    ``max_len`` when the tokenizer pads to ``max_length``.

    Returns:
        dict with
        - ``input_ids`` / ``attention_mask``: long tensors
          [batch, max_thread_len, post_len], zero for padding posts/tokens;
        - ``labels``: long tensor [batch, max_thread_len], -100 at padding posts;
        - ``lengths``: long tensor [batch] with each thread's real post count.

    Raises:
        ValueError: if an item's posts, labels and ``thread_len`` disagree.
    """
    batch_size = len(batch)
    max_thread_len = max(item['thread_len'] for item in batch)
    # Take the post length from the data rather than the global config so the
    # collate works for any dataset max_len (and for dynamically padded posts).
    post_len = max(
        (post['input_ids'].size(0) for item in batch for post in item['encoded_posts']),
        default=0,
    )

    padded_input_ids = torch.zeros(batch_size, max_thread_len, post_len, dtype=torch.long)
    padded_attention_masks = torch.zeros_like(padded_input_ids)
    padded_labels = torch.full((batch_size, max_thread_len), -100, dtype=torch.long)
    lengths = []

    for i, item in enumerate(batch):
        posts = item['encoded_posts']
        labels = item['labels']
        thread_len = item['thread_len']

        # Explicit check rather than `assert`, so it still runs under `python -O`.
        if not (len(posts) == len(labels) == thread_len):
            raise ValueError(
                f"Batch item {i}: posts {len(posts)}, labels {len(labels)}, thread_len {thread_len}"
            )

        for j, post in enumerate(posts):
            n_tokens = post['input_ids'].size(0)
            padded_input_ids[i, j, :n_tokens] = post['input_ids']
            padded_attention_masks[i, j, :n_tokens] = post['attention_mask']
        padded_labels[i, :thread_len] = torch.as_tensor(labels, dtype=torch.long)

        lengths.append(thread_len)

    return {
        'input_ids': padded_input_ids,
        'attention_mask': padded_attention_masks,
        'labels': padded_labels,
        'lengths': torch.tensor(lengths, dtype=torch.long)
    }

tokenizer = BertTokenizer.from_pretrained(config.MODEL_NAME)

# One ThreadDataset per split, all sharing the tokenizer and post length.
train_dataset, val_dataset, test_dataset = (
    ThreadDataset(split_texts, split_labels, tokenizer, config.MAX_POST_LEN)
    for split_texts, split_labels in (
        (train_texts, train_y),
        (val_texts, val_y),
        (test_texts, test_y),
    )
)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=config.BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
    for split_dataset in (val_dataset, test_dataset)
)

print("\nData loaders created successfully!")

# Cell 5: Focal Loss Implementation
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``alpha * (1 - p_t)**gamma * -log(p_t)``.

    ``p_t`` is the softmax probability of the target class, computed from the
    *unweighted* logits. The optional per-class ``weight`` scales each term
    separately. The result is averaged over non-ignored targets only, so
    padding positions (``ignore_index``) do not dilute the loss.

    Args:
        alpha: Global scaling factor.
        gamma: Focusing parameter; 0 recovers (alpha-scaled) cross-entropy.
        weight: Optional 1-D tensor of per-class weights.
        ignore_index: Target value whose positions are excluded.
    """
    def __init__(self, alpha=0.25, gamma=2.0, weight=None, ignore_index=-100):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        """Compute the loss for logits ``inputs`` [N, C] and class ids ``targets`` [N]."""
        valid = targets != self.ignore_index
        if not bool(valid.any()):
            # Nothing to score: return a graph-connected zero so .backward() still works.
            return inputs.sum() * 0.0

        logits = inputs[valid]
        tgt = targets[valid]

        # log p_t from the unweighted distribution; folding class weights into
        # p_t (as exp(-weighted_ce)) would distort the focal modulation.
        log_pt = F.log_softmax(logits, dim=-1).gather(1, tgt.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        focal_loss = -self.alpha * (1.0 - pt) ** self.gamma * log_pt

        if self.weight is not None:
            focal_loss = focal_loss * self.weight.to(device=focal_loss.device, dtype=focal_loss.dtype)[tgt]

        return focal_loss.mean()

# Cell 6: Multi-Head Self-Attention Module
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with residual + LayerNorm output.

    Args:
        hidden_dim: Feature size of input and output; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout applied to the attention weights and to the projected output.

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``num_heads``.
    """
    def __init__(self, hidden_dim, num_heads=8, dropout=0.1):
        super(MultiHeadSelfAttention, self).__init__()
        # Explicit check rather than `assert`, so it still runs under `python -O`.
        if hidden_dim % num_heads != 0:
            raise ValueError("hidden_dim must be divisible by num_heads")

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.scale = 1.0 / np.sqrt(self.head_dim)

        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, mask=None):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: Tensor [batch, seq_len, hidden_dim].
            mask: Optional [batch, seq_len] bool or 0/1 tensor; True/nonzero marks
                positions that may be attended to as keys. Masked positions still
                receive an output row (callers are expected to ignore them).

        Returns:
            Tensor [batch, seq_len, hidden_dim] = LayerNorm(x + dropout(attention output)).
        """
        batch_size, seq_len, hidden_dim = x.shape

        def split_heads(t):
            # [batch, seq, hidden] -> [batch, heads, seq, head_dim]
            return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        Q = split_heads(self.query(x))
        K = split_heads(self.key(x))
        V = split_heads(self.value(x))

        # Attention scores: [batch, heads, seq, seq]
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale

        if mask is not None:
            # Accept bool or integer masks; broadcast over heads and query positions.
            key_mask = mask.to(torch.bool)[:, None, None, :]  # [batch, 1, 1, seq_len]
            # The dtype's most negative finite value instead of a hard-coded -1e9,
            # which cannot be represented in half precision.
            scores = scores.masked_fill(~key_mask, torch.finfo(scores.dtype).min)

        attn_weights = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum of values, then merge heads back to [batch, seq, hidden]
        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_dim)

        # Output projection and residual
        output = self.out_proj(context)
        return self.layer_norm(x + self.dropout(output))

# Cell 7: Enhanced Attn-BiLSTM Model with Fine-tuning
class EnhancedAttnBiLSTM(nn.Module):
    """Per-post sentiment classifier over conversation threads.

    Pipeline: BERT ``pooler_output`` per post -> feature-fusion MLP ->
    BiLSTM across the posts of each thread -> multi-head self-attention over
    the thread -> MLP classifier producing one logit vector per post.

    Args:
        config: Object providing MODEL_NAME, UNFREEZE_LAYERS, EMBEDDING_DIM,
            HIDDEN_DIM, LSTM_LAYERS, NUM_ATTENTION_HEADS, DROPOUT and NUM_CLASSES.
    """
    def __init__(self, config):
        super(EnhancedAttnBiLSTM, self).__init__()

        # Load BERT and selectively unfreeze
        self.bert = BertModel.from_pretrained(config.MODEL_NAME)
        self._freeze_bert_selectively(config.UNFREEZE_LAYERS)

        # Feature fusion layer (input width must match BERT's pooler_output)
        self.feature_fusion = nn.Sequential(
            nn.Linear(config.EMBEDDING_DIM, config.EMBEDDING_DIM),
            nn.LayerNorm(config.EMBEDDING_DIM),
            nn.ReLU(),
            nn.Dropout(config.DROPOUT)
        )

        # BiLSTM: HIDDEN_DIM // 2 per direction, so the concatenated output is HIDDEN_DIM wide
        self.lstm = nn.LSTM(
            input_size=config.EMBEDDING_DIM,
            hidden_size=config.HIDDEN_DIM // 2,
            num_layers=config.LSTM_LAYERS,
            bidirectional=True,
            batch_first=True,
            dropout=config.DROPOUT if config.LSTM_LAYERS > 1 else 0
        )

        # Multi-Head Self-Attention over the posts of a thread
        self.attention = MultiHeadSelfAttention(
            config.HIDDEN_DIM,
            num_heads=config.NUM_ATTENTION_HEADS,
            dropout=config.DROPOUT
        )

        # Enhanced Classifier
        self.classifier = nn.Sequential(
            nn.Linear(config.HIDDEN_DIM, config.HIDDEN_DIM // 2),
            nn.LayerNorm(config.HIDDEN_DIM // 2),
            nn.ReLU(),
            nn.Dropout(config.DROPOUT),
            nn.Linear(config.HIDDEN_DIM // 2, config.NUM_CLASSES)
        )

    def _freeze_bert_selectively(self, unfreeze_layers):
        """Freeze all BERT weights, then re-enable the last ``unfreeze_layers``
        encoder layers and the pooler (the pooler only when ``unfreeze_layers > 0``)."""
        # Freeze all layers first
        for param in self.bert.parameters():
            param.requires_grad = False

        # Unfreeze last N encoder layers
        if unfreeze_layers > 0:
            for layer in self.bert.encoder.layer[-unfreeze_layers:]:
                for param in layer.parameters():
                    param.requires_grad = True

            # Unfreeze pooler (its output is what forward() consumes)
            for param in self.bert.pooler.parameters():
                param.requires_grad = True

            print(f"✓ Unfroze last {unfreeze_layers} BERT layers for fine-tuning")

    def forward(self, input_ids, attention_mask, lengths):
        """Score every post of every thread in the batch.

        Args:
            input_ids: Long tensor [batch, max_thread_len, max_post_len].
            attention_mask: Tensor of the same shape (1 = real token).
            lengths: 1-D tensor [batch] with the real number of posts per thread.

        Returns:
            Logits [batch, max_thread_len, NUM_CLASSES]. Rows beyond a thread's
            length come from zero-padded LSTM states and should be ignored
            (the training code labels them -100).
        """
        batch_size, max_thread_len, max_post_len = input_ids.shape
        device = input_ids.device

        # True for real posts, False for thread padding: [batch, max_thread_len]
        thread_mask = torch.arange(max_thread_len, device=device)[None, :] < lengths.to(device)[:, None]

        # Encode only real posts. Padding posts are never read downstream
        # (pack_padded_sequence skips them and the attention masks them as
        # keys), so running BERT on them only wastes compute and memory.
        flat_valid = thread_mask.reshape(-1)
        flat_input_ids = input_ids.reshape(-1, max_post_len)[flat_valid]
        flat_masks = attention_mask.reshape(-1, max_post_len)[flat_valid]

        outputs = self.bert(input_ids=flat_input_ids, attention_mask=flat_masks)
        valid_embeddings = self.feature_fusion(outputs.pooler_output)

        # Scatter real-post embeddings back into the padded layout; padding rows stay zero.
        post_embeddings = valid_embeddings.new_zeros(batch_size * max_thread_len, valid_embeddings.size(-1))
        post_embeddings[flat_valid] = valid_embeddings

        # Reshape for LSTM: [batch, max_thread_len, EMBEDDING_DIM]
        lstm_input = post_embeddings.view(batch_size, max_thread_len, -1)

        # BiLSTM over real posts only (pack_padded_sequence needs CPU lengths)
        packed_input = pack_padded_sequence(
            lstm_input, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_output, _ = self.lstm(packed_input)
        lstm_output, _ = pad_packed_sequence(
            packed_output, batch_first=True, total_length=max_thread_len
        )

        # Multi-Head Attention restricted to real posts as keys
        attn_output = self.attention(lstm_output, mask=thread_mask)

        # Classification
        logits = self.classifier(attn_output)

        return logits

# Cell 8: Initialize Model
model = EnhancedAttnBiLSTM(config).to(device)

# Parameter counts: everything vs. only what will receive gradients.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("\n".join([
    "\nModel Statistics:",
    f"  Trainable parameters: {trainable_params:,}",
    f"  Total parameters: {total_params:,}",
    f"  Trainable ratio: {trainable_params/total_params*100:.2f}%",
]))

# Cell 9: Setup Training Components
# Separate parameter groups for different learning rates: trainable BERT
# weights use config.BERT_LR, all other trainable weights (fusion, BiLSTM,
# attention, classifier) use config.LEARNING_RATE. Frozen weights are skipped.
bert_params = []
other_params = []

for name, param in model.named_parameters():
    if param.requires_grad:
        if name.startswith('bert.'):
            bert_params.append(param)
        else:
            other_params.append(param)

optimizer = AdamW([
    {'params': bert_params, 'lr': config.BERT_LR},
    {'params': other_params, 'lr': config.LEARNING_RATE}
])

# Learning rate scheduler: linear warmup over WARMUP_RATIO of all optimizer
# steps (one step per training batch), then linear decay.
total_steps = len(train_loader) * config.EPOCHS
warmup_steps = int(total_steps * config.WARMUP_RATIO)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Loss function (-100 marks the thread-padding positions produced by collate_fn)
if config.USE_FOCAL_LOSS:
    loss_fn = FocalLoss(
        alpha=config.FOCAL_ALPHA,
        gamma=config.FOCAL_GAMMA,
        weight=class_weights,
        ignore_index=-100
    )
    print(f"✓ Using Focal Loss (alpha={config.FOCAL_ALPHA}, gamma={config.FOCAL_GAMMA})")
else:
    # Apply the configured label smoothing (label_smoothing needs PyTorch >= 1.10).
    loss_fn = nn.CrossEntropyLoss(
        weight=class_weights,
        ignore_index=-100,
        label_smoothing=config.LABEL_SMOOTHING
    )
    print(f"✓ Using CrossEntropy Loss (label_smoothing={config.LABEL_SMOOTHING})")

print(f"✓ Warmup steps: {warmup_steps}/{total_steps}")

# Cell 10: Training Loop
def train_epoch(model, data_loader, loss_fn, optimizer, scheduler):
    """Run one optimisation pass over ``data_loader``.

    Each batch is moved to the module-level ``device`` and scored by ``model``.
    The loss is backpropagated with gradient-norm clipping, then both the
    optimizer and the per-batch LR scheduler are stepped.

    Returns:
        (mean batch loss, macro F1, accuracy). F1 and accuracy cover every
        position whose label is not -100, i.e. real posts, not thread padding.
    """
    model.train()
    batch_losses, predicted, expected = [], [], []

    bar = tqdm(data_loader, desc="Training")
    for step, batch in enumerate(bar):
        input_ids, masks, labels = (
            batch[key].to(device) for key in ('input_ids', 'attention_mask', 'labels')
        )
        lengths = batch['lengths']

        optimizer.zero_grad()

        # Flatten [batch, thread, classes] -> [batch*thread, classes] for the loss
        flat_logits = model(input_ids, masks, lengths).view(-1, config.NUM_CLASSES)
        flat_labels = labels.view(-1)

        loss = loss_fn(flat_logits, flat_labels)
        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), config.GRADIENT_CLIPPING)
        optimizer.step()
        scheduler.step()

        loss_value = loss.item()
        batch_losses.append(loss_value)

        # Metrics only over real posts
        keep = flat_labels != -100
        predicted.extend(flat_logits.argmax(dim=1)[keep].cpu().numpy())
        expected.extend(flat_labels[keep].cpu().numpy())

        if step % 5 == 0:
            bar.set_postfix({
                'loss': f'{loss_value:.4f}',
                'lr': f'{scheduler.get_last_lr()[0]:.2e}'
            })

    return (
        np.mean(batch_losses),
        f1_score(expected, predicted, average='macro'),
        accuracy_score(expected, predicted),
    )

def eval_model(model, data_loader, loss_fn):
    """Score ``data_loader`` in eval mode without gradient tracking.

    Returns:
        (mean batch loss, macro F1, accuracy), with F1 and accuracy computed
        over non-padding positions (labels != -100).
    """
    model.eval()
    batch_losses, predicted, expected = [], [], []

    with torch.no_grad():
        for batch in data_loader:
            labels = batch['labels'].to(device).view(-1)
            logits = model(
                batch['input_ids'].to(device),
                batch['attention_mask'].to(device),
                batch['lengths'],
            ).view(-1, config.NUM_CLASSES)

            batch_losses.append(loss_fn(logits, labels).item())

            keep = labels != -100
            predicted.extend(torch.argmax(logits, dim=1)[keep].cpu().numpy())
            expected.extend(labels[keep].cpu().numpy())

    return (
        np.mean(batch_losses),
        f1_score(expected, predicted, average='macro'),
        accuracy_score(expected, predicted),
    )

# Cell 11: Training
print("\n" + "="*70)
print("Starting Enhanced Attn-BiLSTM Training with Fine-tuning...")
print("="*70)

# Start below any achievable macro F1 so the first epoch always writes a
# checkpoint. With 0, a model scoring exactly 0.0 every epoch would never
# save, and the final torch.load of the checkpoint would fail.
best_f1 = float('-inf')
best_acc = 0
patience = 0
max_patience = 5

for epoch in range(config.EPOCHS):
    print(f"\nEpoch {epoch+1}/{config.EPOCHS}")
    print("-" * 70)

    train_loss, train_f1, train_acc = train_epoch(model, train_loader, loss_fn, optimizer, scheduler)
    val_loss, val_f1, val_acc = eval_model(model, val_loader, loss_fn)

    print(f"Train Loss: {train_loss:.4f} | F1: {train_f1:.4f} | Acc: {train_acc:.4f}")
    print(f"Val   Loss: {val_loss:.4f} | F1: {val_f1:.4f} | Acc: {val_acc:.4f}")

    # Save best model (by validation macro F1)
    if val_f1 > best_f1:
        best_f1 = val_f1
        best_acc = val_acc
        patience = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            # Plain Python floats: depending on the scikit-learn version these
            # metrics may be NumPy scalars, which restricted unpicklers such as
            # torch.load(weights_only=True) can refuse.
            'val_f1': float(val_f1),
            'val_acc': float(val_acc)
        }, 'best_enhanced_model.bin')
        print(f"✓ Saved Best Model (F1: {best_f1:.4f}, Acc: {best_acc:.4f})")
    else:
        patience += 1
        if patience >= max_patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

# Cell 12: Final Evaluation
print("\n" + "="*70)
print("Final Evaluation on Test Set")
print("="*70)

# map_location lets the checkpoint load even if it was written on a different
# device than the one selected for this run.
checkpoint = torch.load('best_enhanced_model.bin', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# One inference pass over the test set. It collects the loss (computed the
# same way as eval_model) together with the predictions for the detailed
# report, instead of running the model over test_loader twice.
model.eval()
test_losses = []
all_preds = []
all_targets = []
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        masks = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        lengths = batch['lengths']

        logits = model(input_ids, masks, lengths)
        active_logits = logits.view(-1, config.NUM_CLASSES)
        active_labels = labels.view(-1)
        test_losses.append(loss_fn(active_logits, active_labels).item())

        preds = torch.argmax(active_logits, dim=1)
        mask = active_labels != -100  # skip thread-padding positions
        all_preds.extend(preds[mask].cpu().numpy())
        all_targets.extend(active_labels[mask].cpu().numpy())

test_loss = np.mean(test_losses)
test_f1 = f1_score(all_targets, all_preds, average='macro')
test_acc = accuracy_score(all_targets, all_preds)

print(f"\nTest Results:")
print(f"  Accuracy:  {test_acc:.4f}")
print(f"  Macro F1:  {test_f1:.4f}")
print(f"  Test Loss: {test_loss:.4f}")

# Detailed Report
# Class names are ordered by integer id and taken from LABEL_MAP. Passing
# `labels` explicitly keeps the report aligned with target_names even if a
# class never occurs in the test targets or predictions; without it,
# scikit-learn rejects the target_names/labels size mismatch.
class_ids = list(range(config.NUM_CLASSES))
class_names = sorted(config.LABEL_MAP, key=config.LABEL_MAP.get)

print("\n" + "="*70)
print("Classification Report:")
print("="*70)
print(classification_report(all_targets, all_preds, labels=class_ids, target_names=class_names, zero_division=0))