In [None]:
%matplotlib inline

In [None]:
# # This Python 3 environment comes with many helpful analytics libraries installed
# # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# # For example, here's several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# # Input data files are available in the read-only "../input/" directory
# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
#!pip -q install sentence-transformers thop statsmodels
!pip -q install statsmodels

In [None]:
!pip install --upgrade gcsfs fsspec

In [None]:
#!pip -q install Datasets
# At the very top of your notebook
!pip install -q datasets transformers torch torchvision torchaudio
!pip install -q textattack nltk tqdm
!pip install -q git+https://github.com/huggingface/transformers.git

In [None]:
!pip install -q GPUtil psutil

In [None]:
# ====
# Imports and Setup
# ====

import os
import json
import time
import random
import string
import gc
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertModel
from datasets import load_dataset
import matplotlib.pyplot as plt
import seaborn as sns
import psutil
import GPUtil
import threading
from pathlib import Path
from datetime import datetime
from collections import defaultdict
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score, confusion_matrix
import nltk
from nltk.corpus import wordnet, stopwords
import warnings

# Cell A: ablation config & run-id builder (insert right after imports)
import datetime
import json
from pathlib import Path
import numpy as np
from scipy.stats import chi2_contingency
from statsmodels.stats.contingency_tables import mcnemar



# Silence every Python warning for the rest of the session (global filter).
warnings.filterwarnings('ignore')

# Setup device
# Use the CUDA device when PyTorch can see one, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    # Release cached allocator blocks that earlier cells/runs may have left behind.
    torch.cuda.empty_cache()
    print(f"GPU: {torch.cuda.get_device_name()}")
    # total_memory is in bytes; dividing by 1e9 prints decimal gigabytes.
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")



# Set seeds for reproducibility
def set_seed(seed=42):
    """Seed every RNG used by the notebook and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to Python's `random`, NumPy's legacy global
            generator and PyTorch (CPU, plus all CUDA devices when available).
    """
    # CPU-side generators, seeded in the same order as before.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators only exist when a GPU is visible.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Ask cuDNN for deterministic kernels and disable its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed()


# ====
# Dataset Preparation
# ====

# Download NLTK data if needed
# nltk.data.find raises LookupError for a missing resource; if either lookup fails,
# both corpora are downloaded (quietly).
try:
    nltk.data.find('corpora/wordnet')
    nltk.data.find('corpora/stopwords')
except LookupError:
    print("Downloading NLTK data (wordnet, stopwords)...")
    nltk.download('wordnet', quiet=True)
    nltk.download('stopwords', quiet=True)

# SST-2 is fetched through the GLUE benchmark loader; the 'train' and 'validation'
# splits of the returned object are the ones used further down.
print("Loading SST-2 dataset from Hugging Face...")
sst2_dataset = load_dataset("glue", "sst2")

# Tokenizer shared by every dataset and by the attack generator.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

class SentimentDataset(Dataset):
    """Map-style dataset that tokenizes one sentence per item on access.

    Args:
        texts: Indexable, sliceable sequence of sentences.
        labels: Sequence of integer class labels aligned with `texts`
            (list, NumPy array, ...). May be empty or None.
        tokenizer: Callable accepting `(text, max_length=..., padding=...,
            truncation=..., return_tensors='pt')` and returning a mapping with
            'input_ids' and 'attention_mask' tensors of shape (1, max_length).
        max_length: Fixed sequence length every item is padded/truncated to.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Word-count statistics are estimated on the first 1000 texts only.
        sample_lengths = [len(str(text).split()) for text in texts[:1000]]
        # `len(...) > 0` instead of truthiness: `if labels` raises ValueError for
        # NumPy arrays with more than one element.
        has_labels = labels is not None and len(labels) > 0

        self.dataset_stats = {
            'total_samples': len(texts),
            # Guard the empty case: np.mean([]) would emit a warning and return nan.
            'avg_text_length': np.mean(sample_lengths) if sample_lengths else 0.0,
            'label_distribution': np.bincount(labels) if has_labels else None
        }
        print(f"Dataset initialized with {self.dataset_stats['total_samples']} samples")
        print(f"Average text length: {self.dataset_stats['avg_text_length']:.1f} words")
        if self.dataset_stats['label_distribution'] is not None:
            print(f"Label distribution: {dict(enumerate(self.dataset_stats['label_distribution']))}")

    def __len__(self):
        """Number of samples (taken from `texts`)."""
        return len(self.texts)

    def _encode(self, text):
        """Tokenize `text` to fixed-length PyTorch tensors."""
        return self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        """Return a dict with 'input_ids', 'attention_mask' and 'label' for sample `idx`."""
        text = str(self.texts[int(idx)])
        label = int(self.labels[int(idx)])
        try:
            encoding = self._encode(text)
        except Exception as e:
            # Best-effort: keep the batch flowing by encoding an empty string, but
            # report the failing sample so it can be inspected.
            print(f"Tokenization error for sample {idx}: {e}")
            encoding = self._encode("")
        return {
            # squeeze(0) drops only the leading batch axis; a bare squeeze() would
            # also collapse the sequence axis when max_length == 1.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(label, dtype=torch.long)
        }
# ====
# Optimized Hyperparameters for SST-2
# ====
batch_size = 64          # Double the batch size (DistilBERT is light, this fits easily)
max_length = 48          # Reduced from 128. 99% of SST-2 sentences are < 45 tokens.
max_epochs = 5           # 10 is too many for SST-2; it usually converges by epoch 3.
learning_rate = 2e-5     # Standard for DistilBERT
accumulation_steps = 1   # With batch_size 64, you don't need accumulation on most GPUs

# The train/val/test datasets are built exactly once, in the stratified split section
# below. Datasets previously constructed here from the raw GLUE splits were overwritten
# there before ever being used, so they have been removed.

# Seed NumPy's global generator before the split is created.
np.random.seed(42)
# ====
# Proper Train/Val/Test Split
# ====

from sklearn.model_selection import train_test_split

print("\n=== Creating Proper Train/Val/Test Split ===")

# Pool the GLUE 'train' and 'validation' splits into one corpus that is re-split below.
all_texts, all_labels = [], []
for _glue_split in ('train', 'validation'):
    all_texts.extend(sst2_dataset[_glue_split]['sentence'])
    all_labels.extend(sst2_dataset[_glue_split]['label'])

_total_samples = len(all_texts)
print(f"Total samples available: {_total_samples}")

# Stratified 80/10/10 split: first carve off 20%, then halve it into val and test.
train_texts, temp_texts, train_labels, temp_labels = train_test_split(
    all_texts, all_labels, test_size=0.2, random_state=42, stratify=all_labels
)
val_texts, test_texts, val_labels, test_labels = train_test_split(
    temp_texts, temp_labels, test_size=0.5, random_state=42, stratify=temp_labels
)

# (display name, texts, labels) for each split, in reporting order.
_splits = (
    ('Train', train_texts, train_labels),
    ('Val', val_texts, val_labels),
    ('Test', test_texts, test_labels),
)

print(f"\nSplit sizes:")
for _name, _texts, _ in _splits:
    print(f"  {_name}: {len(_texts)} samples ({len(_texts)/_total_samples*100:.1f}%)")

# Check class distribution
print(f"\nClass distribution:")
for _name, _, _labels in _splits:
    _neg, _pos, _count = _labels.count(0), _labels.count(1), len(_labels)
    print(f"  {_name:<5} - Negative: {_neg} ({_neg/_count*100:.1f}%), Positive: {_pos} ({_pos/_count*100:.1f}%)")

# Create datasets (one SentimentDataset per split, built in train/val/test order)
print("\nCreating datasets...")
train_dataset, val_dataset, test_dataset = [
    SentimentDataset(_texts, _labels, tokenizer, max_length=max_length)
    for _, _texts, _labels in _splits
]

# Create dataloaders; only the training loader shuffles.
print("Creating dataloaders...")
_loader_options = dict(
    batch_size=batch_size,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

# Report the size of every split and the settings the loaders were built with.
print(f"\n=== Dataset Summary ===")
print(f"Train samples: {len(train_loader.dataset)}")
print(f"Train batches: {len(train_loader)}")
print(f"Validation samples: {len(val_loader.dataset)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Test samples: {len(test_loader.dataset)}")
print(f"Test batches: {len(test_loader)}")
print(f"Batch size: {batch_size}")
print(f"Max length: {max_length}")
# Effective batch size is batch_size times the configured accumulation steps
# (previously hard-coded as batch_size * 2, which disagreed with accumulation_steps).
print(f"Effective batch size (with accumulation): {batch_size * accumulation_steps}")
print(f"\nSetup complete! Ready for training.\n")

In [None]:
import collections.abc


# Small helper: build run id from config dict
def build_run_id_from_cfg(cfg):
    """Build an underscore-joined run identifier from an ablation config.

    The id always starts with 'BASE'; 'noAdv' / 'noSmooth' are appended when the
    corresponding flag is truthy, followed by the free-form 'tag' when non-empty.

    Args:
        cfg: Mapping that may contain "no_adv", "no_smooth" and "tag".

    Returns:
        The run id string, e.g. 'BASE_noSmooth'.
    """
    # NOTE(review): when 'tag' repeats a flag label (as in RUN_MATRIX), the label
    # appears twice in the id, e.g. 'BASE_noAdv_noAdv' - confirm this is intended.
    flag_labels = (("no_adv", 'noAdv'), ("no_smooth", 'noSmooth'))
    parts = ['BASE'] + [label for flag, label in flag_labels if cfg.get(flag)]
    tag = cfg.get("tag")
    if tag:
        parts.append(tag)
    return '_'.join(parts)

# Results aggregator path (match your notebook paths)
# NOTE(review): log_run() defaults to "aggregate.json" in the working directory, not
# to this path; pass results_path=RESULTS_PATH explicitly to write here.
RESULTS_PATH = Path("results/aggregate.json")
# Create the parent "results" directory up front (no error if it already exists).
RESULTS_PATH.parent.mkdir(exist_ok=True, parents=True)

# Safe logger that appends run results in aggregate.json
import numpy as np
import pandas as pd
import datetime
import json
import os

def _sanitize_json_key(key):
    """Coerce a mapping key to a type `json.dump` accepts (str/int/float/bool/None)."""
    if isinstance(key, np.bool_):
        return bool(key)
    if isinstance(key, np.integer):
        return int(key)
    if isinstance(key, np.floating):
        return float(key)
    if key is None or isinstance(key, (str, bool, int, float)):
        return key
    # Anything else (tuples, timestamps, paths, ...) is stored under its string form.
    return str(key)


def sanitize_for_json(obj):
    """Recursively convert NumPy / pandas / datetime values into JSON-friendly types.

    Args:
        obj: Arbitrary (possibly nested) value.

    Returns:
        A structure made of dicts, lists and Python scalars/strings wherever a
        conversion is known. Objects with no known conversion are returned unchanged.
    """
    # Local imports: the module-level name `datetime` is rebound between the class
    # and the module elsewhere in this notebook, so do not rely on it here.
    import datetime as _dt
    from pathlib import PurePath

    # dict-like: keys are sanitized too, since NumPy scalars are not valid JSON keys.
    if isinstance(obj, dict):
        return {_sanitize_json_key(k): sanitize_for_json(v) for k, v in obj.items()}
    # lists / tuples
    if isinstance(obj, (list, tuple)):
        return [sanitize_for_json(v) for v in obj]
    # sets have no JSON form; emit a list, sorted when the elements are comparable.
    if isinstance(obj, (set, frozenset)):
        try:
            members = sorted(obj)
        except TypeError:
            members = list(obj)
        return [sanitize_for_json(v) for v in members]
    # numpy scalars
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.ndarray):
        # Recurse so object-dtype arrays are sanitized element by element.
        return sanitize_for_json(obj.tolist())
    # pandas types
    if isinstance(obj, pd.Timestamp):
        return obj.isoformat()
    if isinstance(obj, pd.Series):
        return sanitize_for_json(obj.to_dict())
    if isinstance(obj, pd.DataFrame):
        return sanitize_for_json(obj.to_dict(orient="list"))
    # datetimes
    if isinstance(obj, (_dt.datetime, _dt.date)):
        return obj.isoformat()
    # filesystem paths
    if isinstance(obj, PurePath):
        return str(obj)
    # Fallbacks for array-like objects: `.item()` covers single-element containers,
    # `.tolist()` covers multi-element ones (where `.item()` raises).
    for converter_name in ("item", "tolist"):
        converter = getattr(obj, converter_name, None)
        if callable(converter):
            try:
                return sanitize_for_json(converter())
            except Exception:
                continue
    return obj  # assume it's already JSON-serializable

def log_run(run_id, stats, results_path="aggregate.json"):
    """Insert or replace one run's stats in a JSON file keyed by run id.

    Args:
        run_id: Key under which the stats are stored.
        stats: Stats object; passed through `sanitize_for_json` before writing.
        results_path: JSON file to update. Defaults to "aggregate.json" in the
            current working directory.

    An existing file that cannot be parsed as a JSON object is moved aside to
    "<results_path>.corrupt" instead of being silently overwritten. The new content
    is written to a temporary file and swapped in with `os.replace`, so a failed
    serialization leaves the previous file intact.
    """
    results = {}
    if os.path.exists(results_path):
        try:
            with open(results_path, "r") as f:
                loaded = json.load(f)
            if not isinstance(loaded, dict):
                raise ValueError("top-level JSON value is not an object")
            results = loaded
        except (OSError, ValueError) as e:  # json.JSONDecodeError is a ValueError
            backup_path = f"{results_path}.corrupt"
            print(f"⚠ Could not read {results_path} ({e}); keeping a copy at {backup_path}")
            try:
                os.replace(results_path, backup_path)
            except OSError as backup_error:
                print(f"⚠ Could not back up {results_path}: {backup_error}")

    # sanitize stats so json.dump won't fail on numpy/pandas objects
    results[run_id] = sanitize_for_json(stats)

    tmp_path = f"{results_path}.tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(results, f, indent=2)
        os.replace(tmp_path, results_path)
    finally:
        # Only present if the dump or the swap failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"✔ Logged {run_id} to {results_path}")

# Default ablation matrix (4-run mini-ablation)
# Each entry is a config dict consumed by build_run_id_from_cfg.
# NOTE(review): build_run_id_from_cfg already appends 'noAdv'/'noSmooth' for the flags
# and then appends 'tag', so rows 2-4 yield ids with repeated labels
# (e.g. 'BASE_noAdv_noAdv') - confirm this is intended.
RUN_MATRIX = [
    {"no_adv": False, "no_smooth": False, "tag": ""},           # BASE
    {"no_adv": True,  "no_smooth": False, "tag": "noAdv"},     # no adversarial training
    {"no_adv": False, "no_smooth": True,  "tag": "noSmooth"},  # no smoothing
    {"no_adv": True,  "no_smooth": True,  "tag": "noAdv_noSmooth"}  # both off
]


# Create directories for saving results and plots
# Both paths are relative to the current working directory.
RESULTS_DIR = Path('results')
PLOTS_DIR = Path('plots')
RESULTS_DIR.mkdir(exist_ok=True)
PLOTS_DIR.mkdir(exist_ok=True)

def now_tag():
    """Return the current local time as a 'YYYYMMDD_HHMMSS' string."""
    # The global name `datetime` is the *module* by the time this cell runs (a later
    # `import datetime` shadows `from datetime import datetime`), so `datetime.now()`
    # raised AttributeError. Importing the class locally works under either binding.
    from datetime import datetime as _datetime
    return _datetime.now().strftime("%Y%m%d_%H%M%S")


def flatten_dict(d, parent_key='', sep='_'):
    """
    Flattens a nested dictionary, joining keys with a separator.
    It skips lists to avoid excessively long columns (e.g., epoch-by-epoch history).

    Args:
        d: Mapping to flatten.
        parent_key: Prefix for every key of `d`; '' means no prefix.
        sep: Separator placed between a prefix and a key.

    Returns:
        A flat dict. Keys that received a prefix are strings; top-level keys are
        kept as they are.
    """
    flat = {}
    for k, v in d.items():
        # Format rather than concatenate so non-string keys (e.g. integer class ids)
        # nested under a parent no longer raise TypeError.
        new_key = f"{parent_key}{sep}{k}" if parent_key != '' else k
        if isinstance(v, collections.abc.Mapping):
            flat.update(flatten_dict(v, new_key, sep=sep))
        # Exclude lists from being flattened into the CSV
        elif not isinstance(v, list):
            flat[new_key] = v
    return flat



import csv
import os

# Define the path for the new comprehensive CSV results file
ALL_RUNS_CSV_PATH = RESULTS_DIR / "all_runs_results.csv"

def log_run_csv(run_id, stats_dict):
    """
    Logs the results of a single run to a comprehensive CSV file.

    The stats are flattened with `flatten_dict` and written as one row of
    ALL_RUNS_CSV_PATH. Rows are always aligned to the file's header: when the run
    introduces columns the file does not have yet, the file is rewritten with the
    extended header and earlier rows are padded with empty cells. Errors are
    reported with a message rather than raised.

    Args:
        run_id: Identifier stored in the 'run_id' column.
        stats_dict: (Possibly nested) mapping of stats for this run.
    """
    # Add the run_id to the dictionary for logging
    log_data = {'run_id': run_id}
    log_data.update(stats_dict)

    # Flatten the dictionary to make it CSV-friendly; column names read back from the
    # file are strings, so normalize keys to strings for the comparison below.
    flat_data = {str(key): value for key, value in flatten_dict(log_data).items()}

    try:
        existing_rows = []
        if os.path.isfile(ALL_RUNS_CSV_PATH) and os.path.getsize(ALL_RUNS_CSV_PATH) > 0:
            with open(ALL_RUNS_CSV_PATH, 'r', newline='') as f:
                existing_rows = list(csv.reader(f))

        header = existing_rows[0] if existing_rows else []
        new_columns = [key for key in flat_data if key not in header]

        if header and not new_columns:
            # Same schema (or a subset): append, aligned to the existing header.
            with open(ALL_RUNS_CSV_PATH, 'a', newline='') as f:
                csv.DictWriter(f, fieldnames=header, restval='').writerow(flat_data)
        else:
            # New file or new columns: (re)write everything under the extended header.
            fieldnames = header + new_columns
            with open(ALL_RUNS_CSV_PATH, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(fieldnames)
                for row in existing_rows[1:]:
                    writer.writerow(row + [''] * (len(fieldnames) - len(row)))
                writer.writerow([flat_data.get(column, '') for column in fieldnames])

        print(f"✔ Logged {run_id} to {ALL_RUNS_CSV_PATH}")

    except Exception as e:
        print(f"❌ Error logging {run_id} to CSV: {e}")

In [None]:
import torch
import os

# Directory to save checkpoints (using a relative path for portability)
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def save_checkpoint(model, optimizer, epoch, loss, filename="checkpoint.pth"):
    """Saves the model checkpoint.

    Args:
        model: Module whose `state_dict()` is stored.
        optimizer: Optimizer whose `state_dict()` is stored; may be None, in which
            case None is stored under 'optimizer_state_dict'.
        epoch: Epoch index recorded in the checkpoint.
        loss: Loss value recorded in the checkpoint.
        filename: File name inside CHECKPOINT_DIR.

    The checkpoint is written to a temporary file and then renamed over the target,
    so an interrupted save cannot leave a truncated checkpoint behind.
    """
    checkpoint_path = os.path.join(CHECKPOINT_DIR, filename)
    tmp_path = checkpoint_path + ".tmp"
    # `filename` may contain sub-directories; make sure the target folder exists.
    os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)
    try:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict() if optimizer is not None else None,
            'loss': loss,
        }, tmp_path)
        os.replace(tmp_path, checkpoint_path)
    finally:
        # Only present if the save or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"Checkpoint saved: {checkpoint_path}")

def load_checkpoint(model, optimizer, filename="checkpoint.pth", device='cpu'):
    """Loads the model checkpoint if it exists.

    Args:
        model: Module that receives the stored 'model_state_dict'.
        optimizer: Optimizer that receives the stored 'optimizer_state_dict'. Pass
            None to restore the model only (e.g. for evaluation).
        filename: File name inside CHECKPOINT_DIR.
        device: `map_location` used when deserializing the tensors.

    Returns:
        The epoch to resume from (stored epoch + 1), or 0 when no checkpoint exists.
    """
    checkpoint_path = os.path.join(CHECKPOINT_DIR, filename)
    if not os.path.isfile(checkpoint_path):
        print(f"=> no checkpoint found at '{checkpoint_path}', starting from scratch.")
        return 0

    print(f"Loading checkpoint '{checkpoint_path}'")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    start_epoch = checkpoint['epoch'] + 1
    model.load_state_dict(checkpoint['model_state_dict'])

    # Restore optimizer state only when an optimizer was supplied and state was saved.
    optimizer_state = checkpoint.get('optimizer_state_dict')
    if optimizer is not None and optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    print(f"=> loaded checkpoint '{filename}' (resuming from epoch {start_epoch})")
    return start_epoch

In [None]:
# ===================================================
# Hybrid Attack Generator: Original Performance + Modern Monitoring
# ===================================================

def get_synonyms(word):
    """Get synonyms for a word using WordNet

    Args:
        word: Word to look up.

    Returns:
        Alphabetically sorted list of distinct lemma names (underscores replaced by
        spaces) taken from every synset of `word`, excluding `word` itself. Empty
        when nothing is found or the lookup fails.
    """
    synonyms = set()
    try:
        for syn in wordnet.synsets(word):
            for lemma in syn.lemmas():
                synonym = lemma.name().replace('_', ' ')
                if synonym != word:
                    synonyms.add(synonym)
    except Exception:
        # Best-effort lookup: a failing WordNet query simply yields fewer synonyms.
        pass
    # Sort so the order does not depend on set iteration order (string hashing is
    # randomized per process); callers that break ties by position stay reproducible.
    return sorted(synonyms)

class AdversarialAttackGenerator:
    """
    Text-level adversarial example generator for a sentence classifier.

    Three perturbation families are implemented: importance-guided synonym
    substitution, character-level edits and neutral-word insertion. `generate`
    works on raw strings; the two batch helpers decode token ids back to text,
    attack the text and re-tokenize the result.
    """

    # Single attacks that the 'mixed' mode samples from uniformly.
    _MIXED_ATTACKS = ('synonym', 'character', 'insertion')

    def __init__(self, model, tokenizer, device):
        """
        Args:
            model: Classifier called as `model(input_ids, attention_mask)` and
                returning logits of shape (batch, num_classes).
            tokenizer: Tokenizer used to encode perturbed text and decode batches.
            device: Device the model lives on; encodings are moved there before
                the model is queried.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.stopwords = set(stopwords.words('english'))

        # Original neutral words for insertion attacks
        self.neutral_words = [
            'basically', 'literally', 'actually', 'really', 'quite', 'simply',
            'nearly', 'almost', 'essentially', 'truly', 'absolutely', 'somewhat'
        ]

        # Original character substitutions
        self.char_subs = {'a': '@', 'e': '3', 'i': '1', 'o': '0', 's': '$', 'l': '1'}

        # Attack statistics
        self.attack_stats = self._fresh_stats()

    @staticmethod
    def _fresh_stats():
        """Return a zeroed attack-statistics dict."""
        return {
            'synonym_calls': 0,
            'character_calls': 0,
            'insertion_calls': 0,
            'mixed_calls': 0,
            'total_time': 0.0
        }

    def _validate_text(self, text):
        """Validate that text is not empty and has reasonable content"""
        if not text or not text.strip():
            return False
        words = text.split()
        if len(words) < 2:
            return False
        if not any(c.isalpha() for c in text):
            return False
        return True

    def _predict_probs(self, text):
        """Return the softmax class probabilities (shape (1, C)) for one text."""
        encoding = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True).to(self.device)
        with torch.no_grad():
            logits = self.model(encoding['input_ids'], encoding['attention_mask'])
        return F.softmax(logits, dim=1)

    def _get_word_importance(self, text, label):
        """
        Return the index of the most important word, scored by deletion.

        Each non-stopword is removed in turn and the drop in the predicted
        probability of `label` is measured; the index with the largest drop wins.
        (This is occlusion-style scoring - no gradients are involved.) Texts
        shorter than three words, or texts made only of stopwords, get a random
        index. Puts the model in eval mode.
        """
        self.model.eval()
        words = text.split()

        if len(words) < 3:
            return random.choice(range(len(words))) if words else 0

        # Probability of the true label on the unmodified text.
        original_prob = self._predict_probs(text)[0, label].item()

        max_drop, best_idx = -float('inf'), -1
        for i, word in enumerate(words):
            if word.lower() in self.stopwords:
                continue

            # Score the text with word i removed.
            temp_text = ' '.join(words[:i] + words[i+1:])
            temp_prob = self._predict_probs(temp_text)[0, label].item()

            drop = original_prob - temp_prob
            if drop > max_drop:
                max_drop, best_idx = drop, i

        return best_idx if best_idx != -1 else random.choice(range(len(words)))

    def _synonym_attack(self, text, label):
        """
        Replace the most important word with the WordNet synonym that maximizes
        the entropy of the model's predicted distribution. Returns `text`
        unchanged when it has no words or the chosen word has no synonyms.
        """
        words = text.split()
        if not words:
            # Nothing to replace; avoids indexing into an empty word list.
            return text

        idx_to_replace = self._get_word_importance(text, label)
        original_word = words[idx_to_replace]
        synonyms = get_synonyms(original_word)

        if not synonyms or (len(synonyms) == 1 and synonyms[0] == original_word):
            return text

        best_synonym, max_entropy = original_word, -1
        for synonym in synonyms:
            candidate_words = words[:]
            candidate_words[idx_to_replace] = synonym
            probs = self._predict_probs(' '.join(candidate_words))

            # Calculate entropy (model uncertainty); 1e-9 guards log(0).
            entropy = -torch.sum(probs * torch.log(probs + 1e-9)).item()
            if entropy > max_entropy:
                max_entropy, best_synonym = entropy, synonym

        words[idx_to_replace] = best_synonym
        return ' '.join(words)

    def _character_attack(self, text):
        """
        Character-level perturbation of roughly 20% of the words (at least one).

        Only words longer than two characters that are not stopwords are eligible.
        Each chosen word gets one random operation: swap two adjacent characters,
        insert a random lowercase letter, delete a character, or substitute a
        character via `char_subs` (a no-op when the drawn character has no entry).
        """
        words = text.split()
        num_words_to_perturb = max(1, int(len(words) * 0.2))  # Original 20% rate

        candidate_indices = [
            i for i, word in enumerate(words)
            if len(word) > 2 and word.lower() not in self.stopwords
        ]

        if not candidate_indices:
            return text

        indices_to_perturb = random.sample(candidate_indices, min(num_words_to_perturb, len(candidate_indices)))

        for i in indices_to_perturb:
            word = words[i]
            op = random.choice(['swap', 'insert', 'delete', 'substitute'])

            if op == 'swap' and len(word) > 1:
                idx = random.randint(0, len(word) - 2)
                words[i] = word[:idx] + word[idx+1] + word[idx] + word[idx+2:]
            elif op == 'insert':
                idx = random.randint(0, len(word))
                words[i] = word[:idx] + random.choice(string.ascii_lowercase) + word[idx:]
            elif op == 'delete' and len(word) > 1:
                idx = random.randint(0, len(word) - 1)
                words[i] = word[:idx] + word[idx+1:]
            elif op == 'substitute' and len(word) > 0:
                idx = random.randint(0, len(word) - 1)
                char = word[idx]
                if char.lower() in self.char_subs:
                    words[i] = word[:idx] + self.char_subs[char.lower()] + word[idx+1:]

        return ' '.join(words)

    def _insertion_attack(self, text):
        """
        Insert neutral filler words at random positions.

        The number of insertions is drawn uniformly from 1 to
        max(2, 15% of the word count).
        """
        words = text.split()
        num_insertions = random.randint(1, max(2, int(len(words) * 0.15)))  # Original 15% rate

        for _ in range(num_insertions):
            insert_pos = random.randint(0, len(words))
            words.insert(insert_pos, random.choice(self.neutral_words))

        return ' '.join(words)

    def _apply_attack(self, text, label, kind):
        """Run the single attack named `kind`; unknown kinds return `text` unchanged."""
        if kind == 'synonym':
            return self._synonym_attack(text, label)
        if kind == 'character':
            return self._character_attack(text)
        if kind == 'insertion':
            return self._insertion_attack(text)
        return text

    def generate(self, texts, labels, attack_type='mixed'):
        """
        Generate one adversarial text per input text.

        Args:
            texts: Iterable of strings.
            labels: Iterable of integer labels aligned with `texts`.
            attack_type: 'synonym', 'character', 'insertion' or 'mixed' (one of
                the three chosen at random per text). Any other value returns
                the texts unchanged.

        Returns:
            List of strings, same length as the zipped inputs. A text is returned
            unmodified when its attack raises or produces invalid output.

        The model's train/eval mode is restored on exit: the synonym attack puts
        the model in eval mode, which previously leaked into the caller's
        training loop and kept dropout disabled.
        """
        start_time = time.time()
        adv_texts = []
        was_training = bool(getattr(self.model, 'training', False))

        try:
            for text, label in zip(texts, labels):
                try:
                    if attack_type == 'mixed':
                        # Original mixed attack logic: uniform choice among the three.
                        chosen = random.choice(self._MIXED_ATTACKS)
                        adv_text = self._apply_attack(text, label, chosen)
                        self.attack_stats['mixed_calls'] += 1
                    elif attack_type in self._MIXED_ATTACKS:
                        adv_text = self._apply_attack(text, label, attack_type)
                        self.attack_stats[f'{attack_type}_calls'] += 1
                    else:
                        adv_text = text

                    # Fall back to the original text if generation produced junk.
                    adv_texts.append(adv_text if self._validate_text(adv_text) else text)

                except Exception as e:
                    print(f"Attack generation failed for text: {text[:50]}... Error: {e}")
                    adv_texts.append(text)  # Fallback to original
        finally:
            if was_training:
                self.model.train()

        self.attack_stats['total_time'] += time.time() - start_time
        return adv_texts

    def augment_batch_with_adversarial_examples(self, batch, augmentation_ratio=0.6, attack_type='mixed'):
        """
        Replace a random fraction of a batch with adversarial versions.

        Args:
            batch: Mapping with 'input_ids', 'attention_mask' and the labels under
                'labels' (or 'label', the key SentimentDataset emits).
            augmentation_ratio: Fraction of samples to attack (clamped to the
                batch size).
            attack_type: Forwarded to `generate`.

        Returns:
            Tuple (input_ids, attention_mask, labels) of the same batch size,
            shuffled, on the device of the incoming tensors. When the ratio rounds
            down to zero samples the inputs are returned untouched.

        Note: the attacked samples are reconstructed from decoded token ids, so
        they reflect the tokenizer's normalization rather than the raw text.
        """
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels'] if 'labels' in batch else batch['label']

        num_samples = input_ids.shape[0]
        num_to_augment = min(int(augmentation_ratio * num_samples), num_samples)

        if num_to_augment <= 0:
            return input_ids, attention_mask, labels

        # Select a random subset of samples to augment; the rest stay clean.
        indices_to_augment = np.random.choice(num_samples, num_to_augment, replace=False)
        clean_mask = np.ones(num_samples, dtype=bool)
        clean_mask[indices_to_augment] = False

        # Explicit long index tensors: an empty NumPy index array defaults to float
        # and cannot index a tensor (the augmentation_ratio == 1.0 case).
        augment_idx = torch.as_tensor(indices_to_augment, dtype=torch.long)
        clean_idx = torch.as_tensor(np.flatnonzero(clean_mask), dtype=torch.long)

        input_ids_to_augment = input_ids[augment_idx]
        labels_to_augment = labels[augment_idx]

        # Decode tokens to text to generate attacks
        texts_to_augment = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids_to_augment]

        # Generate adversarial texts using the existing 'generate' method
        adv_texts = self.generate(texts_to_augment, labels_to_augment.tolist(), attack_type=attack_type)

        # Re-tokenize the generated adversarial texts to the batch's sequence length
        adv_encodings = self.tokenizer(
            adv_texts,
            max_length=input_ids.shape[1],
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Place the new tensors next to the batch tensors they are concatenated with.
        adv_input_ids = adv_encodings['input_ids'].to(input_ids.device)
        adv_attention_mask = adv_encodings['attention_mask'].to(attention_mask.device)

        # Combine the original clean examples with the new adversarial ones
        combined_input_ids = torch.cat((input_ids[clean_idx], adv_input_ids), dim=0)
        combined_attention_mask = torch.cat((attention_mask[clean_idx], adv_attention_mask), dim=0)
        combined_labels = torch.cat((labels[clean_idx], labels_to_augment), dim=0)

        # Shuffle the combined batch
        shuffle_indices = torch.randperm(combined_input_ids.shape[0])

        return (
            combined_input_ids[shuffle_indices],
            combined_attention_mask[shuffle_indices],
            combined_labels[shuffle_indices]
        )

    def get_attack_stats(self):
        """Get attack generation statistics"""
        return self.attack_stats.copy()

    def reset_stats(self):
        """Reset attack statistics"""
        self.attack_stats = self._fresh_stats()

    def generate_adversarial_examples(self, input_ids, labels, model, num_examples=None, attack_type='mixed', reference_len=128):
        """
        Generate adversarial encodings for (the first `num_examples` of) a batch.

        Args:
            input_ids: Token-id tensor of shape (batch, seq_len).
            labels: Label tensor aligned with `input_ids`.
            model: Unused; attacks always query `self.model`. Kept for interface
                compatibility.
            num_examples: How many leading samples to attack (default: all).
            attack_type: Forwarded to `generate`.
            reference_len: Length the adversarial texts are padded/truncated to.
                Pass None to reuse the width of `input_ids`.

        Returns:
            Tuple (adv_input_ids, adv_attention_mask) on `self.device`.
        """
        if num_examples is None:
            num_examples = input_ids.shape[0]
        if reference_len is None:
            reference_len = input_ids.shape[1]

        # Decode tokens to text
        texts_to_attack = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids[:num_examples]]
        labels_to_attack = labels[:num_examples].tolist()

        # Generate adversarial texts
        adv_texts = self.generate(texts_to_attack, labels_to_attack, attack_type=attack_type)

        # Re-tokenize
        adv_encodings = self.tokenizer(
            adv_texts,
            max_length=reference_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return adv_encodings['input_ids'].to(self.device), adv_encodings['attention_mask'].to(self.device)

In [None]:
# ====
# Model Definitions with Resource Tracking
# ====

class ModelWithResourceTracking(nn.Module):
    """Base module adding inference timing and size/parameter accounting."""

    def __init__(self):
        super().__init__()
        # Seconds per call, appended by forward_with_timing (never trimmed here).
        self.inference_times = []

    def forward_with_timing(self, *args, **kwargs):
        """Run `forward` and record its wall-clock duration in `inference_times`.

        CUDA kernels launch asynchronously, so when the module's parameters live
        on a GPU the device is synchronized before starting and before stopping
        the clock; otherwise only the launch overhead would be measured.

        Returns:
            Whatever `forward` returns.
        """
        first_param = next(self.parameters(), None)
        on_cuda = first_param is not None and first_param.is_cuda

        if on_cuda:
            torch.cuda.synchronize()
        # perf_counter is monotonic and has higher resolution than time.time().
        start_time = time.perf_counter()
        result = self.forward(*args, **kwargs)
        if on_cuda:
            torch.cuda.synchronize()
        self.inference_times.append(time.perf_counter() - start_time)
        return result

    def get_model_size(self):
        """Return the size of all parameters and buffers in MiB (bytes / 1024 / 1024)."""
        param_size = sum(p.numel() * p.element_size() for p in self.parameters())
        buffer_size = sum(b.numel() * b.element_size() for b in self.buffers())
        return (param_size + buffer_size) / 1024 / 1024  # MB

    def count_parameters(self):
        """Return the number of trainable (requires_grad) parameter elements."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# Vanilla DistilBERT
class VanillaDistilBERT(ModelWithResourceTracking):
    """DistilBERT encoder with dropout and a linear head on the first token's state."""

    def __init__(self, num_classes=2, dropout_rate=0.4):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.dropout = nn.Dropout(dropout_rate)
        hidden_size = self.distilbert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_classes)

    def get_input_embeddings(self):
        """Return the encoder's word-embedding layer (used for embedding-space losses)."""
        return self.distilbert.embeddings.word_embeddings

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        """Return class logits; `inputs_embeds`, when given, is used instead of `input_ids`."""
        encoder_kwargs = {'attention_mask': attention_mask}
        if inputs_embeds is None:
            encoder_kwargs['input_ids'] = input_ids
        else:
            encoder_kwargs['inputs_embeds'] = inputs_embeds

        hidden_states = self.distilbert(**encoder_kwargs).last_hidden_state
        # Position 0 of the sequence serves as the sentence representation.
        pooled = self.dropout(hidden_states[:, 0, :])
        return self.classifier(pooled)


# Adversarial Training Model
class AdversarialTrainingModel(ModelWithResourceTracking):
    """DistilBERT classifier used for the adversarial-training runs.

    The architecture is the plain encoder + dropout + linear head; nothing
    adversarial happens inside the module itself.
    """

    def __init__(self, num_classes=2, dropout_rate=0.4):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.distilbert.config.hidden_size, num_classes)

    def get_input_embeddings(self):
        """Return the encoder's word-embedding layer (used for embedding-space losses)."""
        return self.distilbert.embeddings.word_embeddings

    def _encode(self, input_ids, attention_mask, inputs_embeds):
        """Run the encoder on embeddings when provided, otherwise on token ids."""
        if inputs_embeds is not None:
            return self.distilbert(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        return self.distilbert(input_ids=input_ids, attention_mask=attention_mask)

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        """Return class logits computed from the first token's final hidden state."""
        encoded = self._encode(input_ids, attention_mask, inputs_embeds)
        first_token = encoded.last_hidden_state[:, 0, :]
        return self.classifier(self.dropout(first_token))


# Defensive Distillation Model
class DefensiveDistillationModel(ModelWithResourceTracking):
    """DistilBERT classifier whose logits are divided by a fixed temperature.

    The scaling is applied on every forward pass, regardless of train/eval mode.
    """

    def __init__(self, num_classes=2, dropout_rate=0.4, temperature=3.0):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.distilbert.config.hidden_size, num_classes)
        # Divisor applied to the logits in forward().
        self.temperature = temperature

    def get_input_embeddings(self):
        """Return the encoder's word-embedding layer (used for embedding-space losses)."""
        return self.distilbert.embeddings.word_embeddings

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        """Return temperature-scaled class logits."""
        use_embeddings = inputs_embeds is not None
        if use_embeddings:
            encoded = self.distilbert(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        else:
            encoded = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)

        sentence_repr = self.dropout(encoded.last_hidden_state[:, 0, :])
        raw_logits = self.classifier(sentence_repr)

        # Softer distribution for temperature > 1.
        return raw_logits / self.temperature


# Input Preprocessing Model
class InputPreprocessingModel(ModelWithResourceTracking):
    """DistilBERT classifier that runs the first-token state through an MLP before classifying.

    Args:
        num_classes: Size of the classifier output.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        # 768 -> 512 -> 768 bottleneck applied to the first-token representation.
        denoiser_layers = [
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 768),
        ]
        self.denoiser = nn.Sequential(*denoiser_layers)
        self.classifier = nn.Linear(768, num_classes)

    def get_input_embeddings(self):
        """Return the DistilBERT word-embedding layer (used by the TRADES loss and smoothing penalty)."""
        embedding_block = self.distilbert.embeddings
        return embedding_block.word_embeddings

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        """Return logits for token ids or pre-computed embeddings.

        ``inputs_embeds`` takes precedence over ``input_ids`` when both are given.
        """
        if inputs_embeds is None:
            encoder_inputs = {'input_ids': input_ids}
        else:
            encoder_inputs = {'inputs_embeds': inputs_embeds}
        encoded = self.distilbert(attention_mask=attention_mask, **encoder_inputs)

        first_token_state = encoded.last_hidden_state[:, 0, :]
        return self.classifier(self.denoiser(first_token_state))


# Ensemble Defense Model
class EnsembleDefenseModel(ModelWithResourceTracking):
    """Averages the outputs of several independently instantiated ``VanillaDistilBERT`` members.

    Args:
        num_classes: Passed to every member model.
        num_models: Number of ensemble members to create.
    """

    def __init__(self, num_classes=2, num_models=3):
        super().__init__()
        members = []
        for _ in range(num_models):
            members.append(VanillaDistilBERT(num_classes))
        self.models = nn.ModuleList(members)

    def get_input_embeddings(self):
        """Return the embedding layer of the first ensemble member."""
        first_member = self.models[0]
        return first_member.get_input_embeddings()

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        """Feed the same inputs to every member and return the element-wise mean of their outputs."""
        member_outputs = [
            member(input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds)
            for member in self.models
        ]
        return torch.stack(member_outputs).mean(dim=0)


# Original HAT-D Model Components
class OriginalDenoisingNetwork(nn.Module):
    """Residual encoder/decoder MLP over 768-dimensional feature vectors.

    The input is compressed 768 -> 640 -> ``hidden_dim`` and expanded back
    ``hidden_dim`` -> 640 -> 768; the result is scaled by a learnable scalar,
    added to the input, and squashed with ``tanh``.

    Args:
        hidden_dim: Width of the bottleneck between encoder and decoder.
    """

    def __init__(self, hidden_dim=512):
        super().__init__()
        feature_dim, mid_dim = 768, 640
        encoder_layers = [
            nn.Linear(feature_dim, mid_dim),
            nn.LayerNorm(mid_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(mid_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        decoder_layers = [
            nn.Linear(hidden_dim, mid_dim),
            nn.LayerNorm(mid_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(mid_dim, feature_dim),
        ]
        self.decoder = nn.Sequential(*decoder_layers)
        # Learnable weight of the correction term, initialised to 0.3.
        self.residual_scale = nn.Parameter(torch.tensor(0.3))

    def forward(self, x):
        """Return ``tanh(x + residual_scale * decoder(encoder(x)))``."""
        correction = self.decoder(self.encoder(x))
        return torch.tanh(x + self.residual_scale * correction)


class OriginalHybridDefenseModel(ModelWithResourceTracking):
    """DistilBERT encoder followed by ``OriginalDenoisingNetwork`` and a two-layer classifier head.

    Args:
        num_classes: Size of the classifier output.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.denoising = OriginalDenoisingNetwork()
        head_layers = [
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def get_input_embeddings(self):
        """Return the DistilBERT word-embedding layer (used by the TRADES loss and smoothing penalty)."""
        embedding_block = self.distilbert.embeddings
        return embedding_block.word_embeddings

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        """Return logits for token ids or pre-computed embeddings.

        ``inputs_embeds`` takes precedence over ``input_ids`` when both are given.
        """
        if inputs_embeds is None:
            encoder_inputs = {'input_ids': input_ids}
        else:
            encoder_inputs = {'inputs_embeds': inputs_embeds}
        encoded = self.distilbert(attention_mask=attention_mask, **encoder_inputs)

        # The denoiser sees only the first-token state, not the full sequence.
        denoised_state = self.denoising(encoded.last_hidden_state[:, 0, :])
        return self.classifier(denoised_state)


# Enhanced HAT-D Model Components
class CharacterCNN(nn.Module):
    """Multi-width 1-D convolution block with global max pooling over the sequence.

    Args:
        embedding_dim: Channel count of the input and size of the output vector.
        num_filters: Output channels of each convolution.
        kernel_sizes: One convolution is created per listed kernel width.
    """

    def __init__(self, embedding_dim=768, num_filters=192, kernel_sizes=[2, 3, 4, 5]):
        super().__init__()
        pooled_width = len(kernel_sizes) * num_filters
        conv_bank = [
            nn.Conv1d(embedding_dim, num_filters, width, padding=(width - 1) // 2)
            for width in kernel_sizes
        ]
        self.convs = nn.ModuleList(conv_bank)
        self.dropout = nn.Dropout(0.2)
        self.bn = nn.BatchNorm1d(pooled_width)
        self.fc = nn.Linear(pooled_width, embedding_dim)

    def forward(self, x):
        """Map ``x`` of layout (batch, sequence, embedding_dim) to (batch, embedding_dim)."""
        # Conv1d convolves over the last axis, so channels must come second.
        channels_first = x.permute(0, 2, 1)
        pooled = []
        for conv in self.convs:
            activated = F.relu(conv(channels_first))
            # Global max over the sequence axis, then drop the singleton axis.
            pooled.append(F.adaptive_max_pool1d(activated, 1).squeeze(2))
        features = self.dropout(self.bn(torch.cat(pooled, dim=1)))
        return self.fc(features)


class SequenceAttention(nn.Module):
    """Multi-head self-attention over a batch-first sequence that skips padded keys.

    Args:
        embedding_dim: Feature size of the sequence elements.
        num_heads: Number of attention heads.
    """

    def __init__(self, embedding_dim=768, num_heads=8):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, batch_first=True)

    def forward(self, x, attention_mask=None):
        """Self-attend over ``x``.

        Args:
            x: Sequence tensor laid out batch-first (``batch_first=True``).
            attention_mask: Optional mask in which entries equal to 0 mark
                padding. When omitted, every position is attended to.

        Returns:
            The attention output (attention weights are discarded).
        """
        # `None == 0` evaluates to the Python bool False, which MultiheadAttention
        # rejects as a mask, so a missing mask must be passed through as None.
        key_padding_mask = None if attention_mask is None else (attention_mask == 0)
        attn_output, _ = self.attention(x, x, x, key_padding_mask=key_padding_mask)
        return attn_output


class EnhancedHATDModel(ModelWithResourceTracking):
    """Partially frozen DistilBERT with attention and CNN feature branches fused by a learned gate.

    Args:
        num_classes: Size of the classifier output.
        dropout_rate: Dropout probability inside the classifier head.
    """

    def __init__(self, num_classes=2, dropout_rate=0.2):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.embedding_dim = 768

        # Only encoder parameters whose name contains one of these tags stay trainable;
        # everything else in the encoder is frozen.
        trainable_tags = ("transformer.layer.4", "transformer.layer.5", "pooler")
        for param_name, param in self.distilbert.named_parameters():
            param.requires_grad = any(tag in param_name for tag in trainable_tags)

        # Defense components
        self.insertion_defense = SequenceAttention(self.embedding_dim, 4)
        self.character_defense = CharacterCNN(self.embedding_dim, 128)
        # NOTE(review): the denoiser is registered here but forward() never calls it — confirm intent.
        self.denoiser = OriginalDenoisingNetwork()

        # Gate producing per-feature mixing weights in (0, 1) from both branches.
        self.fusion_gate = nn.Sequential(nn.Linear(768 * 2, 768), nn.Sigmoid())

        # Classification head
        head_layers = [
            nn.Linear(768, 384),
            nn.LayerNorm(384),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(384, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def get_input_embeddings(self):
        """Return the DistilBERT word-embedding layer (used by the TRADES loss and smoothing penalty)."""
        embedding_block = self.distilbert.embeddings
        return embedding_block.word_embeddings

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        """Return logits for token ids or pre-computed embeddings.

        ``inputs_embeds`` takes precedence over ``input_ids`` when both are
        given. ``attention_mask`` is passed both to the encoder and to the
        attention branch.
        """
        if inputs_embeds is None:
            encoder_inputs = {'input_ids': input_ids}
        else:
            encoder_inputs = {'inputs_embeds': inputs_embeds}
        encoded = self.distilbert(attention_mask=attention_mask, **encoder_inputs)
        token_states = encoded.last_hidden_state

        # Branch 1: self-attention over the sequence, keeping the first position.
        insertion_features = self.insertion_defense(token_states, attention_mask)[:, 0, :]
        # Branch 2: convolution + max pooling over the whole sequence.
        character_features = self.character_defense(token_states)

        # Convex combination of the two branches, weighted per feature by the gate.
        gate = self.fusion_gate(torch.cat((insertion_features, character_features), dim=1))
        fused = gate * insertion_features + (1 - gate) * character_features

        return self.classifier(fused)


# Model registry for easy instantiation.
# Maps each defense's display name to its model class. The notebook's
# smoke-test loop iterates this dict and calls every class with
# `num_classes=2`, so each registered class must accept that keyword.
MODEL_REGISTRY = {
    'Vanilla_DistilBERT': VanillaDistilBERT,
    'Adversarial_Training': AdversarialTrainingModel,
    'Defensive_Distillation': DefensiveDistillationModel,
    'Input_Preprocessing': InputPreprocessingModel,
    'Ensemble_Defense': EnsembleDefenseModel,
    'Original_HAT-D': OriginalHybridDefenseModel,
    'Enhanced_HAT-D': EnhancedHATDModel
}

In [None]:
# Test all models
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _smoke_test_model(model_name, model_class, device):
    """Instantiate one registry model and exercise its embedding accessor and both forward paths.

    The check runs inside a function and under ``torch.no_grad()`` so that the
    throwaway model, its inputs, and its activations go out of scope when the
    check ends instead of lingering as notebook globals holding device memory.

    Args:
        model_name: Display name used in the printed report.
        model_class: Class to instantiate with ``num_classes=2``.
        device: Device the model and dummy inputs are moved to.

    Raises:
        Exception: Whatever the model raises; the caller reports it.
    """
    # Instantiate model
    model = model_class(num_classes=2).to(device)

    # Test get_input_embeddings
    embeddings = model.get_input_embeddings()
    print(f"  ✓ get_input_embeddings() works: {embeddings.weight.shape}")

    dummy_input_ids = torch.randint(0, 1000, (2, 128)).to(device)
    dummy_attention_mask = torch.ones(2, 128).to(device)

    # Only output shapes are inspected, so no autograd graph is needed.
    with torch.no_grad():
        # Test forward with input_ids
        output = model(input_ids=dummy_input_ids, attention_mask=dummy_attention_mask)
        print(f"  ✓ forward(input_ids) works: {output.shape}")

        # Test forward with inputs_embeds
        dummy_embeds = embeddings(dummy_input_ids)
        output = model(inputs_embeds=dummy_embeds, attention_mask=dummy_attention_mask)
        print(f"  ✓ forward(inputs_embeds) works: {output.shape}")

    print(f"  ✓ {model_name} passed all tests!")


for model_name, model_class in MODEL_REGISTRY.items():
    print(f"\nTesting {model_name}...")

    try:
        _smoke_test_model(model_name, model_class, device)
    except Exception as e:
        print(f"  ✗ {model_name} FAILED: {e}")
    finally:
        # Release the model that was just checked before building the next one.
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
# Print up to 5 random sentences with their labels
import random

# random.sample raises ValueError when asked for more items than exist,
# so cap the request at the dataset size.
sample_size = min(5, len(train_dataset))
random_indices = random.sample(range(len(train_dataset)), sample_size)
for idx in random_indices:
    text = train_dataset.texts[idx]
    label = train_dataset.labels[idx]
    print(f"Text: {text}\nLabel: {label}\n")

In [None]:
# ====
# Resource Monitoring and Latency Tracking
# ====

class ResourceMonitor:
    """Samples CPU, RAM and GPU usage on a background thread for the duration of one epoch.

    Call `start_epoch_monitoring` before an epoch and `end_epoch_monitoring`
    after it; the latter returns a summary dict of averages and maxima.
    """

    # Per-epoch sample lists collected by the sampler thread.
    _METRIC_KEYS = ('cpu_percent', 'memory_percent', 'memory_mb', 'gpu_memory_mb', 'gpu_utilization')

    def __init__(self):
        self.monitoring = False
        self.metrics = defaultdict(list)
        self.monitor_thread = None
        self.epoch_start_time = None
        self.epoch_metrics = {}
        # Stop signal owned by the currently running sampler thread (None when idle).
        self._stop_event = None

    def start_epoch_monitoring(self):
        """Reset the per-epoch samples and launch a fresh daemon sampler thread."""
        # A sampler that outlived the join timeout of the previous epoch must not be
        # revived by the shared `monitoring` flag; signal it through its own event.
        previous_stop = getattr(self, '_stop_event', None)
        if previous_stop is not None:
            previous_stop.set()

        self.epoch_start_time = time.time()
        self.epoch_metrics = {key: [] for key in self._METRIC_KEYS}
        self.monitoring = True
        self._stop_event = threading.Event()
        # The thread receives its own stop event and sample dict, so a straggler
        # can never write into the samples of a later epoch.
        self.monitor_thread = threading.Thread(
            target=self._monitor_resources,
            args=(self._stop_event, self.epoch_metrics),
        )
        self.monitor_thread.daemon = True
        self.monitor_thread.start()

    def end_epoch_monitoring(self):
        """Stop sampling and summarise the epoch.

        Safe to call even if monitoring was never started; all statistics are
        then 0.

        Returns:
            dict with keys ``cpu_percent_avg``, ``cpu_percent_max``,
            ``memory_mb_avg``, ``memory_mb_max``, ``gpu_memory_mb`` (the
            maximum), ``gpu_memory_mb_avg``, ``gpu_utilization_avg``,
            ``gpu_utilization_max`` and ``epoch_duration`` (seconds). A
            statistic is 0 when no samples were collected for it.
        """
        self.monitoring = False
        stop_event = getattr(self, '_stop_event', None)
        if stop_event is not None:
            stop_event.set()
        if self.monitor_thread and self.monitor_thread.is_alive():
            self.monitor_thread.join(timeout=1.0)

        # Snapshot the lists: the sampler may still be alive if the join timed out,
        # and `.get` tolerates a monitor that was never started.
        samples = {key: list(self.epoch_metrics.get(key, ())) for key in self._METRIC_KEYS}

        def _stat(key, reducer):
            return reducer(samples[key]) if samples[key] else 0

        epoch_summary = {}
        epoch_summary['cpu_percent_avg'] = _stat('cpu_percent', np.mean)
        epoch_summary['cpu_percent_max'] = _stat('cpu_percent', np.max)
        epoch_summary['memory_mb_avg'] = _stat('memory_mb', np.mean)
        epoch_summary['memory_mb_max'] = _stat('memory_mb', np.max)
        # Historical naming: the un-suffixed GPU memory key holds the maximum.
        epoch_summary['gpu_memory_mb'] = _stat('gpu_memory_mb', np.max)
        epoch_summary['gpu_memory_mb_avg'] = _stat('gpu_memory_mb', np.mean)
        epoch_summary['gpu_utilization_avg'] = _stat('gpu_utilization', np.mean)
        epoch_summary['gpu_utilization_max'] = _stat('gpu_utilization', np.max)

        if self.epoch_start_time:
            epoch_summary['epoch_duration'] = time.time() - self.epoch_start_time
        else:
            epoch_summary['epoch_duration'] = 0

        return epoch_summary

    def _monitor_resources(self, stop_event=None, metrics=None):
        """Sampler loop run on the background thread.

        Args:
            stop_event: Event that ends this particular sampler when set. When
                None, only the shared ``monitoring`` flag controls the loop.
            metrics: Dict of sample lists to append to; defaults to the
                current ``epoch_metrics``.
        """
        if metrics is None:
            metrics = self.epoch_metrics

        while self.monitoring and (stop_event is None or not stop_event.is_set()):
            try:
                cpu_percent = psutil.cpu_percent(interval=0.1)
                memory = psutil.virtual_memory()
                memory_mb = memory.used / (1024 * 1024)

                metrics['cpu_percent'].append(cpu_percent)
                metrics['memory_percent'].append(memory.percent)
                metrics['memory_mb'].append(memory_mb)

                try:
                    gpus = GPUtil.getGPUs()
                    if gpus:
                        gpu = gpus[0]
                        gpu_memory_mb = gpu.memoryUsed
                        gpu_utilization = gpu.load * 100

                        metrics['gpu_memory_mb'].append(gpu_memory_mb)
                        metrics['gpu_utilization'].append(gpu_utilization)
                except Exception:
                    # GPU sampling is best-effort; CPU/RAM samples are still kept.
                    pass

            except Exception:
                # Monitoring must never interrupt training; skip this sample.
                pass

            if stop_event is not None:
                # Wakes up immediately when the epoch ends instead of sleeping it out.
                stop_event.wait(0.1)
            else:
                time.sleep(0.1)
# Initialize the global resource monitor
resource_monitor = ResourceMonitor()

In [None]:
# MODIFIED FOR CHECKPOINTING + GRADIENT ACCUMULATION
# ===
# Training Function
# ===

from tqdm import tqdm

class EarlyStopper:
    """Signals when the validation loss has stopped improving.

    Args:
        patience: Number of worsening checks tolerated before stopping.
        min_delta: Margin above the best loss that a value must exceed to
            count as worsening.
    """

    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        """Record one validation loss and return True when training should stop.

        A new best loss resets the counter. A loss more than ``min_delta``
        above the best increments it. Anything in between leaves it unchanged.
        """
        best_so_far = self.min_validation_loss

        if validation_loss < best_so_far:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        if validation_loss > (best_so_far + self.min_delta):
            self.counter += 1
            return self.counter >= self.patience

        return False

import torch.nn.functional as F

def trades_loss(model,
                input_ids,
                attention_mask,
                labels,
                device,
                beta=1.0,
                epsilon=0.03,
                alpha=0.01,
                num_iter=5):
    """
    Calculates the TRADES loss, which balances standard cross-entropy with a robustness term.
    This function generates adversarial examples internally via PGD on the embedding space.

    Args:
        model: Callable accepting ``input_ids``/``attention_mask`` and, for the
            attack, ``inputs_embeds``. Its output is either logits or an object
            with a ``.logits`` attribute. Without a ``get_input_embeddings``
            method only the cross-entropy term is returned.
        input_ids: Token ids of the clean batch.
        attention_mask: Mask forwarded to every model call.
        labels: Class indices for the cross-entropy term.
        device: Kept for interface compatibility; the perturbation is created
            on the same device as the embeddings.
        beta: Weight of the robustness (KL) term.
        epsilon: Per-coordinate bound on the embedding perturbation.
        alpha: PGD step size.
        num_iter: Number of PGD steps.

    Returns:
        Scalar tensor ``cross_entropy + beta * KL(clean || adversarial)``.
    """
    outputs_clean = model(input_ids, attention_mask=attention_mask)
    logits_clean = outputs_clean.logits if hasattr(outputs_clean, 'logits') else outputs_clean
    loss_ce = F.cross_entropy(logits_clean, labels)

    if not hasattr(model, 'get_input_embeddings'):
        return loss_ce

    embedding_layer = model.get_input_embeddings()
    inputs_embeds = embedding_layer(input_ids).detach()

    # The attack target is constant across PGD steps, so compute it once.
    clean_probs = F.softmax(logits_clean.detach(), dim=1)

    # `zeros_like` already places delta on the embeddings' device and keeps it a leaf tensor.
    delta = torch.zeros_like(inputs_embeds, requires_grad=True)

    for _ in range(num_iter):
        perturbed_embeds = inputs_embeds + delta
        outputs_adv = model(inputs_embeds=perturbed_embeds, attention_mask=attention_mask)
        logits_adv = outputs_adv.logits if hasattr(outputs_adv, 'logits') else outputs_adv
        loss_kl_attack = F.kl_div(
            F.log_softmax(logits_adv, dim=1),
            clean_probs,
            reduction='sum'
        )
        # Differentiate w.r.t. delta only. Calling .backward() here would also
        # accumulate the attack gradients into the model parameters and
        # contaminate the optimizer step that follows.
        (delta_grad,) = torch.autograd.grad(loss_kl_attack, delta)
        delta = (delta.detach() + alpha * torch.sign(delta_grad)).clamp(-epsilon, epsilon)
        delta.requires_grad_(True)

    perturbed_embeds = (inputs_embeds + delta).detach()
    outputs_adv = model(inputs_embeds=perturbed_embeds, attention_mask=attention_mask)
    logits_adv = outputs_adv.logits if hasattr(outputs_adv, 'logits') else outputs_adv
    loss_kl_robustness = F.kl_div(
        F.log_softmax(logits_adv, dim=1),
        F.softmax(logits_clean, dim=1),
        reduction='batchmean'
    )
    total_loss = loss_ce + beta * loss_kl_robustness
    return total_loss

def _default_smoothing_penalty(logits, lam=1e-4):
    return lam * torch.mean(logits.pow(2))

def randomized_smoothing_penalty(model, input_ids, attention_mask, logits,
                                 num_samples=2, sigma=0.08, lam=1e-2, temp=1.0,
                                 reduction='mean', device=None):
    """Penalise disagreement between clean predictions and predictions under Gaussian embedding noise.

    The model is run ``max(1, num_samples)`` times on the input embeddings plus
    noise of standard deviation ``sigma``; the noisy logits are averaged and
    the KL divergence between their softmax and the clean softmax is scaled by
    ``lam``. If the model has no ``get_input_embeddings`` or any of these
    steps raises, the function falls back to ``_default_smoothing_penalty``.

    Args:
        model: Model accepting ``inputs_embeds`` and ``attention_mask``.
        input_ids: Token ids used to look up the clean embeddings.
        attention_mask: Mask forwarded to every noisy forward pass.
        logits: Clean logits of the same batch.
        num_samples: Number of noisy forward passes (at least one is made).
        sigma: Standard deviation of the embedding noise.
        lam: Weight of the penalty.
        temp: Temperature dividing both clean and noisy logits.
        reduction: ``'sum'`` multiplies the penalty by the batch size; any
            other value returns the batch-mean penalty.
        device: Resolved from the model parameters when None; not otherwise used.

    Returns:
        Scalar penalty tensor.
    """
    if device is None:
        device = next(model.parameters()).device

    if not hasattr(model, "get_input_embeddings"):
        return _default_smoothing_penalty(logits, lam=lam)

    try:
        inputs_embeds = model.get_input_embeddings()(input_ids)
    except Exception:
        return _default_smoothing_penalty(logits, lam=lam)

    sample_count = max(1, num_samples)
    per_sample_logits = []
    for _ in range(sample_count):
        noisy_embeds = inputs_embeds + sigma * torch.randn_like(inputs_embeds)
        try:
            noisy_out = model(inputs_embeds=noisy_embeds, attention_mask=attention_mask)
            sample_logits = getattr(noisy_out, 'logits', noisy_out)
        except Exception:
            return _default_smoothing_penalty(logits, lam=lam)
        per_sample_logits.append(sample_logits)

    mean_noisy_logits = torch.stack(per_sample_logits, dim=0).mean(dim=0)

    divergence = F.kl_div(
        F.log_softmax(logits / temp, dim=-1),
        F.softmax(mean_noisy_logits / temp, dim=-1),
        reduction='batchmean',
    )
    penalty = lam * divergence

    # 'mean' and any unrecognised reduction both return the batch-mean penalty.
    if reduction == 'sum':
        return penalty * logits.shape[0]
    return penalty

def _logits_only_smoothing_penalty(model, input_ids, attention_mask, logits, lam=1e-4, **_unused):
    """Expose `_default_smoothing_penalty` through the full smoothing-penalty call signature.

    The training loop calls every penalty as
    ``fn(model, input_ids, attention_mask, logits, num_samples=..., sigma=..., lam=..., device=...)``
    whereas `_default_smoothing_penalty` only accepts ``(logits, lam)``. This
    adapter drops the arguments the logit penalty cannot use and forwards
    ``lam``, mirroring how `randomized_smoothing_penalty` falls back to it.
    """
    return _default_smoothing_penalty(logits, lam=lam)


def train_model_comprehensive(
    model, model_name, train_loader, val_loader, epochs=max_epochs,
    use_adversarial=False, attack_generator=None, augmentation_ratio=0.65,
    attack_type='mixed', cfg=None, smoothing_penalty_fn=None
):
    """
    Comprehensive training function with modular structure, improved clarity, and robust error handling.
    Includes gradient accumulation, cosine annealing with warm restarts, and AMP (Mixed Precision).

    Relies on notebook globals: ``device``, ``resource_monitor``,
    ``load_checkpoint`` and ``save_checkpoint``.

    Args:
        model: Model to train; moved to the global ``device``.
        model_name: Label used in log output and in the checkpoint filename.
        train_loader: Iterable of batches with keys ``'input_ids'``,
            ``'attention_mask'`` and ``'label'``.
        val_loader: Validation batches with the same keys.
        epochs: Upper bound on the epoch index (training resumes from the
            epoch returned by ``load_checkpoint``).
        use_adversarial: Train with ``trades_loss`` instead of cross-entropy
            plus smoothing, unless ``cfg['no_adv']`` is truthy.
        attack_generator: Accepted for interface compatibility; not used here.
        augmentation_ratio: Accepted for interface compatibility; not used here.
        attack_type: Accepted for interface compatibility; not used here.
        cfg: Optional dict; ``'no_adv'`` disables adversarial training and
            ``'no_smooth'`` disables the smoothing penalty.
        smoothing_penalty_fn: Callable invoked as
            ``fn(model, input_ids, attention_mask, logits, num_samples=3, sigma=0.1, lam=0.05, device=device)``.
            When None (or `_default_smoothing_penalty` itself), the logit
            penalty is used through an adapter matching that signature.

    Returns:
        ``(model, history)``. ``model`` carries the weights of the epoch with
        the best validation accuracy when at least one epoch improved on 0.
        ``history`` holds per-epoch lists (``train_loss``, ``train_acc``,
        ``val_loss``, ``val_acc``, ``epoch_times``, ``learning_rates``,
        ``resource_usage``) plus ``best_val_acc`` and ``total_epochs``.
    """
    if cfg is None:
        cfg = {}
    # `_default_smoothing_penalty(logits, lam)` does not match the call made in the
    # training loop; calling it directly raised TypeError on every batch, which the
    # per-batch error handler swallowed, so nothing was ever trained.
    if smoothing_penalty_fn is None or smoothing_penalty_fn is _default_smoothing_penalty:
        smoothing_penalty_fn = _logits_only_smoothing_penalty

    effective_use_adv = bool(use_adversarial) and (not cfg.get("no_adv", False))
    apply_smoothing = (not cfg.get("no_smooth", False))

    print(f"\n{'='*60}")
    print(f"Training {model_name} | cfg={cfg} | use_adversarial(effectively)={effective_use_adv}, smoothing={apply_smoothing}")
    print(f"{'='*60}")

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=2e-5, weight_decay=1e-3)

    # ── AMP scaler (no-op on CPU) ──────────────────────────────────────────
    use_amp = torch.cuda.is_available()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    if use_amp:
        print(f"  ⚡ AMP (Mixed Precision) enabled")
    else:
        print(f"  ⚠️  AMP disabled (no CUDA device found)")

    # === CHECKPOINTING: LOAD STATE IF EXISTS ===
    checkpoint_filename = f"{model.__class__.__name__}_{model_name}.pth"
    start_epoch = load_checkpoint(model, optimizer, filename=checkpoint_filename, device=device)
    # === END CHECKPOINTING ===

    criterion = nn.CrossEntropyLoss()

    # Cosine Annealing with Warm Restarts
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=5, T_mult=2, eta_min=1e-6
    )

    early_stopper = EarlyStopper(patience=3, min_delta=0.001)

    history = {
        'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [],
        'epoch_times': [], 'learning_rates': [], 'resource_usage': []
    }

    best_val_acc = 0
    best_model_state = None

    # Gradient accumulation for effective larger batch size
    accumulation_steps = 2  # effective batch size × 2

    def _apply_accumulated_gradients():
        """Clip the accumulated gradients, take one optimizer step, and clear them."""
        scaler.unscale_(optimizer)   # required before grad clip
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    # === CHECKPOINTING: UPDATE LOOP TO START FROM SAVED EPOCH ===
    for epoch in range(start_epoch, epochs):
        epoch_start_time = time.time()
        resource_monitor.start_epoch_monitoring()

        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        # Micro-batches whose gradients have been accumulated but not yet applied.
        # Counting these (rather than using batch_idx) keeps the bookkeeping right
        # when a batch fails and works for an empty loader.
        pending_micro_batches = 0
        failed_batches = 0

        # Initialize gradients once per epoch
        optimizer.zero_grad()

        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Train]', leave=False)
        for batch_idx, batch in enumerate(train_pbar):
            try:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)

                if effective_use_adv:
                    # TRADES loss does its own internal gradient computation (PGD),
                    # so we keep it outside autocast to avoid fp16 instability
                    loss = trades_loss(model, input_ids, attention_mask, labels, device=device, beta=4.0)
                    with torch.no_grad():
                        outputs = model(input_ids, attention_mask=attention_mask)
                        logits = outputs.logits if hasattr(outputs, 'logits') else outputs
                        predictions = torch.argmax(logits, dim=1)
                        correct = (predictions == labels).sum().item()
                        total = len(labels)
                else:
                    # ── AMP autocast for standard forward pass ─────────────
                    with torch.cuda.amp.autocast(enabled=use_amp):
                        outputs = model(input_ids, attention_mask=attention_mask)
                        logits = outputs.logits if hasattr(outputs, 'logits') else outputs
                        ce_loss = criterion(logits, labels)

                        smoothing_loss = 0.0
                        if apply_smoothing:
                            smoothing_loss = smoothing_penalty_fn(
                                model, input_ids, attention_mask, logits,
                                num_samples=3, sigma=0.1, lam=0.05, device=device
                            )

                        loss = ce_loss + smoothing_loss

                    predictions = torch.argmax(logits, dim=1)
                    correct = (predictions == labels).sum().item()
                    total = len(labels)

                # Gradient accumulation: scale loss via AMP scaler
                loss = loss / accumulation_steps
                scaler.scale(loss).backward()
                pending_micro_batches += 1

                # Only update weights once accumulation_steps micro-batches are in
                if pending_micro_batches == accumulation_steps:
                    _apply_accumulated_gradients()
                    pending_micro_batches = 0

                # Unscale loss for logging
                train_loss += loss.item() * accumulation_steps
                train_correct += correct
                train_total += total

                current_acc = 100. * train_correct / train_total if train_total > 0 else 0.0
                train_pbar.set_postfix({
                    'Loss': f'{train_loss/(batch_idx+1):.4f}',
                    'Acc': f'{current_acc:.2f}%'
                })

            except Exception as e:
                failed_batches += 1
                print(f"Error in training batch {batch_idx}: {e}")
                continue

        # Handle any remaining accumulated gradients at epoch end
        if pending_micro_batches > 0:
            _apply_accumulated_gradients()
            pending_micro_batches = 0

        # Skipped batches are best-effort, but make a systemic failure visible.
        if failed_batches:
            print(f"  ⚠️  {failed_batches}/{len(train_loader)} training batches failed and were skipped")

        # ---------- validation pass ---------------------
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{epochs} [Val]', leave=False)
            for batch_idx, batch in enumerate(val_pbar):
                try:
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    labels = batch['label'].to(device)

                    with torch.cuda.amp.autocast(enabled=use_amp):
                        outputs = model(input_ids, attention_mask=attention_mask)
                        logits = outputs.logits if hasattr(outputs, 'logits') else outputs
                        loss = criterion(logits, labels)
                    predictions = torch.argmax(logits, dim=1)

                    val_loss += loss.item()
                    val_correct += (predictions == labels).sum().item()
                    val_total += len(labels)

                    current_acc = 100. * val_correct / val_total if val_total > 0 else 0.0
                    val_pbar.set_postfix({
                        'Loss': f'{val_loss/(batch_idx+1):.4f}',
                        'Acc': f'{current_acc:.2f}%'
                    })

                except Exception as e:
                    print(f"Error in validation batch {batch_idx}: {e}")
                    continue

        epoch_resources = resource_monitor.end_epoch_monitoring()
        epoch_time = time.time() - epoch_start_time
        train_loss_avg = train_loss / len(train_loader) if len(train_loader) > 0 else 0
        train_acc = train_correct / train_total if train_total > 0 else 0.0
        val_loss_avg = val_loss / len(val_loader) if len(val_loader) > 0 else 0
        val_acc = val_correct / val_total if val_total > 0 else 0.0
        # Read before scheduler.step() so the logged LR is the one used this epoch.
        current_lr = optimizer.param_groups[0]['lr']

        history['train_loss'].append(train_loss_avg)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss_avg)
        history['val_acc'].append(val_acc)
        history['epoch_times'].append(epoch_time)
        history['learning_rates'].append(current_lr)
        history['resource_usage'].append(epoch_resources)

        print(f"Epoch {epoch+1}/{epochs}:")
        print(f"  Train Loss: {train_loss_avg:.4f}, Train Acc: {train_acc*100:.2f}%")
        print(f"  Val Loss: {val_loss_avg:.4f}, Val Acc: {val_acc*100:.2f}%")
        print(f"  Time: {epoch_time:.2f}s, LR: {current_lr:.2e}")
        print(f"  GPU Memory: {epoch_resources.get('gpu_memory_mb', 0):.1f}MB")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Keep a CPU copy so the best weights do not occupy device memory.
            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            print(f"  ✓ New best validation accuracy: {val_acc*100:.2f}%")

        # === CHECKPOINTING: SAVE STATE AFTER EACH EPOCH ===
        save_checkpoint(model, optimizer, epoch, val_loss_avg, filename=checkpoint_filename)
        # === END CHECKPOINTING ===

        # Step scheduler (no argument needed for CosineAnnealingWarmRestarts)
        scheduler.step()

        if early_stopper.early_stop(val_loss_avg):
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

        if epoch % 2 == 0:
            torch.cuda.empty_cache()
            gc.collect()

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        model.to(device)
        print(f"\nLoaded best model with validation accuracy: {best_val_acc*100:.2f}%")

    history['best_val_acc'] = best_val_acc
    history['total_epochs'] = len(history['train_loss'])

    torch.cuda.empty_cache()
    gc.collect()

    return model, history

In [None]:
# import torch.nn.functional as F

# def trades_loss(
#     model,
#     input_ids,
#     attention_mask,
#     labels,
#     device,
#     beta=1.0,
#     epsilon=0.03,
#     alpha=0.01,
#     num_iter=5,
# ):
#     """
#     Calculates the TRADES loss, which balances standard cross-entropy with a robustness term.
#     This function generates adversarial examples internally via PGD on the embedding space.
#     """
#     outputs_clean = model(input_ids, attention_mask=attention_mask)
#     logits_clean = outputs_clean.logits if hasattr(outputs_clean, 'logits') else outputs_clean
#     loss_ce = F.cross_entropy(logits_clean, labels)

#     if not hasattr(model, 'get_input_embeddings'):
#         return loss_ce

#     embedding_layer = model.get_input_embeddings()
#     inputs_embeds = embedding_layer(input_ids).detach()
#     delta = torch.zeros_like(inputs_embeds, requires_grad=True).to(device)

#     for _ in range(num_iter):
#         perturbed_embeds = inputs_embeds + delta
#         outputs_adv = model(inputs_embeds=perturbed_embeds, attention_mask=attention_mask)
#         logits_adv = outputs_adv.logits if hasattr(outputs_adv, 'logits') else outputs_adv
#         loss_kl_attack = F.kl_div(
#             F.log_softmax(logits_adv, dim=1),
#             F.softmax(logits_clean.detach(), dim=1),
#             reduction='sum'
#         )
#         loss_kl_attack.backward()
#         delta.data = delta.data + alpha * torch.sign(delta.grad.detach())
#         delta.data = torch.clamp(delta.data, -epsilon, epsilon)
#         delta.grad.zero_()

#     perturbed_embeds = (inputs_embeds + delta).detach()
#     outputs_adv = model(inputs_embeds=perturbed_embeds, attention_mask=attention_mask)
#     logits_adv = outputs_adv.logits if hasattr(outputs_adv, 'logits') else outputs_adv
#     loss_kl_robustness = F.kl_div(
#         F.log_softmax(logits_adv, dim=1),
#         F.softmax(logits_clean, dim=1),
#         reduction='batchmean'
#     )
#     total_loss = loss_ce + beta * loss_kl_robustness
#     return total_loss

# def _default_smoothing_penalty(logits, lam=1e-4):
#     return lam * torch.mean(logits.pow(2))

# def randomized_smoothing_penalty(
#     model, input_ids, attention_mask, logits,
#     num_samples=2, sigma=0.08, lam=1e-2, temp=1.0,
#     reduction='mean', device=None
# ):
#     if device is None:
#         device = next(model.parameters()).device

#     if not hasattr(model, "get_input_embeddings"):
#         return _default_smoothing_penalty(logits, lam=lam)

#     try:
#         emb_layer = model.get_input_embeddings()
#         inputs_embeds = emb_layer(input_ids)
#     except Exception:
#         return _default_smoothing_penalty(logits, lam=lam)

#     noisy_logits = []
#     for _ in range(max(1, num_samples)):
#         noise = torch.randn_like(inputs_embeds, device=inputs_embeds.device) * sigma
#         noisy_embeds = inputs_embeds + noise
#         try:
#             out_noisy = model(inputs_embeds=noisy_embeds, attention_mask=attention_mask)
#             logits_noisy = out_noisy.logits if hasattr(out_noisy, 'logits') else out_noisy
#         except Exception:
#             return _default_smoothing_penalty(logits, lam=lam)
#         noisy_logits.append(logits_noisy)

#     noisy_stack = torch.stack(noisy_logits, dim=0)
#     noisy_mean = noisy_stack.mean(dim=0)

#     logp_clean = F.log_softmax(logits / temp, dim=-1)
#     p_noisy = F.softmax(noisy_mean / temp, dim=-1)
#     kl_per_batch = F.kl_div(logp_clean, p_noisy, reduction='batchmean')
#     loss = lam * kl_per_batch
#     if reduction == 'mean':
#         return loss
#     elif reduction == 'sum':
#         return loss * logits.shape[0]
#     else:
#         return loss

# def train_model_comprehensive(
#     model, model_name, train_loader, val_loader, epochs=max_epochs,
#     use_adversarial=False, attack_generator=None, augmentation_ratio=0.65,
#     attack_type='mixed', cfg=None, smoothing_penalty_fn=None
# ):
#     """
#     Comprehensive training function with modular structure, improved clarity, and robust error handling.
#     Includes gradient accumulation and cosine annealing with warm restarts.
#     """
#     if cfg is None:
#         cfg = {}
#     if smoothing_penalty_fn is None:
#         smoothing_penalty_fn = _default_smoothing_penalty

#     effective_use_adv = bool(use_adversarial) and (not cfg.get("no_adv", False))
#     apply_smoothing = (not cfg.get("no_smooth", False))

#     print(f"\n{'='*60}")
#     print(f"Training {model_name} | cfg={cfg} | use_adversarial(effectively)={effective_use_adv}, smoothing={apply_smoothing}")
#     print(f"{'='*60}")

#     model = model.to(device)
#     optimizer = optim.Adam(model.parameters(), lr=2e-5, weight_decay=1e-3)

#     # === CHECKPOINTING: LOAD STATE IF EXISTS ===
#     checkpoint_filename = f"{model.__class__.__name__}_{model_name}.pth"
#     start_epoch = load_checkpoint(model, optimizer, filename=checkpoint_filename, device=device)
#     # === END CHECKPOINTING ===

#     criterion = nn.CrossEntropyLoss()

#     # Cosine Annealing with Warm Restarts
#     scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
#         optimizer, T_0=5, T_mult=2, eta_min=1e-6
#     )

#     early_stopper = EarlyStopper(patience=3, min_delta=0.001)

#     history = {
#         'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [],
#         'epoch_times': [], 'learning_rates': [], 'resource_usage': []
#     }

#     best_val_acc = 0
#     best_model_state = None

#     # Gradient accumulation for effective larger batch size
#     accumulation_steps = 2  # effective batch size × 2

#     # === CHECKPOINTING: UPDATE LOOP TO START FROM SAVED EPOCH ===
#     for epoch in range(start_epoch, epochs):
#         epoch_start_time = time.time()
#         resource_monitor.start_epoch_monitoring()

#         model.train()
#         train_loss = 0.0
#         train_correct = 0
#         train_total = 0

#         # Initialize gradients once per epoch
#         optimizer.zero_grad()

#         train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Train]', leave=False)
#         for batch_idx, batch in enumerate(train_pbar):
#             try:
#                 input_ids = batch['input_ids'].to(device)
#                 attention_mask = batch['attention_mask'].to(device)
#                 labels = batch['label'].to(device)

#                 if effective_use_adv:
#                     loss = trades_loss(model, input_ids, attention_mask, labels, device=device, beta=4.0)
#                     with torch.no_grad():
#                         outputs = model(input_ids, attention_mask=attention_mask)
#                         logits = outputs.logits if hasattr(outputs, 'logits') else outputs
#                         predictions = torch.argmax(logits, dim=1)
#                         correct = (predictions == labels).sum().item()
#                         total = len(labels)
#                 else:
#                     outputs = model(input_ids, attention_mask=attention_mask)
#                     logits = outputs.logits if hasattr(outputs, 'logits') else outputs
#                     ce_loss = criterion(logits, labels)

#                     smoothing_loss = 0.0
#                     if apply_smoothing and smoothing_penalty_fn is not None:
#                         smoothing_loss = smoothing_penalty_fn(
#                             model, input_ids, attention_mask, logits,
#                             num_samples=3, sigma=0.1, lam=0.05, device=device
#                         )

#                     loss = ce_loss + smoothing_loss

#                     predictions = torch.argmax(logits, dim=1)
#                     correct = (predictions == labels).sum().item()
#                     total = len(labels)

#                 # Gradient accumulation: scale loss
#                 loss = loss / accumulation_steps
#                 loss.backward()

#                 # Only update weights every accumulation_steps batches
#                 if (batch_idx + 1) % accumulation_steps == 0:
#                     torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#                     optimizer.step()
#                     optimizer.zero_grad()

#                 # Unscale loss for logging
#                 train_loss += loss.item() * accumulation_steps
#                 train_correct += correct
#                 train_total += total

#                 current_acc = 100. * train_correct / train_total if train_total > 0 else 0.0
#                 train_pbar.set_postfix({
#                     'Loss': f'{train_loss/(batch_idx+1):.4f}',
#                     'Acc': f'{current_acc:.2f}%'
#                 })

#             except Exception as e:
#                 print(f"Error in training batch {batch_idx}: {e}")
#                 continue

#         # Handle any remaining accumulated gradients at epoch end
#         if (batch_idx + 1) % accumulation_steps != 0:
#             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#             optimizer.step()
#             optimizer.zero_grad()

#         # ---------- validation pass ---------------------
#         model.eval()
#         val_loss = 0.0
#         val_correct = 0
#         val_total = 0

#         with torch.no_grad():
#             val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{epochs} [Val]', leave=False)
#             for batch_idx, batch in enumerate(val_pbar):
#                 try:
#                     input_ids = batch['input_ids'].to(device)
#                     attention_mask = batch['attention_mask'].to(device)
#                     labels = batch['label'].to(device)

#                     outputs = model(input_ids, attention_mask=attention_mask)
#                     logits = outputs.logits if hasattr(outputs, 'logits') else outputs
#                     loss = criterion(logits, labels)
#                     predictions = torch.argmax(logits, dim=1)

#                     val_loss += loss.item()
#                     val_correct += (predictions == labels).sum().item()
#                     val_total += len(labels)

#                     current_acc = 100. * val_correct / val_total if val_total > 0 else 0.0
#                     val_pbar.set_postfix({
#                         'Loss': f'{val_loss/(batch_idx+1):.4f}',
#                         'Acc': f'{current_acc:.2f}%'
#                     })

#                 except Exception as e:
#                     print(f"Error in validation batch {batch_idx}: {e}")
#                     continue

#         epoch_resources = resource_monitor.end_epoch_monitoring()
#         epoch_time = time.time() - epoch_start_time
#         train_loss_avg = train_loss / len(train_loader) if len(train_loader) > 0 else 0
#         train_acc = train_correct / train_total if train_total > 0 else 0.0
#         val_loss_avg = val_loss / len(val_loader) if len(val_loader) > 0 else 0
#         val_acc = val_correct / val_total if val_total > 0 else 0.0
#         current_lr = optimizer.param_groups[0]['lr']

#         history['train_loss'].append(train_loss_avg)
#         history['train_acc'].append(train_acc)
#         history['val_loss'].append(val_loss_avg)
#         history['val_acc'].append(val_acc)
#         history['epoch_times'].append(epoch_time)
#         history['learning_rates'].append(current_lr)
#         history['resource_usage'].append(epoch_resources)

#         print(f"Epoch {epoch+1}/{epochs}:")
#         print(f"  Train Loss: {train_loss_avg:.4f}, Train Acc: {train_acc*100:.2f}%")
#         print(f"  Val Loss: {val_loss_avg:.4f}, Val Acc: {val_acc*100:.2f}%")
#         print(f"  Time: {epoch_time:.2f}s, LR: {current_lr:.2e}")
#         print(f"  GPU Memory: {epoch_resources.get('gpu_memory_mb', 0):.1f}MB")

#         if val_acc > best_val_acc:
#             best_val_acc = val_acc
#             best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
#             print(f"  ✓ New best validation accuracy: {val_acc*100:.2f}%")

#         # === CHECKPOINTING: SAVE STATE AFTER EACH EPOCH ===
#         save_checkpoint(model, optimizer, epoch, val_loss_avg, filename=checkpoint_filename)
#         # === END CHECKPOINTING ===

#         # Step scheduler (no argument needed for CosineAnnealingWarmRestarts)
#         scheduler.step()

#         if early_stopper.early_stop(val_loss_avg):
#             print(f"Early stopping triggered at epoch {epoch+1}")
#             break

#         if epoch % 2 == 0:
#             torch.cuda.empty_cache()
#             gc.collect()

#     if best_model_state is not None:
#         model.load_state_dict(best_model_state)
#         model.to(device)
#         print(f"\nLoaded best model with validation accuracy: {best_val_acc*100:.2f}%")

#     history['best_val_acc'] = best_val_acc
#     history['total_epochs'] = len(history['train_loss'])

#     torch.cuda.empty_cache()
#     gc.collect()

#     return model, history

# Testing the Models

In [None]:
# # ====
# # Train Baseline Model (No Defenses)
# # ====

# print("\n" + "="*80)
# print("TRAINING BASELINE MODEL (Vanilla DistilBERT)")
# print("="*80 + "\n")

# # Create model
# model_baseline = VanillaDistilBERT(num_classes=2).to(device)

# print(f"Model parameters: {model_baseline.count_parameters():,}")
# print(f"Model size: {model_baseline.get_model_size():.2f} MB\n")

# # Training configuration
# config_baseline = {
#     "no_adv": True,      # No adversarial training
#     "no_smooth": True,   # No smoothing
#     "tag": "baseline"
# }

# # Train the model
# model_baseline, history_baseline = train_model_comprehensive(
#     model=model_baseline,
#     model_name="Vanilla_DistilBERT",
#     train_loader=train_loader,
#     val_loader=val_loader,
#     epochs=5,  # Start with 5 epochs
#     use_adversarial=False,
#     cfg=config_baseline
# )

# # Print results
# print("\n" + "="*80)
# print("BASELINE TRAINING COMPLETE")
# print("="*80)
# print(f"Best Validation Accuracy: {history_baseline['best_val_acc']*100:.2f}%")
# print(f"Final Training Loss: {history_baseline['train_losses'][-1]:.4f}")
# print(f"Final Validation Loss: {history_baseline['val_losses'][-1]:.4f}")
# print("="*80 + "\n")

In [None]:
# import matplotlib.pyplot as plt
# import seaborn as sns
# import datetime
# from pathlib import Path

# # Assuming RESULTS_DIR and PLOTS_DIR are defined globally (from akDV7zGLUyBW)
# # and max_epochs is defined globally (from Zy1z7OZ8Ka9g)

# def now_tag():
#     return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# def save_plot(fig, plot_name, prefix="experiment"):
#     timestamp = now_tag()
#     filename = f"{prefix}_{plot_name}_{timestamp}.png"
#     # Ensure RESULTS_DIR and PLOTS_DIR are accessible, assuming they are defined in a previous cell
#     # or setting them here for independence if needed.
#     # For this fix, we assume they are already global.
#     global PLOTS_DIR # Access the global variable
#     if not 'PLOTS_DIR' in globals():
#         # Fallback if PLOTS_DIR is not defined globally (e.g., in a fresh session)
#         PLOTS_DIR = Path('plots')
#         PLOTS_DIR.mkdir(exist_ok=True)

#     filepath = PLOTS_DIR / filename
#     try:
#         fig.savefig(filepath, dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
#         print(f"✅ Plot saved to: {filepath}")
#         return str(filepath)
#     except Exception as e:
#         print(f"❌ Error saving plot: {e}")
#         return None

# def plot_individual_model_curves(model_name, data, save_individual=True):
#     """Plot training curves for a single model"""
#     epochs = range(1, data['epochs'] + 1)

#     fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

#     # Plot Loss curves
#     ax1.plot(epochs, data['train_losses'], 'b-', label='Training Loss',
#              linewidth=2.5, marker='o', markersize=6)
#     ax1.plot(epochs, data['val_losses'], 'r-', label='Validation Loss',
#              linewidth=2.5, marker='s', markersize=6)
#     ax1.set_title(f'{model_name} - Training & Validation Loss',
#                   fontsize=14, fontweight='bold')
#     ax1.set_xlabel('Epoch', fontsize=12)
#     ax1.set_ylabel('Loss', fontsize=12)
#     ax1.legend(fontsize=11)
#     ax1.grid(True, alpha=0.3)
#     ax1.set_xlim(1, data['epochs'])

#     # Plot Accuracy curves
#     ax2.plot(epochs, data['train_accs'], 'b-', label='Training Accuracy',
#              linewidth=2.5, marker='o', markersize=6)
#     ax2.plot(epochs, data['val_accs'], 'r-', label='Validation Accuracy',
#              linewidth=2.5, marker='s', markersize=6)
#     ax2.set_title(f'{model_name} - Training & Validation Accuracy',
#                   fontsize=14, fontweight='bold')
#     ax2.set_xlabel('Epoch', fontsize=12)
#     ax2.set_ylabel('Accuracy', fontsize=12)
#     ax2.legend(fontsize=11)
#     ax2.grid(True, alpha=0.3)
#     ax2.set_xlim(1, data['epochs'])
#     ax2.set_ylim(0.75, 1.0)  # Focus on the relevant accuracy range

#     plt.tight_layout()

#     if save_individual:
#         filename = f"{model_name.replace(' ', '_').replace('(', '').replace(')', '')}_training_curves.png"
#         save_plot(fig, filename) # Use the now local save_plot
#         print(f"Individual plot saved: {filename}")

#     plt.show()

# def plot_training_history(history, run_id, save_individual=True):
#     """
#     Map your `history` produced by train_model_comprehensive into
#     the schema expected by plot_individual_model_curves and display/save it.
#     """
#     # history keys used earlier: 'train_loss','train_acc','val_loss','val_acc','total_epochs'
#     # global max_epochs # Assuming max_epochs is globally defined. If not, pass it as argument.

#     epochs = history.get('total_epochs', len(history.get('train_loss', [])))
#     data = {
#         'epochs': epochs,
#         'train_losses': history.get('train_loss', [])[:epochs],
#         'val_losses': history.get('val_loss', [])[:epochs],
#         'train_accs': history.get('train_acc', [])[:epochs],
#         'val_accs': history.get('val_acc', [])[:epochs]
#     }
#     # Use run_id as model name in filename/title
#     plot_individual_model_curves(run_id, data, save_individual=save_individual)

# # Check if history_baseline exists and is valid before plotting
# if 'history_baseline' in globals() and history_baseline and isinstance(history_baseline, dict) and history_baseline.get('train_loss'):
#     print(f"Plotting training history for {run_id_for_plot}...")
#     plot_training_history(history_baseline, run_id_for_plot)
# else:
#     print("history_baseline is either empty, not a dictionary, or does not contain 'train_loss' data.")
#     print("Please ensure a training run has successfully completed and populated 'history_baseline'.")


In [None]:
# # Add near the class definition (module/global scope)
# def augment_batch_with_adversarial_examples(batch, attack_generator, augmentation_ratio=0.65, attack_type='mixed'):
#     # Backwards-compatible wrapper: delegate to the instance method
#     return attack_generator.augment_batch_with_adversarial_examples(
#         batch, augmentation_ratio=augmentation_ratio, attack_type=attack_type
#    )

In [None]:
def collect_predictions_on_loader(model, loader, device=device):
    """Run ``model`` over ``loader`` and collect labels and argmax predictions.

    Args:
        model: Callable invoked as ``model(input_ids, attention_mask=...)``.
            It may return raw logits or an object exposing ``.logits``.
        loader: Iterable of dict batches providing ``'input_ids'``,
            ``'attention_mask'`` and ``'label'`` tensors.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(y_arr, p_arr)``: NumPy arrays with the ground-truth labels and the
        predicted class indices, concatenated in iteration order. Both arrays
        are empty when ``loader`` yields no batches.
    """
    model.eval()
    y_list, p_list = [], []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs
            p_list.append(torch.argmax(logits, dim=1).cpu().numpy())
            # Labels are only compared on the host, so skip the device round trip.
            y_list.append(batch['label'].cpu().numpy())

    # np.concatenate raises on an empty list; an empty loader is not an error here.
    if not y_list:
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
    return np.concatenate(y_list, axis=0), np.concatenate(p_list, axis=0)

def collect_adversarial_predictions(model, loader, attack_generator, attack_type='pgd', device=device):
    """Collect labels and predictions on fully adversarial versions of each batch.

    Every batch is passed through
    ``attack_generator.augment_batch_with_adversarial_examples`` with
    ``augmentation_ratio=1.0`` and the model is evaluated on what it returns.
    When ``attack_generator`` is ``None`` the clean batch is evaluated instead.

    Args:
        model: Callable invoked as ``model(input_ids, attention_mask=...)``.
            It may return raw logits or an object exposing ``.logits``.
        loader: Iterable of dict batches providing ``'input_ids'``,
            ``'attention_mask'`` and ``'label'`` tensors.
        attack_generator: Object exposing
            ``augment_batch_with_adversarial_examples(batch, augmentation_ratio,
            attack_type)`` returning ``(input_ids, attention_mask, labels)``,
            or ``None`` to fall back to clean evaluation.
        attack_type: Attack name forwarded to the generator.
        device: Device the tensors are moved to before the forward pass.

    Returns:
        ``(y_arr, p_arr)``: NumPy arrays of labels and predicted class indices
        concatenated over all batches (empty arrays for an empty loader).
    """
    model.eval()
    y_list, p_list = [], []
    for batch in loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        eval_labels = labels

        if attack_generator is not None:
            # The attack is generated OUTSIDE torch.no_grad(): gradient-based
            # attacks cannot compute input gradients while autograd is disabled.
            # NOTE(review): the rest of this file keys batches by 'label'; the
            # original call here used 'labels'. Both keys are supplied because
            # the generator's expected key is not visible from here - confirm.
            input_ids, attention_mask, labels_adv = attack_generator.augment_batch_with_adversarial_examples(
                {
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'label': labels,
                    'labels': labels,
                },
                augmentation_ratio=1.0,
                attack_type=attack_type
            )
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            # Score against the labels that accompany the returned batch so that
            # labels and predictions stay aligned even if the generator
            # reorders or extends the batch.
            if labels_adv is not None:
                eval_labels = labels_adv

        # Only the forward pass used for scoring needs autograd switched off.
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs
            preds = torch.argmax(logits, dim=1).cpu().numpy()

        y_list.append(torch.as_tensor(eval_labels).detach().cpu().numpy())
        p_list.append(preds)

    # np.concatenate raises on an empty list; an empty loader is not an error here.
    if not y_list:
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
    return np.concatenate(y_list, axis=0), np.concatenate(p_list, axis=0)

In [None]:
# ====
# Evaluation and Robustness Testing
# ====

def evaluate_model_robustness(model, val_loader, attack_generator, model_name):
    """Evaluate ``model`` on clean and adversarially perturbed validation data.

    The validation batches are decoded back to text with the module-level
    ``tokenizer``, re-tokenized to ``max_length`` and scored once per attack
    type (``'clean'``, ``'synonym'``, ``'character'``, ``'insertion'``,
    ``'mixed'``). Adversarial inputs come from
    ``attack_generator.generate_adversarial_examples``.

    Args:
        model: Callable invoked as ``model(input_ids, attention_mask=...)``.
            It may return raw logits or an object exposing ``.logits``.
        val_loader: Loader yielding dict batches with ``'input_ids'`` and
            ``'label'``; its ``batch_size`` is reused for evaluation.
        attack_generator: Object exposing ``generate_adversarial_examples``.
        model_name: Name used in progress messages only.

    Returns:
        dict: one entry per attack type with ``'accuracy'``, ``'f1'``
        (weighted), ``'latency_avg'``, ``'latency_total'``, ``'predictions'``
        and ``'labels'``, plus a ``'robustness_metrics'`` entry holding
        per-attack ``*_drop`` / ``*_retention`` values, ``'overall_robustness'``
        (mean adversarial accuracy) and ``'robustness_drop'``.
    """
    print(f"\nEvaluating {model_name} Robustness...")

    model.eval()
    attack_types = ['clean', 'synonym', 'character', 'insertion', 'mixed']
    adversarial_types = [name for name in attack_types if name != 'clean']
    results = {}

    # Collect all validation data once so every attack sees the same examples.
    all_texts = []
    all_labels = []
    for batch in val_loader:
        all_texts.extend(tokenizer.decode(ids, skip_special_tokens=True) for ids in batch['input_ids'])
        all_labels.extend(batch['label'].tolist())
    all_labels = np.array(all_labels)

    batch_size_eval = val_loader.batch_size

    for attack_type in attack_types:
        print(f"Testing {attack_type} attack...")
        predictions = []
        latencies = []
        correct = 0
        total = 0

        for start in range(0, len(all_texts), batch_size_eval):
            batch_texts = all_texts[start:start + batch_size_eval]
            batch_labels = all_labels[start:start + batch_size_eval]

            # The clean encoding is needed for every attack type, so build it once.
            enc = tokenizer(batch_texts, return_tensors='pt', padding='max_length',
                            truncation=True, max_length=max_length)
            input_ids = enc['input_ids'].to(device)
            attention_mask = enc['attention_mask'].to(device)

            if attack_type != 'clean':
                labels_tensor = torch.tensor(batch_labels).to(device)
                input_ids, attention_mask = attack_generator.generate_adversarial_examples(
                    input_ids, labels_tensor, model, num_examples=len(batch_labels),
                    attack_type=attack_type, reference_len=max_length
                )
                # The generator may hand tensors back on another device; the
                # sibling collector in this file moves them explicitly as well.
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)

            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            with torch.no_grad():
                outputs = model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits if hasattr(outputs, 'logits') else outputs
                # .cpu() forces the device work to finish inside the timed region.
                preds = torch.argmax(logits, dim=1).cpu().numpy()
            latencies.append(time.perf_counter() - start_time)

            predictions.extend(preds)
            correct += (preds == batch_labels).sum()
            total += len(batch_labels)

        accuracy = correct / total if total > 0 else 0
        f1 = f1_score(all_labels, predictions, average='weighted')

        results[attack_type] = {
            'accuracy': accuracy,
            'f1': f1,
            'latency_avg': np.mean(latencies),
            'latency_total': np.sum(latencies),
            'predictions': predictions,
            'labels': all_labels.tolist()
        }
        print(f"{attack_type.capitalize()} - Accuracy: {accuracy:.4f}, F1: {f1:.4f}, Avg Latency: {np.mean(latencies):.4f}s")

    # Robustness metrics: how much accuracy each attack removes relative to clean.
    clean_acc = results['clean']['accuracy']
    robustness_scores = {}
    for attack_type in adversarial_types:
        attack_acc = results[attack_type]['accuracy']
        robustness_scores[f'{attack_type}_drop'] = clean_acc - attack_acc
        robustness_scores[f'{attack_type}_retention'] = attack_acc / clean_acc if clean_acc > 0 else 0

    avg_robustness = np.mean([results[att]['accuracy'] for att in adversarial_types])
    robustness_scores['overall_robustness'] = avg_robustness
    robustness_scores['robustness_drop'] = clean_acc - avg_robustness

    results['robustness_metrics'] = robustness_scores

    return results

In [None]:
# ====
# Utility Functions for Saving and Plotting
# ====

def save_results(results_dict, filename_prefix="experiment"):
    """Serialize ``results_dict`` to a timestamped JSON file under ``RESULTS_DIR``.

    Objects exposing ``.tolist()`` (NumPy arrays/scalars, tensors) are converted
    to plain Python values, recursively through dicts, lists and tuples.
    Anything still not JSON-serializable is written via ``str()``.

    Args:
        results_dict: Arbitrarily nested results structure.
        filename_prefix: Prefix of the file name; ``now_tag()`` is appended.

    Returns:
        The path of the written file as ``str``, or ``None`` if saving failed
        (the error is printed rather than raised).
    """
    timestamp = now_tag()
    filename = f"{filename_prefix}_{timestamp}.json"
    filepath = RESULTS_DIR / filename

    def plain_key(key):
        # json.dump rejects NumPy integer keys (``default`` is not applied to keys).
        return key.tolist() if hasattr(key, 'tolist') else key

    def convert_numpy(obj):
        if hasattr(obj, 'tolist'):
            return obj.tolist()
        elif isinstance(obj, dict):
            return {plain_key(k): convert_numpy(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            # Tuples are included so NumPy scalars inside them are converted
            # instead of being stringified by ``default=str``.
            return [convert_numpy(item) for item in obj]
        else:
            return obj

    serializable_results = convert_numpy(results_dict)

    try:
        # Serialize before touching the file so a failure cannot leave a
        # truncated JSON document behind.
        payload = json.dumps(serializable_results, indent=2, default=str)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(payload)
        print(f"✅ Results saved to: {filepath}")
        return str(filepath)
    except Exception as e:
        print(f"❌ Error saving results: {e}")
        return None

def save_plot(fig, plot_name, prefix="experiment"):
    """Save ``fig`` as a timestamped PNG under ``PLOTS_DIR``.

    Args:
        fig: Object exposing ``savefig`` (e.g. a matplotlib ``Figure``).
        plot_name: Middle part of the file name.
        prefix: Leading part of the file name.

    Returns:
        The path of the written image as ``str``, or ``None`` if saving failed
        (the error is printed rather than raised).
    """
    timestamp = now_tag()
    filename = f"{prefix}_{plot_name}_{timestamp}.png"
    filepath = PLOTS_DIR / filename
    try:
        # Create the target directory on demand: covers a PLOTS_DIR that was
        # never created as well as a plot_name that contains a sub-directory.
        filepath.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(filepath, dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
        print(f"✅ Plot saved to: {filepath}")
        return str(filepath)
    except Exception as e:
        print(f"❌ Error saving plot: {e}")
        return None

def plot_training_history(history, model_name):
    """Plot loss and accuracy curves from a training ``history`` dict.

    Reads the ``'train_loss'``, ``'val_loss'``, ``'train_acc'`` and
    ``'val_acc'`` lists (missing keys count as empty), draws a two-panel
    figure, saves it through ``save_plot`` and shows it.

    Args:
        history: Mapping with per-epoch metric lists.
        model_name: Used in the panel titles and the saved file name.

    Returns:
        None. If there is no accuracy history at all, a message is printed and
        no figure is created.
    """
    train_losses = list(history.get('train_loss', []))
    val_losses = list(history.get('val_loss', []))
    train_accs = list(history.get('train_acc', []))
    val_accs = list(history.get('val_acc', []))

    # Decide before creating the figure so the early exit cannot leak one.
    all_accs = train_accs + val_accs
    if len(all_accs) == 0:
        print("No accuracy history to plot.")
        return

    def epoch_axis(series):
        # Each curve gets an x-axis of its own length; a shared axis raises in
        # matplotlib as soon as one series is shorter (or empty).
        return range(1, len(series) + 1)

    fig, ax = plt.subplots(1, 2, figsize=(14, 5))

    # Loss
    ax[0].plot(epoch_axis(train_losses), train_losses, label='Train Loss', marker='o')
    ax[0].plot(epoch_axis(val_losses), val_losses, label='Val Loss', marker='s')
    ax[0].set_title(f'{model_name} Loss')
    ax[0].set_xlabel('Epoch')
    ax[0].set_ylabel('Loss')
    ax[0].legend()
    ax[0].grid(True, alpha=0.3)

    # Accuracy: detect scale (0-1 vs 0-100) and set sensible y-limits
    max_acc = max(all_accs)
    min_acc = min(all_accs)
    # If accuracies look like percentages (>1.5), treat as 0-100
    if max_acc > 1.5:
        ylabel = 'Accuracy (%)'
        pad = max(1.0, (max_acc - min_acc) * 0.1)
        ymin, ymax = max(0, min_acc - pad), min(100, max_acc + pad)
    else:
        ylabel = 'Accuracy'
        pad = max(0.01, (max_acc - min_acc) * 0.1)
        ymin, ymax = max(0.0, min_acc - pad), min(1.0, max_acc + pad)

    ax[1].plot(epoch_axis(train_accs), train_accs, label='Train Accuracy', marker='o')
    ax[1].plot(epoch_axis(val_accs), val_accs, label='Val Accuracy', marker='s')
    ax[1].set_title(f'{model_name} Accuracy')
    ax[1].set_xlabel('Epoch')
    ax[1].set_ylabel(ylabel)
    ax[1].legend()
    ax[1].grid(True, alpha=0.3)
    ax[1].set_ylim(ymin, ymax)

    plt.tight_layout()
    save_plot(fig, f"{model_name}_training_history")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


# ====
# Additional Plots and Statistical Tests
# ====

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

def perform_mcnemar_test(preds1, preds2, labels):
    """McNemar's chi-square test (with continuity correction) for two classifiers.

    Compares two sets of predictions made on the same examples by looking only
    at the discordant pairs, i.e. examples that exactly one model got right.

    Args:
        preds1: Predictions of the first model (list or array).
        preds2: Predictions of the second model (list or array).
        labels: Ground-truth labels (list or array).

    Returns:
        float: The p-value; ``1.0`` when there are no discordant pairs.

    Raises:
        ValueError: If the three inputs do not contain the same number of items.
    """
    # Coerce to arrays first: with plain lists ``preds == labels`` is a single
    # bool, not an element-wise comparison, which silently gave a wrong p-value.
    preds1 = np.asarray(preds1).ravel()
    preds2 = np.asarray(preds2).ravel()
    labels = np.asarray(labels).ravel()
    if not (preds1.shape == preds2.shape == labels.shape):
        raise ValueError(
            "preds1, preds2 and labels must have the same length, got "
            f"{preds1.shape[0]}, {preds2.shape[0]} and {labels.shape[0]}"
        )

    correct1 = preds1 == labels
    correct2 = preds2 == labels
    n01 = int(np.sum(correct1 & ~correct2))  # only model 1 correct
    n10 = int(np.sum(~correct1 & correct2))  # only model 2 correct
    if n01 + n10 == 0:
        return 1.0
    mcnemar_stat = (abs(n01 - n10) - 1) ** 2 / (n01 + n10)
    # Survival function == 1 - cdf, but keeps precision for very small p-values.
    return float(stats.chi2.sf(mcnemar_stat, 1))

# NOTE(review): an orphaned, indented fragment used to follow the return above.
# It built the `data` dict consumed by the plotting functions (keys 'models',
# 'clean_acc', 'synonym_acc', 'character_acc', 'insertion_acc', 'mixed_acc',
# 'avg_robust_acc', 'robustness_drop', 'trainable_params', 'total_params'),
# with avg_robust_acc as the mean of the four attack accuracies and
# robustness_drop as clean_acc - avg_robust_acc. Its `def` line was missing, so
# it was unreachable code inside perform_mcnemar_test that referenced undefined
# names. Restore it as its own function if the aggregation is still needed.

from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# ---- Plotting functions for cross-model comparison (each takes the aggregated `data` dict) ----

def plot_individual_performance(data, save_path='individual_performance_clean.png'):
    """Grouped bar chart of every model's accuracy under each attack type.

    One cluster per entry in ``data['models']``; within a cluster there is one
    bar per attack, read from ``data['clean_acc']``, ``data['synonym_acc']``,
    ``data['character_acc']``, ``data['insertion_acc']`` and
    ``data['mixed_acc']``. The figure is written to ``save_path`` and shown.
    """
    fig, ax = plt.subplots(figsize=(14, 8))

    # (legend label, key in ``data``, bar colour), in drawing order.
    series = [
        ('Clean', 'clean_acc', '#3498db'),
        ('Synonym', 'synonym_acc', '#e74c3c'),
        ('Character', 'character_acc', '#2ecc71'),
        ('Insertion', 'insertion_acc', '#f39c12'),
        ('Mixed', 'mixed_acc', '#9b59b6'),
    ]
    positions = np.arange(len(data['models']))
    bar_width = 0.15

    for slot, (label, key, colour) in enumerate(series):
        scores = data[key]
        rects = ax.bar(positions + slot*bar_width, scores, bar_width, label=label, color=colour, alpha=0.8)

        # Print the value just above each bar.
        for rect, score in zip(rects, scores):
            top = rect.get_height()
            centre = rect.get_x() + rect.get_width()/2.
            ax.text(centre, top + 0.005, f'{score:.3f}',
                    ha='center', va='bottom', fontsize=8, rotation=0)

    tick_labels = [name.replace('_', '\n') for name in data['models']]

    ax.set_xlabel('Models', fontsize=12, fontweight='bold')
    ax.set_ylabel('Accuracy', fontsize=12, fontweight='bold')
    ax.set_title('Model Performance Across Different Attack Types', fontsize=14, fontweight='bold', pad=20)
    # Centre the tick under the middle (third) bar of each five-bar cluster.
    ax.set_xticks(positions + bar_width * 2)
    ax.set_xticklabels(tick_labels, rotation=45, ha='right')
    ax.legend(loc='upper right', bbox_to_anchor=(1, 1))
    ax.grid(True, alpha=0.3, axis='y')
    ax.set_ylim(0.7, 1.0)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def plot_clean_vs_robust(data, save_path='clean_vs_robust_accuracy.png'):
    """Side-by-side bars of clean vs. average robust accuracy for each model.

    Reads ``data['models']``, ``data['clean_acc']`` and
    ``data['avg_robust_acc']``; the figure is written to ``save_path`` and shown.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    centres = np.arange(len(data['models']))
    bar_width = 0.35
    clean_scores = data['clean_acc']
    robust_scores = data['avg_robust_acc']

    def annotate(rects, scores):
        # Write each value just above its bar.
        for rect, score in zip(rects, scores):
            ax.text(rect.get_x() + rect.get_width()/2., rect.get_height() + 0.005,
                    f'{score:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')

    clean_bars = ax.bar(centres - bar_width/2, clean_scores, bar_width, label='Clean Accuracy',
                        color='#3498db', alpha=0.8, edgecolor='black', linewidth=0.5, hatch='//')
    robust_bars = ax.bar(centres + bar_width/2, robust_scores, bar_width, label='Average Robust Accuracy',
                         color='#e74c3c', alpha=0.8, edgecolor='black', linewidth=0.5, hatch='\\\\')

    annotate(clean_bars, clean_scores)
    annotate(robust_bars, robust_scores)

    tick_labels = [name.replace('_', '\n') for name in data['models']]

    ax.set_xlabel('Models', fontsize=12, fontweight='bold')
    ax.set_ylabel('Accuracy', fontsize=12, fontweight='bold')
    ax.set_title('Clean vs. Average Robust Accuracy Comparison', fontsize=14, fontweight='bold', pad=20)
    ax.set_xticks(centres)
    ax.set_xticklabels(tick_labels, rotation=45, ha='right')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3, axis='y')
    ax.set_ylim(0.75, 0.95)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def plot_robustness_drop(data, save_path='robustness_drop_analysis.png'):
    """Bar chart of each model's robustness drop (clean minus average robust accuracy).

    Reads ``data['models']`` and ``data['robustness_drop']``. Bars are coloured
    and hatched by model family, matched by substring of the model name. The
    figure is written to ``save_path`` and shown.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    drops = list(data['robustness_drop'])

    colors = []
    hatches = []
    for model in data['models']:
        if 'Enhanced_HAT-D' in model:
            colors.append('#f39c12')  # Orange for our method
            hatches.append('//')
        elif 'Adversarial_Training' in model:
            colors.append('#2ecc71')  # Green for best performing
            hatches.append('\\\\')
        elif 'Ensemble_Defense' in model:
            colors.append('#e74c3c')  # Red for ensemble
            hatches.append('xx')
        else:
            colors.append('#3498db')  # Blue for baseline methods
            hatches.append('..')

    bars = ax.bar(range(len(data['models'])), drops, color=colors, alpha=0.8,
                  edgecolor='black', linewidth=0.5)

    for bar, hatch in zip(bars, hatches):
        bar.set_hatch(hatch)

    for bar, val in zip(bars, drops):
        height = bar.get_height()
        # A negative drop (robust accuracy above clean) draws downwards, so its
        # label goes below the bar end instead of on top of the bar.
        if height >= 0:
            label_y, label_va = height + 0.002, 'bottom'
        else:
            label_y, label_va = height - 0.002, 'top'
        ax.text(bar.get_x() + bar.get_width()/2., label_y,
               f'{val:.3f}', ha='center', va=label_va, fontsize=11, fontweight='bold')

    ax.set_xlabel('Models', fontsize=12, fontweight='bold')
    ax.set_ylabel('Robustness Drop (Clean Acc. - Avg Robust Acc.)', fontsize=12, fontweight='bold')
    ax.set_title('Model Robustness Drop Analysis', fontsize=14, fontweight='bold', pad=20)
    ax.set_xticks(range(len(data['models'])))
    ax.set_xticklabels([m.replace('_', '\n') for m in data['models']], rotation=45, ha='right')
    ax.grid(True, alpha=0.3, axis='y')

    # For all-positive drops this is the original (0, max * 1.15). The guards
    # keep the axis valid when the largest drop is zero/negative (which would
    # give an empty or inverted range) and when no models were supplied.
    highest = max(drops, default=0.0)
    lowest = min(drops, default=0.0)
    upper = highest * 1.15 if highest > 0 else 0.05  # small fixed headroom
    lower = lowest * 1.15 if lowest < 0 else 0
    ax.set_ylim(lower, upper)

    legend_elements = [
        Patch(facecolor='#3498db', label='Baseline Methods', hatch='..'),
        Patch(facecolor='#2ecc71', label='Best Performing', hatch='\\\\'),
        Patch(facecolor='#f39c12', label='Our Method', hatch='//'),
        Patch(facecolor='#e74c3c', label='High Parameter Count', hatch='xx')
    ]
    ax.legend(handles=legend_elements, loc='upper right')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def plot_efficiency_analysis(data, save_path='model_efficiency_analysis.png'):
    """Horizontal bar chart of trainable parameters (in millions) per model.

    Reads ``data['models']`` and ``data['trainable_params']`` (raw counts,
    divided by 1e6 for display). Bars are coloured and hatched by model family,
    matched by substring of the model name. The figure is written to
    ``save_path`` and shown.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    trainable_params_millions = [p/1e6 for p in data['trainable_params']]

    colors = []
    hatches = []
    for model in data['models']:
        if 'Enhanced_HAT-D' in model:
            colors.append('#f39c12')  # Orange for our method
            hatches.append('//')
        elif 'Ensemble_Defense' in model:
            colors.append('#e74c3c')  # Red for high parameter count
            hatches.append('xx')
        else:
            colors.append('#3498db')  # Blue for others
            hatches.append('..')

    bars = ax.barh(range(len(data['models'])), trainable_params_millions, color=colors,
                   alpha=0.8, edgecolor='black', linewidth=0.5)

    for bar, hatch in zip(bars, hatches):
        bar.set_hatch(hatch)

    # Axis extent: 15% headroom past the longest bar, with a fallback so the
    # axis stays valid when every count is 0 or no models were supplied.
    longest = max(trainable_params_millions, default=0.0)
    x_upper = longest * 1.15 if longest > 0 else 1.0
    # Offset the value labels relative to the axis extent. A fixed offset in
    # data units (previously 2, i.e. 2M parameters) lands outside the axis when
    # all models are small.
    label_pad = 0.01 * x_upper

    for bar, val in zip(bars, trainable_params_millions):
        width = bar.get_width()
        ax.text(width + label_pad, bar.get_y() + bar.get_height()/2.,
               f'{val:.1f}M', ha='left', va='center', fontsize=11, fontweight='bold')

    ax.set_ylabel('Models', fontsize=12, fontweight='bold')
    ax.set_xlabel('Trainable Parameters (Millions)', fontsize=12, fontweight='bold')
    ax.set_title('Model Efficiency Analysis (Trainable Parameters)', fontsize=14, fontweight='bold', pad=20)
    ax.set_yticks(range(len(data['models'])))
    ax.set_yticklabels([m.replace('_', ' ') for m in data['models']])
    ax.grid(True, alpha=0.3, axis='x')
    ax.set_xlim(0, x_upper)

    legend_elements = [
        Patch(facecolor='#3498db', label='Standard Models', hatch='..'),
        Patch(facecolor='#f39c12', label='Our Method (Efficient)', hatch='//'),
        Patch(facecolor='#e74c3c', label='High Parameter Count', hatch='xx')
    ]
    ax.legend(handles=legend_elements, loc='lower right')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def plot_tradeoff_analysis(data, save_path='robustness_vs_accuracy_tradeoff.png'):
    """Plot robustness vs accuracy trade-off.

    Each model becomes one scatter point (x = clean accuracy, y = average
    robust accuracy). A dashed diagonal marks where both accuracies are equal.

    Args:
        data: Mapping with parallel sequences under 'models', 'clean_acc'
            and 'avg_robust_acc'.
        save_path: File path the figure is written to.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    # (substring of the model name) -> (colour, marker size, marker shape).
    # Rules are tried in order; models matching none use `default_style`.
    style_rules = [
        ('Enhanced_HAT-D', ('#f39c12', 120, 'o')),        # Circle
        ('Adversarial_Training', ('#2ecc71', 100, 's')),  # Square
        ('Ensemble_Defense', ('#e74c3c', 100, '^')),      # Triangle
    ]
    default_style = ('#3498db', 80, 'D')                  # Diamond

    def style_for(model):
        for needle, style in style_rules:
            if needle in model:
                return style
        return default_style

    for model, x_val, y_val in zip(data['models'], data['clean_acc'], data['avg_robust_acc']):
        colour, size, marker = style_for(model)
        ax.scatter(x_val, y_val, c=colour, s=size, marker=marker,
                   edgecolors='black', linewidth=1, alpha=0.7)

    # Axis limits follow the data so no point is clipped. Both axes share the
    # same range, which keeps the equal-accuracy diagonal at 45 degrees.
    all_scores = list(data['clean_acc']) + list(data['avg_robust_acc'])
    if all_scores:
        min_acc = min(all_scores) - 0.02
        max_acc = max(all_scores) + 0.01
    else:
        # Nothing to plot: use a unit square rather than failing in min()/max().
        min_acc, max_acc = 0.0, 1.0
    ax.set_xlim(min_acc, max_acc)
    ax.set_ylim(min_acc, max_acc)
    ax.plot([min_acc, max_acc], [min_acc, max_acc], 'k--', alpha=0.5, linewidth=1, label='Perfect Robustness Line')

    ax.set_xlabel('Clean Accuracy', fontsize=12, fontweight='bold')
    ax.set_ylabel('Average Robust Accuracy', fontsize=12, fontweight='bold')
    ax.set_title('Robustness vs. Accuracy Trade-off', fontsize=14, fontweight='bold', pad=20)
    ax.grid(True, alpha=0.3)

    # Legend with marker shapes
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', label='Enhanced HAT-D (Ours)',
               markerfacecolor='#f39c12', markersize=10, markeredgecolor='black'),
        Line2D([0], [0], marker='s', color='w', label='Adversarial Training',
               markerfacecolor='#2ecc71', markersize=8, markeredgecolor='black'),
        Line2D([0], [0], marker='^', color='w', label='Ensemble Defense',
               markerfacecolor='#e74c3c', markersize=8, markeredgecolor='black'),
        Line2D([0], [0], marker='D', color='w', label='Baseline Methods',
               markerfacecolor='#3498db', markersize=6, markeredgecolor='black'),
        Line2D([0], [0], linestyle='--', color='k', label='Perfect Robustness Line')
    ]
    ax.legend(handles=legend_elements, loc='lower right')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def plot_performance_summary(data, save_path='performance_summary_table.png'):
    """Render the per-model accuracy figures as a colour-coded table image.

    Args:
        data: Mapping with parallel sequences under 'models', 'clean_acc',
            'avg_robust_acc' and 'robustness_drop'.
        save_path: File path the figure is written to.
    """
    fig, ax = plt.subplots(figsize=(14, 8))
    ax.axis('tight')
    ax.axis('off')

    # Exact model name -> status tag; any other model is tagged 'OK'.
    status_by_model = {
        'Enhanced_HAT-D (Ours)': 'OURS',
        'Adversarial_Training': 'BEST',
        'Ensemble_Defense': 'FIXED',
    }
    # Status tag -> background colour of its cell; 'OK' uses the fallback blue.
    status_colours = {'OURS': '#f39c12', 'BEST': '#2ecc71', 'FIXED': '#e74c3c'}
    fallback_colour = '#3498db'
    header = ['Model', 'Clean Acc.', 'Robust Acc.', 'Robustness Drop', 'Status']
    status_col = len(header) - 1

    rows = []
    for idx, name in enumerate(data['models']):
        rows.append([
            name.replace('_', ' '),
            f"{data['clean_acc'][idx]:.3f}",
            f"{data['avg_robust_acc'][idx]:.3f}",
            f"{data['robustness_drop'][idx]:.3f}",
            status_by_model.get(name, 'OK'),
        ])

    table = ax.table(cellText=rows,
                     colLabels=header,
                     cellLoc='center',
                     loc='center',
                     bbox=[0, 0, 1, 1])

    table.auto_set_font_size(False)
    table.set_fontsize(11)
    table.scale(1, 2)

    # Table row 0 is the header, so data row k lives at table row k + 1.
    for table_row, row in enumerate(rows, start=1):
        status_cell = table[(table_row, status_col)]
        status_cell.set_facecolor(status_colours.get(row[status_col], fallback_colour))
        status_cell.set_text_props(weight='bold', color='white')

    for col in range(len(header)):
        header_cell = table[(0, col)]
        header_cell.set_facecolor('#34495e')
        header_cell.set_text_props(weight='bold', color='white')

    plt.title('Performance Summary', fontsize=16, fontweight='bold', pad=20)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# ---- Adapter: training-history -> plot_individual_model_curves ----------------

def plot_individual_model_curves(model_name, data, save_individual=True):
    """Draw side-by-side loss and accuracy curves for one model.

    Args:
        model_name: Used in both panel titles and in the saved file name.
        data: Mapping with 'epochs' (int) plus the per-epoch series
            'train_losses', 'val_losses', 'train_accs' and 'val_accs'.
        save_individual: When true, the figure is also written to
            '<sanitised model name>_training_curves.png'.
    """
    n_epochs = data['epochs']
    x_values = range(1, n_epochs + 1)

    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    # One entry per panel: (metric label, training series key, validation series key).
    panels = (
        ('Loss', 'train_losses', 'val_losses'),
        ('Accuracy', 'train_accs', 'val_accs'),
    )
    for axis, (metric, train_key, val_key) in zip(axes, panels):
        axis.plot(x_values, data[train_key], 'b-', label=f'Training {metric}',
                  linewidth=2.5, marker='o', markersize=6)
        axis.plot(x_values, data[val_key], 'r-', label=f'Validation {metric}',
                  linewidth=2.5, marker='s', markersize=6)
        axis.set_title(f'{model_name} - Training & Validation {metric}',
                       fontsize=14, fontweight='bold')
        axis.set_xlabel('Epoch', fontsize=12)
        axis.set_ylabel(metric, fontsize=12)
        axis.legend(fontsize=11)
        axis.grid(True, alpha=0.3)
        axis.set_xlim(1, n_epochs)

    # Only the accuracy panel gets a fixed y-range, to focus on the relevant band.
    axes[1].set_ylim(0.75, 1.0)

    plt.tight_layout()

    if save_individual:
        # Spaces become underscores and parentheses are dropped for the file name.
        safe_name = model_name.replace(' ', '_').replace('(', '').replace(')', '')
        filename = f"{safe_name}_training_curves.png"
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Individual plot saved: {filename}")

    plt.show()

# ---- Adapter wrappers used by run_experiment ------------------------------------------------

def plot_training_history(history, run_id, save_individual=True):
    """
    Map a training `history` dict onto the schema expected by
    plot_individual_model_curves and display/save the curves.

    Args:
        history: Mapping that may contain 'train_loss', 'train_acc',
            'val_loss', 'val_acc' (per-epoch sequences) and 'total_epochs'.
        run_id: Used as the model name in the plot title and file name.
        save_individual: Forwarded to plot_individual_model_curves.
    """
    # plotter key -> history key
    series_keys = {
        'train_losses': 'train_loss',
        'val_losses': 'val_loss',
        'train_accs': 'train_acc',
        'val_accs': 'val_acc',
    }
    series = {}
    for target, source in series_keys.items():
        values = history.get(source)
        series[target] = list(values) if values is not None else []

    # The plotter draws every series against range(1, epochs + 1), so `epochs`
    # must not exceed the shortest recorded series. A larger (or missing/None)
    # 'total_epochs' is clamped to what was actually recorded.
    recorded = min(len(values) for values in series.values())
    declared = history.get('total_epochs')
    epochs = recorded if declared is None else min(int(declared), recorded)

    data = {'epochs': epochs}
    for target, values in series.items():
        data[target] = values[:epochs]

    # Use run_id as model name in filename/title
    plot_individual_model_curves(run_id, data, save_individual=save_individual)

def create_summary_plots(results_df, save_dir=None):
    """
    Build the `data` dict expected by the plotting functions from results_df
    and call each summary plotter with it.

    results_df is expected to have at least: model_name, clean_acc.
    Per-attack columns (synonym_acc, character_acc, insertion_acc, mixed_acc)
    are used when present; a missing one is filled from the first available of
    avg_robust_acc / adv_acc / pgd_acc / fgsm_acc, or from clean_acc if none
    of those exist.

    Args:
        results_df: DataFrame, or anything pd.DataFrame() accepts.
        save_dir: Directory for the generated images (created if missing).
            Defaults to the current working directory.
    """
    if save_dir is None:
        save_dir = Path.cwd()
    save_dir = Path(save_dir)
    # savefig does not create parent directories, so make sure the target exists.
    save_dir.mkdir(parents=True, exist_ok=True)

    # Convert to DataFrame (if not already)
    if not isinstance(results_df, pd.DataFrame):
        results_df = pd.DataFrame(results_df)

    models = results_df['model_name'].tolist()
    clean_acc = results_df['clean_acc'].tolist()

    # Fallback: first generic robustness column that exists, else clean accuracy.
    fallback_adv = None
    for col in ['avg_robust_acc', 'adv_acc', 'pgd_acc', 'fgsm_acc']:
        if col in results_df.columns:
            fallback_adv = results_df[col].tolist()
            break
    default_attack = fallback_adv if fallback_adv is not None else clean_acc

    def attack_column(name):
        """Per-attack accuracies as a fresh list, or a copy of the fallback."""
        if name in results_df.columns:
            return results_df[name].tolist()
        return list(default_attack)

    def param_column(name):
        """Parameter counts as a list; a missing column or missing value becomes 0."""
        if name not in results_df.columns:
            return [0] * len(models)
        return pd.to_numeric(results_df[name], errors='coerce').fillna(0).tolist()

    synonym_acc = attack_column('synonym_acc')
    character_acc = attack_column('character_acc')
    insertion_acc = attack_column('insertion_acc')
    mixed_acc = attack_column('mixed_acc')

    # Compute average robust accuracy if not present
    if 'avg_robust_acc' in results_df.columns:
        avg_robust_acc = results_df['avg_robust_acc'].tolist()
    else:
        avg_robust_acc = [(s+c+i+m)/4 for s,c,i,m in zip(synonym_acc, character_acc, insertion_acc, mixed_acc)]

    robustness_drop = [clean - avg for clean, avg in zip(clean_acc, avg_robust_acc)]

    data = {
        'models': models,
        'clean_acc': clean_acc,
        'synonym_acc': synonym_acc,
        'character_acc': character_acc,
        'insertion_acc': insertion_acc,
        'mixed_acc': mixed_acc,
        'avg_robust_acc': avg_robust_acc,
        'robustness_drop': robustness_drop,
        'trainable_params': param_column('num_trainable_params'),
        'total_params': param_column('total_params')
    }

    # Call each summary plotter with the shared data dict.
    plot_individual_performance(data, save_path=str(save_dir / 'individual_performance_clean.png'))
    plot_clean_vs_robust(data, save_path=str(save_dir / 'clean_vs_robust_accuracy.png'))
    plot_robustness_drop(data, save_path=str(save_dir / 'robustness_drop_analysis.png'))
    plot_efficiency_analysis(data, save_path=str(save_dir / 'model_efficiency_analysis.png'))
    plot_tradeoff_analysis(data, save_path=str(save_dir / 'robustness_vs_accuracy_tradeoff.png'))
    plot_performance_summary(data, save_path=str(save_dir / 'performance_summary_table.png'))

def plot_individual_results(results_df, save_dir=None):
    """
    Alias for create_summary_plots, kept so earlier calls that used the
    name plot_individual_results continue to work. Both arguments are
    forwarded unchanged.
    """
    create_summary_plots(results_df, save_dir)

# Confirmation message emitted once the definitions above have been evaluated.
print("Analysis and visualization functions defined.")

In [None]:
# Replace your existing run_experiment cell with this one.

import json
import numpy as np
from pathlib import Path
import itertools
from statsmodels.stats.contingency_tables import mcnemar
from scipy.stats import chi2_contingency

# Ensure these globals exist: RUN_MATRIX, RESULTS_DIR, RESULTS_PATH, build_run_id_from_cfg, log_run
# Ensure train_model_comprehensive accepts `cfg` and `smoothing_penalty_fn`
# Ensure randomized_smoothing_penalty and collect_predictions_on_loader / collect_adversarial_predictions exist

def _compute_run_stats_and_save_preds(run_id, model, val_loader, attack_generator, device=device):
    """
    Evaluate one trained run on clean, PGD and FGSM inputs and persist the
    per-sample results.

    Accuracies are reported in percent. The clean-pass labels and the three
    prediction arrays are saved under RESULTS_DIR as y_<run_id>.npy and
    preds_{clean,pgd,fgsm}_<run_id>.npy for later statistical tests.

    Returns:
        (stats, arrays): `stats` holds run_id, clean_acc, pgd_acc, fgsm_acc and
        n_test; `arrays` holds y_true, preds_clean, preds_pgd and preds_fgsm.
    """
    labels, clean_preds = collect_predictions_on_loader(model, val_loader, device=device)
    pgd_labels, pgd_preds = collect_adversarial_predictions(
        model, val_loader, attack_generator, attack_type='pgd', device=device)
    fgsm_labels, fgsm_preds = collect_adversarial_predictions(
        model, val_loader, attack_generator, attack_type='fgsm', device=device)

    def percent_correct(predictions, targets):
        """Share of matching entries, as a plain float in percent."""
        return float(100.0 * np.mean(predictions == targets))

    # Each attack is scored against the labels returned by its own pass;
    # the clean-pass labels are the ones saved and returned as ground truth.
    stats = {
        "run_id": run_id,
        "clean_acc": percent_correct(clean_preds, labels),
        "pgd_acc": percent_correct(pgd_preds, pgd_labels),
        "fgsm_acc": percent_correct(fgsm_preds, fgsm_labels),
        "n_test": int(len(labels))
    }

    # Save per-sample arrays for statistical tests
    to_save = (
        ("y", labels),
        ("preds_clean", clean_preds),
        ("preds_pgd", pgd_preds),
        ("preds_fgsm", fgsm_preds),
    )
    for prefix, values in to_save:
        np.save(RESULTS_DIR / f"{prefix}_{run_id}.npy", values)

    arrays = {
        "y_true": labels,
        "preds_clean": clean_preds,
        "preds_pgd": pgd_preds,
        "preds_fgsm": fgsm_preds,
    }
    return stats, arrays


def contingency_table_from_preds(y_true, preds_a, preds_b):
    """Cross-tabulate per-sample correctness of two prediction sets.

    Args:
        y_true: Ground-truth labels (any array-like).
        preds_a, preds_b: Predictions of the two runs, same shape as y_true.

    Returns:
        (table, a_corr, b_corr) where table is the 2x2 integer array
        [[both correct, only A correct], [only B correct, both wrong]] and
        a_corr / b_corr are boolean arrays marking where each run is correct.

    Raises:
        ValueError: If the three inputs do not share one shape.
    """
    # Coerce first: comparing plain lists with `==` yields a single bool,
    # which would silently collapse the whole table to one count.
    y_true = np.asarray(y_true)
    preds_a = np.asarray(preds_a)
    preds_b = np.asarray(preds_b)
    if not (y_true.shape == preds_a.shape == preds_b.shape):
        raise ValueError(
            "y_true, preds_a and preds_b must have the same shape, got "
            f"{y_true.shape}, {preds_a.shape} and {preds_b.shape}"
        )

    a_corr = preds_a == y_true
    b_corr = preds_b == y_true
    n11 = int(np.count_nonzero(a_corr & b_corr))
    n10 = int(np.count_nonzero(a_corr & ~b_corr))
    n01 = int(np.count_nonzero(~a_corr & b_corr))
    n00 = int(np.count_nonzero(~a_corr & ~b_corr))
    return np.array([[n11, n10], [n01, n00]]), a_corr, b_corr

def pairwise_stats(y_true, preds_a, preds_b):
    """Compare two runs' per-sample correctness with McNemar and chi-square tests.

    Args:
        y_true: Ground-truth labels.
        preds_a, preds_b: Predictions of the two runs on the same samples.

    Returns:
        dict with the 2x2 contingency table, the McNemar statistic/p-value
        (chi-square approximation, `exact=False`), the uncorrected chi-square
        statistic/p-value/dof/expected counts, the phi coefficient and n.
        Chi-square fields are NaN when the test is undefined for the table.
    """
    table, _, _ = contingency_table_from_preds(y_true, preds_a, preds_b)
    table = np.asarray(table)
    n_total = int(table.sum())

    # McNemar only looks at the discordant cells. With none of them the
    # statistic divides by zero, so report "no evidence of a difference"
    # explicitly instead of passing on an inf statistic / zero p-value.
    discordant = int(table[0, 1] + table[1, 0])
    if discordant > 0:
        mcnemar_res = mcnemar(table, exact=False)
        m_stat = float(mcnemar_res.statistic)
        m_p = float(mcnemar_res.pvalue)
    else:
        m_stat, m_p = 0.0, 1.0

    # Chi-square test (no Yates correction here). chi2_contingency raises
    # ValueError when an expected count is zero (e.g. one run is correct on
    # every sample); in that case the statistic is undefined, not fatal.
    try:
        chi2, chi2_p, dof, expected = chi2_contingency(table, correction=False)
    except ValueError:
        chi2 = chi2_p = float('nan')
        dof = (table.shape[0] - 1) * (table.shape[1] - 1)
        if n_total > 0:
            row_totals = table.sum(axis=1, keepdims=True)
            col_totals = table.sum(axis=0, keepdims=True)
            expected = row_totals * col_totals / n_total
        else:
            expected = np.zeros(table.shape)

    phi = np.sqrt(chi2 / n_total) if n_total > 0 else float('nan')
    return {
        "contingency_table": table.tolist(),
        "mcnemar_stat": m_stat, "mcnemar_p": m_p,
        "chi2": float(chi2), "chi2_p": float(chi2_p), "chi2_dof": int(dof),
        "chi2_expected": np.asarray(expected).tolist(),
        "phi": float(phi),
        "n": n_total
    }

def perform_pairwise_tests_for_model(run_pred_map):
    """
    Run pairwise_stats on every unordered pair of runs.

    run_pred_map: dict run_id -> {'y_true': ..., 'preds_clean': ...}
    Run ids are sorted first, so each pair appears once under the key
    "<a>__vs__<b>" with a < b. The labels of the first run in the pair are
    used as ground truth for both runs.
    Returns a dict mapping each pair key to its pairwise_stats result.
    """
    ordered_ids = sorted(run_pred_map)
    return {
        f"{first}__vs__{second}": pairwise_stats(
            run_pred_map[first]['y_true'],
            run_pred_map[first]['preds_clean'],
            run_pred_map[second]['preds_clean'],
        )
        for first, second in itertools.combinations(ordered_ids, 2)
    }


In [None]:
import numpy as np

def convert_np_types(obj):
    """Recursively replace numpy values with built-in Python equivalents.

    Dicts (including numpy-scalar keys), lists and tuples are walked; numpy
    arrays become nested lists; every numpy scalar (integer, floating, bool,
    ...) is unwrapped with .item(). Anything else is returned unchanged.
    The container type of lists and tuples is preserved.
    """
    if isinstance(obj, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key): convert_np_types(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        converted = [convert_np_types(item) for item in obj]
        return converted if isinstance(obj, list) else tuple(converted)
    if isinstance(obj, np.ndarray):
        # tolist() already yields built-in scalars; recurse for object arrays.
        return convert_np_types(obj.tolist())
    if isinstance(obj, np.generic):
        return obj.item()
    return obj

In [None]:
!ls -lh plots | head
!ls -lh results | head

In [None]:
def _first_not_none(mapping, *keys):
    """Return the value of the first key in `keys` whose entry in `mapping` is not None."""
    for key in keys:
        value = mapping.get(key)
        if value is not None:
            return value
    return None


def _select_base_run(model_runs):
    """Pick the run whose cfg sets neither 'no_adv' nor 'no_smooth'; else the first run (or None)."""
    for run_stats in model_runs.values():
        run_cfg = run_stats.get('cfg')
        if isinstance(run_cfg, dict) and not run_cfg.get('no_adv', False) and not run_cfg.get('no_smooth', False):
            return run_stats
    return next(iter(model_runs.values()), None)


def _save_ablation_summary(results_all_models):
    """Write one summary row per model (taken from its BASE run) to CSV and JSON in RESULTS_DIR."""
    summary_rows = []
    for name, model_runs in results_all_models.items():
        base_stats = _select_base_run(model_runs)
        if base_stats is None:
            # nothing recorded for this model (shouldn't happen) -- skip
            continue

        # Normalize field names (handle possible name variants). `is not None`
        # semantics matter here: a legitimate accuracy of 0.0 must be kept.
        summary_rows.append({
            'model_name': name,
            'clean_acc': _first_not_none(base_stats, 'clean_acc', 'clean_accuracy', 'clean'),
            'pgd_acc': _first_not_none(base_stats, 'pgd_acc', 'adv_acc', 'avg_robust_acc'),
            'fgsm_acc': _first_not_none(base_stats, 'fgsm_acc', 'fgsm'),
            'num_trainable_params': base_stats.get('num_trainable_params'),
            'total_params': base_stats.get('total_params')
        })

    results_df = pd.DataFrame(summary_rows)
    results_df.to_csv(RESULTS_DIR / "summary_results.csv", index=False)
    summary_clean = convert_np_types(results_df.to_dict(orient='records'))
    with open(RESULTS_DIR / "summary_results.json", "w") as summary_file:
        json.dump(summary_clean, summary_file, indent=2)
    print(f"Saved summary table to {RESULTS_DIR / 'summary_results.csv'} and JSON")
    return results_df


def run_experiment_ablation():
    """
    Runs every configuration in RUN_MATRIX for each model in MODEL_REGISTRY.

    For each run a fresh model is trained, its training curves are plotted,
    per-sample predictions are saved to RESULTS_DIR and a summarised stats row
    is logged once via log_run_csv. After all runs of a model, pairwise stats
    (McNemar + Chi2) on the clean predictions are computed and the BASE
    comparisons printed, and the cross-model summary CSV/JSON is rewritten.

    Returns a nested dictionary results_all_models[model_name][run_id] = stats.
    Each stats dict also carries the run's configuration under 'cfg'.
    """
    results_all_models = {}

    for model_name, model_class in MODEL_REGISTRY.items():
        print(f"\n--- Starting experiments for {model_name} (ablation matrix) ---")
        model_results = {}         # store per-run summary stats
        run_preds_for_stats = {}   # store per-run per-sample preds for statistical tests

        for cfg in RUN_MATRIX:
            run_id = build_run_id_from_cfg(cfg)
            print(f"\n>>> Running {run_id} for model {model_name}")

            # instantiate fresh model for each run
            model = model_class()
            model.to(device)

            # attack generator depends on model instance
            attack_generator = AdversarialAttackGenerator(model, tokenizer, device)

            # Adversarial training is switched on by model name; cfg is passed
            # to the trainer as well so it can override per run.
            use_adv_flag = (model_name == 'Adversarial_Training')

            # Train: pass cfg and smoothing_penalty_fn (randomized smoothing)
            trained_model, history = train_model_comprehensive(
                model, run_id, train_loader, val_loader,
                epochs=max_epochs,
                use_adversarial=use_adv_flag,
                attack_generator=attack_generator,
                augmentation_ratio=0.65,
                attack_type='mixed',
                cfg=cfg,
                smoothing_penalty_fn=randomized_smoothing_penalty
            )

            # Plotting is best-effort: report the reason but keep the sweep going.
            try:
                plot_training_history(history, run_id)
            except Exception as exc:
                print(f"plot_training_history failed or not available; continuing. ({exc!r})")

            # Robustness evaluation is also best-effort; per-sample predictions
            # are collected separately below.
            try:
                robustness_results = evaluate_model_robustness(trained_model, val_loader, attack_generator, run_id)
            except Exception as exc:
                print(f"evaluate_model_robustness failed for {run_id}; continuing without it. ({exc!r})")
                robustness_results = {}

            # Compute & save per-sample stats and arrays
            stats, preds_dict = _compute_run_stats_and_save_preds(run_id, trained_model, val_loader, attack_generator, device=device)

            # Merge robustness_results into stats if available
            stats.update(robustness_results if isinstance(robustness_results, dict) else {})

            # Add model meta
            stats.update({
                'model_size_MB': trained_model.get_model_size() if hasattr(trained_model, 'get_model_size') else None,
                'num_trainable_params': trained_model.count_parameters() if hasattr(trained_model, 'count_parameters') else None,
                'history': history
            })

            # CSV row: same stats, but with the epoch-by-epoch history reduced
            # to its two summary fields to keep the CSV clean.
            csv_stats = stats.copy()
            csv_stats['history'] = {
                'best_val_acc': history.get('best_val_acc'),
                'total_epochs': history.get('total_epochs')
            }
            # Log exactly once per run; a second call would duplicate the row.
            log_run_csv(run_id, csv_stats)

            # Record the configuration on the in-memory stats (after the CSV
            # copy was taken) so the BASE run can be identified for the summary.
            stats.setdefault('cfg', cfg)

            model_results[run_id] = stats
            run_preds_for_stats[run_id] = {'y_true': preds_dict['y_true'], 'preds_clean': preds_dict['preds_clean']}

        # After all RUN_MATRIX runs for this model: compute pairwise stats (BASE vs others)
        pairwise_results = perform_pairwise_tests_for_model(run_preds_for_stats)

        # print succinct BASE comparisons
        print(f"\nPairwise stats for model {model_name} (showing BASE comparisons):")
        for key, val in pairwise_results.items():
            if key.startswith("BASE__vs__") or key.endswith("__vs__BASE"):
                print(key)
                print("  contingency:", val["contingency_table"])
                print(f"  McNemar: stat={val['mcnemar_stat']:.3f}, p={val['mcnemar_p']:.4f}")
                print(f"  Chi2:    χ2={val['chi2']:.3f}, p={val['chi2_p']:.4f}, phi={val['phi']:.4f}, n={val['n']}")
                print("")

        results_all_models[model_name] = model_results

        # Rewrite the cross-model summary after every model so the files on
        # disk always cover the models finished so far. Summary plotting is
        # currently disabled here; create_summary_plots can be run on the
        # saved summary separately.
        _save_ablation_summary(results_all_models)

    print("\nAll experiments completed.")
    return results_all_models

# Run the ablation experiment when this cell executes; `results` holds the
# nested {model_name: {run_id: stats}} dictionary it returns.
results = run_experiment_ablation()

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

def generate_dissertation_artifacts(results_csv="results/all_runs_results.csv"):
    """Produce the write-up artifacts from the per-run results CSV.

    Reads `results_csv` (columns used: Model_Name, Config_Tag, Clean_Acc,
    Mixed_Acc, Synonym_Acc, Character_Acc, Avg_Latency) and then
      1. saves a clean-vs-mixed accuracy scatter to results/pareto_frontier.png,
      2. saves attack success rates to results/attack_success_rates.csv,
      3. prints a latency vs. mixed-accuracy table.
    """
    runs = pd.read_csv(results_csv)
    id_columns = ['Model_Name', 'Config_Tag']

    # --- 1. Pareto Plot (Clean vs Mixed Robustness) ---
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.scatterplot(data=runs, x='Clean_Acc', y='Mixed_Acc', hue='Model_Name',
                    style='Config_Tag', s=100, ax=ax)
    ax.set_title("Robustness-Accuracy Trade-off (Pareto Frontier)")
    ax.set_xlabel("Clean Data Accuracy")
    ax.set_ylabel("Mixed Attack Accuracy")
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.savefig("results/pareto_frontier.png")

    # --- 2. Attack Success Rate Table ---
    # ASR = (Clean Acc - Attack Acc) / Clean Acc
    clean = runs['Clean_Acc']
    asr_sources = {'Synonym_ASR': 'Synonym_Acc', 'Char_ASR': 'Character_Acc'}
    for asr_name, acc_name in asr_sources.items():
        runs[asr_name] = (clean - runs[acc_name]) / clean

    asr_table = runs[id_columns + list(asr_sources)]
    asr_table.to_csv("results/attack_success_rates.csv", index=False)

    # --- 3. Efficiency Summary ---
    efficiency_df = runs[id_columns + ['Avg_Latency', 'Mixed_Acc']]
    print("\n📊 Efficiency vs. Robustness Summary:")
    print(efficiency_df)

    print("\n✅ Artifacts generated in /results folder.")