In [23]:
import warnings
warnings.filterwarnings("ignore")

import os
import math
import json
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from transformers import AutoTokenizer, AutoModel, get_cosine_schedule_with_warmup
from sklearn.metrics import f1_score


In [24]:
# DOWNLOAD PRETRAINED MODEL AND TOKENIZER
roberta_tokenizer = AutoTokenizer.from_pretrained("roberta-base")
roberta_base = AutoModel.from_pretrained("roberta-base")

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using device: ", device)

using device:  cuda


In [26]:
# Remove the pooler layer
roberta_base.pooler = None
roberta_base.gradient_checkpointing_enable()

In [27]:
roberta_base.config.hidden_size

768

In [28]:
train_path = "/content/emopillar_train_filtered.csv"
val_path = "/content/emopillar_validation_filtered.csv"
test_path = "/content/emopillar_test_filtered.csv"

df_train = pd.read_csv(train_path)
df_val = pd.read_csv(val_path)
df_test = pd.read_csv(test_path)

In [29]:
class EmoPillars_Dataset(Dataset):
    def __init__(self, data: pd.DataFrame, tokenizer):
        self.tokenizer = tokenizer
        self.data = data
        self.max_len = 128
        self.target_cols = [str(i) for i in range(28)]

    def __len__(self):
        return(len(self.data))

    def __getitem__(self, idx):
        item = self.data.iloc[idx]
        text = str(item.utterance)
        encoding = self.tokenizer.encode_plus(text,
                                            add_special_tokens=True,
                                            truncation=True,
                                            return_tensors='pt',
                                            max_length=self.max_len,
                                            padding='max_length',
                                            return_attention_mask=True)

        target = torch.tensor(item[self.target_cols].values.astype('float32'))

        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "atten_masks": encoding["attention_mask"].squeeze(0),
            "hard_target": target
        }

In [30]:
data = EmoPillars_Dataset(df_train, roberta_tokenizer)

In [31]:
len(data)

82750

In [32]:
data.__getitem__(1)

{'input_ids': tensor([    0,   100,   240,     7,   465,  7330,  2413, 10913,     8,   146,
           123,   582,    13,    39, 45492,   219,     4,     2,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,   

In [33]:
# Data Loaders
train_dataloader = DataLoader(EmoPillars_Dataset(df_train, roberta_tokenizer), batch_size=64, num_workers=4, shuffle=True)
val_dataloader = DataLoader(EmoPillars_Dataset(df_val, roberta_tokenizer), batch_size=64, num_workers=4)
test_dataloader = DataLoader(EmoPillars_Dataset(df_test, roberta_tokenizer), batch_size=64, num_workers=4)

In [34]:
with open("/content/label_embeddings.json", "r") as f:
    label_embeddings = json.load(f)

emo_emb=[]
for k, v in label_embeddings.items():
    emb = torch.tensor(v, dtype=torch.float32)
    emo_emb.append(emb)

emotion_bank = torch.cat(emo_emb, dim=0).to(device)

In [35]:
class Encoder(nn.Module):
    """
    Text Encoder:
    - Takes in tokenized text (from tokenizer)
    - Generates the text embedding vector
    """
    def __init__(self, base_encoder):
        super().__init__()
        self.encoder = base_encoder

    def forward(self, inputs):
        """
        inputs: tokenizer output dict (input_ids, attention_mask)
        """
        outputs = self.encoder(**inputs, output_hidden_states=True)
        last_hidden_state = outputs.hidden_states[-1]                                                                   # [B, T, H]

        atten_mask = inputs['attention_mask']                                                                           # [B, T]
        # Mean pooling text_embeddings
        atten_mask = atten_mask.unsqueeze(-1).float()
        pooled_text_emb = (last_hidden_state * atten_mask).sum(dim=1) / atten_mask.sum(dim=1).clamp(min=1e-9)           # [B, H]
        pooled_text_emb = F.normalize(pooled_text_emb, p=2, dim=1)

        return {
            "text_emb": pooled_text_emb,
            "atten_mask": atten_mask
        }

In [36]:
class Classifier(nn.Module):
    def __init__(self, input_dim=768, num_classes=28):
        super().__init__()
        self.input_dim = input_dim

        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.GELU(),
            nn.LayerNorm(512),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )

    def forward(self, h):
        return self.mlp(h)

In [37]:
# Main Model class
class EmoAxis(nn.Module):
    def __init__(self, encoder, classifier):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier


    def total_params(self):
        """Utility function to check trainable vs total params."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total parameters: {total:,}")
        print(f"Trainable parameters: {trainable:,}")


    def forward(self, input_ids, atten_mask):
        # Encoder
        outputs = self.encoder(
            inputs = {"input_ids": input_ids, "attention_mask": atten_mask}
            )

        # Classifier
        logits = self.classifier(outputs["text_emb"])

        return logits, outputs["text_emb"]


    def infer(self, input_ids, atten_mask):
        """
        Predict emotions from raw text (inference using prior p(z|x)).
        """
        self.eval()
        with torch.no_grad():
            logits, _ = self.forward(input_ids, atten_mask)
            return logits

In [38]:
encoder = Encoder(base_encoder=roberta_base)
classifier = Classifier()

# Initialize model
model = EmoAxis(
    encoder=encoder,
    classifier=classifier
)

model.total_params()

Total parameters: 124,464,156
Trainable parameters: 124,464,156


In [39]:
def contrastive_loss(text_emb, targets, temp=0.07, emo_bank=emotion_bank):
    # similarity
    sim_matrix = torch.matmul(text_emb, emo_bank.T)                    # [B, C]
    sim_matrix = sim_matrix / temp

    # log-softmax
    log_prob = F.log_softmax(sim_matrix, dim=1)

    # mask only true emotions
    mask = targets.bool()
    loss_con = -(log_prob * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    # mean over batch
    return loss_con.mean()

In [40]:
def freeze_encoder_layers(encoder, freeze_upto: int=0):
    roberta_base_model = encoder.encoder

    for name, param in roberta_base_model.named_parameters():
        param.requires_grad = True

    if freeze_upto >= 0:
        for layer_idx in range(freeze_upto + 1):
            for param in roberta_base_model.encoder.layer[layer_idx].parameters():
                param.requires_grad = False

    print(f"\n\nFrozen encoder layers - 0 to {freeze_upto}\n\n")

In [41]:
def evaluate(model, dataloader, device, threshold=0.5):
    """
    Evaluate the model on validation set.
    Returns:
        - avg_val_loss: per-sample BCE loss
        - micro_f1, macro_f1: multi-label F1 scores
    """
    model.eval()
    loss_func = nn.BCEWithLogitsLoss()
    val_loss = 0.0
    total_samples = 0

    preds_all = []
    truths_all = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['atten_masks'].to(device)
            hard_target = batch['hard_target'].to(device)

            # Forward pass
            logits = model.infer(input_ids, attention_mask)
            loss = loss_func(logits, hard_target)

            batch_size = input_ids.size(0)
            val_loss += loss.item() * batch_size
            total_samples += batch_size

            # Predictions
            probs = torch.sigmoid(logits)
            preds = (probs >= threshold).int()

            preds_all.append(preds.cpu())
            truths_all.append(hard_target.cpu().int())

    # Concatenate all batches
    preds_all = torch.cat(preds_all, dim=0).numpy()
    truths_all = torch.cat(truths_all, dim=0).numpy()

    # Compute metrics
    avg_val_loss = val_loss / total_samples
    micro_f1 = f1_score(truths_all, preds_all, average='micro', zero_division=0)
    macro_f1 = f1_score(truths_all, preds_all, average='macro', zero_division=0)

    return {
        "avg_val_loss": avg_val_loss,
        "micro_f1": micro_f1,
        "macro_f1": macro_f1
    }

In [42]:
def train(
    model: torch.nn.Module,
    train_dataloader,
    val_dataloader,
    device: torch.device,
    epochs: int = 20,
    lr_encoder: float = 2.5e-5,
    lr_other: float = 1.5e-4,
    weight_decay: float = 0.001,
    warmup_ratio: float = 0.07,
    gradient_accumulation_steps: int = 4,
    max_grad_norm: float = 1.0,
    use_amp: bool = True,
    early_stopping_patience: int = 3,
    min_epochs_before_stop: int = 3
):
    model.to(device)

    steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    total_steps = steps_per_epoch * epochs
    warmup_steps = int(total_steps * warmup_ratio)

    BCE_loss = nn.BCEWithLogitsLoss()

    global_step = 0
    best_val_microF1 = -1.0
    epochs_no_improve = 0
    current_freeze_config = None

    # Keep track of losses and metrics for plotting
    history = {"train_loss": [], "val_loss": [], "micro_f1": [], "macro_f1": []}

    scaler = GradScaler(enabled=(use_amp and device.type == 'cuda'))

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        total_samples = 0

        # Progressive unfreezing
        if epoch < 1:
            freeze_level = 11
        elif epoch < 2:
            freeze_level = 7
        elif epoch < 4:
            freeze_level = 3
        else:
            freeze_level = -1

        # Reset optimizer/scheduler only if freeze config changed
        if freeze_level != current_freeze_config:
            freeze_encoder_layers(model.encoder, freeze_upto=freeze_level)
            current_freeze_config = freeze_level

            encoder_params, other_params = [], []
            for name, p in model.named_parameters():
                if not p.requires_grad:
                    continue
                if 'encoder' in name:
                    encoder_params.append(p)
                else:
                    other_params.append(p)

            optimizer = AdamW([
                {"params": encoder_params, "lr": lr_encoder},
                {"params": other_params, "lr": lr_other}
            ], weight_decay=weight_decay)

            scheduler = get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=total_steps
            )

        optimizer.zero_grad()

        for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['atten_masks'].to(device)
            hard_target = batch['hard_target'].to(device)

            batch_size = input_ids.size(0)

            with autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                logits, text_emb = model(input_ids=input_ids, atten_mask=attention_mask)
                loss_con = contrastive_loss(text_emb, hard_target)
                loss_bce = BCE_loss(logits, hard_target)
                loss = 0.5 * loss_bce + 0.5 * loss_con
                total_loss = loss / gradient_accumulation_steps

            scaler.scale(total_loss).backward()

            if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_dataloader):
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                scaler.step(optimizer)
                scheduler.step()
                scaler.update()
                optimizer.zero_grad()
                global_step += 1

            # Accumulate epoch loss per sample
            epoch_loss += loss.item() * batch_size
            total_samples += batch_size

        avg_train_loss = epoch_loss / total_samples
        history["train_loss"].append(avg_train_loss)
        print(f"> | Epoch {epoch+1}/{epochs} | Avg Train Loss: {avg_train_loss:.4f}")
        print(f"\nBCE: {loss_bce.item():.4f}, Contrastive: {loss_con.item():.4f}\n")

        # Validation
        val_metrics = evaluate(model, val_dataloader, device, threshold=0.2)
        avg_val_loss = val_metrics.get("avg_val_loss", None)
        micro_F1 = val_metrics.get("micro_f1", -1)
        macro_F1 = val_metrics.get("macro_f1", -1)

        history["val_loss"].append(avg_val_loss)
        history["micro_f1"].append(micro_F1)
        history["macro_f1"].append(macro_F1)

        print(f"Validation | Avg Loss: {avg_val_loss:.4f}, Micro-F1: {micro_F1:.4f}, Macro-F1: {macro_F1:.4f}\n")

        # Early stopping
        if micro_F1 > best_val_microF1:
            best_val_microF1 = micro_F1
            epochs_no_improve = 0
            torch.save(model.state_dict(), os.path.join("/content/model", "best_model.pt"))
            print("Micro-F1 improved — model saved.\n")
        else:
            epochs_no_improve += 1
            print(f"No improvement for {epochs_no_improve} epoch(s).")

        if epoch + 1 >= min_epochs_before_stop and epochs_no_improve >= early_stopping_patience:
            print("\nEarly stopping activated.")
            break
        # Memory cleanup
        torch.cuda.empty_cache()
        gc.collect()

    print(f"\nTraining completed. Best validation micro-F1 = {best_val_microF1:.4f}")
    return model, history


In [43]:
trained_model, training_history = train(model, train_dataloader, val_dataloader, device)




Frozen encoder layers - 0 to 11




100%|██████████| 1293/1293 [06:48<00:00,  3.16it/s]

> | Epoch 1/20 | Avg Train Loss: 2.0052

BCE: 0.2408, Contrastive: 3.6059






Validation | Avg Loss: 0.2617, Micro-F1: 0.5474, Macro-F1: 0.3285

Micro-F1 improved — model saved.



Frozen encoder layers - 0 to 7




100%|██████████| 1293/1293 [07:22<00:00,  2.92it/s]

> | Epoch 2/20 | Avg Train Loss: 1.2880

BCE: 0.1039, Contrastive: 1.6716






Validation | Avg Loss: 0.1629, Micro-F1: 0.6917, Macro-F1: 0.6036

Micro-F1 improved — model saved.



Frozen encoder layers - 0 to 3




100%|██████████| 1293/1293 [07:55<00:00,  2.72it/s]

> | Epoch 3/20 | Avg Train Loss: 0.8745

BCE: 0.0923, Contrastive: 1.6254






Validation | Avg Loss: 0.1551, Micro-F1: 0.7209, Macro-F1: 0.6525

Micro-F1 improved — model saved.



100%|██████████| 1293/1293 [07:55<00:00,  2.72it/s]

> | Epoch 4/20 | Avg Train Loss: 0.8121

BCE: 0.0679, Contrastive: 1.4559






Validation | Avg Loss: 0.1622, Micro-F1: 0.7304, Macro-F1: 0.6650

Micro-F1 improved — model saved.



Frozen encoder layers - 0 to -1




100%|██████████| 1293/1293 [08:28<00:00,  2.54it/s]

> | Epoch 5/20 | Avg Train Loss: 0.7692

BCE: 0.0615, Contrastive: 1.4076






Validation | Avg Loss: 0.1616, Micro-F1: 0.7438, Macro-F1: 0.6779

Micro-F1 improved — model saved.



100%|██████████| 1293/1293 [08:28<00:00,  2.54it/s]

> | Epoch 6/20 | Avg Train Loss: 0.7556

BCE: 0.0721, Contrastive: 1.4422






Validation | Avg Loss: 0.1659, Micro-F1: 0.7466, Macro-F1: 0.6781

Micro-F1 improved — model saved.



100%|██████████| 1293/1293 [08:28<00:00,  2.54it/s]

> | Epoch 7/20 | Avg Train Loss: 0.7301

BCE: 0.0635, Contrastive: 1.4805






Validation | Avg Loss: 0.1720, Micro-F1: 0.7519, Macro-F1: 0.6843

Micro-F1 improved — model saved.



100%|██████████| 1293/1293 [08:28<00:00,  2.54it/s]

> | Epoch 8/20 | Avg Train Loss: 0.7070

BCE: 0.0551, Contrastive: 1.3438






Validation | Avg Loss: 0.1759, Micro-F1: 0.7569, Macro-F1: 0.6900

Micro-F1 improved — model saved.



100%|██████████| 1293/1293 [08:28<00:00,  2.54it/s]

> | Epoch 9/20 | Avg Train Loss: 0.6876

BCE: 0.0567, Contrastive: 1.4148






Validation | Avg Loss: 0.1857, Micro-F1: 0.7579, Macro-F1: 0.6892

Micro-F1 improved — model saved.



100%|██████████| 1293/1293 [08:28<00:00,  2.54it/s]

> | Epoch 10/20 | Avg Train Loss: 0.6703

BCE: 0.0397, Contrastive: 1.3103






Validation | Avg Loss: 0.1905, Micro-F1: 0.7597, Macro-F1: 0.6926

Micro-F1 improved — model saved.



100%|██████████| 1293/1293 [08:28<00:00,  2.54it/s]

> | Epoch 11/20 | Avg Train Loss: 0.6565

BCE: 0.0511, Contrastive: 1.2931






Validation | Avg Loss: 0.1988, Micro-F1: 0.7606, Macro-F1: 0.6916

Micro-F1 improved — model saved.



100%|██████████| 1293/1293 [08:28<00:00,  2.54it/s]

> | Epoch 12/20 | Avg Train Loss: 0.6433

BCE: 0.0295, Contrastive: 1.1953






Validation | Avg Loss: 0.2036, Micro-F1: 0.7623, Macro-F1: 0.6931

Micro-F1 improved — model saved.



  2%|▏         | 30/1293 [00:11<08:20,  2.52it/s]


KeyboardInterrupt: 

In [58]:
model = EmoAxis(encoder=encoder,
    classifier=classifier
)

# 2️⃣ Load weights into model
state_dict = torch.load("/content/model/best_model.pt", map_location=device)
model.load_state_dict(state_dict)

# 3️⃣ Move to device
model.to(device)

# 4️⃣ Set eval mode
model.eval()

EmoAxis(
  (encoder): Encoder(
    (encoder): RobertaModel(
      (embeddings): RobertaEmbeddings(
        (word_embeddings): Embedding(50265, 768, padding_idx=1)
        (position_embeddings): Embedding(514, 768, padding_idx=1)
        (token_type_embeddings): Embedding(1, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): RobertaEncoder(
        (layer): ModuleList(
          (0-11): 12 x RobertaLayer(
            (attention): RobertaAttention(
              (self): RobertaSdpaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): RobertaSelfOutput(
                (dense): Linear(in_features=768, o

In [59]:
from sklearn.metrics import precision_score, recall_score

def evaluate(model, dataloader, device, threshold=0.5):
    """
    Evaluate the model on validation set.
    Returns:
        - avg_val_loss: per-sample BCE loss
        - micro_f1, macro_f1: multi-label F1 scores
    """
    preds_all = []
    truths_all = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['atten_masks'].to(device)
            hard_target = batch['hard_target'].to(device)

            # Forward pass
            logits = model.infer(input_ids, attention_mask)

            # Predictions
            probs = torch.sigmoid(logits)
            preds = (probs >= threshold).int()

            preds_all.append(preds.cpu())
            truths_all.append(hard_target.cpu().int())

    # Concatenate all batches
    preds_all = torch.cat(preds_all, dim=0).numpy()
    truths_all = torch.cat(truths_all, dim=0).numpy()

    # Compute metrics
    micro_precission = precision_score(truths_all, preds_all, average='micro', zero_division=0)
    macro_precission = precision_score(truths_all, preds_all, average='macro', zero_division=0)
    micro_recall = recall_score(truths_all, preds_all, average='micro', zero_division=0)
    macro_recall = recall_score(truths_all, preds_all, average='macro', zero_division=0)
    micro_f1 = f1_score(truths_all, preds_all, average='micro', zero_division=0)
    macro_f1 = f1_score(truths_all, preds_all, average='macro', zero_division=0)

    print(f"micro_precission: {micro_precission}\n macro_precission: {macro_precission}\n micro_recall: {micro_recall}\n")
    print(f"micro_f1: {micro_f1}\n macro_f1: {macro_f1}")

In [60]:
evaluate(model, test_dataloader, device)

micro_precission: 0.8021522372224448
 macro_precission: 0.7906524102331024
 micro_recall: 0.7299685723574673

micro_f1: 0.7643599900965585
 macro_f1: 0.6849397272644077
