In [None]:
!pip install "torch>=2.5.0" "torch_xla[tpu]>=2.5.0" -f https://storage.googleapis.com/libtpu-releases/index.html -q
!pip install transformers datasets scikit-learn accelerate seaborn matplotlib -q

import re
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from datasets import load_dataset
from sklearn.metrics import f1_score, accuracy_score, hamming_loss, jaccard_score, precision_recall_curve
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

# --- CONFIGURATION ---
# All training tensors and models are moved to this XLA device.
DEVICE = xm.xla_device()

# Optimisation / input-size settings.
BATCH_SIZE = 128
EPOCHS = 5
MAX_LEN = 128  # pad/truncate length used by the tokenizer

# Encoder checkpoint; HIDDEN_DIM must equal its last_hidden_state width.
MODEL_NAME = "roberta-base"
HIDDEN_DIM = 768

# --- DATA LOADING ---
print("\n1. Loading GoEmotions Dataset...")

dataset = load_dataset("go_emotions", "simplified")
# The "labels" column is a sequence of class labels; its inner feature carries the names.
_label_feature = dataset["train"].features["labels"].feature
emotion_labels = _label_feature.names
num_labels = len(emotion_labels)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def preprocess(example):
    """Add tokenizer outputs and a multi-hot label vector to one dataset row.

    Reads ``example["text"]`` and ``example["labels"]`` (a list of label ids).
    Writes ``input_ids`` and ``attention_mask`` (padded/truncated to
    ``MAX_LEN``) and ``label_vector`` (length ``num_labels``, 1.0 at every
    listed id, 0.0 elsewhere) back into ``example``. Returns the same mapping.
    """
    label_vector = [0.0 for _ in range(num_labels)]
    for label_id in example["labels"]:
        label_vector[label_id] = 1.0

    encoded = tokenizer(
        example["text"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LEN,
    )
    example.update(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        label_vector=label_vector,
    )
    return example

print("\n2. Preprocessing data...")

dataset = dataset.map(preprocess, remove_columns=["text", "labels", "id"])
dataset.set_format(type="torch")

# Training drops the ragged final batch so every XLA step sees the same batch
# shape (avoids a recompilation for the last step). Validation must keep every
# example: per-label thresholds and reported metrics are computed from it, and
# evaluation runs on CPU where a smaller last batch is harmless.
train_loader = DataLoader(dataset["train"], batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader = DataLoader(dataset["validation"], batch_size=BATCH_SIZE, drop_last=False)

# Class Weights (Handle Imbalance)
# pos_weight[k] = (#negatives for label k) / (#positives for label k), passed to
# BCEWithLogitsLoss so rare labels contribute more to the loss.
all_labels = torch.stack([x["label_vector"] for x in dataset["train"]])
pos_counts = all_labels.sum(dim=0)
pos_weight = (len(dataset["train"]) - pos_counts) / (pos_counts + 1e-6)
pos_weight = pos_weight.to(DEVICE)

# --- GRAPH CONSTRUCTION ---
print("\n3. Constructing Emotion Correlation Graph...")

# occ_count[i]: occurrences of label i across the training split.
# co_mat[i, j]: co-occurrence count of labels i and j (i != j) within an example.
occ_count = np.zeros(num_labels)
co_mat = np.zeros((num_labels, num_labels))
# Reload the raw split: the preprocessed `dataset` no longer has a "labels" column.
raw_train = load_dataset("go_emotions", "simplified")["train"]

for record in raw_train:
    present = record["labels"]
    for src in present:
        occ_count[src] += 1
        for dst in present:
            if dst != src:
                co_mat[src, dst] += 1

# Row i estimates P(label j | label i); the identity adds a self-loop per label.
row_totals = occ_count.reshape(-1, 1) + 1e-6
conditional_prob = co_mat / row_totals
initial_adj_np = conditional_prob + np.eye(num_labels)
initial_adj = torch.from_numpy(initial_adj_np).float().to(DEVICE)

# --- MODEL DEFINITIONS ---
class BaselineModel(nn.Module):
    """Pretrained encoder plus a linear head on the first-token representation.

    Produces one logit per emotion label (``num_labels`` outputs).
    """

    def __init__(self):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(MODEL_NAME)
        self.classifier = nn.Linear(HIDDEN_DIM, num_labels)

    def forward(self, ids, mask):
        """Return label logits for a batch of token ids and attention masks."""
        encoded = self.transformer(ids, mask)
        first_token = encoded.last_hidden_state[:, 0, :]
        return self.classifier(first_token)

class CoTEGModel(nn.Module):
    """Score text against label embeddings refined by a learnable label graph.

    A trainable adjacency ``A`` (initialised from ``adj``) mixes label
    embeddings through one GCN-style layer (linear -> graph propagation ->
    ReLU -> residual -> LayerNorm). Each text's first-token vector is then
    dot-producted with every refined label vector, giving one logit per label.

    Sub-module and parameter names (``text_encoder.transformer``,
    ``label_embeddings``, ``A``, ``gcn.linear``, ``gcn.norm``) define the
    ``state_dict`` keys and are matched by name when grouping parameters for
    the optimizer, so they must stay as they are.
    """

    def __init__(self, adj):
        super().__init__()
        # Keep this creation order: it fixes the order in which random
        # initialisation draws from the RNG and the order of named_parameters().
        self.text_encoder = nn.Module()
        self.text_encoder.transformer = AutoModel.from_pretrained(MODEL_NAME)
        self.label_embeddings = nn.Embedding(num_labels, HIDDEN_DIM)
        # Learnable copy of the prior adjacency; presumably (num_labels, num_labels) — confirm with caller.
        self.A = nn.Parameter(adj.clone())
        self.gcn = nn.Module()
        self.gcn.linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
        self.gcn.norm = nn.LayerNorm(HIDDEN_DIM)

    def forward(self, ids, mask):
        """Return (batch, num_labels) logits for token ids and attention masks."""
        hidden = self.text_encoder.transformer(ids, mask).last_hidden_state
        sentence_vec = hidden[:, 0, :]

        # Non-negative, row-normalised adjacency (epsilon guards all-zero rows).
        adjacency = torch.relu(self.A)
        row_sums = adjacency.sum(1, keepdim=True)
        adjacency = adjacency / (row_sums + 1e-6)

        label_emb = self.label_embeddings.weight  # (num_labels, HIDDEN_DIM)
        propagated = torch.relu(adjacency @ self.gcn.linear(label_emb))
        label_nodes = self.gcn.norm(propagated + label_emb)

        return sentence_vec @ label_nodes.T

print("\n4. Training and Evaluating Models...")  # announce the training/evaluation phase

# --- TRAINING ENGINE ---
def _split_param_groups(model):
    """Split ``model``'s parameters into (encoder_params, head_params) lists.

    A parameter belongs to the pretrained encoder when its qualified name
    contains ``"transformer"`` or ``"text_encoder"``; everything else (e.g. the
    classifier, label embeddings, adjacency and GCN layers of the models in
    this notebook) is treated as a newly initialised head parameter.
    """
    encoder_params, head_params = [], []
    for param_name, param in model.named_parameters():
        is_encoder = "transformer" in param_name or "text_encoder" in param_name
        (encoder_params if is_encoder else head_params).append(param)
    return encoder_params, head_params


def _row_normalized_adjacency(adj):
    """ReLU ``adj`` then scale every row to sum to ~1 (epsilon guards zero rows)."""
    clipped = torch.relu(adj)
    return clipped / (clipped.sum(1, keepdim=True) + 1e-6)


def _collect_probs(model, loader):
    """Run ``model`` over ``loader`` without gradients.

    Returns:
        (probs, trues): NumPy arrays of sigmoid probabilities and label vectors,
        stacked over all batches.
    """
    probs_list, trues_list = [], []
    with torch.no_grad():
        for batch in loader:
            logits = model(batch["input_ids"], batch["attention_mask"])
            probs_list.append(torch.sigmoid(logits).numpy())
            trues_list.append(batch["label_vector"].numpy())
    return np.vstack(probs_list), np.vstack(trues_list)


def _tune_thresholds(trues, probs, default=0.5):
    """Choose, per label column, the probability cut-off that maximises F1.

    Columns without any positive example, or without a usable PR-curve
    threshold, fall back to ``default``. Returned cut-offs are meant to be
    applied as ``prob >= cut-off``, matching ``precision_recall_curve``.
    """
    thresholds = []
    for col in range(trues.shape[1]):
        y_true = trues[:, col]
        if not y_true.any():
            # Recall is undefined without positives; the PR curve is meaningless.
            thresholds.append(default)
            continue
        precision, recall, cutoffs = precision_recall_curve(y_true, probs[:, col])
        if len(cutoffs) == 0:
            thresholds.append(default)
            continue
        f1 = 2 * precision * recall / (precision + recall + 1e-6)
        # precision/recall carry one extra trailing point (recall = 0) with no
        # matching cut-off, so only the first len(cutoffs) F1 values are eligible.
        best = int(np.argmax(f1[: len(cutoffs)]))
        thresholds.append(float(cutoffs[best]))
    return thresholds


def train_and_evaluate(model, name, use_graph_loss=False, *, graph_reg_weight=0.01, log_every=50):
    """Fine-tune ``model`` on ``train_loader``, then tune thresholds and score it on ``val_loader``.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` returning label logits.
        name: display name used in progress messages.
        use_graph_loss: add a Frobenius penalty keeping the model's normalised
            adjacency ``model.A`` close to the (identically normalised) prior
            ``initial_adj``. Requires the model to have an ``A`` parameter.
        graph_reg_weight: weight of that penalty.
        log_every: print the current step loss every this many steps
            (0 disables per-step logging). Each print forces an XLA sync.

    Returns:
        (model, thresholds, metrics): the trained model (moved back to
        ``DEVICE``, left in eval mode), one F1-optimal threshold per label,
        and a dict of macro/weighted F1, exact-match accuracy and Hamming loss.

    NOTE(review): thresholds are tuned on the same validation split the
    metrics are reported on, so the metrics are optimistic; a held-out test
    split would give an unbiased estimate.
    """
    print(f"\n- Training {name.upper()} Model...")

    model = model.to(DEVICE)

    encoder_params, head_params = _split_param_groups(model)
    optimizer = AdamW([
        {"params": encoder_params, "lr": 2e-5},  # low LR for the pretrained encoder
        {"params": head_params, "lr": 1e-3},     # high LR for freshly initialised layers
    ])

    steps_per_epoch = len(train_loader)
    scheduler = get_linear_schedule_with_warmup(optimizer, 0, steps_per_epoch * EPOCHS)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    # The model ReLUs and row-normalises its adjacency, so the prior must be
    # normalised the same way for the penalty to have a reachable minimum.
    target_adj = _row_normalized_adjacency(initial_adj) if use_graph_loss else None

    model.train()

    for epoch in range(EPOCHS):
        device_loader = pl.ParallelLoader(train_loader, [DEVICE]).per_device_loader(DEVICE)
        # Accumulate on device: calling .item() every step would block on the TPU.
        loss_sum = torch.zeros((), device=DEVICE)

        for step, batch in enumerate(device_loader):
            optimizer.zero_grad()
            logits = model(batch["input_ids"], batch["attention_mask"])
            loss = criterion(logits, batch["label_vector"])

            if use_graph_loss:
                drift = _row_normalized_adjacency(model.A) - target_adj
                loss = loss + graph_reg_weight * torch.norm(drift, p="fro")

            loss.backward()
            xm.optimizer_step(optimizer)
            scheduler.step()
            loss_sum += loss.detach()

            if log_every and (step + 1) % log_every == 0:
                print(f"\rEpoch {epoch+1} | Step {step+1}/{steps_per_epoch} | Loss: {loss.item():.4f}", end="")

        print(f"\nEpoch {epoch+1}/{EPOCHS} | Avg Loss: {loss_sum.item()/steps_per_epoch:.4f}")

    # --- EVALUATION ---
    print(f"\n- Evaluating {name}...")
    model.eval()
    model.to("cpu")

    probs, trues = _collect_probs(model, val_loader)
    thresholds = _tune_thresholds(trues, probs)

    # `>=` matches how precision_recall_curve defines its thresholds.
    preds = (probs >= np.asarray(thresholds)).astype(int)
    trues = trues.astype(int)

    metrics = {
        "macro_f1": float(f1_score(trues, preds, average="macro", zero_division=0)),
        "weighted_f1": float(f1_score(trues, preds, average="weighted", zero_division=0)),
        "exact_accuracy": float(accuracy_score(trues, preds)),
        "hamming_loss": float(hamming_loss(trues, preds)),
    }

    print(f"- {name} Metrics: {metrics}")

    model.to(DEVICE)

    return model, thresholds, metrics

# --- RUN TRAINING ---
# Train the plain classifier first, then the graph model with the adjacency
# regulariser enabled. Each call yields (model, thresholds, metrics).
baseline_result = train_and_evaluate(BaselineModel(), "Baseline")
baseline_model, base_thrs, base_metrics = baseline_result

coteg_result = train_and_evaluate(CoTEGModel(initial_adj), "CoTEG", use_graph_loss=True)
coteg_model, coteg_thrs, coteg_metrics = coteg_result

# --- VISUALIZATION ---
# Split on the connectives "but"/"however" (any case, optional trailing comma)
# and on runs of sentence punctuation, so contrasting clauses are scored apart.
_CLAUSE_SPLIT_RE = re.compile(r"\b(?:but|however)\b,?|[.!?]+", flags=re.IGNORECASE)


def predict_emotions(text_input, *, min_score=0.05, show_plot=True):
    """Compare baseline and CoTEG emotion scores for ``text_input`` and plot them.

    The text is split into clauses (see ``_CLAUSE_SPLIT_RE``), and stripped
    clauses of 5 characters or fewer are dropped. If nothing remains, the
    whole text is used. Each clause is scored by both global models, and the
    per-label maximum probability across clauses is kept.

    Args:
        text_input: raw text to analyse.
        min_score: a label is reported when either model's max probability
            exceeds this value.
        show_plot: draw a grouped seaborn bar chart of the reported scores.

    Returns:
        A DataFrame with columns ``Emotion``, ``Score``, ``Model`` (two rows per
        reported label). It is empty when no label passes ``min_score``.

    Side effect: leaves ``baseline_model`` and ``coteg_model`` in eval mode on CPU.
    """
    print(f"\n5. VISUAL ANALYSIS FOR INPUT: \"{text_input}\"")

    # 1. SETUP
    baseline_model.eval().to("cpu")
    coteg_model.eval().to("cpu")

    # 2. SMART SPLITTING ("Divide and Conquer")
    stripped = (part.strip() for part in _CLAUSE_SPLIT_RE.split(text_input))
    chunks = [c for c in stripped if len(c) > 5]  # length checked after stripping
    if not chunks:
        chunks = [text_input]  # fallback: score the whole input

    print(f"\n6. Logic: Split input into {len(chunks)} segments: {chunks}\n")

    # 3. INFERENCE LOOP — keep the MAXIMUM probability per emotion across chunks.
    base_final_probs = np.zeros(num_labels)
    coteg_final_probs = np.zeros(num_labels)

    with torch.no_grad():
        for chunk in chunks:
            tokenized = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=MAX_LEN)
            ids, mask = tokenized["input_ids"], tokenized["attention_mask"]

            base_probs = torch.sigmoid(baseline_model(ids, mask))[0].numpy()
            coteg_probs = torch.sigmoid(coteg_model(ids, mask))[0].numpy()

            base_final_probs = np.maximum(base_final_probs, base_probs)
            coteg_final_probs = np.maximum(coteg_final_probs, coteg_probs)

    # 4. FILTERING
    rows = []
    for i, label in enumerate(emotion_labels):
        if base_final_probs[i] > min_score or coteg_final_probs[i] > min_score:
            rows.append({"Emotion": label, "Score": base_final_probs[i], "Model": "Baseline"})
            rows.append({"Emotion": label, "Score": coteg_final_probs[i], "Model": "CoTEG"})

    df = pd.DataFrame(rows, columns=["Emotion", "Score", "Model"])

    if df.empty:
        print("No strong emotions detected!")
        return df

    # 5. PLOT
    if show_plot:
        plt.figure(figsize=(12, 5))
        sns.barplot(data=df, x="Emotion", y="Score", hue="Model")
        plt.title(f"Emotion scores (max over {len(chunks)} segment(s))")
        plt.ylim(0, 1)
        plt.xticks(rotation=45, ha="right")
        plt.tight_layout()
        plt.show()

    return df

In [None]:
from google.colab import drive

print("\n5. Connecting to Google Drive...")
drive.mount('/content/drive')

print("\n6. Saving models to the Drive root folder...")

# Move both models to CPU so the saved state_dicts hold CPU tensors
# (presumably so the checkpoints load without an XLA device — confirm).
for _model in (baseline_model, coteg_model):
    _model.to("cpu")

# Each checkpoint bundles weights, per-label thresholds and validation metrics.
_checkpoints = [
    (coteg_model, coteg_thrs, coteg_metrics, "/content/drive/My Drive/coteg_model.pth"),
    (baseline_model, base_thrs, base_metrics, "/content/drive/My Drive/baseline_model.pth"),
]
for _model, _thr, _metrics, _path in _checkpoints:
    torch.save({
        'state': _model.state_dict(),
        'thr': _thr,
        'metrics': _metrics
    }, _path)

print("Models saved successfully!")