# AURA V11.6 — DB-MTL with Scale Clamping & EMA Smoothing

**Base:** V11.5 (DB-MTL + all V11 fixes retained)

## New in V11.6
- **Scale clamp (5.0)** — each task's gradient scale factor is capped at 5×, so a task with a vanishing gradient norm can no longer be boosted without bound
- **EMA-smoothed gradient norms (β=0.9)** — stabilizes per-step noise
  - Inspired by LibMTL official DB-MTL implementation
  - Inactive tasks retain their previous EMA value (no decay)

## Why V11.6?
V11.5 showed that DB-MTL works (Emotion F1 0.6178 and Sentiment F1 0.9484, both the best results so far),
but Reporting's scale factor exploded to 2×10¹¹, causing:
- Stress test regression: 10/11 → 6/11 (false positive Toxic predictions)
- Toxicity F1 regression: 0.7836 → 0.7758

Scale clamping + EMA directly address both failure modes.

## Retained from V11.5
- DB-MTL dual balancing (log-loss + gradient normalization)
- No learnable task weights
- All V11 fixes (task mask, optimizer reset, official splits)


In [None]:
# Cell 1: Imports & Seed — IDENTICAL TO V10.2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from transformers import RobertaModel, RobertaTokenizer, get_linear_schedule_with_warmup
from tqdm.notebook import tqdm
from sklearn.metrics import (
    f1_score, classification_report, confusion_matrix, 
    multilabel_confusion_matrix, precision_recall_fscore_support
)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Reproducibility — SAME SEED AS V10.2
SEED = 42
# Seed both RNG families used by the notebook (NumPy and PyTorch).
np.random.seed(SEED)
torch.manual_seed(SEED)
# Ask cuDNN for deterministic algorithms and switch off its autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'\U0001f527 Device: {device}')
if device.type == 'cuda':
    # Report the first GPU's name and total memory in GB.
    print(f'   GPU: {torch.cuda.get_device_name(0)}')
    print(f'   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

In [None]:
# Cell 2: Configuration — V11.6: DB-MTL + Clamp + EMA
# Single source of hyperparameters for the run. Key order is preserved
# because the summary below prints the entries in insertion order.
CONFIG = dict(
    # Model
    encoder='roberta-base',
    hidden_dim=768,
    n_heads=8,
    num_emotion_classes=7,
    max_length=128,
    dropout=0.3,

    # Training
    batch_size=16,
    gradient_accumulation=4,  # Effective batch = 64
    epochs=10,  # SAME AS V10.2 FINAL RUN
    lr_encoder=1e-5,
    lr_heads=5e-5,
    weight_decay=0.01,
    max_grad_norm=1.0,
    warmup_ratio=0.1,

    # Regularization (Module 3)
    focal_gamma=2.0,
    label_smoothing=0.1,
    patience=5,
    freezing_epochs=1,

    # V11.6: DB-MTL stabilization
    balancing='db-mtl',
    db_scale_clamp=5.0,   # Max scale factor any task can receive
    db_ema_beta=0.9,      # EMA decay for gradient norm smoothing
)

# Dataset root and the emotion label columns (one column per class).
DATA_DIR = '/kaggle/input/aura-v11-data'
EMO_COLS = ['anger', 'disgust', 'fear', 'joy', 'sadness', 'surprise', 'neutral']

# Echo the full configuration so it is captured in the notebook output.
print('\U0001f4cb V11.6 Configuration (DB-MTL + Clamp + EMA):')
for k, v in CONFIG.items():
    print(f'   {k}: {v}')


In [None]:
# Cell 3: Visualization Functions — IDENTICAL TO V10.2
def plot_class_distribution(df, label_col, title, ax=None):
    """Draw a bar chart of how often each value of ``label_col`` occurs.

    Args:
        df: DataFrame holding the label column.
        label_col: Name of the column whose values are counted.
        title: Title placed above the chart.
        ax: Axes to draw on; a new 6x4 figure is created when omitted.

    Returns:
        The axes the chart was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 4))

    # Counts ordered by label value, so the bars appear in class order.
    class_counts = df[label_col].value_counts().sort_index()
    heights = class_counts.values
    rects = ax.bar(class_counts.index.astype(str), heights,
                   color=['#66c2a5', '#fc8d62'])

    ax.set_title(title)
    ax.set_xlabel('Class')
    ax.set_ylabel('Count')

    # Write each count just above its bar (fixed offset of 50 data units).
    for rect, n in zip(rects, heights):
        x_center = rect.get_x() + rect.get_width()/2
        y_text = rect.get_height() + 50
        ax.text(x_center, y_text, str(n), ha='center', fontsize=10)
    return ax

def plot_confusion_matrix_heatmap(y_true, y_pred, labels, title='Confusion Matrix', ax=None):
    """Render the confusion matrix of ``y_true`` vs ``y_pred`` as a heatmap.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels.
        labels: Tick labels used on both axes.
        title: Title placed above the heatmap.
        ax: Axes to draw on; a new 6x5 figure is created when omitted.

    Returns:
        The axes the heatmap was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 5))

    # Integer-annotated heatmap with a colour bar labelled 'Count'.
    heatmap_options = dict(
        annot=True, fmt='d', cmap='Blues',
        xticklabels=labels, yticklabels=labels,
        cbar_kws={'label': 'Count'},
    )
    sns.heatmap(confusion_matrix(y_true, y_pred), ax=ax, **heatmap_options)

    ax.set_title(title)
    ax.set_ylabel('Actual')
    ax.set_xlabel('Predicted')
    return ax

def plot_multilabel_confusion_matrices(y_true, y_pred, labels, normalize=True):
    """Plot one 2x2 confusion matrix per label in a grid of heatmaps.

    The grid has at most four columns; unused cells are hidden. The figure
    is shown with ``plt.show()`` and nothing is returned.

    Args:
        y_true: Ground-truth multilabel indicator matrix.
        y_pred: Predicted multilabel indicator matrix.
        labels: Label names, one per column, used as panel titles.
        normalize: When True each matrix row is divided by its row total so
            cells show rates; when False raw counts are shown.
    """
    cms = multilabel_confusion_matrix(y_true, y_pred)
    n_labels = len(labels)
    cols = min(4, n_labels)
    rows = (n_labels + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(cols*3, rows*3))
    # A 1x1 grid yields a single Axes object rather than an array.
    axes = axes.flatten() if n_labels > 1 else [axes]

    for ax, cm, label in zip(axes, cms, labels):
        if normalize:
            row_totals = cm.sum(axis=1, keepdims=True)
            # A label with no actual negatives (or positives) has an empty
            # row; leave it at 0 instead of producing NaN cells.
            cm = np.divide(
                cm, row_totals,
                out=np.zeros(cm.shape, dtype=float),
                where=row_totals > 0,
            )
            fmt = '.2f'
        else:
            fmt = 'd'
        sns.heatmap(cm, annot=True, fmt=fmt, cmap='YlGnBu', ax=ax,
                    xticklabels=['Neg', 'Pos'], yticklabels=['Neg', 'Pos'],
                    vmin=0, vmax=1 if normalize else None, cbar=False)
        ax.set_title(label, fontsize=10)
        ax.set_ylabel('Actual')
        ax.set_xlabel('Predicted')

    # Hide grid cells that have no label assigned.
    for ax in axes[n_labels:]:
        ax.axis('off')

    # The title must reflect what is actually plotted.
    mode = 'Normalized' if normalize else 'Counts'
    plt.suptitle(f'Multilabel Confusion Matrices ({mode})', fontsize=12)
    plt.tight_layout()
    plt.show()

def plot_training_history(history):
    """Show training loss, validation F1 and task weights side by side.

    Args:
        history: Dict with per-epoch lists under ``'train_loss'``,
            ``'val_f1'`` and ``'task_weights'`` (one 4-entry row per epoch,
            plotted as Toxicity, Emotion, Sentiment, Reporting).
    """
    fig, (loss_ax, f1_ax, weight_ax) = plt.subplots(1, 3, figsize=(15, 4))

    # Panel 1: training loss per epoch (epochs are numbered from 1).
    train_loss = history['train_loss']
    loss_ax.plot(range(1, len(train_loss)+1), train_loss, 'b-o', label='Train')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training Loss')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Panel 2: validation macro F1 per epoch.
    val_f1 = history['val_f1']
    f1_ax.plot(range(1, len(val_f1)+1), val_f1, 'g-o', label='Val F1')
    f1_ax.set_xlabel('Epoch')
    f1_ax.set_ylabel('Macro F1')
    f1_ax.set_title('Validation F1 Score')
    f1_ax.legend()
    f1_ax.grid(True, alpha=0.3)

    # Panel 3: one line per task from the (epochs x tasks) weight table.
    weights = np.array(history['task_weights'])
    epoch_axis = range(1, len(weights)+1)
    task_names = ['Toxicity', 'Emotion', 'Sentiment', 'Reporting']
    for column, task_name in enumerate(task_names):
        weight_ax.plot(epoch_axis, weights[:, column], '-o', label=task_name)
    weight_ax.set_xlabel('Epoch')
    weight_ax.set_ylabel('Weight')
    weight_ax.set_title('DB-MTL Effective Weights (clamped + EMA)')
    weight_ax.legend()
    weight_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

print('\U0001f4ca Visualization functions loaded.')

In [None]:
# Cell 4: Task-Specific Multi-Head Attention Module — IDENTICAL TO V10.2
class TaskSpecificMHA(nn.Module):
    """Self-attention block with a residual connection and LayerNorm.

    One instance is created per task so that each task can learn its own
    attention pattern over the shared encoder states (the design intent is
    that, e.g., toxicity and reporting attend to different cue words).
    """

    def __init__(self, hidden_dim, n_heads, dropout=0.1):
        super().__init__()
        # Batch-first attention: inputs are (batch, seq, hidden_dim).
        self.mha = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.layernorm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states, attention_mask):
        """Attend over ``hidden_states``, ignoring padded key positions.

        Args:
            hidden_states: Sequence of token states used as query, key and
                value.
            attention_mask: Mask whose zero entries mark padding tokens.

        Returns:
            Tuple ``(output, attn_weights)``: the layer-normalised residual
            output and the attention weights returned by the MHA module.
        """
        # MultiheadAttention expects True at positions to be ignored.
        is_padding = attention_mask.eq(0)
        attended, attn_weights = self.mha(
            query=hidden_states,
            key=hidden_states,
            value=hidden_states,
            key_padding_mask=is_padding,
        )
        residual = hidden_states + self.dropout(attended)
        return self.layernorm(residual), attn_weights

print('\U0001f9e0 TaskSpecificMHA module defined.')

In [None]:
# Cell 5: AURA V10 Model — shared_rep for DB-MTL gradient balancing (V11.3+)
class AURA_V10(nn.Module):
    """RoBERTa encoder followed by four parallel task-specific branches.

    Each branch is a ``TaskSpecificMHA`` block, masked mean pooling, dropout
    and a linear head. The raw encoder output is also returned so the loss
    can balance per-task gradients on the shared representation.
    """

    def __init__(self, config):
        super().__init__()
        self.roberta = RobertaModel.from_pretrained(config['encoder'])
        width = config['hidden_dim']
        n_heads, p_drop = config['n_heads'], config['dropout']

        # One attention block per task (toxicity, emotion, sentiment, reporting).
        self.tox_mha = TaskSpecificMHA(width, n_heads, p_drop)
        self.emo_mha = TaskSpecificMHA(width, n_heads, p_drop)
        self.sent_mha = TaskSpecificMHA(width, n_heads, p_drop)
        self.report_mha = TaskSpecificMHA(width, n_heads, p_drop)

        self.dropout = nn.Dropout(p_drop)

        # Linear heads: 2-way toxicity, multi-label emotion, 2-way sentiment,
        # and a single reporting logit.
        self.toxicity_head = nn.Linear(width, 2)
        self.emotion_head = nn.Linear(width, config['num_emotion_classes'])
        self.sentiment_head = nn.Linear(width, 2)
        self.reporting_head = nn.Linear(width, 1)

        # Start the toxicity head biased towards class 0 (Non-Toxic) to
        # reflect the class imbalance: bias = [+2.5, -2.5].
        with torch.no_grad():
            self.toxicity_head.bias.copy_(torch.tensor([2.5, -2.5]))

    def _mean_pool(self, seq, mask):
        """Average ``seq`` over the sequence axis, counting only mask==1 tokens."""
        token_weights = mask.unsqueeze(-1).expand(seq.size()).float()
        summed = (seq * token_weights).sum(dim=1)
        # Clamp so a fully-masked row divides by 1e-9 instead of zero.
        n_tokens = token_weights.sum(dim=1).clamp(min=1e-9)
        return summed / n_tokens

    def forward(self, input_ids, attention_mask):
        """Return per-task logits plus the shared encoder output.

        Returns:
            Dict with keys ``'toxicity'``, ``'emotion'``, ``'sentiment'``,
            ``'reporting'`` (last dimension squeezed) and ``'shared_rep'``.
        """
        shared = self.roberta(input_ids, attention_mask).last_hidden_state

        # Run all attention branches first, then the heads, in task order.
        tox_seq, _ = self.tox_mha(shared, attention_mask)
        emo_seq, _ = self.emo_mha(shared, attention_mask)
        sent_seq, _ = self.sent_mha(shared, attention_mask)
        rep_seq, _ = self.report_mha(shared, attention_mask)

        def classify(head, seq):
            return head(self.dropout(self._mean_pool(seq, attention_mask)))

        tox_logits = classify(self.toxicity_head, tox_seq)
        emo_logits = classify(self.emotion_head, emo_seq)
        sent_logits = classify(self.sentiment_head, sent_seq)
        rep_logits = classify(self.reporting_head, rep_seq).squeeze(-1)

        return {
            'toxicity': tox_logits,
            'emotion': emo_logits,
            'sentiment': sent_logits,
            'reporting': rep_logits,
            'shared_rep': shared  # encoder output for DB-MTL gradient balancing
        }

print('\U0001f985 AURA_V10 model defined.')

In [None]:
# Cell 6: Loss Functions
# V11.6: DB-MTL with scale clamping and EMA-smoothed gradient norms.
#
# DB-MTL: Dual-Balancing Multi-Task Learning (Lin et al., 2023)
# Two mechanisms to prevent any task from dominating:
#   1. Log-transform each task loss -> normalizes loss scales
#   2. Normalize gradient magnitudes -> all tasks contribute equally to encoder
#
# V11.6 additions (fixes V11.5 Reporting explosion):
#   3. Scale clamp -> no task can exceed scale_clamp x the gradient budget
#   4. EMA smoothing -> gradient norms are exponentially averaged for stability
#
# No learnable task weights -> the task weights themselves cannot diverge.

def focal_loss(logits, targets, weight=None, gamma=2.0, smoothing=0.0):
    """Focal loss for single-label classification with optional smoothing.

    The per-sample cross-entropy ``ce`` is modulated by ``(1 - pt) ** gamma``
    with ``pt = exp(-ce)``, then averaged over the batch.

    Args:
        logits: ``(batch, n_classes)`` unnormalised scores.
        targets: ``(batch,)`` integer class indices.
        weight: Optional per-class weights. They scale each sample's
            cross-entropy by the weight of its target class, with or without
            label smoothing. Previously they were silently dropped whenever
            ``smoothing > 0``.
        gamma: Focusing exponent; 0 reduces to plain (weighted) CE.
        smoothing: Label-smoothing mass spread over the non-target classes.

    Returns:
        Scalar tensor: the mean focal loss over the batch.
    """
    n_classes = logits.size(-1)
    if smoothing > 0:
        with torch.no_grad():
            # max(..., 1) keeps a single-logit input from dividing by zero.
            off_value = smoothing / max(n_classes - 1, 1)
            smooth = torch.full_like(logits, off_value)
            smooth.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing)
        log_probs = F.log_softmax(logits, dim=-1)
        ce = -(smooth * log_probs).sum(dim=-1)
        if weight is not None:
            # Mirror F.cross_entropy(weight=..., reduction='none'): scale
            # every sample by the weight of its target class.
            ce = ce * weight[targets]
    else:
        ce = F.cross_entropy(logits, targets, weight=weight, reduction='none')

    pt = torch.exp(-ce)
    focal = ((1 - pt) ** gamma) * ce
    return focal.mean()


class DBMTLLoss(nn.Module):
    """DB-MTL with scale clamping and EMA-smoothed gradient norms (V11.6).

    Extends DB-MTL (Lin et al., 2023) with two stabilisation mechanisms:

    1. Log-loss scale balancing (original DB-MTL).
    2. Gradient-magnitude balancing (original DB-MTL): every active task is
       scaled by ``max_norm / task_norm`` on the shared representation.
    3. Scale clamping: each scale factor is capped at ``scale_clamp``.
    4. EMA smoothing: per-task gradient norms are exponentially averaged
       across steps before the scale factors are computed.

    A task's EMA is seeded with its first observed norm and is left
    untouched on steps where the task is inactive. The module has no
    learnable parameters; its balancing state lives in NumPy arrays.
    """

    def __init__(self, n_tasks, ema_beta=0.9, scale_clamp=5.0):
        super().__init__()
        self.n_tasks = n_tasks
        self.ema_beta = ema_beta
        self.scale_clamp = scale_clamp
        # EMA-smoothed gradient norm per task; stays None until the first
        # step on which gradient balancing actually runs.
        self._grad_norm_ema = None
        # Marks tasks whose EMA slot already holds a real observation, so a
        # task first seen later is seeded with its raw norm rather than
        # being averaged against a placeholder zero.
        self._ema_seen = np.zeros(n_tasks, dtype=bool)
        # Effective weights of the most recent forward call (logging only).
        self._last_effective_weights = np.ones(n_tasks)

    def forward(self, task_losses, shared_rep, task_mask=None):
        """Compute the balanced total loss for one batch.

        Args:
            task_losses: List of ``n_tasks`` scalar loss tensors.
            shared_rep: Encoder output the per-task gradients are measured
                against.
            task_mask: Optional list of bools, True for tasks present in the
                batch. Tasks whose loss does not require grad are skipped
                regardless.

        Returns:
            Scalar tensor: the scaled sum of log-losses of the active tasks
            (a zero tensor when no task is active).
        """
        active_indices = [
            i for i in range(self.n_tasks)
            if (task_mask is None or task_mask[i]) and task_losses[i].requires_grad
        ]

        if not active_indices:
            self._record_weights([], [])
            return torch.tensor(0.0, device=shared_rep.device, requires_grad=True)

        # Step 1: log-transform the losses (loss-scale balancing).
        log_losses = [
            torch.log(task_losses[i].clamp(min=1e-8)) for i in active_indices
        ]

        # A lone task needs no balancing; report a unit weight for it so the
        # logged weights describe this step rather than the previous one.
        if len(log_losses) == 1:
            self._record_weights(active_indices, [1.0])
            return log_losses[0]

        # Step 2: gradient-magnitude balancing. Only possible when the
        # encoder output carries gradient; with a frozen encoder we fall
        # back to an unscaled sum of log-losses.
        if shared_rep.requires_grad:
            raw_norms = self._raw_grad_norms(log_losses, shared_rep)
            smoothed = self._smooth_grad_norms(active_indices, raw_norms)
            scales = self._clamped_scales(smoothed)
        else:
            scales = [1.0] * len(log_losses)

        # Step 3: weighted sum of log-losses.
        total = torch.tensor(0.0, device=shared_rep.device)
        for log_loss, scale in zip(log_losses, scales):
            total = total + scale * log_loss

        self._record_weights(active_indices, scales)
        return total

    @staticmethod
    def _raw_grad_norms(log_losses, shared_rep):
        """Return the gradient norm of each log-loss w.r.t. ``shared_rep``."""
        norms = []
        for log_loss in log_losses:
            # retain_graph: the caller still back-propagates the total loss.
            grad = torch.autograd.grad(
                log_loss, shared_rep, retain_graph=True, create_graph=False,
                allow_unused=True
            )[0]
            norms.append(0.0 if grad is None else grad.norm().item())
        return norms

    def _smooth_grad_norms(self, active_indices, raw_norms):
        """Fold the raw norms into the per-task EMA and return the active values."""
        if self._grad_norm_ema is None:
            self._grad_norm_ema = np.zeros(self.n_tasks)
        for idx, raw in zip(active_indices, raw_norms):
            if self._ema_seen[idx]:
                self._grad_norm_ema[idx] = (
                    self.ema_beta * self._grad_norm_ema[idx]
                    + (1 - self.ema_beta) * raw
                )
            else:
                # First observation for this task: seed the EMA directly.
                # Averaging against the zero placeholder would under-state
                # the norm by (1 - beta) and pin the scale at the clamp.
                self._grad_norm_ema[idx] = raw
                self._ema_seen[idx] = True
        return [self._grad_norm_ema[idx] for idx in active_indices]

    def _clamped_scales(self, smoothed):
        """Scale each task to the largest smoothed norm, capped at ``scale_clamp``."""
        strongest = max(smoothed)
        max_norm = strongest if strongest > 1e-12 else 1.0
        return [
            float(min(max_norm / max(norm, 1e-12), self.scale_clamp))
            for norm in smoothed
        ]

    def _record_weights(self, active_indices, scales):
        """Store this step's effective weights (1.0 for inactive tasks)."""
        weights = np.ones(self.n_tasks)
        for idx, scale in zip(active_indices, scales):
            weights[idx] = scale
        self._last_effective_weights = weights

    def get_weights(self):
        """Return last observed effective weights (for logging/viz only)."""
        return self._last_effective_weights.copy()


print('\u2696\ufe0f Loss functions defined (Focal + DB-MTL).')
print('   V11.6: DB-MTL with scale clamping and EMA smoothing.')
print('   (1) log-loss scale balancing, (2) gradient magnitude normalization,')
print(f'   (3) scale clamp at {CONFIG["db_scale_clamp"]}, '
      f'(4) EMA beta={CONFIG["db_ema_beta"]}.')
print('   No learnable task weights. V11 FIX retained: task_mask.')


In [None]:
# Cell 7: Dataset Classes — IDENTICAL TO V10.2
class BaseDataset(Dataset):
    """CSV-backed dataset that shares tokenisation settings across tasks.

    Subclasses implement ``__getitem__`` and call ``encode`` on the row text.
    """

    def __init__(self, path, tokenizer, max_len):
        # The whole CSV is loaded eagerly; rows are tokenised on access.
        self.df = pd.read_csv(path)
        self.tok = tokenizer
        self.max_len = max_len

    def __len__(self):
        return self.df.shape[0]

    def encode(self, text):
        """Tokenise ``text`` (coerced to ``str``), padded/truncated to ``max_len``."""
        tokenizer_options = dict(
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        return self.tok(str(text), **tokenizer_options)

class ToxicityDataset(BaseDataset):
    """Toxicity samples (task id 0) read from the ``text``/``label`` columns."""

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        encoded = self.encode(record['text'])
        # The tokenizer returns (1, max_len) tensors; flatten to 1-D.
        token_ids = encoded['input_ids'].flatten()
        attention = encoded['attention_mask'].flatten()
        label = torch.tensor(int(record['label']), dtype=torch.long)
        return {'ids': token_ids, 'mask': attention, 'tox': label, 'task': 0}

class EmotionDataset(BaseDataset):
    """Multi-label emotion samples (task id 1), one float target per column."""

    def __init__(self, path, tokenizer, max_len, cols):
        super().__init__(path, tokenizer, max_len)
        self.cols = cols
        # When the CSV carries a label_sum column, drop rows with no label.
        if 'label_sum' in self.df.columns:
            has_label = self.df['label_sum'] > 0
            self.df = self.df.loc[has_label].reset_index(drop=True)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        encoded = self.encode(record['text'])
        # The tokenizer returns (1, max_len) tensors; flatten to 1-D.
        token_ids = encoded['input_ids'].flatten()
        attention = encoded['attention_mask'].flatten()
        targets = torch.tensor(
            [float(record[col]) for col in self.cols], dtype=torch.float
        )
        return {'ids': token_ids, 'mask': attention, 'emo': targets, 'task': 1}

class SentimentDataset(BaseDataset):
    """Sentiment samples (task id 2) read from the ``text``/``label`` columns."""

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        encoded = self.encode(record['text'])
        # The tokenizer returns (1, max_len) tensors; flatten to 1-D.
        token_ids = encoded['input_ids'].flatten()
        attention = encoded['attention_mask'].flatten()
        label = torch.tensor(int(record['label']), dtype=torch.long)
        return {'ids': token_ids, 'mask': attention, 'sent': label, 'task': 2}

class ReportingDataset(BaseDataset):
    """Reporting samples (task id 3) read from ``text``/``is_reporting``."""

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        encoded = self.encode(record['text'])
        # The tokenizer returns (1, max_len) tensors; flatten to 1-D.
        token_ids = encoded['input_ids'].flatten()
        attention = encoded['attention_mask'].flatten()
        label = torch.tensor(int(record['is_reporting']), dtype=torch.long)
        return {'ids': token_ids, 'mask': attention, 'rep': label, 'task': 3}

def collate_fn(batch):
    """Merge samples from different tasks into one mixed-task batch.

    Inputs are stacked for every sample; labels are stacked per task, in
    batch order, and set to ``None`` for tasks absent from the batch.

    Returns:
        Dict with ``'ids'``, ``'mask'``, ``'tasks'`` (task id per sample) and
        one entry per label key: ``'tox'``, ``'emo'``, ``'sent'``, ``'rep'``.
    """
    def stack_labels(label_key, task_id):
        picked = [sample[label_key] for sample in batch if sample['task'] == task_id]
        return torch.stack(picked) if picked else None

    collated = {
        'ids': torch.stack([sample['ids'] for sample in batch]),
        'mask': torch.stack([sample['mask'] for sample in batch]),
        'tasks': torch.tensor([sample['task'] for sample in batch]),
    }
    # Task ids 0..3 map to these label keys, in this order.
    for task_id, label_key in enumerate(('tox', 'emo', 'sent', 'rep')):
        collated[label_key] = stack_labels(label_key, task_id)
    return collated

print('\U0001f4e6 Dataset classes defined \u2014 identical to V10.2.')

In [None]:
# Cell 8: Load Data
# V11 FIX #1: Load proper held-out validation sets for emotion and sentiment.
# These come from official GoEmotions dev and SST-2 dev splits,
# generated by prepare_v11_datasets.py. No data leak.

tokenizer = RobertaTokenizer.from_pretrained(CONFIG['encoder'])

# Training sets: one CSV per task, all tokenised to the same max length.
tox_train = ToxicityDataset(
    path=f'{DATA_DIR}/toxicity_train.csv',
    tokenizer=tokenizer, max_len=CONFIG['max_length'],
)
emo_train = EmotionDataset(
    path=f'{DATA_DIR}/emotions_train.csv',
    tokenizer=tokenizer, max_len=CONFIG['max_length'], cols=EMO_COLS,
)
sent_train = SentimentDataset(
    path=f'{DATA_DIR}/sentiment_train.csv',
    tokenizer=tokenizer, max_len=CONFIG['max_length'],
)
rep_train = ReportingDataset(
    path=f'{DATA_DIR}/reporting_examples_augmented.csv',
    tokenizer=tokenizer, max_len=CONFIG['max_length'],
)

# Validation sets (V11: held-out files; no validation set for reporting).
tox_val = ToxicityDataset(
    path=f'{DATA_DIR}/toxicity_val.csv',
    tokenizer=tokenizer, max_len=CONFIG['max_length'],
)
emo_val = EmotionDataset(
    path=f'{DATA_DIR}/emotions_val.csv',
    tokenizer=tokenizer, max_len=CONFIG['max_length'], cols=EMO_COLS,
)
sent_val = SentimentDataset(
    path=f'{DATA_DIR}/sentiment_val.csv',
    tokenizer=tokenizer, max_len=CONFIG['max_length'],
)

# One shuffled loader over all four tasks, so batches mix tasks.
train_ds = ConcatDataset([tox_train, emo_train, sent_train, rep_train])
train_loader = DataLoader(
    train_ds,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=2,
    pin_memory=True,
)

# Unshuffled validation loaders, one per evaluated task.
tox_val_loader = DataLoader(tox_val, collate_fn=collate_fn, batch_size=CONFIG['batch_size'])
emo_val_loader = DataLoader(emo_val, collate_fn=collate_fn, batch_size=CONFIG['batch_size'])
sent_val_loader = DataLoader(sent_val, collate_fn=collate_fn, batch_size=CONFIG['batch_size'])

# Summary of split sizes.
print('\n'.join([
    '='*60,
    '\U0001f4ca DATASET SUMMARY',
    '='*60,
    f'Training Samples: {len(train_ds):,}',
    f'  \u251c\u2500 Toxicity:  {len(tox_train):,}',
    f'  \u251c\u2500 Emotion:   {len(emo_train):,}',
    f'  \u251c\u2500 Sentiment: {len(sent_train):,}',
    f'  \u2514\u2500 Reporting: {len(rep_train):,}',
    'Validation Samples:',
    f'  \u251c\u2500 Toxicity:  {len(tox_val):,} (OLID official test)',
    f'  \u251c\u2500 Emotion:   {len(emo_val):,} (GoEmotions official dev)',
    f'  \u2514\u2500 Sentiment: {len(sent_val):,} (SST-2 official dev)',
]))

In [None]:
# Cell 9: Data Distribution Analysis (NB11 Pattern)
# Draws a 2x2 overview of the training data: toxicity class balance,
# per-task sample counts, emotion label frequencies, and the number of
# emotion labels per sample. Every name assigned here is a notebook global.
print('='*60)
print('\U0001f4c8 CLASS DISTRIBUTION ANALYSIS (NB11)')
print('='*60)

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# 1. Toxicity Distribution (top-left)
# The CSV is re-read here rather than reusing tox_train.df.
tox_df = pd.read_csv(f'{DATA_DIR}/toxicity_train.csv')
plot_class_distribution(tox_df, 'label', 'Toxicity: Class Distribution', axes[0, 0])
# Replace the raw class values on the x axis with readable names.
axes[0, 0].set_xticklabels(['Non-Toxic (0)', 'Toxic (1)'])

# 2. Task Sample Distribution (top-right): training-set size per task
task_counts = {'Toxicity': len(tox_train), 'Emotion': len(emo_train), 
               'Sentiment': len(sent_train), 'Reporting': len(rep_train)}
colors = ['#66c2a5', '#fc8d62', '#8da0cb', '#e78ac3']
bars = axes[0, 1].bar(task_counts.keys(), task_counts.values(), color=colors)
axes[0, 1].set_title('Task Sample Distribution')
axes[0, 1].set_ylabel('Count')
# Annotate each bar with its count, placed 500 data units above the bar top.
for bar, count in zip(bars, task_counts.values()):
    axes[0, 1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 500, 
                    f'{count:,}', ha='center', fontsize=9)

# 3. Emotion Label Distribution (Multilabel) (bottom-left)
emo_df = pd.read_csv(f'{DATA_DIR}/emotions_train.csv')
# Same filter as EmotionDataset: drop rows that carry no emotion label.
if 'label_sum' in emo_df.columns:
    emo_df = emo_df[emo_df['label_sum'] > 0]
# Column sums = number of samples tagged with each emotion.
emo_counts = emo_df[EMO_COLS].sum().sort_values(ascending=True)
emo_counts.plot(kind='barh', ax=axes[1, 0], color='#8da0cb')
axes[1, 0].set_title('Emotion Label Distribution')
axes[1, 0].set_xlabel('Count')

# 4. Number of labels per sample (bottom-right)
# Use the precomputed label_sum column when present, otherwise derive it
# by summing the emotion columns row-wise.
if 'label_sum' in emo_df.columns:
    label_counts = emo_df['label_sum'].value_counts().sort_index()
else:
    label_counts = emo_df[EMO_COLS].sum(axis=1).value_counts().sort_index()
label_counts.plot(kind='bar', ax=axes[1, 1], color='#fc8d62')
axes[1, 1].set_title('Emotion: # of Labels per Sample')
axes[1, 1].set_xlabel('Number of Emotion Labels')
axes[1, 1].set_ylabel('Count')
axes[1, 1].tick_params(axis='x', rotation=0)

plt.tight_layout()
plt.show()

# Counts sorted by label value; the unpacking requires exactly two classes.
# NOTE(review): presumably 0 = Non-Toxic and 1 = Toxic, as the tick labels
# above suggest — confirm against the CSV.
neg, pos = tox_df['label'].value_counts().sort_index()
print(f'\n\u26a0\ufe0f Toxicity Imbalance: {neg:,} Non-Toxic vs {pos:,} Toxic ({pos/(neg+pos)*100:.1f}% minority class)')

In [None]:
# Cell 10: Model & Optimizer Setup
# V11.6: DBMTLLoss with EMA + clamp hyperparameters from CONFIG

model = AURA_V10(CONFIG).to(device)

# V11.6: DB-MTL loss with scale clamping and EMA smoothing (4 tasks).
loss_fn = DBMTLLoss(
    n_tasks=4,
    scale_clamp=CONFIG['db_scale_clamp'],
    ema_beta=CONFIG['db_ema_beta'],
).to(device)

# Split parameters into encoder vs. everything else (attention blocks and
# heads) so the two groups can use different learning rates.
encoder_params = set(model.roberta.parameters())
head_params = list(filter(lambda p: p not in encoder_params, model.parameters()))

optimizer = torch.optim.AdamW(
    [
        dict(params=model.roberta.parameters(), lr=CONFIG['lr_encoder']),
        dict(params=head_params, lr=CONFIG['lr_heads']),
    ],
    weight_decay=CONFIG['weight_decay'],
)

# Toxicity class weights: [class 0, class 1].
tox_weights = torch.tensor([0.5, 2.0]).to(device)

# One scheduler step per full gradient-accumulation window.
total_steps = CONFIG['epochs'] * (len(train_loader) // CONFIG['gradient_accumulation'])
warmup_steps = int(CONFIG['warmup_ratio'] * total_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

n_params = sum(p.numel() for p in model.parameters())
n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)

# Setup summary.
print('\n'.join([
    '='*60,
    '\U0001f3d7\ufe0f MODEL SETUP',
    '='*60,
    f'Encoder: {CONFIG["encoder"]}',
    f'Total parameters:     {n_params:,}',
    f'Trainable parameters: {n_train:,}',
    f'Total optimization steps: {total_steps}',
    f'Warmup steps: {warmup_steps}',
    f'Effective batch size: {CONFIG["batch_size"] * CONFIG["gradient_accumulation"]}',
    f'V11.6: DB-MTL + clamp={CONFIG["db_scale_clamp"]} + EMA beta={CONFIG["db_ema_beta"]}',
]))


In [None]:
# Cell 11: Training & Evaluation Functions
# V11.6: DB-MTL with scale clamping and EMA-smoothed gradient norms.
# All V11 fixes retained: task_mask, optimizer state reset.
# Key difference from V11.5: DBMTLLoss now internally handles
# EMA smoothing and scale clamping — no changes needed in train_epoch.

def _compute_task_losses(out, batch, tasks):
    """Build the per-task losses for one mixed-task batch.

    Args:
        out: Model output dict with one logits entry per task.
        batch: Collated batch; a label entry is ``None`` when its task is absent.
        tasks: Tensor of task ids, one per sample in the batch.

    Returns:
        ``(losses, task_mask)`` in the fixed order Toxicity, Emotion,
        Sentiment, Reporting. Absent tasks contribute a constant zero tensor
        (no grad) and ``False`` in the mask.
    """
    # Both focal-loss hyperparameters come from CONFIG, so changing
    # CONFIG['focal_gamma'] actually takes effect.
    focal_kwargs = {
        'gamma': CONFIG['focal_gamma'],
        'smoothing': CONFIG['label_smoothing'],
    }
    # (task id, label key in batch, loss builder taking (row selector, labels))
    builders = (
        (0, 'tox', lambda sel, y: focal_loss(
            out['toxicity'][sel], y.to(device),
            weight=tox_weights, **focal_kwargs)),
        (1, 'emo', lambda sel, y: F.binary_cross_entropy_with_logits(
            out['emotion'][sel], y.to(device))),
        (2, 'sent', lambda sel, y: focal_loss(
            out['sentiment'][sel], y.to(device), **focal_kwargs)),
        (3, 'rep', lambda sel, y: F.binary_cross_entropy_with_logits(
            out['reporting'][sel], y.float().to(device))),
    )

    losses, task_mask = [], []
    for task_id, label_key, build in builders:
        selector = tasks == task_id
        labels = batch[label_key]
        present = labels is not None and bool(selector.any())
        if present:
            losses.append(build(selector, labels))
        else:
            losses.append(torch.tensor(0., device=device, requires_grad=False))
        task_mask.append(present)
    return losses, task_mask


def train_epoch(epoch):
    """Train for one epoch and return ``(mean_loss, mean_effective_weights)``.

    RoBERTa is frozen while ``epoch <= CONFIG['freezing_epochs']``; on the
    first unfrozen epoch any optimizer state held for encoder parameters is
    cleared. Gradients are accumulated over
    ``CONFIG['gradient_accumulation']`` micro-batches per optimizer step.

    Args:
        epoch: 1-based epoch number.

    Returns:
        Tuple of the summed (un-normalised) batch loss divided by
        ``len(train_loader)`` and the DB-MTL effective weights averaged over
        the steps for which a loss was computed.
    """
    model.train()

    # Progressive freezing
    if epoch <= CONFIG['freezing_epochs']:
        print(f'\u2744\ufe0f Epoch {epoch}: RoBERTa FROZEN')
        for p in model.roberta.parameters():
            p.requires_grad = False
    else:
        # V11 FIX #4: reset Adam states on the first unfrozen epoch only.
        first_unfrozen = epoch == CONFIG['freezing_epochs'] + 1
        if first_unfrozen:
            print(f'\U0001f525 Epoch {epoch}: RoBERTa UNFROZEN (resetting optimizer states)')
        else:
            print(f'\U0001f525 Epoch {epoch}: RoBERTa UNFROZEN')
        for p in model.roberta.parameters():
            p.requires_grad = True
            # Clear stale Adam momentum/variance from frozen epochs
            if first_unfrozen and p in optimizer.state:
                del optimizer.state[p]

    accum = CONFIG['gradient_accumulation']
    total_loss = 0
    step_weights = []  # DB-MTL effective weights per step
    pending = 0        # micro-batches back-propagated since the last optimizer step
    optimizer.zero_grad()
    pbar = tqdm(train_loader, desc=f'Epoch {epoch}', mininterval=10.0)

    for step, batch in enumerate(pbar):
        ids = batch['ids'].to(device)
        mask = batch['mask'].to(device)
        tasks = batch['tasks']

        out = model(ids, mask)
        shared_rep = out['shared_rep']  # For DB-MTL gradient balancing

        # V11 FIX #3: task_mask records which tasks are present in the batch.
        losses, task_mask = _compute_task_losses(out, batch, tasks)

        # Skip fully empty batches
        if not any(task_mask):
            print(f"\u26a0\ufe0f Warning: Empty batch at step {step}, skipping")
            optimizer.zero_grad()
            pending = 0
            continue

        # V11.6: DB-MTL loss with EMA + clamp (all handled inside loss_fn)
        loss = loss_fn(losses, shared_rep, task_mask=task_mask) / accum

        # Track effective weights for visualization
        step_weights.append(loss_fn.get_weights())

        # NaN/Inf safety check.
        # Note: log-losses are negative (log of values < 1), this is expected
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"\u26a0\ufe0f Warning: Invalid loss {loss.item():.4f} at step {step}, skipping batch")
            optimizer.zero_grad()
            pending = 0
            continue

        loss.backward()
        pending += 1

        # Gradient Accumulation
        if (step + 1) % accum == 0:
            nn.utils.clip_grad_norm_(model.parameters(), CONFIG['max_grad_norm'])
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            pending = 0

        total_loss += loss.item() * accum
        if step % 50 == 0:
            pbar.set_postfix({'loss': f'{loss.item() * accum:.3f}'})

    # Apply gradients left over from a trailing partial accumulation window.
    # Without this they were back-propagated but then discarded by the
    # zero_grad() at the start of the next epoch. The scheduler is not
    # advanced here, so its step count stays within the
    # len(train_loader) // gradient_accumulation per-epoch budget it was
    # configured with.
    if pending:
        nn.utils.clip_grad_norm_(model.parameters(), CONFIG['max_grad_norm'])
        optimizer.step()
        optimizer.zero_grad()

    # Average effective weights across all steps this epoch
    avg_weights = np.mean(step_weights, axis=0) if step_weights else np.ones(4)
    return total_loss / len(train_loader), avg_weights

@torch.no_grad()
def evaluate_toxicity():
    """Return the macro F1 of the toxicity head over ``tox_val_loader``."""
    model.eval()
    predicted, expected = [], []
    for batch in tox_val_loader:
        logits = model(batch['ids'].to(device), batch['mask'].to(device))['toxicity']
        # Predicted class = index of the largest logit.
        predicted.extend(logits.argmax(1).cpu().numpy())
        expected.extend(batch['tox'].numpy())
    return f1_score(expected, predicted, average='macro', zero_division=0)

@torch.no_grad()
def evaluate_emotion():
    """Return the mean per-emotion binary F1 over ``emo_val_loader``.

    Each emotion is predicted independently by thresholding its sigmoid
    probability at 0.5.
    """
    model.eval()
    pred_chunks, true_chunks = [], []
    for batch in emo_val_loader:
        logits = model(batch['ids'].to(device), batch['mask'].to(device))['emotion']
        pred_chunks.append((torch.sigmoid(logits) > 0.5).cpu().numpy())
        true_chunks.append(batch['emo'].numpy())
    y_pred = np.concatenate(pred_chunks)
    y_true = np.concatenate(true_chunks)
    # One binary F1 per emotion column, then the unweighted mean.
    per_emotion_f1 = [
        f1_score(y_true[:, col], y_pred[:, col], average='binary', zero_division=0)
        for col in range(len(EMO_COLS))
    ]
    return np.mean(per_emotion_f1)

@torch.no_grad()
def evaluate_sentiment():
    """Return the macro-F1 of the sentiment head over `sent_val_loader` (V11 FIX #2).

    The model is switched to eval mode; predictions are the argmax over the
    'sentiment' logits and are compared against each batch's 'sent' labels.
    (Per the original note, the loader holds the held-out SST-2 dev set.)
    """
    model.eval()
    predicted, expected = [], []
    for batch in sent_val_loader:
        input_ids = batch['ids'].to(device)
        attention_mask = batch['mask'].to(device)
        logits = model(input_ids, attention_mask)['sentiment']
        predicted.extend(logits.argmax(1).cpu().numpy())
        expected.extend(batch['sent'].numpy())
    return f1_score(expected, predicted, average='macro', zero_division=0)

# Confirm the helpers above are defined and echo the DB-MTL settings read
# from CONFIG; emitted as one newline-separated print.
print(
    '\U0001f3af Training & evaluation functions defined.',
    '   V11.6: DB-MTL with scale clamping + EMA smoothing.',
    f'   Scale clamp: {CONFIG["db_scale_clamp"]}, EMA beta: {CONFIG["db_ema_beta"]}',
    '   V11 FIX #3: task_mask for absent tasks (retained).',
    '   V11 FIX #4: Optimizer states reset on RoBERTa unfreeze (retained).',
    sep='\n',
)


In [None]:
# Cell 12: Main Training Loop
# V11.6: DB-MTL effective weights tracked per epoch (clamped + EMA).
#
# Each epoch: train, evaluate the three validated tasks, record the results in
# `history`, and checkpoint the model whenever toxicity F1 (the primary
# metric) improves. Training stops early after CONFIG['patience'] consecutive
# epochs without improvement.

print('='*60)
print('\U0001f680 AURA V11.6 \u2014 TRAINING START (DB-MTL + Clamp + EMA)')
print('='*60)

# F1 lies in [0, 1]; starting below that range guarantees the first epoch
# writes a checkpoint, so the later load of 'aura_v11.6_best.pt' never picks
# up a missing or stale file when toxicity F1 stays at 0.0.
best_f1 = -1.0
# Bound up front so the summary prints (and later cells) cannot hit a
# NameError if the improvement branch below is never taken.
best_emo_f1 = 0.0
best_sent_f1 = 0.0
patience_counter = 0
history = {
    'train_loss': [], 
    'val_f1': [],            # Toxicity (primary)
    'val_emo_f1': [],        # Emotion (V11)
    'val_sent_f1': [],       # Sentiment (V11)
    'task_weights': []       # DB-MTL effective weights (clamped + EMA)
}

for epoch in range(1, CONFIG['epochs'] + 1):
    train_loss, epoch_weights = train_epoch(epoch)
    val_f1 = evaluate_toxicity()
    emo_f1 = evaluate_emotion()
    sent_f1 = evaluate_sentiment()
    
    history['train_loss'].append(train_loss)
    history['val_f1'].append(val_f1)
    history['val_emo_f1'].append(emo_f1)
    history['val_sent_f1'].append(sent_f1)
    # Stored as a plain list so `history` stays JSON-serializable.
    history['task_weights'].append(epoch_weights.tolist())
    
    print(f'\nEpoch {epoch} Summary:')
    print(f'  Train Loss:     {train_loss:.4f}')
    print(f'  Toxicity Val F1:  {val_f1:.4f}')
    print(f'  Emotion Val F1:   {emo_f1:.4f}')
    print(f'  Sentiment Val F1: {sent_f1:.4f}')
    print(f'  DB-MTL Weights [Tox/Emo/Sent/Rep]: {np.array(epoch_weights).round(3)}')
    
    if val_f1 > best_f1:
        # Model selection is driven by toxicity F1 only; the auxiliary scores
        # are recorded as they were at the selected epoch.
        best_f1 = val_f1
        best_emo_f1 = emo_f1
        best_sent_f1 = sent_f1
        patience_counter = 0
        torch.save(model.state_dict(), 'aura_v11.6_best.pt')
        print('  >>> BEST MODEL SAVED <<<')
    else:
        patience_counter += 1
        print(f'  (No improvement. Patience: {patience_counter}/{CONFIG["patience"]})')
        if patience_counter >= CONFIG['patience']:
            print(f'\n\u26a0\ufe0f Early stopping at epoch {epoch}')
            break

print('\n' + '='*60)
print('\u2705 Training Complete.')
print(f'   Best Toxicity F1:  {best_f1:.4f}')
print(f'   Best Emotion F1:   {best_emo_f1:.4f}')
print(f'   Best Sentiment F1: {best_sent_f1:.4f}')
print('='*60)


In [None]:
# Cell 13: Training History Visualization
# 2x2 grid: training loss, toxicity F1, auxiliary-task F1, DB-MTL weights.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Each entry: (axis, [(history key, line format, legend label), ...], y-label, title)
panels = [
    (axes[0, 0], [('train_loss', 'b-o', 'Train')], 'Loss', 'Training Loss'),
    (axes[0, 1], [('val_f1', 'g-o', 'Toxicity Val F1')], 'Macro F1', 'Toxicity Validation F1'),
    (
        axes[1, 0],
        [('val_emo_f1', 'r-o', 'Emotion Val F1'), ('val_sent_f1', 'm-o', 'Sentiment Val F1')],
        'F1',
        'Auxiliary Task Validation F1',
    ),
]
for ax, curves, ylabel, title in panels:
    for key, fmt, label in curves:
        series = history[key]
        # Epochs are numbered from 1 on the x-axis.
        ax.plot(range(1, len(series) + 1), series, fmt, label=label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

# DB-MTL Effective Weights (clamped + EMA): one line per task column, plus a
# dashed reference line at the configured clamp value.
weight_ax = axes[1, 1]
weights = np.array(history['task_weights'])
epoch_ticks = range(1, len(weights) + 1)
for i, name in enumerate(['Toxicity', 'Emotion', 'Sentiment', 'Reporting']):
    weight_ax.plot(epoch_ticks, weights[:, i], '-o', label=name)
weight_ax.axhline(y=CONFIG['db_scale_clamp'], color='red', linestyle='--', alpha=0.5, label='Clamp limit')
weight_ax.set_xlabel('Epoch')
weight_ax.set_ylabel('Effective Weight')
weight_ax.set_title('DB-MTL Effective Weights (clamped + EMA)')
weight_ax.legend()
weight_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


In [None]:
# Cell 14: Final Evaluation — Toxicity
# Reloads the best checkpoint written by the training loop and reports a
# classification report plus a confusion matrix over `tox_val_loader`.
print('='*60)
print('\U0001f52c FINAL EVALUATION: TOXICITY')
print('='*60)

# map_location remaps the saved tensors onto the active device, so the
# checkpoint also loads when this cell runs on a different device than the
# one that trained it (e.g. a CPU-only session after a GPU run).
model.load_state_dict(torch.load('aura_v11.6_best.pt', map_location=device))
model.eval()

preds, trues = [], []
with torch.no_grad():
    for batch in tox_val_loader:
        out = model(batch['ids'].to(device), batch['mask'].to(device))
        # Predicted class = argmax over the toxicity logits.
        preds.extend(out['toxicity'].argmax(1).cpu().numpy())
        trues.extend(batch['tox'].numpy())

print('\n--- Classification Report ---')
print(classification_report(trues, preds, target_names=['Non-Toxic', 'Toxic']))

fig, ax = plt.subplots(figsize=(6, 5))
plot_confusion_matrix_heatmap(trues, preds, ['Non-Toxic', 'Toxic'], 'Toxicity Confusion Matrix', ax)
plt.tight_layout()
plt.show()

In [None]:
# Cell 15: Final Evaluation — Emotion
# V11 FIX #2: Uses official GoEmotions dev split instead of training data tail.
# This eliminates the data leak present in V10.2.

banner = '='*60
print(banner)
print('\U0001f52c FINAL EVALUATION: EMOTION (GoEmotions Dev)')
print(banner)

# Collect per-batch arrays, then stack them once after the loop.
emo_preds, emo_trues = [], []
model.eval()
with torch.no_grad():
    for batch in tqdm(emo_val_loader, desc='Evaluating Emotions'):
        logits = model(batch['ids'].to(device), batch['mask'].to(device))['emotion']
        # Multi-label decision: sigmoid probability above 0.5 counts as present.
        emo_preds.append((torch.sigmoid(logits) > 0.5).cpu().numpy())
        emo_trues.append(batch['emo'].numpy())

emo_preds = np.concatenate(emo_preds)
emo_trues = np.concatenate(emo_trues)

# Per-emotion metrics: each column is scored as its own binary problem.
print('\n--- Per-Emotion Metrics ---')
for i, emo in enumerate(EMO_COLS):
    p, r, f1, _ = precision_recall_fscore_support(
        emo_trues[:, i],
        emo_preds[:, i],
        average='binary',
        zero_division=0,
    )
    print(f'{emo:10s}: P={p:.3f}, R={r:.3f}, F1={f1:.3f}')

# Multilabel Confusion Matrices
print('\n--- Multilabel Confusion Matrices ---')
plot_multilabel_confusion_matrices(emo_trues, emo_preds, EMO_COLS)

In [None]:
# Cell 16: Final Evaluation — Sentiment (V11 NEW)
# Uses official SST-2 dev split for evaluation.

banner = '='*60
print(banner)
print('\U0001f52c FINAL EVALUATION: SENTIMENT (SST-2 Dev)')
print(banner)

sent_preds, sent_trues = [], []
model.eval()
with torch.no_grad():
    for batch in tqdm(sent_val_loader, desc='Evaluating Sentiment'):
        input_ids = batch['ids'].to(device)
        attention_mask = batch['mask'].to(device)
        logits = model(input_ids, attention_mask)['sentiment']
        # Predicted class = argmax over the sentiment logits.
        sent_preds.extend(logits.argmax(1).cpu().numpy())
        sent_trues.extend(batch['sent'].numpy())

class_names = ['Negative', 'Positive']
print('\n--- Classification Report ---')
print(classification_report(sent_trues, sent_preds, target_names=class_names))

fig, ax = plt.subplots(figsize=(6, 5))
plot_confusion_matrix_heatmap(sent_trues, sent_preds, ['Negative', 'Positive'], 'Sentiment Confusion Matrix', ax)
plt.tight_layout()
plt.show()

In [None]:
# Cell 17: Qualitative Stress Test — IDENTICAL TO V10.2
# Runs a fixed list of hand-written sentences through the toxicity head and
# prints a per-sentence pass/fail table followed by the overall accuracy.
print('='*60)
print('\U0001f9ea QUALITATIVE STRESS TEST')
print('='*60)

test_cases = [
    # Negative wording that is not aimed at a person.
    *[(sentence, "Non-Toxic") for sentence in (
        "I hate rainy Mondays",
        "This soup is disgusting",
        "I am so angry at the traffic",
        "This movie was terrible",
    )],
    # Insults or threats addressed to the reader.
    *[(sentence, "Toxic") for sentence in (
        "You are an idiot",
        "I hate you so much",
        "Go kill yourself",
        "You are worthless garbage",
    )],
    # Toxic wording that is being reported or discussed, not asserted.
    *[(sentence, "Non-Toxic") for sentence in (
        "He said you are an idiot",
        "The article discusses hate speech",
        "Someone wrote 'go die' in the comments",
    )],
]

print(f"{'Text':<50} {'Expected':<12} {'Predicted':<12} {'Status'}")
print('-'*80)

correct = 0
model.eval()
with torch.no_grad():
    for text, expected in test_cases:
        enc = tokenizer(text, max_length=128, padding='max_length', truncation=True, return_tensors='pt')
        input_ids = enc['input_ids'].to(device)
        attention_mask = enc['attention_mask'].to(device)
        logits = model(input_ids, attention_mask)['toxicity']
        # Class index 1 is reported as 'Toxic'; anything else as 'Non-Toxic'.
        pred_label = 'Toxic' if logits.argmax(1).item() == 1 else 'Non-Toxic'
        is_hit = pred_label == expected
        correct += int(is_hit)
        status = '\u2705' if is_hit else '\u274c'
        # Text is cut to 48 characters so the table columns stay aligned.
        print(f"{text[:48]:<50} {expected:<12} {pred_label:<12} {status}")

print('-'*80)
print(f'Stress Test Accuracy: {correct}/{len(test_cases)} ({correct/len(test_cases)*100:.0f}%)')

In [None]:
# Cell 18: V11.6 Summary & Comparison
# Prints this run's best scores next to hard-coded results from earlier
# versions, then lists the changes carried by V11.6.
print('='*60)
print('\u2697\ufe0f V11.6 RESULTS SUMMARY')
print('='*60)

# Known previous results (recorded constants from earlier notebook versions).
V10_TOX_F1 = 0.7572
V11_TOX_F1 = 0.7418
V113_TOX_F1 = 0.7830
V114_TOX_F1 = 0.7836
V115_TOX_F1 = 0.7758

V11_EMO_F1 = 0.6202
V11_SENT_F1 = 0.9403
V113_EMO_F1 = 0.6112
V113_SENT_F1 = 0.9334
V114_EMO_F1 = 0.5969
V114_SENT_F1 = 0.9426
V115_EMO_F1 = 0.6178
V115_SENT_F1 = 0.9484

# This run's toxicity score comes from the training loop's best epoch.
V116_TOX_F1 = best_f1

# The V11.6 column reads the clamp and EMA settings from CONFIG, matching the
# change list printed below, so the table cannot drift from the values
# actually used if the configuration is edited.
v116_clamp = str(CONFIG["db_scale_clamp"])
v116_ema_beta = str(CONFIG["db_ema_beta"])

print(f'\n{"Metric":<25} {"V11.3":<10} {"V11.4":<10} {"V11.5":<10} {"V11.6":<10}')
print('-'*65)
print(f'{"Toxicity Val F1":<25} {V113_TOX_F1:<10.4f} {V114_TOX_F1:<10.4f} {V115_TOX_F1:<10.4f} {V116_TOX_F1:<10.4f}')
print(f'{"Emotion Val F1":<25} {V113_EMO_F1:<10.4f} {V114_EMO_F1:<10.4f} {V115_EMO_F1:<10.4f} {best_emo_f1:<10.4f}')
print(f'{"Sentiment Val F1":<25} {V113_SENT_F1:<10.4f} {V114_SENT_F1:<10.4f} {V115_SENT_F1:<10.4f} {best_sent_f1:<10.4f}')
print(f'{"Gradient Balance":<25} {"GradNorm":<10} {"GradNorm":<10} {"DB-MTL":<10} {"DB-MTL":<10}')
print(f'{"Scale Clamp":<25} {"No":<10} {"No":<10} {"No":<10} {v116_clamp:<10}')
print(f'{"EMA Smoothing":<25} {"No":<10} {"No":<10} {"No":<10} {v116_ema_beta:<10}')
print(f'{"Learnable Weights":<25} {"Yes":<10} {"Yes":<10} {"No":<10} {"No":<10}')
print('-'*65)

# Signed toxicity-F1 differences against earlier versions.
delta_v115 = V116_TOX_F1 - V115_TOX_F1
delta_v114 = V116_TOX_F1 - V114_TOX_F1
delta_v10 = V116_TOX_F1 - V10_TOX_F1
print(f'\n\u0394 vs V11.5 (DB-MTL unclamped): {delta_v115:+.4f} F1')
print(f'\u0394 vs V11.4 (GradNorm \u03b1=1.0):  {delta_v114:+.4f} F1')
print(f'\u0394 vs V10.2 (Kendall):          {delta_v10:+.4f} F1')

print('\n--- V11.6 Changes ---')
print('  V11 fixes retained:')
print('    1. \u2705 Proper emotion/sentiment validation')
print('    2. \u2705 Emotion eval data leak fixed')
print('    3. \u2705 Task mask in multi-task loss')
print('    4. \u2705 Optimizer state reset on unfreeze')
print('  V11.5 retained:')
print('    5. \u2705 DB-MTL gradient balancing (replaces GradNorm)')
print('    6. \u2705 No learnable weights (zero divergence risk)')
print('  V11.6 new:')
print(f'    7. \u2705 Scale clamp at {CONFIG["db_scale_clamp"]}')
print(f'    8. \u2705 EMA-smoothed gradient norms (beta={CONFIG["db_ema_beta"]})')


In [None]:
# Cell 19: Save Artifacts
# Writes the training curves, best scores, CONFIG and run metadata to
# 'aura_v11.6_history.json', then prints a short recap.
import json as json_save

print('='*60)
print('\U0001f4be SAVING V11.6 ARTIFACTS')
print('='*60)

# Per-epoch curves are copied from `history` under the same keys, in order.
history_serializable = {
    key: history[key]
    for key in ('train_loss', 'val_f1', 'val_emo_f1', 'val_sent_f1', 'task_weights')
}
# Best scores, configuration and descriptive metadata follow the curves.
history_serializable.update({
    'best_f1': best_f1,
    'best_emo_f1': best_emo_f1,
    'best_sent_f1': best_sent_f1,
    'config': CONFIG,
    'model_type': 'aura_v11.6_dbmtl_clamped_ema',
    'balancing_method': 'DB-MTL + Clamp + EMA (Lin et al., 2023)',
    'v11_fixes': [
        'proper_emotion_sentiment_validation',
        'task_mask_in_loss',
        'optimizer_state_reset_on_unfreeze',
        'emotion_eval_data_leak_fixed',
    ],
    'v116_improvements': [
        'db_mtl_gradient_balancing',
        'no_learnable_task_weights',
        'scale_clamping',
        'ema_smoothed_gradient_norms',
    ],
})

with open('aura_v11.6_history.json', 'w') as f:
    json_save.dump(history_serializable, f, indent=2)

# The checkpoint itself is written by the training loop; this only reports it.
print('\u2705 Model saved: aura_v11.6_best.pt')
print('\u2705 History saved: aura_v11.6_history.json')
print(f'\n\U0001f3c6 Best Toxicity F1:  {best_f1:.4f}')
print(f'\U0001f3c6 Best Emotion F1:   {best_emo_f1:.4f}')
print(f'\U0001f3c6 Best Sentiment F1: {best_sent_f1:.4f}')
