In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import optuna

from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
import random
import os
import copy

SEED = 42


def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs so runs are reproducible.

    Args:
        seed: integer seed applied to `random`, `numpy.random` and `torch`
            (CPU and, when available, every CUDA device).

    On CUDA this also forces cuDNN to use deterministic kernels and disables
    its autotuner: with `benchmark=True` cuDNN times several algorithms and may
    pick a different (non-deterministic) one on each run.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


set_seed(SEED)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# CSV of reply threads; the loader reads the columns thread_id, mid, parent_id, text, sentiment.
# Set the CSV_FILE_PATH environment variable instead of editing this placeholder.
CSV_FILE_PATH = os.environ.get('CSV_FILE_PATH', 'DATA_PATH')
# Hugging Face model id or local directory of the pretrained encoder.
PLM_MODEL_PATH = os.environ.get('PLM_MODEL_PATH', 'hfl/chinese-roberta-wwm-ext')
MAX_LEN = 256  # tokens per post; every post is truncated/padded to this length
print(f"Running on: {DEVICE}")

In [None]:
class TreeNode:
    """A single post in a reply thread together with its direct replies.

    Attributes:
        mid: identifier of the post.
        text: post text.
        label: class id of the post's label, or None when not given.
        children: direct replies, kept in the order they were attached.
        input_ids, attention_mask: tokenizer outputs, assigned after construction.
        state_h, state_c: reserved slots for Tree-LSTM hidden/cell states (start as None).
    """

    def __init__(self, mid, text, label=None):
        self.mid, self.text, self.label = mid, text, label
        self.children = []
        # Filled in by the data loader once the text has been tokenized.
        self.input_ids = self.attention_mask = None
        # Reserved for model states; nothing is stored here at construction time.
        self.state_h = self.state_c = None

    def add_child(self, node):
        """Append `node` as the newest reply to this post."""
        self.children += [node]

def _encode_text(tokenizer, text):
    """Tokenize one post into 1-D `input_ids` / `attention_mask` tensors of length MAX_LEN."""
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    return encoding['input_ids'].squeeze(0), encoding['attention_mask'].squeeze(0)


def _build_thread_roots(group, tokenizer, label2id):
    """Build TreeNodes for one thread's rows and return the nodes with no in-thread parent."""
    records = group.to_dict('records')

    # 1. One node per row, with tokenized text attached.
    nodes = {}
    for row in records:
        raw_text = row['text']
        # A missing post text would otherwise be fed to the model as the string "nan".
        text = '' if pd.isna(raw_text) else str(raw_text)
        node = TreeNode(row['mid'], text, label2id[row['sentiment']])
        node.input_ids, node.attention_mask = _encode_text(tokenizer, text)
        nodes[row['mid']] = node

    # 2. Link replies to parents. parent_id 0, a parent not in this thread (incl. NaN),
    #    or a self-reference makes the node a root; a self-loop would otherwise make the
    #    node unreachable from every root and silently drop it.
    roots = []
    for row in records:
        mid, parent_id = row['mid'], row['parent_id']
        node = nodes[mid]
        if parent_id == 0 or parent_id == mid or parent_id not in nodes:
            roots.append(node)
        else:
            nodes[parent_id].add_child(node)
    return roots


def load_and_process_data(csv_path, tokenizer):
    """Read the thread CSV and turn every reply thread into one or more TreeNode trees.

    Args:
        csv_path: CSV with columns thread_id, mid, parent_id, text, sentiment.
        tokenizer: Hugging Face tokenizer used to encode each post's text.

    Returns:
        (trees, label2id, id2label): `trees` is a flat list of root nodes (a thread
        yields several roots when some parents are missing from it); label ids
        follow the order in which labels first appear in the CSV.
    """
    df = pd.read_csv(csv_path)

    unique_labels = df['sentiment'].unique()
    label2id = {l: i for i, l in enumerate(unique_labels)}
    id2label = {i: l for i, l in enumerate(unique_labels)}
    print(f"Label Mapping: {label2id}")

    trees = []
    for _, group in tqdm(df.groupby('thread_id'), desc="Building Trees"):
        trees.extend(_build_thread_roots(group, tokenizer, label2id))

    return trees, label2id, id2label

# Fail fast on a missing data file, before downloading/loading the tokenizer.
if not os.path.exists(CSV_FILE_PATH):
    raise FileNotFoundError(f"Please upload {CSV_FILE_PATH}")

tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_PATH)
all_trees, label2id, id2label = load_and_process_data(CSV_FILE_PATH, tokenizer)
NUM_CLASSES = len(label2id)

# 70% train / 15% val / 15% test over root nodes.
# NOTE(review): a thread with orphaned replies yields several roots, which may land
# in different splits — confirm this cross-split overlap of a thread is acceptable.
TEST_FRACTION = 0.15
VAL_FRACTION = 0.15
train_val, test_trees = train_test_split(all_trees, test_size=TEST_FRACTION, random_state=SEED, shuffle=True)
# The val share is taken from what remains after the test split: 0.15 / 0.85 ≈ 0.1765.
train_trees, val_trees = train_test_split(
    train_val, test_size=VAL_FRACTION / (1 - TEST_FRACTION), random_state=SEED
)

print(f"Total Trees: {len(all_trees)}")
print(f"Train: {len(train_trees)}, Val: {len(val_trees)}, Test: {len(test_trees)}")

In [None]:
class ChildSumTreeLSTMCell(nn.Module):
    """Child-Sum Tree-LSTM cell.

    For a node with input x and children k with states (h_k, c_k):
        h_tilde = sum_k h_k
        i = sigmoid(W_i x + U_i h_tilde), o = sigmoid(W_o x + U_o h_tilde),
        u = tanh(W_u x + U_u h_tilde)
        f_k = sigmoid(W_f x + U_f h_k)            (one forget gate per child)
        c = i * u + sum_k f_k * c_k
        h = o * tanh(c)
    """

    def __init__(self, input_dim, hidden_dim):
        super(ChildSumTreeLSTMCell, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Input, output and update gates share one projection of x and one of h_tilde.
        self.W_iou = nn.Linear(input_dim, 3 * hidden_dim)
        self.U_iou = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)

        # Forget gate, evaluated separately for each child.
        self.W_f = nn.Linear(input_dim, hidden_dim)
        self.U_f = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x, child_c, child_h):
        """Compute the new (h, c) of a node from its input and its children's states.

        Args:
            x: (B, input_dim) node input (B == 1 for one node at a time).
            child_c: list of (B, hidden_dim) child cell states; may be empty for a leaf.
            child_h: list of (B, hidden_dim) child hidden states, same order as child_c.

        Returns:
            (h_new, c_new), each of shape (B, hidden_dim).

        Raises:
            ValueError: if child_c and child_h have different lengths.
        """
        if len(child_c) != len(child_h):
            raise ValueError(
                f"child_c has {len(child_c)} states but child_h has {len(child_h)}"
            )

        if child_h:
            child_h_stack = torch.stack(child_h, dim=0)  # (K, B, hidden_dim)
            child_c_stack = torch.stack(child_c, dim=0)  # (K, B, hidden_dim)
            h_sum = child_h_stack.sum(dim=0)             # (B, hidden_dim)
        else:
            # new_zeros inherits x's device AND dtype (fp16/fp64 inputs keep working).
            h_sum = x.new_zeros(x.size(0), self.hidden_dim)

        iou = self.W_iou(x) + self.U_iou(h_sum)
        i, o, u = torch.split(iou, self.hidden_dim, dim=1)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)

        c_new = i * u
        if child_h:
            # W_f(x) is identical for every child: compute it once and broadcast over K.
            f = torch.sigmoid(self.W_f(x).unsqueeze(0) + self.U_f(child_h_stack))
            c_new = c_new + (f * child_c_stack).sum(dim=0)

        h_new = o * torch.tanh(c_new)
        return h_new, c_new

class ReplyTreeRNN(nn.Module):
    """PLM text encoder + Child-Sum Tree-LSTM + per-node classifier over a reply tree.

    Every node's text is embedded by the pretrained model, combined bottom-up with
    its children's Tree-LSTM states, and classified from its own hidden state.
    """

    def __init__(self, plm_name, hidden_dim, num_classes, dropout=0.3):
        super(ReplyTreeRNN, self).__init__()
        # 1. Pretrained language model used as the per-node text encoder.
        self.plm = AutoModel.from_pretrained(plm_name)
        self.plm_config = self.plm.config
        embedding_dim = self.plm_config.hidden_size

        # 2. Tree-LSTM over the reply structure.
        self.dropout = nn.Dropout(dropout)
        self.treelstm = ChildSumTreeLSTMCell(embedding_dim, hidden_dim)

        # 3. Classifier applied to every node's hidden state.
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def _embed_node(self, node, device):
        """Encode one node's tokens with the PLM and apply dropout; returns (1, hidden_size)."""
        input_ids = node.input_ids.unsqueeze(0).to(device)
        att_mask = node.attention_mask.unsqueeze(0).to(device)
        plm_out = self.plm(input_ids, attention_mask=att_mask)
        # Prefer the pooler output (CLS + Linear + Tanh); fall back to the raw [CLS]
        # vector for checkpoints loaded without a pooling layer.
        node_emb = getattr(plm_out, "pooler_output", None)
        if node_emb is None:
            node_emb = plm_out.last_hidden_state[:, 0, :]
        return self.dropout(node_emb)

    def forward(self, root_node, predictions_dict):
        """Run the tree bottom-up and store per-node logits.

        Args:
            root_node: TreeNode - the root of the tree to process.
            predictions_dict: dict filled in place with {node.mid: logits (1, num_classes)}.

        Returns:
            None (predictions are stored in predictions_dict).
        """
        # Inputs follow the model's own device rather than a global setting.
        device = self.classifier.weight.device

        # Iterative post-order walk (children before parent, children in list order):
        # long reply chains would exceed Python's recursion limit with a recursive walk.
        states = {}  # id(node) -> (h, c) of processed nodes not yet consumed by a parent
        stack = [(root_node, False)]
        while stack:
            node, children_done = stack.pop()
            if not children_done:
                stack.append((node, True))
                stack.extend((child, False) for child in reversed(node.children))
                continue

            child_states = [states.pop(id(child)) for child in node.children]
            child_h = [h for h, _ in child_states]
            child_c = [c for _, c in child_states]

            node_emb = self._embed_node(node, device)
            h, c = self.treelstm(node_emb, child_c, child_h)
            predictions_dict[node.mid] = self.classifier(h)
            states[id(node)] = (h, c)

In [None]:
def _gather_node_outputs(root, preds):
    """Return (logits list, label list) for every node under `root` present in `preds`, pre-order."""
    logits, labels = [], []
    stack = [root]
    while stack:  # iterative: deep reply chains would exceed the recursion limit
        node = stack.pop()
        if node.mid in preds:
            logits.append(preds[node.mid])
            labels.append(node.label)
        stack.extend(reversed(node.children))
    return logits, labels


def train_one_epoch(model, trees, optimizer, accumulation_steps=4,
                    scheduler=None, max_grad_norm=1.0, progress=True):
    """Train `model` for one pass over `trees` with gradient accumulation.

    Each tree contributes the mean cross-entropy over its nodes. Gradients of
    `accumulation_steps` consecutive trees are averaged, clipped to
    `max_grad_norm`, and applied in one optimizer step (the last group may be shorter).

    Args:
        model: module called as ``model(root, preds)`` that fills ``preds`` with
            ``{node.mid: logits}`` (logits of shape (1, num_classes)).
        trees: list of root nodes; not modified (a shuffled copy is iterated).
        optimizer: torch optimizer over the model's parameters.
        accumulation_steps: trees per optimizer step (>= 1).
        scheduler: optional LR scheduler, stepped after every optimizer step.
        max_grad_norm: gradient-norm clipping threshold used before every step.
        progress: show a tqdm progress bar.

    Returns:
        Mean per-tree loss over the epoch; 0.0 when `trees` is empty.

    Raises:
        ValueError: if accumulation_steps < 1.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer.zero_grad()

    order = list(trees)
    random.shuffle(order)
    n_trees = len(order)
    if n_trees == 0:
        return 0.0

    def _optimizer_step():
        # Clip on every step, including the final partial group.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    total_loss = 0.0
    iterator = tqdm(order, desc="Training", leave=False) if progress else order
    for i, root in enumerate(iterator):
        preds = {}
        model(root, preds)
        logits, labels = _gather_node_outputs(root, preds)

        # Size of the accumulation group this tree belongs to (the last one may be short),
        # so every group's gradient is a true average over its trees.
        group_start = (i // accumulation_steps) * accumulation_steps
        group_size = min(accumulation_steps, n_trees - group_start)

        if logits:
            target = torch.tensor(labels, dtype=torch.long, device=logits[0].device)
            tree_loss = criterion(torch.cat(logits, dim=0), target)  # mean over nodes
            (tree_loss / group_size).backward()
            total_loss += tree_loss.item()

        if i + 1 == group_start + group_size:
            _optimizer_step()

    return total_loss / n_trees

def evaluate(model, trees):
    """Node-level accuracy and macro-F1 of `model` over every node of `trees`.

    Args:
        model: module called as ``model(root, preds)`` that fills ``preds`` with
            ``{node.mid: logits}`` (logits of shape (1, num_classes)).
        trees: iterable of root nodes.

    Returns:
        (accuracy, macro_f1); (0.0, 0.0) when no node received a prediction.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for root in trees:
            preds = {}
            model(root, preds)

            # Iterative pre-order walk: deep reply chains would overflow a recursive one.
            stack = [root]
            while stack:
                node = stack.pop()
                if node.mid in preds:
                    all_preds.append(torch.argmax(preds[node.mid], dim=1).item())
                    all_labels.append(node.label)
                stack.extend(reversed(node.children))

    if not all_labels:
        return 0.0, 0.0
    acc = accuracy_score(all_labels, all_preds)
    # zero_division=0: a class that is never predicted scores F1 = 0 without emitting warnings.
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    return acc, f1

In [None]:
# Manually chosen hyper-parameters. They stand in for an Optuna study's
# `best_params`, which the final-training cell reads; without this dict a
# fresh Restart & Run All fails there with NameError.
lr = 0.0003
hidden_dim = 256
dropout = 0.3
best_params = {'lr': lr, 'hidden_dim': hidden_dim, 'dropout': dropout}

In [None]:
# Final model: train on train+val (no validation split is left), report on test once at the end.
full_train_trees = train_trees + val_trees

# `best_params` normally comes from an Optuna study; if that study has not been run in
# this kernel, fall back to the manually chosen lr / hidden_dim / dropout.
final_params = globals().get('best_params') or {'lr': lr, 'hidden_dim': hidden_dim, 'dropout': dropout}

print("Training Final Model with Best Params...")
final_model = ReplyTreeRNN(
    PLM_MODEL_PATH,
    final_params['hidden_dim'],
    NUM_CLASSES,
    final_params['dropout'],
).to(DEVICE)

optimizer = AdamW(final_model.parameters(), lr=final_params['lr'])

EPOCHS = 10
for epoch in range(EPOCHS):
    loss = train_one_epoch(final_model, full_train_trees, optimizer, accumulation_steps=8)
    # Test metrics are deliberately not computed per epoch: with no validation split
    # left, watching them would turn the test set into a model-selection signal.
    print(f"Epoch {epoch+1} | Loss: {loss:.4f}")

acc, f1 = evaluate(final_model, test_trees)
print(f"Accuracy: {acc:.4f}")
print(f"Macro-F1: {f1:.4f}")
print("="*30)