# In [None]:  (Jupyter cell prompt left over from the notebook export)
# Cell 1: Environment Setup & Imports
import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm

import logging
from transformers import logging as hf_logging


# Raise the transformers library's log threshold to ERROR so that its
# informational and warning messages are not printed.
hf_logging.set_verbosity_error()

# Set random seed for reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    The PyTorch seeding covers the CPU generator and, through
    ``torch.cuda.manual_seed_all``, every CUDA device.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)

# Seed every RNG once, before any data is split or any model is built.
set_seed(42)

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Cell 2: Configuration (Optimized for Concat Model)
class Config:
    """Static settings for the concatenated-context sentiment classifier.

    All values are class attributes. The module-level ``config`` instance
    created right after the class is what the rest of the script reads.
    """

    # Hugging Face checkpoint id (Chinese RoBERTa-wwm-ext).
    MODEL_NAME = "hfl/chinese-roberta-wwm-ext"
    # Placeholder: replace with the actual path of the CSV data file.
    DATA_PATH = "DATA_PATH"

    # --- Hyperparameters ---
    # Samples per batch. Every sample is padded to MAX_LEN tokens, so memory
    # use grows with BATCH_SIZE * MAX_LEN.
    BATCH_SIZE = 32

    # Peak learning rate handed to AdamW (reached after the warm-up phase).
    # NOTE(review): 3e-4 is well above the 1e-5..5e-5 range usually used to
    # fine-tune BERT-style encoders -- confirm this value is intentional.
    LEARNING_RATE = 3e-4

    # Number of full passes over the training split.
    EPOCHS = 10

    # Token budget for the whole "[CLS] context [SEP] target [SEP]" input.
    # Longer pairs are truncated by the tokenizer. The encoder itself accepts
    # at most 512 tokens, so this may be raised up to 512 if memory allows.
    MAX_LEN = 256

    # Maximum gradient norm used when clipping during training.
    GRADIENT_CLIPPING = 1.0
    # Sentiment string in the data file -> integer class id.
    LABEL_MAP = {"Negative": 0, "Neutral": 1, "Positive": 2}
    # Derived from LABEL_MAP so the two cannot drift apart.
    NUM_CLASSES = len(set(LABEL_MAP.values()))


config = Config()

# Cell 3: Data Preprocessing (Context Construction)
# This is the key difference from Flat RoBERTa.
# We must reconstruct the thread history for each post.

def load_and_construct_context(file_path):
    """Load the post-level CSV and attach each post's preceding thread history.

    The CSV must provide the columns ``text``, ``sentiment``, ``thread_id``
    and ``turn_index``. Rows missing any of them, or whose ``sentiment`` is
    not a key of ``config.LABEL_MAP``, are dropped.

    Args:
        file_path: Path (or any other source accepted by ``pandas.read_csv``)
            of the CSV file.

    Returns:
        The cleaned DataFrame, sorted by ``thread_id`` then ``turn_index``,
        with three added columns: ``label`` (integer class id),
        ``target_text`` (the post itself as ``str``) and ``context_text``
        (all earlier posts of the same thread in order, each followed by
        ``" [SEP] "``; the empty string for the first post of a thread).
    """
    df = pd.read_csv(file_path)

    # Basic cleaning. The explicit copies make each filtered frame an
    # independent object, so the column assignments below never write
    # through to (or warn about) the frame they were derived from.
    df = df.dropna(subset=['text', 'sentiment', 'thread_id', 'turn_index']).copy()
    df['label'] = df['sentiment'].map(config.LABEL_MAP)
    df = df.dropna(subset=['label']).copy()
    df['label'] = df['label'].astype(int)

    # Order posts chronologically inside each thread.
    df = df.sort_values(by=['thread_id', 'turn_index'])

    row_labels = []
    context_texts = []
    target_texts = []

    for _, group in tqdm(df.groupby('thread_id'), desc="Constructing Threads"):
        # History seen so far in this thread; empty before the first post.
        # It is not capped here: the tokenizer's truncation is the only
        # length limit applied later on.
        history = ""
        for row_label, raw_text in zip(group.index, group['text'].tolist()):
            target = str(raw_text)

            row_labels.append(row_label)
            context_texts.append(history)
            target_texts.append(target)

            # NOTE(review): the trailing separator presumably produces two
            # consecutive [SEP] tokens once the tokenizer inserts its own
            # separator between context and target -- confirm this is wanted.
            history += target + " [SEP] "

    # Attach by row label rather than by position, so correctness does not
    # depend on the group iteration order matching the frame's row order.
    df['context_text'] = pd.Series(context_texts, index=row_labels, dtype=object)
    df['target_text'] = pd.Series(target_texts, index=row_labels, dtype=object)

    return df

print("Processing data to construct conversation context...")
df = load_and_construct_context(config.DATA_PATH)
print(f"Total processed samples: {len(df)}")
# Show the second row as a quick sanity check of the constructed context.
print("Sample Context (Row 1):", df.iloc[1]['context_text'])
print("Sample Target  (Row 1):", df.iloc[1]['target_text'])

# Split on thread ids, not on rows, so every post of a thread lands in the
# same split: first carve out 20% of the threads for testing, then 10% of
# the remaining threads for validation.
unique_threads = df['thread_id'].unique()
train_threads, test_threads = train_test_split(
    unique_threads, test_size=0.2, random_state=42
)
train_threads, val_threads = train_test_split(
    train_threads, test_size=0.1, random_state=42
)

train_df, val_df, test_df = (
    df[df['thread_id'].isin(thread_ids)].reset_index(drop=True)
    for thread_ids in (train_threads, val_threads, test_threads)
)

print(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")

# Cell 4: Compute Class Weights
# The weighted loss needs exactly one weight per class id 0..NUM_CLASSES-1,
# in id order. Deriving the class list from the training labels alone would
# silently produce a shorter vector if a class were absent from the split,
# so check for that up front and fail with a clear message.
expected_classes = np.arange(config.NUM_CLASSES)
missing_classes = sorted(
    set(expected_classes.tolist()) - set(train_df['label'].tolist())
)
if missing_classes:
    raise ValueError(
        f"Training split has no samples for class id(s) {missing_classes}; "
        "balanced class weights cannot be computed."
    )

class_weights = compute_class_weight(
    class_weight='balanced',
    classes=expected_classes,
    y=train_df['label']
)
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
print(f"Class Weights: {class_weights}")

# Cell 5: Concat Dataset Class
class ConcatWeiboDataset(Dataset):
    """Map-style dataset that encodes (thread context, target post) pairs.

    Args:
        dataframe: Frame with the columns ``context_text``, ``target_text``
            and ``label``. Its index may be arbitrary; rows are addressed by
            position.
        tokenizer: Callable tokenizer accepting ``text`` / ``text_pair`` and
            the keyword options used in ``__getitem__``.
        max_len: Length every encoded pair is padded or truncated to.
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.context = dataframe['context_text']
        self.target = dataframe['target_text']
        self.labels = dataframe['label']

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return the encoded tensors and label of the row at position ``index``."""
        # .iloc addresses rows by position, which is what a sampler hands in;
        # plain [] would look the number up as an index *label* and is only
        # correct when the frame happens to carry a 0..n-1 index.
        context_text = str(self.context.iloc[index])
        target_text = str(self.target.iloc[index])
        label = int(self.labels.iloc[index])

        # Sentence-pair encoding: [CLS] context [SEP] target [SEP].
        # The tokenizer is called directly; encode_plus is the deprecated
        # spelling of the same operation.
        # NOTE(review): truncation=True uses the library's default strategy,
        # which presumably trims the longer segment from its end, i.e. the
        # most recent context of long threads -- confirm this is acceptable.
        encoding = self.tokenizer(
            text=context_text,
            text_pair=target_text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,  # lets the model tell context from target
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # return_tensors='pt' yields a leading batch axis of size 1; flatten
        # drops it so the DataLoader can stack samples itself.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'token_type_ids': encoding['token_type_ids'].flatten(),
            'targets': torch.tensor(label, dtype=torch.long)
        }

# The tokenizer is loaded from the checkpoint named in the config.
tokenizer = BertTokenizer.from_pretrained(config.MODEL_NAME)

# One dataset per split, all sharing the tokenizer and the length budget.
train_dataset, val_dataset, test_dataset = (
    ConcatWeiboDataset(split_frame, tokenizer, config.MAX_LEN)
    for split_frame in (train_df, val_df, test_df)
)

# Only the training loader shuffles; validation and test keep row order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config.BATCH_SIZE)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=config.BATCH_SIZE)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=config.BATCH_SIZE)

# Cell 6: Concat RoBERTa Model Architecture
class ConcatRoBERTa(nn.Module):
    """Pretrained encoder followed by dropout and a linear classification head.

    Args:
        n_classes: Number of output logits produced per sample.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained(config.MODEL_NAME)
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, n_classes)
        self.dropout = nn.Dropout(0.3)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return the class logits for a batch of encoded context/target pairs."""
        # token_type_ids tell the encoder which tokens belong to the context
        # segment and which to the target segment.
        encoder_out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        regularised = self.dropout(encoder_out.pooler_output)
        logits = self.classifier(regularised)
        return logits

model = ConcatRoBERTa(config.NUM_CLASSES)
model = model.to(device)

# Cell 7: Training Setup
# Class-weighted cross-entropy; all model parameters are optimised by AdamW.
optimizer = AdamW(model.parameters(), lr=config.LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss(weight=class_weights)

# The scheduler is stepped once per batch, so its horizon is the total number
# of batches over all epochs; the first 10% of those steps are warm-up.
total_steps = config.EPOCHS * len(train_loader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(total_steps * 0.1),
    num_training_steps=total_steps,
)

# Cell 8: Training Loop
def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: Module called as ``model(input_ids, attention_mask, token_type_ids)``
            and returning per-class logits.
        data_loader: Yields dict batches with the keys ``input_ids``,
            ``attention_mask``, ``token_type_ids`` and ``targets``.
        loss_fn: Callable ``loss_fn(logits, targets)`` returning a scalar loss.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        scheduler: Learning-rate scheduler stepped once per batch.

    Returns:
        Tuple ``(accuracy, mean_loss)``. ``accuracy`` is a 0-d double tensor:
        correct predictions divided by ``len(data_loader.dataset)``.
        ``mean_loss`` is the mean of the per-batch losses, or NaN when the
        loader yields no batch.
    """
    model.train()
    batch_losses = []
    # Start from a tensor rather than a Python int so the final .double()
    # also works when the loader yields no batches.
    correct_predictions = torch.zeros((), dtype=torch.long, device=device)

    # Drop any gradients left over from before this call so they cannot leak
    # into the first update.
    optimizer.zero_grad()

    for batch in data_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        token_type_ids = batch["token_type_ids"].to(device)
        targets = batch["targets"].to(device)

        logits = model(input_ids, attention_mask, token_type_ids)
        loss = loss_fn(logits, targets)
        preds = logits.argmax(dim=1)

        correct_predictions += torch.sum(preds == targets)
        batch_losses.append(loss.item())

        loss.backward()
        # Clip before stepping so one exploding batch cannot wreck the weights.
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.GRADIENT_CLIPPING)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    mean_loss = np.mean(batch_losses) if batch_losses else float("nan")
    return correct_predictions.double() / len(data_loader.dataset), mean_loss

def eval_model(model, data_loader, loss_fn, device):
    """Evaluate ``model`` on ``data_loader`` without updating any weights.

    Args:
        model: Module called as ``model(input_ids, attention_mask, token_type_ids)``
            and returning per-class logits.
        data_loader: Yields dict batches with the keys ``input_ids``,
            ``attention_mask``, ``token_type_ids`` and ``targets``.
        loss_fn: Callable ``loss_fn(logits, targets)`` returning a scalar loss.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(accuracy, mean_loss, macro_f1)``. ``accuracy`` is a 0-d
        double tensor: correct predictions divided by
        ``len(data_loader.dataset)``. ``mean_loss`` is the mean of the
        per-batch losses and ``macro_f1`` the macro-averaged F1 over all
        samples; both are NaN when the loader yields no batch.
    """
    model.eval()
    batch_losses = []
    # Tensor start value: a Python int would have no .double() if the loader
    # yields no batches.
    correct_predictions = torch.zeros((), dtype=torch.long, device=device)
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            token_type_ids = batch["token_type_ids"].to(device)
            targets = batch["targets"].to(device)

            logits = model(input_ids, attention_mask, token_type_ids)
            preds = logits.argmax(dim=1)
            loss = loss_fn(logits, targets)

            correct_predictions += torch.sum(preds == targets)
            batch_losses.append(loss.item())
            all_preds.extend(preds.cpu().tolist())
            all_targets.extend(targets.cpu().tolist())

    accuracy = correct_predictions.double() / len(data_loader.dataset)
    if not all_targets:
        # Nothing was evaluated, so the metrics are undefined.
        return accuracy, float("nan"), float("nan")

    macro_f1 = f1_score(all_targets, all_preds, average='macro')
    return accuracy, np.mean(batch_losses), macro_f1

print("Starting Concat RoBERTa Training (Optimized)...")
# Start below any reachable F1 so the first epoch always writes a checkpoint.
# With a start value of 0, a run whose validation F1 stays at 0.0 would never
# save, and the later checkpoint load would fail or silently pick up a file
# left behind by an earlier run.
best_macro_f1 = float("-inf")
for epoch in range(config.EPOCHS):
    train_acc, train_loss = train_epoch(model, train_loader, loss_fn, optimizer, device, scheduler)
    val_acc, val_loss, val_f1 = eval_model(model, val_loader, loss_fn, device)
    print(f"Epoch {epoch+1}/{config.EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val F1: {val_f1:.4f}")

    # Keep only the weights of the epoch with the best validation macro F1.
    if val_f1 > best_macro_f1:
        best_macro_f1 = val_f1
        torch.save(model.state_dict(), 'optimized_concat_roberta.bin')
        print("=> Saved Best Model")

# Cell 9: Final Evaluation
# Restore the best checkpoint. map_location places the tensors on the device
# in use now, so a checkpoint written on a GPU also loads on a CPU-only host.
model.load_state_dict(torch.load('optimized_concat_roberta.bin', map_location=device))
test_acc, test_loss, test_f1 = eval_model(model, test_loader, loss_fn, device)
print(f"\nFinal Test Accuracy: {test_acc:.4f}")
print(f"Final Test Macro F1: {test_f1:.4f}")

# Detailed Report
# eval_model returns only aggregate metrics, so the per-sample predictions
# needed for the report are collected in a second pass over the test set.
model.eval()
all_preds = []
all_targets = []
with torch.no_grad():
    for d in test_loader:
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        token_type_ids = d["token_type_ids"].to(device)
        targets = d["targets"].to(device)
        outputs = model(input_ids, attention_mask, token_type_ids)
        _, preds = torch.max(outputs, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_targets.extend(targets.cpu().numpy())

# Pass the class ids explicitly. Without `labels`, the report infers the
# classes from the data and raises when a class is absent from both targets
# and predictions, because target_names would then have too many entries.
report_labels = sorted(config.LABEL_MAP.values())
label_names = {class_id: name for name, class_id in config.LABEL_MAP.items()}

print("\nClassification Report:")
print(classification_report(
    all_targets,
    all_preds,
    labels=report_labels,
    target_names=[label_names[class_id] for class_id in report_labels],
    zero_division=0,
))