# AURA V8: 4-Task Engine (Kaggle Edition)

---
## Instructions
1. **Runtime** -> Make sure you have GPU enabled
2. **Add Data** -> Search for and add your dataset (`aura-mega-data`); it must end up mounted at the path shown in step 3
3. The CSV files should be directly in /kaggle/input/aura-v8-data/
4. Run all cells below
---


In [None]:
# Kaggle Environment already has transformers installed
# dataset is mounted at /kaggle/input/aura-v8-data

In [None]:
# 2. Setup & Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torch.optim.lr_scheduler import OneCycleLR
from transformers import BertModel, BertTokenizer
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score, classification_report
import pandas as pd
import numpy as np
import os
import warnings
warnings.filterwarnings('ignore')

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')
if device.type == 'cuda':
    # Report the name of the first visible CUDA device.
    print(f'GPU: {torch.cuda.get_device_name(0)}')

# Fix the torch and NumPy RNG seeds (Python's `random` module is not seeded here).
torch.manual_seed(42)
np.random.seed(42)

In [None]:
# 3. Configuration
# Hyper-parameters and settings shared by the model, the data pipeline and the training loop.
CONFIG = {
    'encoder': 'bert-base-uncased',  # checkpoint name used for both BertModel and BertTokenizer
    'max_length': 128,  # tokenizer max_length; inputs are padded/truncated to this many tokens
    'num_emotion_classes': 5,  # output width of the emotion head
    'dropout': 0.4,  # dropout probability applied to BERT's pooled output
    'batch_size': 32,  # DataLoader batch size (train and validation)
    'gradient_accumulation': 2,  # number of batches accumulated per optimizer step
    'epochs': 4,  # maximum number of training epochs
    'lr': 5e-6,  # used both as the AdamW lr and as OneCycleLR's max_lr
    'weight_decay': 0.03,  # AdamW weight decay
    'patience': 3,  # NOTE(review): never read in the visible code; the training loop defines its own PATIENCE = 2
    'mc_samples': 10,  # number of noisy logit samples (T) drawn inside the loss functions
    'focal_gamma': 2.0,  # gamma exponent passed to the focal loss
    'output_dir': '.'  # NOTE(review): never read in the visible code
}

# Columns of the emotion CSV that make up the emotion label vector, in this order.
EMO_COLS = ['anger', 'disgust', 'fear', 'joy', 'neutral']
# Directory containing the task CSV files.
DATA_DIR = '/kaggle/input/aura-v8-data'  # Unzipped folder

In [None]:
# 4. Model Class (4-Head)
class AURA_MultiTask(nn.Module):
    """Shared BERT encoder feeding four linear task heads.

    The heads (toxicity, emotion, sentiment, hate) all read the same
    dropout-regularised pooled BERT output. Each task also owns a learnable
    one-element log-variance parameter, initialised to zero, which is
    returned alongside the logits for use by the loss functions.

    Args:
        config: mapping providing 'encoder' (pretrained checkpoint name),
            'dropout' (dropout probability) and 'num_emotion_classes'
            (width of the emotion head).
    """

    def __init__(self, config):
        super().__init__()
        # Attribute names and creation order are kept as-is: they determine
        # the state_dict keys and the order in which layers are initialised.
        self.bert = BertModel.from_pretrained(config['encoder'])
        self.dropout = nn.Dropout(config['dropout'])
        width = self.bert.config.hidden_size

        # One linear classifier per task on top of the shared representation.
        self.toxicity_head = nn.Linear(width, 2)
        self.emotion_head = nn.Linear(width, config['num_emotion_classes'])
        self.sentiment_head = nn.Linear(width, 2)
        self.hate_head = nn.Linear(width, 2)

        # Learnable per-task log-variances, each starting at 0.
        self.tox_log_var = nn.Parameter(torch.zeros(1))
        self.emo_log_var = nn.Parameter(torch.zeros(1))
        self.sent_log_var = nn.Parameter(torch.zeros(1))
        self.hate_log_var = nn.Parameter(torch.zeros(1))

    def forward(self, input_ids, attention_mask):
        """Encode the batch once and run every head on the pooled output.

        Returns:
            dict with the logits under 'toxicity', 'emotion', 'sentiment'
            and 'hate', plus 'log_vars', a dict mapping the same four task
            names to their log-variance parameters.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        features = self.dropout(encoded.pooler_output)

        heads = (
            ('toxicity', self.toxicity_head),
            ('emotion', self.emotion_head),
            ('sentiment', self.sentiment_head),
            ('hate', self.hate_head),
        )
        result = {task: head(features) for task, head in heads}
        result['log_vars'] = {
            'toxicity': self.tox_log_var,
            'emotion': self.emo_log_var,
            'sentiment': self.sent_log_var,
            'hate': self.hate_log_var,
        }
        return result

In [None]:
# 5. Loss Functions
def focal_loss_with_uncertainty(logits, log_var, targets, gamma=2.0, T=10, label_smoothing=0.1):
    """Focal loss on Monte-Carlo averaged softmax probabilities.

    Gaussian noise with standard deviation exp(0.5 * log_var) is added to
    T copies of the logits; the softmax outputs are averaged over the T
    samples and a focal loss is computed on the averaged probability of the
    target class. 0.5 * log_var is added to the result as a penalty term.

    Args:
        logits: 2-D tensor of class scores (rows are samples).
        log_var: learnable log-variance; clamped to [-10, 10] and squeezed.
        targets: integer class index per row of `logits`.
        gamma: focal exponent applied to (1 - p_target).
        T: number of noisy samples to average.
        label_smoothing: accepted for interface compatibility only; the
            current implementation does not use it.

    Returns:
        The mean focal loss plus 0.5 * log_var (a scalar tensor when
        `log_var` holds a single element).
    """
    log_var = log_var.clamp(-10, 10).squeeze()
    sigma = (0.5 * log_var).exp()

    # Draw T noisy versions of the logits and average their softmax outputs.
    repeated = logits.unsqueeze(0).expand(T, -1, -1)
    perturbed = repeated + torch.randn_like(repeated) * sigma
    mean_probs = F.softmax(perturbed, dim=-1).mean(dim=0)

    # Averaged probability assigned to each sample's true class.
    rows = torch.arange(len(targets), device=mean_probs.device)
    target_prob = mean_probs[rows, targets]

    modulator = (1 - target_prob) ** gamma
    neg_log_prob = -torch.log(target_prob + 1e-8)  # epsilon guards log(0)
    return (modulator * neg_log_prob).mean() + 0.5 * log_var

def mc_bce_loss(logits, log_var, targets, T=10):
    """Binary cross-entropy on Monte-Carlo averaged sigmoid probabilities.

    Gaussian noise with standard deviation exp(0.5 * log_var) is added to
    T copies of the logits; the element-wise sigmoid outputs are averaged
    over the T samples and compared with `targets` via binary
    cross-entropy. 0.5 * log_var is added to the result as a penalty term.

    Args:
        logits: 2-D tensor of per-label scores (rows are samples).
        log_var: learnable log-variance; clamped to [-10, 10] and squeezed.
        targets: float tensor of the same shape as `logits`.
        T: number of noisy samples to average.

    Returns:
        The mean BCE plus 0.5 * log_var (a scalar tensor when `log_var`
        holds a single element).
    """
    log_var = log_var.clamp(-10, 10).squeeze()
    sigma = (0.5 * log_var).exp()

    # Draw T noisy versions of the logits and average their sigmoid outputs.
    repeated = logits.unsqueeze(0).expand(T, -1, -1)
    perturbed = repeated + torch.randn_like(repeated) * sigma
    mean_probs = torch.sigmoid(perturbed).mean(dim=0)

    bce = F.binary_cross_entropy(mean_probs, targets, reduction='mean')
    return bce + 0.5 * log_var


In [None]:
# 6. Data Loading with Custom Collate
def custom_collate(batch, label_width=5):
    """Collate samples whose labels are either scalars or fixed-width vectors.

    Scalar (0-dim) labels are written into slot 0 of a zero vector of length
    `label_width` so that every label in the batch can be stacked into one
    tensor; vector labels are passed through unchanged.

    Args:
        batch: list of dicts with 'input_ids', 'attention_mask', 'label'
            (tensors) and 'task' entries.
        label_width: length scalar labels are padded to. It must equal the
            length of the vector labels in the batch for stacking to
            succeed; defaults to 5, the previously hard-coded value.

    Returns:
        dict with stacked 'input_ids', 'attention_mask' and 'label' tensors
        and 'task' as a list with one entry per sample.
    """
    input_ids = torch.stack([sample['input_ids'] for sample in batch])
    attention_mask = torch.stack([sample['attention_mask'] for sample in batch])
    tasks = [sample['task'] for sample in batch]

    padded_labels = []
    for sample in batch:
        label = sample['label']
        if label.dim() == 0:
            # Scalar class index: keep it in position 0, zeros elsewhere.
            padded = torch.zeros(label_width)
            padded[0] = label.item()
            label = padded
        padded_labels.append(label)
    labels = torch.stack(padded_labels)

    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'label': labels, 'task': tasks}

class TaskDataset(Dataset):
    """Single-task text dataset backed by a CSV file.

    Each item is tokenised on access and tagged with the task name so that
    samples from several tasks can be mixed in one batch.

    Args:
        csv_path: CSV file with a 'text' column, plus either the emotion
            columns (task_type == 'emotion') or an integer 'label' column.
        tokenizer: callable Hugging Face-style tokenizer.
        max_len: length every sequence is padded/truncated to.
        task_type: task name stored under 'task' in each item.
        emo_cols: columns forming the emotion label vector; falls back to
            the module-level EMO_COLS when not given.
    """

    def __init__(self, csv_path, tokenizer, max_len, task_type, emo_cols=None):
        self.df = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.task_type = task_type
        self.emo_cols = emo_cols or EMO_COLS

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the tokenised text, task name and label for row `idx`."""
        row = self.df.iloc[idx]
        text = str(row['text'])
        # Call the tokenizer directly: `encode_plus` is deprecated in favour
        # of `__call__`, which takes the same arguments here.
        enc = self.tokenizer(
            text,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        item = {
            # return_tensors='pt' adds a leading batch axis; flatten removes it.
            'input_ids': enc['input_ids'].flatten(),
            'attention_mask': enc['attention_mask'].flatten(),
            'task': self.task_type,
        }
        if self.task_type == 'emotion':
            # Multi-label target: one float per emotion column.
            item['label'] = torch.tensor([float(row[c]) for c in self.emo_cols])
        else:
            # Single class index stored as a 0-dim tensor.
            item['label'] = torch.tensor(int(row['label']))
        return item

tokenizer = BertTokenizer.from_pretrained(CONFIG['encoder'])


def _load_split(filename, task):
    """Build a TaskDataset for `filename` under DATA_DIR with the shared tokenizer settings."""
    return TaskDataset(f'{DATA_DIR}/{filename}', tokenizer, CONFIG['max_length'], task)


# Training splits, one per task.
tox_train = _load_split('toxicity_train.csv', 'toxicity')
emo_train = _load_split('emotions_train.csv', 'emotion')
sent_train = _load_split('sentiment_train.csv', 'sentiment')
hate_train = _load_split('hate_train.csv', 'hate')
# Only the toxicity task has a validation split here.
tox_val = _load_split('toxicity_validation.csv', 'toxicity')

# --- BALANCED TASK SAMPLING ---
# Cap sentiment (73k) to match other tasks (~15k each)
# This prevents sentiment gradients from dominating toxicity learning

from torch.utils.data import Subset
import random

def balance_dataset(ds, max_samples=15000, seed=42):
    """Cap a dataset at `max_samples` items by random subsampling.

    Args:
        ds: any sized, indexable dataset.
        max_samples: maximum number of items to keep.
        seed: seed for the private RNG that picks the subset, so the same
            indices are chosen on every run. Pass None for a
            non-reproducible draw.

    Returns:
        `ds` itself when it already has at most `max_samples` items,
        otherwise a `Subset` over `max_samples` distinct random indices.
    """
    if len(ds) <= max_samples:
        return ds
    # A dedicated generator is used because the global `random` module is
    # not seeded anywhere in this file, which made the subset differ per run.
    rng = random.Random(seed)
    indices = rng.sample(range(len(ds)), max_samples)
    return Subset(ds, indices)

# Apply balancing (critical for preventing overfit)
# Per-task caps: 12000 for toxicity and hate, 15000 for emotion and sentiment.
tox_train_bal = balance_dataset(tox_train, max_samples=12000)
emo_train_bal = balance_dataset(emo_train, max_samples=15000)
sent_train_bal = balance_dataset(sent_train, max_samples=15000)
hate_train_bal = balance_dataset(hate_train, max_samples=12000)

# All four tasks are mixed into a single training set.
train_set = ConcatDataset([
    tox_train_bal,
    emo_train_bal,
    sent_train_bal,
    hate_train_bal,
])
print(f'Balanced Training Set: {len(train_set)} samples (was 113k)')

# Training batches are shuffled across tasks; validation uses toxicity only.
train_loader = DataLoader(
    train_set,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    collate_fn=custom_collate,
)
val_loader = DataLoader(
    tox_val,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    collate_fn=custom_collate,
)

print(f'Total Samples: {len(train_set)}')
print(f'Train Loader Batches: {len(train_loader)}')


In [None]:
# 7. Training
def train_epoch(model, loader, optimizer, scheduler, config):
    """Run one training epoch over mixed-task batches.

    For each batch the samples are grouped by their 'task' tag. The
    toxicity, sentiment and hate samples go through the focal loss on their
    own head, the emotion samples through the MC BCE loss, and the per-task
    losses are summed. Gradients are accumulated over
    config['gradient_accumulation'] batches before each optimizer/scheduler
    step.

    Args:
        model: network returning a dict of per-task logits and 'log_vars'.
        loader: iterable of collated batches ('input_ids', 'attention_mask',
            'label', 'task').
        optimizer: optimizer updating `model`.
        scheduler: LR scheduler stepped once per optimizer update.
        config: mapping with 'gradient_accumulation', 'focal_gamma' and
            'mc_samples'.

    Returns:
        (mean unscaled batch loss, macro F1 over the toxicity samples seen
        during the epoch, or 0 when there were none).
    """
    model.train()
    accum_steps = config['gradient_accumulation']
    total_loss = 0.0
    num_batches = 0
    pending_grads = False  # True while un-applied gradients are accumulated
    tox_preds, tox_labels = [], []

    loop = tqdm(loader, desc='Training')
    optimizer.zero_grad()

    for step, batch in enumerate(loop):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        tasks = batch['task']

        outputs = model(input_ids, attention_mask)
        loss = torch.tensor(0.0, device=device)

        for task in ('toxicity', 'sentiment', 'hate'):
            flags = [t == task for t in tasks]
            if not any(flags):
                continue
            mask = torch.tensor(flags, dtype=torch.bool, device=device)
            task_logits = outputs[task][mask]
            # Single-label tasks keep their class index in label column 0.
            task_labels = labels[mask][:, 0].long()
            loss = loss + focal_loss_with_uncertainty(
                task_logits,
                outputs['log_vars'][task],
                task_labels,
                config['focal_gamma'],
                config['mc_samples'],
            )
            if task == 'toxicity':
                tox_preds.extend(torch.argmax(task_logits, dim=1).cpu().numpy())
                tox_labels.extend(task_labels.cpu().numpy())

        emo_flags = [t == 'emotion' for t in tasks]
        if any(emo_flags):
            emo_mask = torch.tensor(emo_flags, dtype=torch.bool, device=device)
            loss = loss + mc_bce_loss(
                outputs['emotion'][emo_mask],
                outputs['log_vars']['emotion'],
                labels[emo_mask].float(),
                config['mc_samples'],
            )

        # Scale so the accumulated gradient is an average over the window.
        loss = loss / accum_steps
        loss.backward()
        pending_grads = True

        if (step + 1) % accum_steps == 0:
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            pending_grads = False

        total_loss += loss.item() * accum_steps
        num_batches += 1
        loop.set_postfix(loss=loss.item())

    if pending_grads:
        # The batch count was not a multiple of the accumulation window.
        # Apply the leftover gradients instead of discarding them. The
        # scheduler is deliberately not stepped, so the number of scheduler
        # steps per epoch (and hence its total_steps budget) is unchanged.
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

    mean_loss = total_loss / num_batches if num_batches else 0.0
    train_f1 = f1_score(tox_labels, tox_preds, average='macro') if tox_labels else 0
    return mean_loss, train_f1

@torch.no_grad()
def validate(model, loader):
    """Return the macro F1 of the toxicity head over `loader`.

    The model is put in eval mode and run without gradient tracking; the
    predicted class is the argmax of the 'toxicity' logits and the reference
    class is read from column 0 of each batch's 'label' tensor.
    """
    model.eval()
    all_preds = []
    all_labels = []
    progress = tqdm(loader, desc='Validating', leave=False)
    for batch in progress:
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        # Class index sits in the first label column.
        gold = batch['label'][:, 0].long().to(device)
        toxicity_logits = model(ids, mask)['toxicity']
        all_preds += torch.argmax(toxicity_logits, dim=1).cpu().numpy().tolist()
        all_labels += gold.cpu().numpy().tolist()
    return f1_score(all_labels, all_preds, average='macro')

# --- MAIN LOOP ---
model = AURA_MultiTask(CONFIG).to(device)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['lr'],
    weight_decay=CONFIG['weight_decay'],
)
# The scheduler is stepped once per optimizer update, i.e. once every
# `gradient_accumulation` batches, hence the integer division.
scheduler = OneCycleLR(
    optimizer,
    max_lr=CONFIG['lr'],
    total_steps=len(train_loader) * CONFIG['epochs'] // CONFIG['gradient_accumulation'],
)

best_f1 = 0
no_improve_count = 0
# Epochs without a validation-F1 improvement tolerated before stopping.
# NOTE(review): CONFIG['patience'] is 3 but is not used; this constant decides.
PATIENCE = 2
history = {'train_loss': [], 'train_f1': [], 'val_f1': []}
print('STARTING V8 TRAINING')
for epoch in range(1, CONFIG['epochs'] + 1):
    loss, train_f1 = train_epoch(model, train_loader, optimizer, scheduler, CONFIG)
    val_f1 = validate(model, val_loader)
    print(f'Epoch {epoch}: Train Loss={loss:.4f}, Train F1={train_f1:.4f}, Val F1={val_f1:.4f}')
    history['train_loss'].append(loss)
    history['train_f1'].append(train_f1)
    history['val_f1'].append(val_f1)

    if val_f1 > best_f1:
        # New best validation score: checkpoint the weights and reset the counter.
        best_f1 = val_f1
        no_improve_count = 0
        torch.save(model.state_dict(), 'aura_v8_best.pt')
        print('  NEW BEST!')
        continue

    no_improve_count += 1
    if no_improve_count >= PATIENCE:
        print(f'Early stopping triggered at epoch {epoch}!')
        break
print(f'COMPLETE. Best: {best_f1:.4f}')


# Persist the per-epoch metrics so they can be plotted later.
import pickle
with open('history_v8.pkl', 'wb') as f:
    pickle.dump(history, f)

In [None]:
# 8. Visualization
import matplotlib.pyplot as plt
import seaborn as sns

def plot_training_history(history):
    """Plot training loss and train/validation F1 side by side.

    Args:
        history: dict with equally long 'train_loss', 'train_f1' and
            'val_f1' sequences, one entry per completed epoch.

    The figure is written to 'aura_v8_training_curves.png' and then shown.
    """
    def _finish_panel(title, ylabel):
        # Shared decoration for the currently active subplot.
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True)

    epoch_axis = range(1, len(history['train_loss']) + 1)
    plt.figure(figsize=(15, 5))

    # Left panel: training loss.
    plt.subplot(1, 2, 1)
    plt.plot(epoch_axis, history['train_loss'], 'b-o', label='Train Loss')
    _finish_panel('Training Loss', 'Loss')

    # Right panel: train and validation F1.
    plt.subplot(1, 2, 2)
    f1_series = (
        ('train_f1', 'b-o', 'Train F1'),
        ('val_f1', 'r-s', 'Validation F1'),
    )
    for key, fmt, label in f1_series:
        plt.plot(epoch_axis, history[key], fmt, label=label)
    _finish_panel('F1 Score', 'F1')

    plt.tight_layout()
    plt.savefig('aura_v8_training_curves.png')
    plt.show()

# Plot the in-memory history if training ran in this session; otherwise fall
# back to the history file saved by a previous run.
if 'history' in locals():
    plot_training_history(history)
elif os.path.exists('history_v8.pkl'):
    # `pickle` is otherwise imported only at the end of the training cell,
    # after `history` is created, so it is never available when this branch
    # runs. Import it here to avoid a NameError.
    import pickle
    # NOTE: pickle.load can execute arbitrary code; only load a history file
    # that this notebook wrote itself.
    with open('history_v8.pkl', 'rb') as f:
        history = pickle.load(f)
    plot_training_history(history)
else:
    print('No history found. Training must complete first.')