In [2]:
from transformers import AutoTokenizer, AutoModel

In [None]:
import warnings
warnings.filterwarnings("ignore")

import torch
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader

# Hugging Face checkpoint used for both the tokenizer and the backbone encoder.
# `from_pretrained()` requires a model id or local path; calling it with no argument fails.
MODEL_NAME = "roberta-base"

# DOWNLOAD PRETRAINED MODEL AND TOKENIZER
roberta_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
roberta_base = AutoModel.from_pretrained(MODEL_NAME)

Some weights of RobertaModel were not initialized from the model checkpoint at C:\Users\subin\.cache\huggingface\hub\models--roberta-base\snapshots\roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Hidden-state width of the backbone; the downstream modules below are sized for 768.
roberta_base.config.hidden_size

768

In [19]:
# Remove the pooler layer: Encoder builds its own sentence embedding by masked
# mean-pooling of the token states, so the pooler (reported as newly initialized
# in the load warning) is not used by this model.
roberta_base.pooler = None

In [None]:
import numpy as np
import pandas as pd
import ast

# Paths to the dataset splits. Assumed to be CSV files (list-valued columns such as
# `emotions_used_to_generate_context` are stored as Python-literal strings and parsed
# with ast.literal_eval in the dataset) — adjust the reader if another format is used.
TRAIN_PATH = ""     # enter path to train dataset
VAL_PATH = ""       # enter path to validation dataset
TEST_PATH = ""      # enter path to test dataset

# EmoPillars_Dataset indexes rows with `.iloc`, so it needs DataFrames, not path strings.
df_train = pd.read_csv(TRAIN_PATH)
df_val = pd.read_csv(VAL_PATH)
df_test = pd.read_csv(TEST_PATH)

In [None]:
class EmoPillars_Dataset(Dataset):
    """
    Map-style dataset over an EmoPillars DataFrame.

    Each row must provide:
      - `utterance`: the text to encode
      - `emotions_used_to_generate_context`: stringified list of emotion names, e.g. "['joy', 'anger']"
      - `expressiveness`: stringified list of floats aligned with those emotion names
      - columns "0" .. "<num_classes-1>": multi-hot hard targets
      - columns "0_exp" .. "<num_classes-1>_exp": soft targets
    """
    def __init__(self, data: pd.DataFrame, tokenizer, max_len: int = 64, num_classes: int = 28):
        self.tokenizer = tokenizer
        self.data = data
        self.max_len = max_len
        self.target_cols = [str(i) for i in range(num_classes)]
        self.soft_target_cols = [f"{i}_exp" for i in range(num_classes)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize row `idx` and return its fixed-size tensors plus the raw label lists."""
        item = self.data.iloc[idx]
        text = str(item["utterance"])
        # Calling the tokenizer directly is the recommended replacement for `encode_plus`.
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            truncation=True,
            max_length=self.max_len,
            padding="max_length",
            return_attention_mask=True,
            return_tensors="pt",
        )

        # convert "['...']" → ['...'] (literal_eval only evaluates Python literals)
        labels = ast.literal_eval(item["emotions_used_to_generate_context"])
        expressiveness = ast.literal_eval(item["expressiveness"])

        target = torch.tensor(item[self.target_cols].values.astype("float32"))
        soft_target = torch.tensor(item[self.soft_target_cols].values.astype("float32"))

        return {
            # return_tensors="pt" adds a leading batch axis of size 1; drop it per sample
            "input_ids": encoding["input_ids"].squeeze(0),
            "atten_masks": encoding["attention_mask"].squeeze(0),
            "label_names": labels,
            "expressiveness": expressiveness,
            "hard_target": target,
            "soft_target": soft_target,
        }

    # Backward-compatible alias for the original (non-protocol) method name; DataLoader
    # and `dataset[idx]` only use __getitem__.
    __get_item__ = __getitem__

In [6]:
# Instantiate the training split once to sanity-check its length and one sample below.
data = EmoPillars_Dataset(df_train, roberta_tokenizer)

In [7]:
# Number of rows in the training split (the dataset's __len__ returns len of its DataFrame).
len(data)

77477

In [8]:
# Inspect one encoded sample (row index 1) by calling the dataset's item method directly.
data.__get_item__(1)

{'input_ids': tensor([   0,  100,  399,   75, 1057, 7738,    7,   28,   98,  490,    8, 5322,
           59,   39, 6453,    4,    2,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1]),
 'atten_masks': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 'label_names': ['surprise', 'admiration', 'curiosity'],
 'expressiveness': [0.8, 0.6, 0.3],
 'hard_target': tensor([1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]),
 'soft_target': tensor([0.6000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,

In [None]:
import multiprocessing


def emopillars_collate(batch):
    """
    Collate a list of EmoPillars_Dataset items into one batch dict.

    The default collate_fn cannot batch the ragged per-sample lists `label_names` and
    `expressiveness`: it raises on unequal lengths and transposes equal-length ones.
    Fixed-size tensors are therefore stacked along a new batch axis, and the ragged
    fields are kept as lists of per-sample lists (CrossAttentionModule zips them per sample).
    """
    collated = {
        key: torch.stack([sample[key] for sample in batch])
        for key in ("input_ids", "atten_masks", "hard_target", "soft_target")
    }
    collated["label_names"] = [sample["label_names"] for sample in batch]
    collated["expressiveness"] = [sample["expressiveness"] for sample in batch]
    return collated


BATCH_SIZE = 32
# Under the "spawn"/"forkserver" start methods (the default on Windows and macOS), worker
# processes must re-import the dataset class and collate function from the notebook's
# __main__, which fails; load batches in the main process there.
NUM_WORKERS = 4 if multiprocessing.get_start_method() == "fork" else 0

# Data Loaders — only the training split is shuffled, so each epoch sees a new sample order
train_dataloader = DataLoader(EmoPillars_Dataset(df_train, roberta_tokenizer), batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=NUM_WORKERS, collate_fn=emopillars_collate)
val_dataloader = DataLoader(EmoPillars_Dataset(df_val, roberta_tokenizer), batch_size=BATCH_SIZE,
                            num_workers=NUM_WORKERS, collate_fn=emopillars_collate)
test_dataloader = DataLoader(EmoPillars_Dataset(df_test, roberta_tokenizer), batch_size=BATCH_SIZE,
                             num_workers=NUM_WORKERS, collate_fn=emopillars_collate)

In [None]:
import json
import torch

# Load Precomputed Label Embeddings (JSON → Tensor).
# JSON text is UTF-8 by specification; pass the encoding explicitly because open()
# otherwise uses the platform locale encoding (e.g. cp1252 on Windows).
with open("label_embeddings.json", "r", encoding="utf-8") as f:
    emo_embed_raw = json.load(f)

# Convert lists -> torch tensors (float32). Keys are normalised to str because
# CrossAttentionModule looks embeddings up by the string label names from the dataset.
emo_embed = {str(k): torch.tensor(v, dtype=torch.float32) for k, v in emo_embed_raw.items()}

## Model Architecture Design

In [None]:
import os
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from transformers.optimization import get_cosine_schedule_with_warmup
from sklearn.metrics import f1_score

In [None]:
class Encoder(nn.Module):
    """
    Text Encoder:
    - Takes in tokenized text (from tokenizer)
    - Runs the backbone transformer and mean-pools its last hidden layer over
      non-padding tokens to produce one text embedding per sample
    """
    def __init__(self, base_encoder):
        super().__init__()
        self.encoder = base_encoder

    def forward(self, inputs):
        """
        inputs: tokenizer output dict (input_ids, attention_mask), forwarded to the backbone as kwargs.

        Returns a dict with:
          - "text_embed":        masked mean of the last hidden layer       [B, H]
          - "last_hidden_state": last-layer token states                    [B, T, H]
          - "atten_mask":        attention mask as float32 with a trailing axis [B, T, 1]
        """
        # `last_hidden_state` is the same tensor as `hidden_states[-1]`; reading it directly
        # avoids returning every layer's activations (output_hidden_states=True).
        outputs = self.encoder(**inputs)
        last_hidden_state = outputs.last_hidden_state                                                                   # [B, T, H]

        # [B, T] -> [B, T, 1] so the mask broadcasts over the hidden dimension
        atten_mask = inputs["attention_mask"].unsqueeze(-1).float()
        # Mean pooling text_embeddings over real tokens; clamp avoids 0/0 for all-padding rows
        pooled_text_emb = (last_hidden_state * atten_mask).sum(dim=1) / atten_mask.sum(dim=1).clamp(min=1e-9)           # [B, H]

        return {
            "text_embed": pooled_text_emb,
            "last_hidden_state": last_hidden_state,
            "atten_mask": atten_mask
        }

In [None]:
class CrossAttentionModule(nn.Module):
    """
    Fuses token-level text states with a per-sample emotion embedding.

    For each sample, the precomputed embeddings of its emotion labels are combined into
    one vector by an expressiveness-weighted mean. The text tokens attend to that vector
    (query = text, key/value = emotion), the result is added back residually,
    mean-pooled over non-padding tokens and layer-normalised.
    """
    def __init__(self, label_embedding_dict, hidden_size: int = 768, num_heads: int = 6):
        """
        label_embedding_dict: {label name: 1-D tensor} of precomputed label embeddings.
        hidden_size: width of the text encoder's hidden states (768 for roberta-base).
        num_heads: number of attention heads; must divide hidden_size.
        """
        super().__init__()
        if not label_embedding_dict:
            raise ValueError("label_embedding_dict must contain at least one label embedding")

        self.hidden_size = hidden_size
        self.label_embedding_dict = {
            k: v.clone().detach() for k, v in label_embedding_dict.items()
        }
        # Stack the label embeddings into one table registered as a non-persistent buffer:
        # it follows model.to(device) (no per-step host->device copies) and keeps the
        # state_dict keys unchanged.
        self.label_to_index = {k: i for i, k in enumerate(self.label_embedding_dict)}
        self.register_buffer(
            "label_table",
            torch.stack([v.float() for v in self.label_embedding_dict.values()], dim=0),
            persistent=False,
        )
        label_dim = self.label_table.shape[1]

        # Multi-head cross attention; kdim/vdim let the label embeddings have a width
        # different from the text hidden size (identical to the default when equal).
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            kdim=label_dim,
            vdim=label_dim,
            batch_first=True
        )
        self.layer_norm = nn.LayerNorm(hidden_size)

    def _emotion_embedding(self, labels, weights):
        """Expressiveness-weighted mean of the label embeddings for one sample, [D]."""
        table = self.label_table
        if len(labels) == 0:
            # No emotion information for this sample: use a zero key/value vector
            # instead of failing in torch.stack on an empty list.
            return table.new_zeros(table.shape[1])

        idx = torch.tensor([self.label_to_index[lbl] for lbl in labels], device=table.device)
        emb_stack = table[idx]                                                                     # [num_labels, D]

        # normalize weights → weighted mean; the clamp avoids 0/0 = NaN for all-zero weights
        w = torch.as_tensor(weights, dtype=table.dtype, device=table.device).reshape(-1, 1)
        w = w / w.sum().clamp_min(1e-9)
        return (emb_stack * w).sum(dim=0)                                                          # [D]

    def forward(self, encoder_out, emotion_labels, expressiveness):
        """
        encoder_out: Encoder output dict with "last_hidden_state" [B, T, H] and "atten_mask" [B, T, 1].
        emotion_labels: length-B sequence of per-sample lists of label names.
        expressiveness: length-B sequence of per-sample weight lists aligned with the labels.

        Returns the fused, mask-pooled and layer-normalised embedding, [B, H].
        """
        emotion_emb = torch.stack(
            [self._emotion_embedding(labels, weights)
             for labels, weights in zip(emotion_labels, expressiveness)],
            dim=0,
        ).unsqueeze(1)                                                                             # [B, 1, D]

        text_states = encoder_out["last_hidden_state"]                                             # [B, T, H]
        # cross-attention (query=text, key/value=emotion)
        attn_out, _ = self.cross_attn(
            query=text_states,
            key=emotion_emb,
            value=emotion_emb
        )

        # Fuse residually, then mean-pool over non-padding tokens
        fused_hidden_state = text_states + attn_out
        atten_mask = encoder_out["atten_mask"]
        fused_emo_text_emb = (fused_hidden_state * atten_mask).sum(dim=1) / atten_mask.sum(dim=1).clamp(min=1e-9)
        return self.layer_norm(fused_emo_text_emb)

In [None]:
class PosteriorNetwork(nn.Module):
    """
    Emotion-aware posterior q(z | x, e).

    Maps the fused text+emotion embedding to the mean and clamped log-variance of a
    diagonal Gaussian over the latent space, and draws one reparameterised sample.
    """
    def __init__(self, input_dim=768, latent_dim=128):
        super().__init__()
        layers = [
            nn.Linear(input_dim, 512), nn.GELU(), nn.LayerNorm(512),
            nn.Linear(512, 256), nn.GELU(), nn.LayerNorm(256),
        ]
        self.mlp = nn.Sequential(*layers)
        self.mu_posterior = nn.Linear(256, latent_dim)
        self.logvar_posterior = nn.Linear(256, latent_dim)

    def forward(self, fused_emo_text_emb):
        """
        Returns (z, mu, logvar): a sample z = mu + eps * sigma with eps ~ N(0, I),
        the posterior mean, and the log-variance clamped to [-10, 10].
        """
        hidden = self.mlp(fused_emo_text_emb)
        mean = self.mu_posterior(hidden)
        log_var = self.logvar_posterior(hidden).clamp(min=-10, max=10)

        # Reparameterisation trick keeps the sample differentiable w.r.t. mean and log_var
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mean + noise * sigma, mean, log_var

In [None]:
class PriorNetwork(nn.Module):
    """
    Text-only prior p(z | x).

    Predicts the mean and clamped log-variance of a diagonal Gaussian over the latent
    space from the pooled text embedding alone (no emotion labels).
    """
    def __init__(self, input_dim=768, latent_dim=128):
        super().__init__()
        modules = []
        width_in = input_dim
        for width_out in (512, 256):
            modules.extend([nn.Linear(width_in, width_out), nn.GELU(), nn.LayerNorm(width_out)])
            width_in = width_out
        self.mlp = nn.Sequential(*modules)
        self.mu_prior = nn.Linear(width_in, latent_dim)
        self.logvar_prior = nn.Linear(width_in, latent_dim)

    def forward(self, text_emb):
        """Return (mu, logvar), with logvar clamped to [-10, 10] for numerical stability."""
        features = self.mlp(text_emb)
        return self.mu_prior(features), self.logvar_prior(features).clamp(min=-10, max=10)

In [None]:
class Classifier(nn.Module):
    """
    Final shared emotion classifier layer.

    Maps a latent vector to one raw logit per class through two hidden blocks
    (Linear → GELU → LayerNorm → Dropout(0.25)) of width 256 and 128.
    """
    def __init__(self, latent_dim=128, num_classes=28):
        super().__init__()
        blocks = []
        for d_in, d_out in ((latent_dim, 256), (256, 128)):
            blocks += [nn.Linear(d_in, d_out), nn.GELU(), nn.LayerNorm(d_out), nn.Dropout(0.25)]
        blocks.append(nn.Linear(128, num_classes))
        self.mlp = nn.Sequential(*blocks)

    def forward(self, z):
        """Return pre-sigmoid logits with `num_classes` entries on the last axis."""
        return self.mlp(z)

In [None]:
# Main Model class
class EmoAxis(nn.Module):
    """
    Latent-variable multi-label emotion classifier.

    During training, both latent paths run. The posterior q(z|x,e) sees the text fused
    with its emotion labels, and the prior p(z|x) sees the text only. Both latents go
    through the shared classifier. `inference` uses only the prior mean, so no labels
    are needed at prediction time.
    """
    def __init__(self, encoder, cross_atten_module, posterior_net, prior_net, classifier):
        """
        Architecture:
        - Encoder              → produces the text embedding and token states
        - CrossAttentionModule → fuses token states with the emotion-label embedding
        - PosteriorNet         → q(z|x,e)
        - PriorNet             → p(z|x)
        - Classifier           → predicts emotion logits from latent z
        """
        super().__init__()
        self.encoder = encoder
        self.cross_atten = cross_atten_module
        self.posterior = posterior_net
        self.prior = prior_net
        self.classifier = classifier


    def total_params(self):
        """Utility function to print trainable vs total parameter counts."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total parameters: {total:,}")
        print(f"Trainable parameters: {trainable:,}")


    def forward(self, input_ids, atten_mask, emotion_labels, expressiveness):
        """
        Training forward pass.

        Returns a dict with the posterior/prior Gaussian parameters ("mu_post",
        "logvar_post", "mu_prior", "logvar_prior") and classifier logits computed from
        a posterior sample ("logits_post") and a prior sample ("logits_prior").
        """
        # 1. Encoder
        encoder_outputs = self.encoder(
            inputs={"input_ids": input_ids, "attention_mask": atten_mask},
        )
        text_emb = encoder_outputs["text_embed"]                  # [B, H]

        # 2. Cross Attention (text tokens attend to the emotion-label embedding)
        fused_emb = self.cross_atten(encoder_outputs, emotion_labels, expressiveness)    # [B, H]

        # 3. Posterior Net (returns its own reparameterised sample)
        z_post, mu_post, logvar_post = self.posterior(fused_emb)

        # 4. Prior Net + reparameterised sample: mu + sigma * eps, eps ~ N(0, I)
        mu_prior, logvar_prior = self.prior(text_emb)
        z_prior = mu_prior + torch.exp(0.5 * logvar_prior) * torch.randn_like(mu_prior)

        # 5. Classifier (shared weights for both latents)
        logits_post = self.classifier(z_post)           # from sampled posterior
        logits_prior = self.classifier(z_prior)         # from sampled prior (not its mean)

        return {
            "mu_post": mu_post,
            "logvar_post": logvar_post,
            "mu_prior": mu_prior,
            "logvar_prior": logvar_prior,
            "logits_post": logits_post,
            "logits_prior": logits_prior
        }


    def inference(self, input_ids, atten_mask):
        """
        Predict emotion logits from raw text using the prior mean of p(z|x) (no labels needed).

        Runs in eval mode without gradients. The module's previous train/eval mode is
        restored afterwards, so a caller in the middle of training is not silently
        switched to eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                encoder_outputs = self.encoder(
                    inputs={"input_ids": input_ids, "attention_mask": atten_mask},
                )
                # get only text embeddings
                text_emb = encoder_outputs["text_embed"]

                # latent mean from prior (log-variance unused for a point prediction)
                mu_prior, _ = self.prior(text_emb)

                # classification
                return self.classifier(mu_prior)
        finally:
            self.train(was_training)

In [None]:
# Build each sub-module (in the same order as the EmoAxis arguments), then wire them together
encoder = Encoder(roberta_base)
cross_atten_module = CrossAttentionModule(emo_embed)
posterior_net = PosteriorNetwork()
prior_net = PriorNetwork()
classifier = Classifier()

model = EmoAxis(encoder, cross_atten_module, posterior_net, prior_net, classifier)

model.total_params()

Total parameters: 128,264,604
Trainable parameters: 78,649,500


## Loss function

In [None]:
def kl_divergence(mu_post, logvar_post, mu_prior, logvar_prior):
    """
    KL(N(mu_q, var_q) || N(mu_p, var_p)) between diagonal Gaussians, summed over the
    latent dimension (dim=1) and averaged over the batch.
    Uses diagonal covariance (logvar = log(sigma^2)).

    Formula: 0.5 * sum( log(var_p/var_q) + (var_q + (mu_q-mu_p)^2)/var_p - 1 )
    """
    term = logvar_prior - logvar_post + (torch.exp(logvar_post) + (mu_post - mu_prior) ** 2) / torch.exp(logvar_prior) - 1.0
    kl = 0.5 * torch.sum(term, dim=1)

    return kl.mean()


def compute_loss(logits_post, logits_prior, mu_post, logvar_post,
                 mu_prior, logvar_prior, hard_target, soft_target,
                 epoch, lambda_soft=1.0, lambda_kl=0.1, prior_soft_start_epoch=1):
    """
    Training objective for EmoAxis:

        total = BCE(logits_post, hard_target)
                + lambda_soft * soft_MSE
                + lambda_kl * KL(q(z|x,e) || p(z|x))

    soft_MSE is the MSE between sigmoid(logits) and `soft_target`. Before
    `prior_soft_start_epoch` it uses the posterior logits only; from that epoch on, the
    prior logits' MSE is added too.

    Returns (total_loss tensor, dict of the three component losses as Python floats).
    """
    # BCE - supervised (posterior only)
    loss_bce = F.binary_cross_entropy_with_logits(logits_post, hard_target)

    # Soft MSE - posterior and (after the warm-up epochs) prior, on sigmoid outputs
    probs_post = torch.sigmoid(logits_post)
    loss_soft_post = F.mse_loss(probs_post, soft_target)
    if epoch < prior_soft_start_epoch:
        loss_soft = loss_soft_post
    else:
        probs_prior = torch.sigmoid(logits_prior)
        loss_soft_prior = F.mse_loss(probs_prior, soft_target)
        loss_soft = loss_soft_post + loss_soft_prior

    # KL Divergence between posterior and prior
    loss_kl = kl_divergence(mu_post, logvar_post, mu_prior, logvar_prior)

    # Total loss
    total_loss = loss_bce + lambda_soft * loss_soft + lambda_kl * loss_kl

    return total_loss, {"loss_bce": loss_bce.item(), "loss_soft": loss_soft.item(), "loss_kl": loss_kl.item()}

## Model Training

In [None]:
def freeze_encoder_layers(model, freeze_upto: int):
    """
    Freezes encoder layers from 0 → freeze_upto (inclusive); every other parameter is unfrozen.
    Example: freeze_upto = 5 → freeze encoder.layer[0] ... encoder.layer[5]; -1 freezes nothing.
    Values beyond the last layer freeze the whole layer stack.

    Expects `model.encoder` to be the Encoder wrapper whose `.encoder` attribute is the
    Hugging Face backbone: a bare RobertaModel, or a task model exposing it as `.roberta`.
    """
    model.requires_grad_(True)                      # unfreeze all layers

    # Encoder wraps the backbone as `.encoder`; a bare RobertaModel has no `.roberta`
    # attribute, so only descend into it when present.
    backbone = model.encoder.encoder
    backbone = getattr(backbone, "roberta", backbone)
    layers = backbone.encoder.layer

    # freeze (clamped so out-of-range values cannot raise IndexError)
    n_frozen = max(0, min(freeze_upto + 1, len(layers)))
    for layer in layers[:n_frozen]:
        layer.requires_grad_(False)

    if n_frozen:
        print(f"\n\nFrozen encoder layers - 0 to {n_frozen - 1}\n\n")
    else:
        print("\n\nNo encoder layers frozen\n\n")

In [None]:
def validate(model, val_dataloader, device, threshold=0.5):
    """
    Score `model` on `val_dataloader` with prior-only inference (`model.inference`).

    Per-class sigmoid probabilities are thresholded at `threshold`. Returns a dict with
    micro- and macro-averaged F1 over the whole split (zero_division=0).
    """
    model.eval()
    pred_chunks = []
    true_chunks = []

    with torch.no_grad():
        for batch in val_dataloader:
            ids = batch['input_ids'].to(device)
            mask = batch['atten_masks'].to(device)
            gold = batch['hard_target'].to(device)

            scores = torch.sigmoid(model.inference(ids, mask)).cpu().numpy()
            pred_chunks.append((scores >= threshold).astype(int))
            true_chunks.append(gold.cpu().numpy().astype(int))

    y_pred = np.concatenate(pred_chunks, axis=0)
    y_true = np.concatenate(true_chunks, axis=0)

    return {
        "micro_f1": f1_score(y_true, y_pred, average='micro', zero_division=0),
        "macro_f1": f1_score(y_true, y_pred, average='macro', zero_division=0),
    }

In [None]:
# Main model training function
def train(
    model: nn.Module,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    device: torch.device,
    epochs: int = 10,
    lr_encoder: float = 2e-5,
    lr_other: float = 1e-4,
    weight_decay: float = 0.01,
    warmup_ratio: float = 0.1,
    kl_anneal_ratio: float = 0.25,
    gradient_accumulation_steps: int = 4,
    max_grad_norm: float = 1.0,
    use_amp: bool = True,
    early_stopping_patience: int = 3,
    min_epochs_before_stop: int = 3,
    save_dir: str = "model_dir",
    amp_dtype: torch.dtype = torch.bfloat16,
):
    """
    Train EmoAxis. The loop uses:
    - gradual encoder unfreezing
    - a cosine LR schedule with warmup
    - linear KL annealing
    - gradient accumulation
    - optional mixed precision (CUDA only)
    - early stopping on validation micro-F1

    The best checkpoint is saved to `<save_dir>/best_model.pt`. Returns the model with
    the weights of the last epoch run (not necessarily the best checkpoint).

    NOTE(review): the annealed KL weight is passed as `lambda_kl`, so it replaces
    compute_loss's 0.1 default and ramps up to 1.0 — confirm this is intended.
    """
    model.to(device)

    # Separate RoBERTa-encoder params & other params for different learning rates.
    # Every parameter is registered, including ones currently frozen: the freezing schedule
    # below unfreezes more encoder layers in later epochs, and a parameter left out of the
    # optimizer would never be updated. AdamW skips parameters whose .grad is None.
    encoder_params, other_params = [], []
    for name, p in model.named_parameters():
        if 'encoder' in name:
            encoder_params.append(p)
        else:
            other_params.append(p)

    optimizer = AdamW([
        {"params": encoder_params, "lr": lr_encoder},
        {"params": other_params, "lr": lr_other}
    ], weight_decay=weight_decay)

    # Schedules count optimizer steps (one per accumulation group), not batches
    steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    total_steps = steps_per_epoch * epochs

    warmup_steps = int(total_steps * warmup_ratio)
    kl_anneal_steps = int(total_steps * kl_anneal_ratio)

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    amp_enabled = use_amp and device.type == 'cuda'
    # Loss scaling is only needed for float16; bfloat16 shares float32's exponent range.
    scaler = GradScaler(enabled=(amp_enabled and amp_dtype == torch.float16))

    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, "best_model.pt")

    global_step = 0
    best_val_microF1 = -1.0
    epochs_no_improve = 0

    # Training Loop
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        optimizer.zero_grad()

        # gradual encoder layers unfreezing
        if epoch < 2:
            freeze_encoder_layers(model, freeze_upto=5)         # Train last 6 layers only
        elif epoch < 4:
            freeze_encoder_layers(model, freeze_upto=3)         # Train last 8 layers
        else:
            freeze_encoder_layers(model, freeze_upto=-1)        # Train all layers

        for step, batch in enumerate(train_dataloader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['atten_masks'].to(device)
            labels = batch['label_names']                       # dataset key (there is no 'labels' key)
            expressiveness = batch['expressiveness']
            hard_target = batch['hard_target'].to(device)
            soft_target = batch['soft_target'].to(device)

            with autocast(device_type='cuda', dtype=amp_dtype, enabled=amp_enabled):
                outputs = model(input_ids=input_ids,
                                atten_mask=attention_mask,      # EmoAxis.forward names this `atten_mask`
                                emotion_labels=labels,
                                expressiveness=expressiveness)

                # KL weight annealing schedule: linear ramp 0 -> 1 over kl_anneal_steps optimizer steps
                kl_weight = min(1.0, global_step / max(1, kl_anneal_steps))

                total_loss, _ = compute_loss(
                    outputs['logits_post'], outputs['logits_prior'],
                    outputs['mu_post'], outputs['logvar_post'],
                    outputs['mu_prior'], outputs['logvar_prior'],
                    hard_target, soft_target,
                    epoch=epoch, lambda_kl=kl_weight
                )
                # Normalize loss for gradient accumulation
                total_loss = total_loss / gradient_accumulation_steps

            # Backward
            scaler.scale(total_loss).backward()

            if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_dataloader):
                # gradient clipping on unscaled gradients
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

            epoch_loss += total_loss.item() * gradient_accumulation_steps

        # Report once per epoch, after all batches are summed
        avg_loss = epoch_loss / max(1, len(train_dataloader))
        print(f"Epoch {epoch+1}/{epochs} | Avg Loss: {avg_loss:.4f}")

        # Validation
        metrics = validate(model, val_dataloader, device)
        micro_F1 = metrics.get("micro_f1", -1)
        macro_F1 = metrics.get("macro_f1", -1)
        print(f"Validation: micro-F1 = {micro_F1:.4f}, macro-F1 = {macro_F1:.4f}")

        if micro_F1 > best_val_microF1:
            best_val_microF1 = micro_F1
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_model_path)
            print("Micro-F1 score improved — model saved.")
        else:
            epochs_no_improve += 1
            print(f"No improvement for {epochs_no_improve} epoch(s).")

        # Early Stopping
        if epoch + 1 >= min_epochs_before_stop and epochs_no_improve >= early_stopping_patience:
            print("\nEarly stopping activated.")
            break

    print(f"\n\nTraining completed. Best validation micro-F1 = {best_val_microF1:.4f}")
    return model

In [None]:
# SET DEVICE: use the GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# train() returns the same model object; assigning it also avoids dumping the model repr
model = train(model=model, train_dataloader=train_dataloader, val_dataloader=val_dataloader, device=device)