# AURA V4: BERT + Clean Data + Focal Loss

---
## BEFORE RUNNING:
1. **Settings** -> **Accelerator** -> **GPU T4 x2**
2. **Add Input** -> attach the `aura-data-v2` dataset
---

### V4 Features
| Component | Implementation |
|-----------|----------------|
| Backbone | BERT-base |
| Data | Clean (7 emotion classes) |
| Loss (Toxicity) | **Focal Loss** (γ=2.0) |
| Loss (Emotions) | BCE |
| MTL Balancing | Kendall Uncertainty |

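**Theoretical Advantage**: Focal Loss scales each example's cross-entropy by $(1 - p_t)^\gamma$, dynamically down-weighting easy, well-classified examples and concentrating training on hard, misclassified examples (of either class).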

## 1. Setup & Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torch.optim.lr_scheduler import OneCycleLR
from transformers import BertModel, BertTokenizer
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score, classification_report
import pandas as pd
import numpy as np
import os
import warnings
warnings.filterwarnings('ignore')

# --- Device selection: this notebook refuses to run without a CUDA GPU. ---
print("=" * 50)
if not torch.cuda.is_available():
    device = torch.device('cpu')
    # Message means "Enable the GPU!" (Kaggle: Settings -> Accelerator).
    raise RuntimeError("ATTIVA LA GPU!")
device = torch.device('cuda')
print(f"GPU: {torch.cuda.get_device_name(0)}")
print("=" * 50)

# Seed torch and NumPy RNGs (NumPy drives the GoEmotions subsampling later on;
# torch drives weight init, dropout and DataLoader shuffling).
torch.manual_seed(42)
np.random.seed(42)

## 2. Configuration

In [None]:
CONFIG = {
    'encoder': 'bert-base-uncased',
    'max_length': 128,
    'num_emotion_classes': 7,
    'dropout': 0.1,
    'batch_size': 16,
    'gradient_accumulation': 2,
    'epochs': 5,
    'lr': 2e-5,
    'weight_decay': 0.01,
    'patience': 2,
    'mc_samples': 10,
    'focal_gamma': 2.0,  # Ablation: Test [1.0, 2.0, 3.0]
    'output_dir': '/kaggle/working'
}

# V4: Focal Loss replaces explicit class weights for the toxicity head.
# Report the gamma actually in CONFIG (it is an ablation knob, so it must not
# be hard-coded in the log line).
print(f"V4: Focal Loss enabled (gamma={CONFIG['focal_gamma']})")

# Data paths. Prefer the cleaned GoEmotions export (goemotions_clean.csv);
# fall back to the legacy goemotions_processed.csv with a warning.
# NOTE(review): only the GoEmotions file is checked here; the OLID CSVs are
# assumed to live in the same directory.
DATA_DIR = None
GOEMO_FILE = None
for path in ['/kaggle/input/aura-data-v2', '/kaggle/input/aura_data_v2', '/kaggle/input/aura-data', 'data/processed', 'data/kaggle_upload_v2']:
    # A missing directory also fails both file checks below, so no separate
    # directory-existence test is needed.
    if os.path.exists(os.path.join(path, 'goemotions_clean.csv')):
        DATA_DIR, GOEMO_FILE = path, 'goemotions_clean.csv'
        break
    if os.path.exists(os.path.join(path, 'goemotions_processed.csv')):
        DATA_DIR, GOEMO_FILE = path, 'goemotions_processed.csv'
        print("WARNING: Using old goemotions_processed.csv!")
        break

if DATA_DIR is None:
    # Message means "Dataset not found!".
    raise FileNotFoundError("Dataset non trovato!")

print(f"Dataset: {DATA_DIR}")
print(f"GoEmotions file: {GOEMO_FILE}")

## 3. Model: AURA Bayesian (BERT)

In [None]:
class AURA_Bayesian(nn.Module):
    """BERT encoder shared by a toxicity head and an emotion head.

    Also owns one learnable log-variance scalar per task (homoscedastic
    uncertainty, Kendall et al. 2018). They are returned from ``forward`` so
    the loss functions can use them.
    """

    def __init__(self, config):
        super().__init__()
        # Keep this creation order and these attribute names: the names are the
        # state_dict keys of saved checkpoints, and nn.Linear init draws from
        # the global RNG in creation order.
        self.bert = BertModel.from_pretrained(config['encoder'])
        width = self.bert.config.hidden_size
        self.dropout = nn.Dropout(config['dropout'])

        # Task heads: 2-way toxicity logits, one logit per emotion class.
        self.toxicity_head = nn.Linear(width, 2)
        self.emotion_head = nn.Linear(width, config['num_emotion_classes'])

        # Per-task log-variance, initialised to 0 (i.e. sigma = exp(0) = 1).
        self.tox_log_var = nn.Parameter(torch.zeros(1))
        self.emo_log_var = nn.Parameter(torch.zeros(1))

    def forward(self, input_ids, attention_mask):
        """Return (toxicity_logits, emotion_logits, tox_log_var, emo_log_var)."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Dropout is applied to BERT's pooled output before both heads.
        features = self.dropout(encoded.pooler_output)
        return (
            self.toxicity_head(features),
            self.emotion_head(features),
            self.tox_log_var,
            self.emo_log_var,
        )

## 4. Loss Functions (with Focal Loss)

In [None]:
def focal_loss_with_uncertainty(logits, log_var, targets, gamma=2.0, T=10, alpha=None):
    """
    Focal loss on Monte-Carlo-averaged softmax probabilities plus a Kendall-style
    log-variance regulariser (Lin et al., 2017 + Kendall et al., 2018).

    FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t),
    where p_t is the MC-averaged probability of the correct class.

    The logits are corrupted T times with Gaussian noise of standard deviation
    exp(0.5 * log_var). Softmax probabilities are averaged over the T samples
    before the focal loss is applied.

    Args:
        logits: Class logits [batch, num_classes].
        log_var: Learnable log-variance tensor; clamped to [-10, 10] before use.
        targets: Integer class indices [batch].
        gamma: Focusing parameter; gamma=0 gives MC-averaged cross-entropy.
        T: Number of Monte Carlo noise samples.
        alpha: Optional per-class weights (sequence/tensor of length
            num_classes). None (default) applies no class weighting.

    Returns:
        Tensor shaped like ``log_var``: mean focal loss + 0.5 * clamped log_var.
        For an empty batch it returns zeros (still attached to ``log_var``)
        instead of the NaN that averaging zero elements would produce.
    """
    log_var_clamped = torch.clamp(log_var, min=-10, max=10)

    if targets.numel() == 0:
        # mean() over zero elements is NaN; contribute nothing instead.
        return log_var_clamped * 0.0

    std = torch.exp(0.5 * log_var_clamped)

    # Monte Carlo sampling: T noisy copies of the logits -> [T, batch, num_classes]
    logits_expanded = logits.unsqueeze(0).expand(T, -1, -1)
    noise = torch.randn_like(logits_expanded)
    corrupted_logits = logits_expanded + noise * std

    # Average probabilities (not logits) over the MC samples
    avg_probs = F.softmax(corrupted_logits, dim=-1).mean(dim=0)

    # p_t: probability of the correct class; gather stays on the logits' device
    targets_long = targets.long()
    p_t = avg_probs.gather(1, targets_long.unsqueeze(1)).squeeze(1)

    # Focal Loss formula: -(1 - p_t)^gamma * log(p_t); 1e-8 guards log(0)
    focal_weight = (1 - p_t) ** gamma
    ce_loss = -torch.log(p_t + 1e-8)
    per_example = focal_weight * ce_loss

    if alpha is not None:
        # alpha-balanced variant: weight each example by its class's alpha
        alpha_tensor = torch.as_tensor(alpha, dtype=per_example.dtype, device=per_example.device)
        per_example = alpha_tensor[targets_long] * per_example

    focal_loss = per_example.mean()

    # Kendall regularization: discourages inflating the variance without bound
    regularization = 0.5 * log_var_clamped

    return focal_loss + regularization


def monte_carlo_uncertainty_loss_multilabel(logits, log_var, targets, T=10):
    """
    Bayesian uncertainty loss for the multi-label emotion task.

    The logits are corrupted T times with Gaussian noise of standard deviation
    exp(0.5 * log_var). Per-label sigmoid probabilities are averaged over the T
    samples and scored with binary cross-entropy (mean over all labels). A
    Kendall-style 0.5 * log_var regulariser is then added.

    Args:
        logits: Emotion logits [batch, num_labels].
        log_var: Learnable log-variance tensor; clamped to [-10, 10] before use.
        targets: Multi-hot targets [batch, num_labels]; cast to the probability
            dtype, so integer 0/1 targets are accepted.
        T: Number of Monte Carlo noise samples.

    Returns:
        Tensor shaped like ``log_var``: BCE + 0.5 * clamped log_var. For an empty
        batch it returns zeros (still attached to ``log_var``) instead of NaN.
    """
    log_var_clamped = torch.clamp(log_var, min=-10, max=10)

    if targets.numel() == 0:
        # BCE with reduction='mean' over zero elements is NaN; contribute nothing.
        return log_var_clamped * 0.0

    std = torch.exp(0.5 * log_var_clamped)

    # T noisy copies of the logits -> [T, batch, num_labels]
    logits_expanded = logits.unsqueeze(0).expand(T, -1, -1)
    noise = torch.randn_like(logits_expanded)
    corrupted_logits = logits_expanded + noise * std

    # Average per-label probabilities over the MC samples
    avg_probs = torch.sigmoid(corrupted_logits).mean(dim=0)

    # binary_cross_entropy requires floating targets of the same dtype
    bce = F.binary_cross_entropy(avg_probs, targets.to(avg_probs.dtype), reduction='mean')
    regularization = 0.5 * log_var_clamped

    return bce + regularization

## 5. Dataset Class

In [None]:
class AURADataset(Dataset):
    """CSV-backed dataset that serves a single task (toxicity OR emotions).

    Every item carries targets for both tasks. The task this dataset does not
    serve gets sentinel values (-1 toxicity label, all -1.0 emotion vector), and
    ``is_toxicity_task`` tells the training loop which target is real.
    """

    def __init__(self, csv_path, tokenizer, max_length, is_toxicity=True):
        self.df = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.is_toxicity = is_toxicity
        # Column order defines the layout of the emotion target vector.
        self.emo_cols = ['anger', 'disgust', 'fear', 'joy', 'sadness', 'surprise', 'neutral']

    def __len__(self):
        return len(self.df)

    def _get_text(self, row):
        """Return the row's 'text' (else 'tweet') column as str; missing cells -> ''."""
        text = row.get('text', row.get('tweet', ''))
        # pandas reads empty cells as NaN; str(NaN) would feed the literal "nan" to BERT.
        if pd.isna(text):
            return ''
        return str(text)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        text = self._get_text(row)

        enc = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Sentinels for the task this sample does not belong to.
        tox_label = -1
        emo_label = torch.full((len(self.emo_cols),), -1.0)

        if self.is_toxicity:
            # Prefer a 'label' column; otherwise fall back to 'subtask_a' (default 'NOT').
            label_raw = row['label'] if 'label' in row else row.get('subtask_a', 'NOT')
            # 1 or 'OFF' -> offensive (1); anything else -> 0.
            tox_label = 1 if label_raw in [1, 'OFF'] else 0
        else:
            emo_label = torch.tensor([float(row[c]) for c in self.emo_cols], dtype=torch.float32)

        return {
            # return_tensors='pt' yields a leading batch dim of 1; flatten drops it.
            'input_ids': enc['input_ids'].flatten(),
            'attention_mask': enc['attention_mask'].flatten(),
            'toxicity_target': torch.tensor(tox_label, dtype=torch.long),
            'emotion_target': emo_label,
            'is_toxicity_task': torch.tensor(1 if self.is_toxicity else 0, dtype=torch.long)
        }

## 6. Data Loading

In [None]:
tokenizer = BertTokenizer.from_pretrained(CONFIG['encoder'])

# Load datasets
olid_train = AURADataset(f"{DATA_DIR}/olid_train.csv", tokenizer, CONFIG['max_length'], is_toxicity=True)
olid_val = AURADataset(f"{DATA_DIR}/olid_validation.csv", tokenizer, CONFIG['max_length'], is_toxicity=True)
goemo_full = AURADataset(f"{DATA_DIR}/{GOEMO_FILE}", tokenizer, CONFIG['max_length'], is_toxicity=False)

# Verify GoEmotions is the clean version: the broken export had no disgust/neutral
# positives. Reuse the DataFrame the dataset already loaded instead of parsing the
# CSV a second time.
goemo_df = goemo_full.df
disgust_count = goemo_df['disgust'].sum() if 'disgust' in goemo_df.columns else 0
neutral_count = goemo_df['neutral'].sum() if 'neutral' in goemo_df.columns else 0
print(f"GoEmotions verification:")
print(f"  Disgust samples: {disgust_count} (should be >0)")
print(f"  Neutral samples: {neutral_count} (should be >0)")

if disgust_count == 0 or neutral_count == 0:
    print("  WARNING: Using broken GoEmotions data!")
else:
    print("  OK: Clean data confirmed!")

# Subsample GoEmotions: up to 40,000 rows without replacement (all rows if fewer).
# Uses the NumPy RNG seeded during setup, so the subset is reproducible.
goemo_indices = np.random.choice(len(goemo_full), min(40000, len(goemo_full)), replace=False)
goemo_subset = torch.utils.data.Subset(goemo_full, goemo_indices)

# Combine both tasks into one shuffled training stream; validation is OLID only.
train_set = ConcatDataset([olid_train, goemo_subset])
train_loader = DataLoader(train_set, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(olid_val, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=2, pin_memory=True)

print(f"\nDataset sizes:")
print(f"  Training: {len(train_set)} (OLID: {len(olid_train)}, GoEmo: {len(goemo_subset)})")
print(f"  Validation: {len(olid_val)}")

## 7. Training Function

In [None]:
def train_epoch(model, loader, optimizer, scheduler, epoch, config):
    """Run one training epoch over mixed toxicity/emotion batches.

    Each sample is routed by ``is_toxicity_task`` to either the focal toxicity
    loss or the multi-label emotion loss; both may appear in the same batch.
    Gradients are accumulated over ``config['gradient_accumulation']`` batches,
    clipped to norm 1.0, then the optimizer and scheduler step.

    Returns:
        (avg_loss, train_f1): mean unscaled loss per batch, and macro-F1 of the
        toxicity head's argmax predictions on this epoch's toxicity samples
        (0 if there were none).
    """
    model.train()
    accum = config['gradient_accumulation']
    total_loss = 0.0
    tox_preds, tox_labels = [], []

    loop = tqdm(loader, desc=f"Epoch {epoch}", leave=True)
    optimizer.zero_grad()

    for step, batch in enumerate(loop):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        tox_targets = batch['toxicity_target'].to(device)
        emo_targets = batch['emotion_target'].to(device)
        is_tox_task = batch['is_toxicity_task'].to(device)

        tox_logits, emo_logits, tox_log_var, emo_log_var = model(input_ids, attention_mask)

        loss = torch.tensor(0.0, device=device)

        # Toxicity loss (focal loss + Kendall regulariser) on toxicity samples only
        tox_mask = is_tox_task == 1
        if tox_mask.any():
            tox_loss = focal_loss_with_uncertainty(
                tox_logits[tox_mask],
                tox_log_var,
                tox_targets[tox_mask],
                gamma=config['focal_gamma'],
                T=config['mc_samples']
            )
            loss = loss + tox_loss

            preds = torch.argmax(tox_logits[tox_mask], dim=1).cpu().numpy()
            tox_preds.extend(preds)
            tox_labels.extend(tox_targets[tox_mask].cpu().numpy())

        # Emotion loss (MC-averaged BCE + Kendall regulariser) on emotion samples only
        emo_mask = is_tox_task == 0
        if emo_mask.any():
            emo_loss = monte_carlo_uncertainty_loss_multilabel(
                emo_logits[emo_mask],
                emo_log_var,
                emo_targets[emo_mask],
                T=config['mc_samples']
            )
            loss = loss + emo_loss

        # Scale so the accumulated gradient is the mean over the accumulation window
        (loss / accum).backward()

        if (step + 1) % accum == 0:
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        # NOTE: if len(loader) is not a multiple of accum, the trailing batches'
        # gradients are never applied; zero_grad() at the start of the next call
        # discards them. Flushing them here would need extra scheduler steps
        # beyond a total_steps computed as len(loader) * epochs // accum.

        # Log the unscaled loss, both in the running total and in the progress bar
        batch_loss = loss.item()
        total_loss += batch_loss

        sigma_tox = torch.exp(0.5 * tox_log_var).item()
        sigma_emo = torch.exp(0.5 * emo_log_var).item()
        loop.set_postfix(loss=batch_loss, s_tox=f"{sigma_tox:.3f}", s_emo=f"{sigma_emo:.3f}")

    avg_loss = total_loss / len(loader)
    train_f1 = f1_score(tox_labels, tox_preds, average='macro') if tox_labels else 0

    return avg_loss, train_f1

## 8. Validation Function

In [None]:
@torch.no_grad()
def validate(model, loader):
    """Evaluate the toxicity head on ``loader`` without gradients.

    The model runs in eval mode (dropout off). The 2-way toxicity logits are
    scored with plain cross-entropy (no focal term, no MC noise); emotion
    outputs are ignored.

    Returns:
        (avg_loss, macro_f1, preds, labels), where preds/labels are flat lists of
        per-example class ids.
    """
    model.eval()
    running = 0.0
    predicted, gold = [], []

    for batch in tqdm(loader, desc="Validating", leave=False):
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        targets = batch['toxicity_target'].to(device)

        logits = model(ids, mask)[0]

        running += F.cross_entropy(logits, targets).item()
        predicted.extend(logits.argmax(dim=1).cpu().numpy())
        gold.extend(targets.cpu().numpy())

    mean_loss = running / len(loader)
    macro_f1 = f1_score(gold, predicted, average='macro')
    return mean_loss, macro_f1, predicted, gold

## 9. Main Training Loop

In [None]:
model = AURA_Bayesian(CONFIG).to(device)
# Report the real parameter count instead of a hard-coded figure, since the
# encoder is configurable.
n_params = sum(p.numel() for p in model.parameters())
print(f"Model: BERT ({n_params / 1e6:.0f}M params) on {device}")

optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=CONFIG['weight_decay'])

# One scheduler step per optimizer step (every `gradient_accumulation` batches).
total_steps = len(train_loader) * CONFIG['epochs'] // CONFIG['gradient_accumulation']
scheduler = OneCycleLR(optimizer, max_lr=CONFIG['lr'], total_steps=total_steps, pct_start=0.1)

best_ckpt_path = os.path.join(CONFIG['output_dir'], 'aura_v4_focal_best.pt')
# Start below any achievable F1 so the first epoch always writes a checkpoint;
# with 0 and a strict '>' a run stuck at F1 == 0 saved nothing and the final
# evaluation's torch.load crashed.
best_f1 = -1.0
patience_counter = 0

print("\n" + "="*60)
print("STARTING V4 TRAINING (Clean Data + Focal Loss)")
print("="*60)

for epoch in range(1, CONFIG['epochs'] + 1):
    print(f"\nEpoch {epoch}/{CONFIG['epochs']}")

    train_loss, train_f1 = train_epoch(model, train_loader, optimizer, scheduler, epoch, CONFIG)
    val_loss, val_f1, preds, labels = validate(model, val_loader)

    sigma_tox = torch.exp(0.5 * model.tox_log_var).item()
    sigma_emo = torch.exp(0.5 * model.emo_log_var).item()
    # Train/val macro-F1 gap in percentage points (overfitting indicator)
    gap = abs(train_f1 - val_f1) * 100

    print(f"   Train Loss: {train_loss:.4f} | Train F1: {train_f1:.4f}")
    print(f"   Val Loss:   {val_loss:.4f} | Val F1:   {val_f1:.4f}")
    print(f"   Gap: {gap:.1f}% | sigma_Tox: {sigma_tox:.4f} | sigma_Emo: {sigma_emo:.4f}")

    if val_f1 > best_f1:
        best_f1 = val_f1
        patience_counter = 0
        torch.save(model.state_dict(), best_ckpt_path)
        print(f"   NEW BEST! (F1: {best_f1:.4f})")
    else:
        patience_counter += 1
        print(f"   No improvement ({patience_counter}/{CONFIG['patience']})")

    # Early stopping on validation macro-F1
    if patience_counter >= CONFIG['patience']:
        print(f"\nEarly stopping at epoch {epoch}")
        break

print("\n" + "="*60)
print(f"TRAINING COMPLETE | Best Val F1: {best_f1:.4f}")
print("="*60)

## 10. Final Evaluation

In [None]:
# Reload the best checkpoint written during training and re-score the validation set.
checkpoint_file = f"{CONFIG['output_dir']}/aura_v4_focal_best.pt"
model.load_state_dict(torch.load(checkpoint_file))
val_loss, val_f1, preds, labels = validate(model, val_loader)

print("\nFINAL CLASSIFICATION REPORT (V4 - Clean Data + Focal Loss)")
print("=" * 50)
# Class 0 = not offensive, class 1 = offensive (toxicity label encoding).
report = classification_report(labels, preds, target_names=['NOT', 'OFF'])
print(report)

print(f"\nFinal Macro-F1: {val_f1:.4f}")
print(f"Model saved: {checkpoint_file}")