# 🚀 Complete Training Pipeline for Toxic Comment Classification
# Optimized for Google Colab GPU (15GB VRAM)

This notebook trains all models for Phases 2–4 (see the section headings below):
- Logistic Regression Baseline
- BiLSTM
- BiLSTM + Attention  
- DistilBERT (faster)
- BERT-base (best performance)

- **Runtime**: ~2–3 hours on a Colab GPU
- **GPU usage**: optimized batch sizes and mixed-precision training


## Setup: Clone Repo and Install Dependencies


In [None]:
# Clone your repo (replace with your actual repo URL)
!git clone https://github.com/Nak1106/Toxic_comment-classifier.git
%cd Toxic_comment-classifier

# Install requirements
!pip install -q torch transformers scikit-learn pandas matplotlib seaborn tqdm

# Check GPU
import torch
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


## Upload Dataset to Colab


In [None]:
# Option 1: Upload from local (run this cell and upload train.csv)
from google.colab import files
uploaded = files.upload()  # Upload your jigsaw train.csv
!mv train.csv data/jigsaw_train.csv

# Option 2: Download from Kaggle (if you have kaggle.json)
# !mkdir -p ~/.kaggle
# !cp kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json
# !kaggle competitions download -c jigsaw-toxic-comment-classification-challenge
# !unzip jigsaw-toxic-comment-classification-challenge.zip -d data/
# !mv data/train.csv data/jigsaw_train.csv


## Imports and Configuration


In [None]:
import sys
import json
import random
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from tqdm.auto import tqdm

# Import project modules
from src.config import DATA_DIR, MODELS_DIR, REPORTS_DIR, LABELS
from src.data_utils import load_raw_jigsaw, train_valid_test_split, build_dataloaders_rnn, basic_text_clean
from src.metrics import compute_classification_metrics
from src.models.rnn_models import BiLSTMClassifier, BiLSTMAttentionClassifier
from src.models.transformer_models import create_bert_base, create_distilbert

# Reproducibility: seed Python, NumPy and PyTorch (torch.manual_seed also seeds
# the CUDA generators). Some cuDNN kernels may still be nondeterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# GPU optimization settings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Memory-efficient settings for Colab
BATCH_SIZE_RNN = 128      # RNN models are lightweight
# NOTE(review): BATCH_SIZE_RNN is never passed to build_dataloaders_rnn() in this
# notebook; confirm that helper's own batch size or wire this constant through.
BATCH_SIZE_BERT = 16      # BERT needs more memory
MAX_SEQ_LEN = 128         # Shorter sequences = faster + less memory
# Mixed precision via torch.cuda.amp only applies on a CUDA device, so enable it
# only when one is present.
USE_AMP = device.type == "cuda"

# Create directories
MODELS_DIR.mkdir(exist_ok=True, parents=True)
REPORTS_DIR.mkdir(exist_ok=True, parents=True)


## Load and Split Data


In [None]:
# Read the raw Jigsaw CSV and split it into train / validation / test frames.
print("Loading data...")
raw_csv_path = DATA_DIR / "jigsaw_train.csv"
df = load_raw_jigsaw(raw_csv_path)
train_df, valid_df, test_df = train_valid_test_split(df)

split_sizes = {"Train": len(train_df), "Valid": len(valid_df), "Test": len(test_df)}
print(", ".join(f"{split}: {size}" for split, size in split_sizes.items()))
print(f"\nLabel distribution:")
# Positive count per label in the training split.
print(train_df[LABELS].sum())


## Phase 2: Logistic Regression Baseline (CPU)


In [None]:
print("="*60)
print("Training Logistic Regression Baseline...")
print("="*60)

start_time = time.time()

# Prepare text: clean every comment with the project's basic_text_clean helper.
train_texts = [basic_text_clean(t) for t in train_df["comment_text"].tolist()]
valid_texts = [basic_text_clean(t) for t in valid_df["comment_text"].tolist()]
y_train = train_df[LABELS].values
y_valid = valid_df[LABELS].values

# TF-IDF on unigrams + bigrams. The vocabulary is fit on the training split only
# and then applied to the validation split.
print("Fitting TF-IDF...")
tfidf = TfidfVectorizer(max_features=10000, ngram_range=(1, 2), min_df=3)
X_train_tfidf = tfidf.fit_transform(train_texts)
X_valid_tfidf = tfidf.transform(valid_texts)

# One independent binary LogisticRegression per label.
# - n_jobs is set on MultiOutputClassifier so the per-label fits run in
#   parallel; on a single binary lbfgs LogisticRegression it has no effect.
# - max_iter=1000: lbfgs with C=4.0 on 10k TF-IDF features can hit a
#   100-iteration cap before converging (ConvergenceWarning).
print("Training LogReg...")
logreg = MultiOutputClassifier(
    LogisticRegression(max_iter=1000, C=4.0, solver='lbfgs'),
    n_jobs=-1,
)
logreg.fit(X_train_tfidf, y_train)

# Predict: column k holds the positive-class probability from the k-th label's estimator.
y_prob_logreg = np.array([clf.predict_proba(X_valid_tfidf)[:, 1] for clf in logreg.estimators_]).T

# Evaluate
metrics_logreg = compute_classification_metrics(y_valid, y_prob_logreg, threshold=0.5, label_names=LABELS)

# Save
with open(REPORTS_DIR / "logreg_baseline_metrics.json", "w") as f:
    json.dump(metrics_logreg, f, indent=2)

elapsed = time.time() - start_time
print(f"\n✅ LogReg Complete! Time: {elapsed/60:.1f} min")
print(f"Macro F1: {metrics_logreg['macro_f1']:.4f}")
print(f"Micro F1: {metrics_logreg['micro_f1']:.4f}")


## Phase 3A: BiLSTM Baseline (GPU)


In [None]:
banner = "=" * 60
print(banner)
print("Training BiLSTM Baseline...")
print(banner)

start_time = time.time()

# Vocabulary + index-sequence loaders for the recurrent models. The same
# loaders, vocabulary and criterion are reused by the attention variant below.
# NOTE(review): BATCH_SIZE_RNN from the config cell is not passed here; confirm
# build_dataloaders_rnn's own batch size.
train_loader, valid_loader, vocab = build_dataloaders_rnn(train_df, valid_df, max_len=100)
vocab_size = len(vocab)

bilstm_hparams = dict(
    vocab_size=vocab_size,
    embed_dim=128,
    hidden_dim=128,
    num_labels=len(LABELS),
    pad_idx=vocab["<pad>"],
)
model_bilstm = BiLSTMClassifier(**bilstm_hparams).to(device)

# Multi-label objective: an independent sigmoid + BCE term per label.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model_bilstm.parameters(), lr=1e-3)
# Training function
def train_epoch_bilstm(model, loader, opt=None, loss_fn=None):
    """Train ``model`` for one epoch over ``loader`` and return the average loss.

    Args:
        model: module mapping a batch ``x`` to logits.
        loader: DataLoader yielding ``(x, y)`` batches.
        opt: optimizer to step. Defaults to the notebook-level ``optimizer``
            (resolved at call time), so existing two-argument calls still work.
        loss_fn: loss taking ``(logits, y)``. Defaults to the notebook-level
            ``criterion``.

    Returns:
        Sum over batches of (batch mean loss * batch size), divided by
        ``len(loader.dataset)``.
    """
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn

    model.train()
    total_loss = 0.0
    for x, y in tqdm(loader, desc="Training"):
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        opt.step()
        total_loss += loss.item() * x.size(0)
    return total_loss / len(loader.dataset)

# Eval function
@torch.no_grad()
def eval_epoch_bilstm(model, loader):
    """Run ``model`` over ``loader`` in eval mode and collect predictions.

    Returns:
        ``(y_true, y_prob)`` numpy arrays concatenated over batches along
        axis 0, where ``y_prob = sigmoid(logits)``.
    """
    model.eval()
    all_probs, all_labels = [], []
    for x, y in loader:
        logits = model(x.to(device))
        all_probs.append(torch.sigmoid(logits).cpu().numpy())
        # Labels are only collected for scoring; keep them on the CPU instead of
        # copying them to the device and back.
        all_labels.append(y.cpu().numpy())
    return np.concatenate(all_labels, 0), np.concatenate(all_probs, 0)

# Train for 5 epochs
num_epochs = 5
best_f1 = float("-inf")  # -inf so the first epoch is always checkpointed
best_metrics_bilstm = None
for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch_bilstm(model_bilstm, train_loader)
    y_true, y_prob_bilstm = eval_epoch_bilstm(model_bilstm, valid_loader)
    metrics_bilstm = compute_classification_metrics(y_true, y_prob_bilstm, threshold=0.5, label_names=LABELS)

    print(f"Epoch {epoch} - Loss: {train_loss:.4f}, Macro F1: {metrics_bilstm['macro_f1']:.4f}")

    if metrics_bilstm['macro_f1'] > best_f1:
        best_f1 = metrics_bilstm['macro_f1']
        best_metrics_bilstm = metrics_bilstm
        torch.save({"state_dict": model_bilstm.state_dict(), "vocab": vocab},
                   MODELS_DIR / "bilstm_baseline.pt")
        with open(REPORTS_DIR / "bilstm_baseline_metrics.json", "w") as f:
            json.dump(metrics_bilstm, f, indent=2)

# The comparison cell reads `metrics_bilstm`. Point it at the best epoch, so it
# matches the saved checkpoint and JSON, rather than at the last epoch's scores.
if best_metrics_bilstm is not None:
    metrics_bilstm = best_metrics_bilstm

elapsed = time.time() - start_time
print(f"\n✅ BiLSTM Complete! Time: {elapsed/60:.1f} min")
print(f"Best Macro F1: {best_f1:.4f}")


In [None]:
banner = "=" * 60
print(banner)
print("Training BiLSTM + Attention...")
print(banner)

start_time = time.time()

# Same sizes as the plain BiLSTM. vocab / vocab_size (and the loaders and
# criterion used later) come from the BiLSTM cell.
attn_hparams = {
    "vocab_size": vocab_size,
    "embed_dim": 128,
    "hidden_dim": 128,
    "num_labels": len(LABELS),
    "pad_idx": vocab["<pad>"],
}
model_bilstm_attn = BiLSTMAttentionClassifier(**attn_hparams).to(device)

# Dedicated optimizer so this model trains independently of the baseline.
optimizer_attn = torch.optim.Adam(model_bilstm_attn.parameters(), lr=1e-3)

# Training function
def train_epoch_attn(model, loader, opt=None, loss_fn=None):
    """Train an attention model for one epoch and return the average loss.

    The model's forward pass returns ``(logits, attn)``; only the logits feed
    the loss.

    Args:
        model: module whose forward returns ``(logits, attn)``.
        loader: DataLoader yielding ``(x, y)`` batches.
        opt: optimizer to step. Defaults to the notebook-level
            ``optimizer_attn`` (resolved at call time).
        loss_fn: loss taking ``(logits, y)``. Defaults to the notebook-level
            ``criterion``.

    Returns:
        Sum over batches of (batch mean loss * batch size), divided by
        ``len(loader.dataset)``.
    """
    opt = optimizer_attn if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn

    model.train()
    total_loss = 0.0
    for x, y in tqdm(loader, desc="Training"):
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        logits, _attn = model(x)  # attention weights are not used for training
        loss = loss_fn(logits, y)
        loss.backward()
        opt.step()
        total_loss += loss.item() * x.size(0)
    return total_loss / len(loader.dataset)

# Eval function
@torch.no_grad()
def eval_epoch_attn(model, loader):
    """Run an attention model over ``loader`` in eval mode and collect predictions.

    The model's forward pass returns ``(logits, attn)``; the attention weights
    are discarded.

    Returns:
        ``(y_true, y_prob)`` numpy arrays concatenated over batches along
        axis 0, where ``y_prob = sigmoid(logits)``.
    """
    model.eval()
    all_probs, all_labels = [], []
    for x, y in loader:
        logits, _attn = model(x.to(device))
        all_probs.append(torch.sigmoid(logits).cpu().numpy())
        # Labels are only collected for scoring; keep them on the CPU instead of
        # copying them to the device and back.
        all_labels.append(y.cpu().numpy())
    return np.concatenate(all_labels, 0), np.concatenate(all_probs, 0)

# Train for 5 epochs
num_epochs = 5
best_f1_attn = float("-inf")  # -inf so the first epoch is always checkpointed
best_metrics_attn = None
for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch_attn(model_bilstm_attn, train_loader)
    y_true, y_prob_attn = eval_epoch_attn(model_bilstm_attn, valid_loader)
    metrics_attn = compute_classification_metrics(y_true, y_prob_attn, threshold=0.5, label_names=LABELS)

    print(f"Epoch {epoch} - Loss: {train_loss:.4f}, Macro F1: {metrics_attn['macro_f1']:.4f}")

    if metrics_attn['macro_f1'] > best_f1_attn:
        best_f1_attn = metrics_attn['macro_f1']
        best_metrics_attn = metrics_attn
        torch.save({"state_dict": model_bilstm_attn.state_dict(), "vocab": vocab},
                   MODELS_DIR / "bilstm_attention.pt")
        with open(REPORTS_DIR / "bilstm_attention_metrics.json", "w") as f:
            json.dump(metrics_attn, f, indent=2)

# The comparison cell reads `metrics_attn`. Point it at the best epoch, so it
# matches the saved checkpoint and JSON, rather than at the last epoch's scores.
if best_metrics_attn is not None:
    metrics_attn = best_metrics_attn

elapsed = time.time() - start_time
print(f"\n✅ BiLSTM+Attention Complete! Time: {elapsed/60:.1f} min")
print(f"Best Macro F1: {best_f1_attn:.4f}")


## Phase 4: DistilBERT (GPU + Mixed Precision)


In [None]:
banner = "=" * 60
print(banner, "Training DistilBERT (Fast + Memory Efficient)...", banner, sep="\n")

start_time = time.time()

# Hand cached-but-unused CUDA blocks back to the driver before the transformer
# is allocated.
if device.type == "cuda":
    torch.cuda.empty_cache()

# Base class for the BERT-style dataset defined next.
from torch.utils.data import Dataset

class JigsawBertDataset(Dataset):
    """Map-style dataset that tokenizes one comment per access.

    Args:
        texts: indexable sequence of comments; each item is coerced with ``str``.
        labels: array of label rows, stored as float32 via ``astype``.
        tokenizer: callable tokenizer invoked with ``truncation``, ``padding``,
            ``max_length`` and ``return_tensors="pt"``.
        max_len: every example is truncated / padded to this many tokens.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels.astype("float32")
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            str(self.texts[idx]),
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
        )
        # Drop the leading batch-of-one dimension so the DataLoader can stack
        # examples itself.
        item = {key: encoded[key].squeeze(0) for key in ("input_ids", "attention_mask")}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float32)
        return item

# Tokenizer matching the DistilBERT checkpoint. train_ds / valid_ds are also
# reused by the BERT-base cell.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

train_ds, valid_ds = (
    JigsawBertDataset(frame["comment_text"].tolist(), frame[LABELS].values, tokenizer, MAX_SEQ_LEN)
    for frame in (train_df, valid_df)
)

train_loader_bert = DataLoader(train_ds, batch_size=BATCH_SIZE_BERT, shuffle=True, num_workers=2)
valid_loader_bert = DataLoader(valid_ds, batch_size=BATCH_SIZE_BERT * 2, shuffle=False, num_workers=2)

# Model
model_distilbert = create_distilbert(len(LABELS)).to(device)

# Class weights for imbalanced labels: per-label neg/pos ratio, with the
# positive count clamped to >= 1 to avoid division by zero.
# NOTE(review): very rare labels get very large weights; combined with the fixed
# 0.5 threshold this likely trades precision for recall. The RNN models above
# train without pos_weight, so the comparison is not like-for-like.
pos_counts = train_df[LABELS].sum().values
neg_counts = len(train_df) - pos_counts
pos_weight = torch.tensor(neg_counts / np.maximum(pos_counts, 1), dtype=torch.float32, device=device)
criterion_bert = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

optimizer_bert = torch.optim.AdamW(model_distilbert.parameters(), lr=2e-5)
# Schedule length = batches per epoch x 3 epochs; the first 10% is linear warm-up.
total_steps = len(train_loader_bert) * 3
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer_bert,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

# Mixed precision scaler (a no-op when USE_AMP is False)
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

# Training function with AMP
def train_epoch_bert(model, loader):
    """Run one mixed-precision training epoch for the DistilBERT model.

    Uses the notebook-level ``optimizer_bert``, ``scheduler``, ``scaler`` and
    ``criterion_bert``. The scheduler is stepped once per batch.

    Returns:
        Sum over batches of (batch mean loss * batch size), divided by
        ``len(loader.dataset)``.
    """
    model.train()
    running_loss = 0.0
    for batch in tqdm(loader, desc="Training"):
        ids, mask, targets = (batch[key].to(device) for key in ("input_ids", "attention_mask", "labels"))

        optimizer_bert.zero_grad()

        with torch.cuda.amp.autocast(enabled=USE_AMP):
            batch_loss = criterion_bert(model(input_ids=ids, attention_mask=mask), targets)

        # GradScaler sequence: scaled backward -> step (skipped on inf/nan) -> update scale.
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer_bert)
        scaler.update()
        scheduler.step()

        running_loss += batch_loss.item() * ids.size(0)
    return running_loss / len(loader.dataset)

# Eval function
@torch.no_grad()
def eval_epoch_bert(model, loader):
    """Predict sigmoid probabilities for every batch in ``loader``.

    Runs in eval mode without gradients; used for both DistilBERT and BERT-base
    in this notebook.

    Returns:
        ``(y_true, y_prob)`` numpy arrays concatenated over batches along axis 0.
    """
    model.eval()
    all_probs, all_labels = [], []
    for batch in loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        all_probs.append(torch.sigmoid(logits).cpu().numpy())
        # Labels are only collected for scoring; keep them on the CPU instead of
        # copying them to the device and back.
        all_labels.append(batch["labels"].cpu().numpy())
    return np.concatenate(all_labels, 0), np.concatenate(all_probs, 0)

# Train for 3 epochs (enough for DistilBERT)
num_epochs = 3  # must equal the epoch count used to size `total_steps`
best_f1_distilbert = float("-inf")  # -inf so the first epoch is always checkpointed
best_metrics_distilbert = None
for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch_bert(model_distilbert, train_loader_bert)
    y_true, y_prob_distilbert = eval_epoch_bert(model_distilbert, valid_loader_bert)
    metrics_distilbert = compute_classification_metrics(y_true, y_prob_distilbert, threshold=0.5, label_names=LABELS)

    print(f"Epoch {epoch} - Loss: {train_loss:.4f}, Macro F1: {metrics_distilbert['macro_f1']:.4f}")

    if metrics_distilbert['macro_f1'] > best_f1_distilbert:
        best_f1_distilbert = metrics_distilbert['macro_f1']
        best_metrics_distilbert = metrics_distilbert
        torch.save(model_distilbert.state_dict(), MODELS_DIR / "distilbert_toxic.pt")
        with open(REPORTS_DIR / "distilbert_toxic_metrics.json", "w") as f:
            json.dump(metrics_distilbert, f, indent=2)

# The comparison cell reads `metrics_distilbert`. Point it at the best epoch, so
# it matches the saved checkpoint and JSON, rather than at the last epoch's scores.
if best_metrics_distilbert is not None:
    metrics_distilbert = best_metrics_distilbert

elapsed = time.time() - start_time
print(f"\n✅ DistilBERT Complete! Time: {elapsed/60:.1f} min")
print(f"Best Macro F1: {best_f1_distilbert:.4f}")


In [None]:
print("="*60)
print("Training BERT-base (Best Performance)...")
print("="*60)

start_time = time.time()

# Release the DistilBERT run before BERT-base is allocated (its weights are
# already checkpointed to MODELS_DIR). Deleting only the model is not enough:
# `optimizer_bert` (AdamW state) and `scheduler` (via its optimizer) still
# reference its parameters. empty_cache() can only return memory that is no
# longer referenced, so it must run after the deletes. pop(..., None) keeps
# this cell re-runnable.
import gc

for _stale_name in ("model_distilbert", "optimizer_bert", "scheduler"):
    globals().pop(_stale_name, None)
gc.collect()
if device.type == "cuda":
    torch.cuda.empty_cache()

# Model
# NOTE(review): train_ds / valid_ds were tokenized with the
# "distilbert-base-uncased" tokenizer. That is only correct if create_bert_base()
# loads an uncased BERT checkpoint with the same WordPiece vocabulary
# (e.g. bert-base-uncased). Confirm in src/models/transformer_models.py.
model_bert = create_bert_base(len(LABELS)).to(device)
optimizer_bert2 = torch.optim.AdamW(model_bert.parameters(), lr=2e-5)

# Smaller batch size for BERT (more memory)
BATCH_SIZE_BERT_LARGE = 8
train_loader_bert_large = DataLoader(train_ds, batch_size=BATCH_SIZE_BERT_LARGE, shuffle=True, num_workers=2)
valid_loader_bert_large = DataLoader(valid_ds, batch_size=BATCH_SIZE_BERT_LARGE*2, shuffle=False, num_workers=2)

# Size the LR schedule from THIS loader. With half of DistilBERT's batch size
# there are twice as many optimizer steps per epoch. DistilBERT's `total_steps`
# would decay the LR to zero halfway through training, and it would stay at
# zero for the remaining steps.
NUM_EPOCHS_BERT = 3  # keep in sync with the epoch loop below
total_steps_bert = len(train_loader_bert_large) * NUM_EPOCHS_BERT
scheduler2 = get_linear_schedule_with_warmup(
    optimizer_bert2,
    num_warmup_steps=int(0.1 * total_steps_bert),
    num_training_steps=total_steps_bert,
)

# Training function
def train_epoch_bert_large(model, loader, optimizer, scheduler):
    """Run one mixed-precision training epoch with an explicit optimizer/scheduler.

    Uses the notebook-level ``scaler`` and ``criterion_bert``. ``scheduler`` is
    stepped after every batch.

    Returns:
        Sum over batches of (batch mean loss * batch size), divided by
        ``len(loader.dataset)``.
    """
    model.train()
    loss_sum = 0.0
    for batch in tqdm(loader, desc="Training"):
        ids, mask, targets = (batch[k].to(device) for k in ("input_ids", "attention_mask", "labels"))

        optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=USE_AMP):
            batch_loss = criterion_bert(model(input_ids=ids, attention_mask=mask), targets)

        # GradScaler sequence: scaled backward -> step (skipped on inf/nan) -> update scale.
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        loss_sum += batch_loss.item() * ids.size(0)
    return loss_sum / len(loader.dataset)

# Train for 3 epochs
num_epochs = 3  # must equal the epoch count used to size scheduler2
best_f1_bert = float("-inf")  # -inf so the first epoch is always checkpointed
best_metrics_bert = None
for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch_bert_large(model_bert, train_loader_bert_large, optimizer_bert2, scheduler2)
    y_true, y_prob_bert = eval_epoch_bert(model_bert, valid_loader_bert_large)
    metrics_bert = compute_classification_metrics(y_true, y_prob_bert, threshold=0.5, label_names=LABELS)

    print(f"Epoch {epoch} - Loss: {train_loss:.4f}, Macro F1: {metrics_bert['macro_f1']:.4f}")

    if metrics_bert['macro_f1'] > best_f1_bert:
        best_f1_bert = metrics_bert['macro_f1']
        best_metrics_bert = metrics_bert
        torch.save(model_bert.state_dict(), MODELS_DIR / "bert_toxic.pt")
        with open(REPORTS_DIR / "bert_toxic_metrics.json", "w") as f:
            json.dump(metrics_bert, f, indent=2)

# The comparison cell reads `metrics_bert`. Point it at the best epoch, so it
# matches the saved checkpoint and JSON, rather than at the last epoch's scores.
if best_metrics_bert is not None:
    metrics_bert = best_metrics_bert

elapsed = time.time() - start_time
print(f"\n✅ BERT Complete! Time: {elapsed/60:.1f} min")
print(f"Best Macro F1: {best_f1_bert:.4f}")


## 📊 Final Results Comparison


In [None]:
# Create comprehensive comparison from the metrics dicts left by the training cells.
# NOTE: every score here is on the validation split, the same split used to pick
# each neural model's best epoch, so these are optimistic estimates. `test_df`
# is never evaluated in this notebook.
results = {
    "Logistic Regression": metrics_logreg,
    "BiLSTM": metrics_bilstm,
    "BiLSTM + Attention": metrics_attn,
    "DistilBERT": metrics_distilbert,
    "BERT-base": metrics_bert,
}
models = list(results.keys())

# Overall comparison (numeric columns; formatting happens only when printing)
print("\n" + "="*80)
print("FINAL RESULTS COMPARISON")
print("="*80)

comparison_df = pd.DataFrame({
    "Model": models,
    "Macro F1": [results[m]["macro_f1"] for m in models],
    "Micro F1": [results[m]["micro_f1"] for m in models],
})
print(comparison_df.to_string(index=False, float_format="{:.4f}".format))

# Per-label comparison: rows = models, columns = labels
print("\n" + "="*80)
print("PER-LABEL F1 SCORES")
print("="*80)

per_label_matrix = np.array([[results[m]["per_label"][label]["f1"] for label in LABELS] for m in models])
per_label_df = pd.DataFrame(per_label_matrix.T, columns=models)  # one row per label
per_label_df.insert(0, "Label", list(LABELS))
print(per_label_df.to_string(index=False, float_format="{:.4f}".format))

# Visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Overall F1: grouped bars, macro vs micro per model
x = np.arange(len(models))
width = 0.35

ax1.bar(x - width/2, comparison_df["Macro F1"].to_numpy(), width, label='Macro F1', color='steelblue')
ax1.bar(x + width/2, comparison_df["Micro F1"].to_numpy(), width, label='Micro F1', color='coral')
ax1.set_xlabel('Model')
ax1.set_ylabel('F1 Score')
ax1.set_title('Overall Performance Comparison')
ax1.set_xticks(x)
ax1.set_xticklabels(models, rotation=45, ha='right')
ax1.legend()
ax1.grid(axis='y', alpha=0.3)

# Per-label heatmap
im = ax2.imshow(per_label_matrix, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)
ax2.set_xticks(np.arange(len(LABELS)))
ax2.set_yticks(np.arange(len(models)))
ax2.set_xticklabels(LABELS, rotation=45, ha='right')
ax2.set_yticklabels(models)
ax2.set_title('Per-Label F1 Scores Heatmap')

# Annotate each heatmap cell with its score
for (i, j), score in np.ndenumerate(per_label_matrix):
    ax2.text(j, i, f'{score:.3f}', ha="center", va="center", color="black", fontsize=9)

fig.colorbar(im, ax=ax2)
fig.tight_layout()
fig.savefig(REPORTS_DIR / "model_comparison.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"\n✅ All models trained! Results saved to: {REPORTS_DIR}")


## 💾 Download Results from Colab


In [None]:
# Zip all results
!zip -r results.zip models/ reports/

# Download
from google.colab import files
files.download('results.zip')

print("âœ… Download complete! Extract and commit to your repo.")
