## 02 Training

### Load necessary libraries

In [None]:
%load_ext autoreload
%autoreload 2

# Import necessary libraries
import pandas as pd
import numpy as np
import sys
import os
from datetime import datetime

from torch import nn
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, DistilBertTokenizer
from tqdm.auto import tqdm # For a nice progress bar
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report
from datasets import Dataset


# Adds the parent directory to the python path
sys.path.append(os.path.abspath("..")) 

from src.core.features import clean_aegis, clean_jailbreak, merge_data
from src.utils.logger import Logger

### Load data

In [None]:
# Load & clean Aegis data.
aegis_base = "hf://datasets/nvidia/Aegis-AI-Content-Safety-Dataset-2.0/"
splits = {'train': 'train.json', 'validation': 'validation.json', 'test': 'test.json'}
train = pd.read_json(aegis_base + splits["train"])
test = pd.read_json(aegis_base + splits["test"])

# The original call passed `train` twice and left the loaded `test` frame unused.
# NOTE(review): assumes clean_aegis takes (train_frame, test_frame) -- confirm
# against src.core.features.clean_aegis.
aegis_clean = clean_aegis(train, test)
print(f"Aegis data shape: {aegis_clean.shape}")

# Load & clean jailbreak data.
jailbreak = pd.read_csv("hf://datasets/allenai/wildjailbreak/train/train.tsv", sep="\t")

jailbreak_clean = clean_jailbreak(jailbreak)
print(f"Jailbreak data shape: {jailbreak_clean.shape}")

# Merge both sources into a single frame.
data = merge_data(jailbreak_clean, aegis_clean)

# Build the id <-> name mappings from the label names, in order of first
# appearance in the merged frame. This must happen before the column is
# overwritten with integer ids below.
id2label = {i: name for i, name in enumerate(data['label'].unique())}
label2id = {name: i for i, name in id2label.items()}

# Replace the string labels with their integer ids.
data['label'] = data['label'].map(label2id)

data.shape

### Create data splits

In [None]:
# Stratified 80/10/10 split: first hold out 10% for testing, then take 1/9 of
# the remaining 90% (i.e. another 10% of the full data) for validation.
holdout_kwargs = dict(test_size=0.10, stratify=data['label'], random_state=42)
train_val_df, test_df = train_test_split(data, **holdout_kwargs)

validation_kwargs = dict(
    test_size=(1/9),
    stratify=train_val_df['label'],
    random_state=42,
)
train_df, val_df = train_test_split(train_val_df, **validation_kwargs)

# Report the split sizes and the class balance of the training portion.
print(f"Train: {len(train_df)} | Val: {len(val_df)} | Test: {len(test_df)}")
print(train_df['label'].value_counts())

### Load model and tokenizer

In [None]:
# Imported here so this cell is self-contained; the tokenizer must come from
# the same checkpoint as the model.
from transformers import AutoTokenizer

model_path = "google-bert/bert-base-uncased"

# Size the classification head from the label mapping instead of a hard-coded
# count, so it always agrees with id2label / label2id.
model = AutoModelForSequenceClassification.from_pretrained(
    model_path,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

# The original loaded DistilBertTokenizer from 'distilbert-base-uncased' while
# the model is BERT; load the tokenizer that belongs to `model_path` instead.
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Freeze the base model except its pooler parameters (single pass replaces the
# original freeze-all / unfreeze-pooler pair). Parameters outside
# `model.base_model` (the classification head) are left untouched.
for name, param in model.base_model.named_parameters():
    param.requires_grad = "pooler" in name

### Tokenise

In [None]:
def tokenize_function(examples):
    """Tokenize the ``prompt`` field of a batch with the global ``tokenizer``.

    Every sequence is padded or truncated to exactly 128 tokens.
    """
    prompts = examples["prompt"]
    encoded = tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=128,
    )
    return encoded

def _as_torch_dataset(frame):
    """Convert a dataframe split into a tokenized dataset in torch format."""
    tokenized = Dataset.from_pandas(frame).map(tokenize_function, batched=True)
    tokenized.set_format("torch")
    return tokenized


# Build the three splits in the same order as before: train, val, test.
train_ds, val_ds, test_ds = [
    _as_torch_dataset(frame) for frame in (train_df, val_df, test_df)
]

# Loss and optimiser used by the training loop.
lr = 1e-3
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

### Training loop function

In [None]:
# Only the training data needs shuffling; validation metrics are aggregated
# over the whole split, so its order is irrelevant and is kept fixed.
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=2)

# Use the GPU when one is available and place the model on it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

def train(num_epoch, logger, model_save_path):
    """Train the global ``model`` and checkpoint the epoch with the best F1.

    Uses the module-level ``model``, ``optimizer``, ``loss_fn``,
    ``train_loader``, ``val_loader`` and ``device``.

    Args:
        num_epoch: Number of passes over ``train_loader``.
        logger: Object whose ``log(dict)`` method receives one metrics
            dictionary per epoch.
        model_save_path: File the model ``state_dict`` is written to whenever
            the weighted validation F1 improves.

    Returns:
        The best weighted validation F1 seen (0 if no epoch improved on 0).
    """
    # torch.save does not create missing parent directories.
    save_dir = os.path.dirname(model_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    best_f1 = 0
    for epoch in range(num_epoch):
        print(f"\n--- Epoch {epoch + 1} ---")

        # ---- training pass ----
        model.train()
        total_train_loss = 0.0
        train_preds = []
        train_labels = []

        for batch in train_loader:
            # Batches come off the DataLoader on the CPU; they must live on
            # the same device as the model or the forward pass fails on GPU.
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            loss = loss_fn(logits, labels)
            total_train_loss += loss.item()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            preds = torch.argmax(logits, dim=1)
            train_preds.extend(preds.detach().cpu().tolist())
            train_labels.extend(labels.detach().cpu().tolist())

        # max(..., 1) avoids a ZeroDivisionError on an empty loader.
        avg_train_loss = total_train_loss / max(len(train_loader), 1)
        train_f1 = f1_score(train_labels, train_preds, average='weighted')

        # ---- validation pass ----
        model.eval()
        total_val_loss = 0.0
        correct_predictions = 0  # plain int so an empty loader cannot break it
        val_preds = []
        val_labels = []

        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)

                outputs = model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                total_val_loss += loss_fn(logits, labels).item()

                preds = torch.argmax(logits, dim=1)
                val_preds.extend(preds.cpu().tolist())
                val_labels.extend(labels.cpu().tolist())
                correct_predictions += (preds == labels).sum().item()

        avg_val_loss = total_val_loss / max(len(val_loader), 1)
        # Accuracy over the samples actually evaluated.
        val_acc = correct_predictions / len(val_labels) if val_labels else 0.0
        val_f1 = f1_score(val_labels, val_preds, average='weighted')

        # LOGGING
        epoch_metrics = {
            'epoch': epoch + 1,
            'train_loss': avg_train_loss,
            'train_f1': train_f1,
            'val_loss': avg_val_loss,
            'val_f1': val_f1,
            'accuracy': val_acc,
            'lr': optimizer.param_groups[0]['lr']
        }
        logger.log(epoch_metrics)

        # Keep only the weights of the best-scoring epoch.
        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save(model.state_dict(), model_save_path)
            print(f"Model saved. New best F1: {val_f1:.4f}")

    return best_f1



### Training

In [None]:
training_num = 1
# Use '-' instead of ':' in the time part: colons are not valid in Windows
# file names, and this stamp is embedded in both the log and checkpoint names.
time_start = datetime.today().strftime('%Y-%m-%d_%H-%M-%S')

# NOTE(review): `model` and `optimizer` are module-level and are not
# re-created here, so with training_num > 1 each run continues from the
# previous run's weights rather than starting fresh.
for i in range(training_num):
    logger = Logger(f'{time_start}_training{i}.csv')

    model_save_path = f'models/{time_start}_best_model_run_{i}.pt'

    train(3, logger, model_save_path)

    logger.visualise()
    best_stats = logger.end()


### Testing

In [None]:
# Evaluation order does not matter, so the test loader is not shuffled.
test_loader = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=2)

# Rebuild the same architecture that was trained (`model_path`) so the saved
# state_dict keys match, then load the checkpoint written by the training
# loop. The original loaded a DistilBERT model and a different file name.
model = AutoModelForSequenceClassification.from_pretrained(
    model_path,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_save_path, map_location=device))
model.to(device)

model.eval()

test_preds = []
test_labels = []

with torch.no_grad():
    for batch in test_loader:
        # Move to device
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        # The dataset column is 'label' (singular), as used in the training loop.
        labels = batch['label'].to(device)

        # Get predictions
        outputs = model(ids, attention_mask=mask)
        preds = torch.argmax(outputs.logits, dim=1)

        test_preds.extend(preds.cpu().numpy())
        test_labels.extend(labels.cpu().numpy())

# Report per-class metrics with class names taken from id2label. Passing
# `labels=` keeps names aligned with ids even if a class is absent from the
# test predictions.
class_ids = sorted(id2label)
class_names = [str(id2label[i]) for i in class_ids]
print(classification_report(
    test_labels, test_preds, labels=class_ids, target_names=class_names
))

### Visualise model performance

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# 1. Derive the class names from id2label so that row/column i of the matrix
#    is labelled with the class whose id is i. A hand-written list can drift
#    out of order, and data['label'] now holds integer ids, not names.
class_ids = sorted(id2label)
categories = [str(id2label[i]) for i in class_ids]

# 2. Generate the matrix (uses 'test_labels' and 'test_preds' from the test
#    loop); `labels=` fixes the row/column order and keeps absent classes.
cm = confusion_matrix(test_labels, test_preds, labels=class_ids)

# 3. Create a clean Heatmap
plt.figure(figsize=(12, 8))
sns.heatmap(
    cm, annot=True, fmt='d', cmap='Blues',
    xticklabels=categories, yticklabels=categories
)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Clean-Talk Guardrail: Confusion Matrix')
plt.show()

# 4. Print the full report for precision/recall per class
print(classification_report(
    test_labels, test_preds, labels=class_ids, target_names=categories
))