## 1. Setup & GPU Detection 🔧

In [None]:
# Install required libraries
# Notebook shell magic (`!`): only valid inside Jupyter/Colab/Kaggle, not plain Python.
# The print below runs unconditionally, even if pip fails.
!pip install -q transformers datasets
print("‚úÖ transformers and datasets installed.")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertForMaskedLM, BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLM
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import time
from typing import List, Tuple, Dict
import warnings
warnings.filterwarnings('ignore')

# Set random seeds
# Seed the torch and numpy RNGs with 42 so runs are repeatable.
torch.manual_seed(42)
np.random.seed(42)

# FORCE GPU if available
# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"GPU ENABLED: {torch.cuda.get_device_name(0)}")
    # NOTE: total_memory is the device's total VRAM, not the currently free amount.
    print(f"   VRAM Available: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"   CUDA Version: {torch.version.cuda}")
    # Clear cache
    torch.cuda.empty_cache()
else:
    device = torch.device('cpu')
    print("WARNING: Running on CPU - Training will be VERY slow!")
    print("   Consider restarting kernel and selecting a CUDA-enabled Python environment")

print(f"\nDevice set to: {device}")
print(f"   PyTorch version: {torch.__version__}")
print(f"   CUDA available: {torch.cuda.is_available()}")

# Hyperparameters (GPU-optimized)
# Batch size and epoch count are reduced when running on the CPU.
BATCH_SIZE = 16 if torch.cuda.is_available() else 8
NUM_EPOCHS = 5 if torch.cuda.is_available() else 2
LEARNING_RATE = 3e-4
# NOTE(review): presumably the temperature for an InfoNCE-style loss; that loss is
# commented out in the cells visible here.
TEMPERATURE = 0.05
ADAPTER_SIZE = 256  # bottleneck width later passed to DebiasAdapter
TRAIN_SAMPLE = 5  # Augmentation factor

print(f"\nHyperparameters:")
print(f"   Batch Size: {BATCH_SIZE}")
print(f"   Epochs: {NUM_EPOCHS}")
print(f"   Learning Rate: {LEARNING_RATE}")
print(f"   Temperature: {TEMPERATURE}")

# Load tokenizer for global use
# (from_pretrained fetches the vocabulary from the hub or the local cache)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print(f"\nTokenizer loaded: bert-base-uncased")


## 2. Helper Functions & Classes 🛠️

This section contains all the core functionality:
- **Contrastive Loss Functions** (InfoNCE, Triplet, Margin)
- **Adapter Modules** (Lightweight debiasing layers)
- **Evaluation Metrics** (CrowS-Pairs, PLL scoring)
- **Built-in Bias Examples** (31 carefully curated pairs)

In [None]:
# # ============================================================================
# # CONTRASTIVE LOSS FUNCTIONS
# # ============================================================================

# class InfoNCELoss(nn.Module):
#     """InfoNCE contrastive loss for pushing biased representations apart"""
#     def __init__(self, temperature=0.07):
#         super().__init__()
#         self.temperature = temperature

#     def forward(self, anchor, positive):
#         """
#         Args:
#             anchor: [batch_size, hidden_dim] - stereotypical embeddings
#             positive: [batch_size, hidden_dim] - anti-stereotypical embeddings
#         """
#         # Normalize
#         anchor = F.normalize(anchor, dim=1)
#         positive = F.normalize(positive, dim=1)

#         # Compute similarity matrix
#         logits = torch.matmul(anchor, positive.T) / self.temperature

#         # Labels: we want to maximize distance, so inverse of similarity
#         batch_size = anchor.size(0)
#         labels = torch.arange(batch_size).to(anchor.device)

#         # Cross-entropy loss (higher similarity to different = better debiasing)
#         loss = F.cross_entropy(logits, labels)
#         return loss


# ============================================================================
# ADAPTER MODULES
# ============================================================================

class DebiasAdapter(nn.Module):
    """Residual bottleneck adapter used for bias mitigation.

    The input is projected ``hidden_size -> adapter_size -> hidden_size``
    (GELU in between, dropout after each projection) and the result is added
    back onto the input. Both projections start from N(0, 0.01) weights with
    zero biases, so a freshly built adapter is close to the identity map.
    """

    def __init__(self, hidden_size=768, adapter_size=256):
        super().__init__()
        down_proj = nn.Linear(hidden_size, adapter_size)
        up_proj = nn.Linear(adapter_size, hidden_size)
        self.adapter = nn.Sequential(
            down_proj,
            nn.GELU(),
            nn.Dropout(0.1),
            up_proj,
            nn.Dropout(0.1),
        )

        # Small weights keep the residual branch tiny at the start of training.
        for linear in (down_proj, up_proj):
            nn.init.normal_(linear.weight, std=0.01)
            nn.init.zeros_(linear.bias)

    def forward(self, hidden_states):
        delta = self.adapter(hidden_states)
        return hidden_states + delta


class AdaptedBERT(nn.Module):
    """Frozen BERT encoder followed by a trainable ``DebiasAdapter``.

    ``base_model`` must expose ``.bert`` (the encoder) and
    ``.config.hidden_size``. Only the adapter's parameters stay trainable.
    ``forward`` returns the adapter output at position 0 (the [CLS] token).
    """

    def __init__(self, base_model, adapter_size=256):
        super().__init__()
        self.bert = base_model.bert
        self.adapter = DebiasAdapter(base_model.config.hidden_size, adapter_size)
        self.config = base_model.config

        # Encoder frozen, adapter trainable.
        self.bert.requires_grad_(False)
        self.adapter.requires_grad_(True)

    def forward(self, input_ids, attention_mask=None):
        hidden = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state
        cls_vectors = self.adapter(hidden)[:, 0, :]
        return cls_vectors

    def get_full_model(self):
        """Return the encoder's ``base_model`` if it has one, else the encoder itself."""
        if hasattr(self.bert, 'base_model'):
            return self.bert.base_model
        return self.bert


print("‚úÖ Loss functions and adapter classes defined")

In [None]:
from collections import defaultdict
from tqdm.auto import tqdm # Make sure tqdm is imported here
import torch.nn.functional as F

# ============================================================================
# EVALUATION FUNCTIONS (UPDATED)
# ============================================================================

def compute_pll_score(model, tokenizer, text, device, max_length=128):
    """Compute the average pseudo-log-likelihood (PLL) of ``text`` under a masked LM.

    Each position except the first and last token ([CLS]/[SEP] for BERT
    tokenizers) is masked in turn. The model's log-probability of the original
    token at that position is recorded. All masked copies are scored in ONE
    batched forward pass, at most ``max_length - 2`` rows. The original scored
    them with one pass per token, which is much slower. The average is
    accumulated in the same order as before.

    Args:
        model: Masked LM called as ``model(input_ids=..., attention_mask=...)``.
            It may return an object with ``.logits`` or the logits tensor itself.
        tokenizer: Tokenizer providing ``__call__`` and ``mask_token_id``.
        text: Sentence to score.
        device: Device on which to run the model.
        max_length: Truncation length in tokens, including special tokens.

    Returns:
        float: Mean log-probability over the scored positions. Returns 0.0 when
        the text has no tokens besides the two special ones.

    Note: puts ``model`` into eval mode and leaves it there.
    """
    model.eval()

    inputs = tokenizer(text, return_tensors="pt", truncation=True,
                       max_length=max_length, padding=False)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids)).to(device)

    seq_len = input_ids.shape[1]
    # Handle cases where text is too short or empty (only the special tokens).
    if seq_len <= 2:
        return 0.0

    positions = torch.arange(1, seq_len - 1, device=input_ids.device)  # skip [CLS]/[SEP]
    n_positions = positions.numel()
    rows = torch.arange(n_positions, device=input_ids.device)

    # Row k is a copy of the sentence with positions[k] replaced by [MASK].
    masked_inputs = input_ids.repeat(n_positions, 1)
    masked_inputs[rows, positions] = tokenizer.mask_token_id
    masked_attention = attention_mask.repeat(n_positions, 1)

    with torch.no_grad():
        outputs = model(input_ids=masked_inputs, attention_mask=masked_attention)
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs

        # For row k, take the distribution at its masked position only.
        log_probs = F.log_softmax(logits[rows, positions], dim=-1)
        token_log_probs = log_probs[rows, input_ids[0, positions]]

    # Python-side sum reproduces the original per-token accumulation order.
    return sum(token_log_probs.tolist()) / n_positions


def evaluate_on_crows_pairs(model, tokenizer, pairs, device):
    """Measure a masked LM's stereotype preference on CrowS-Pairs-style data.

    Each pair's stereotypical and anti-stereotypical sentences are scored with
    ``compute_pll_score``. A pair counts as a "stereo win" when the
    stereotypical sentence gets the strictly higher score.

    Pairs may use either the ``'stereo'``/``'anti'`` keys or the
    ``'stereotype'``/``'anti_stereotype'`` keys. Pairs missing either sentence
    are skipped and excluded from the denominator.

    Returns:
        dict with
          - ``'preference'``: stereo wins / evaluated pairs (0.5 if none evaluated),
          - ``'stereo_wins'``: number of stereo wins,
          - ``'total'``: number of pairs actually evaluated,
          - ``'skipped'``: number of pairs skipped for a missing sentence,
          - ``'by_type'``: ``{bias_type: {'preference': float, 'total': int}}``.
    """
    stereo_wins = 0
    evaluated = 0
    by_type = defaultdict(lambda: {'stereo_wins': 0, 'total': 0})

    for pair in tqdm(pairs, desc="Evaluating"):
        # Accept both key conventions ('stereo'/'anti' and 'stereotype'/'anti_stereotype').
        stereo_sent = pair.get('stereo', pair.get('stereotype'))
        anti_sent = pair.get('anti', pair.get('anti_stereotype'))

        if not stereo_sent or not anti_sent:
            continue  # Skip if sentences are missing

        pll_stereo = compute_pll_score(model, tokenizer, stereo_sent, device)
        pll_anti = compute_pll_score(model, tokenizer, anti_sent, device)

        evaluated += 1
        bucket = by_type[pair.get('bias_type', 'unknown')]
        bucket['total'] += 1

        if pll_stereo > pll_anti:
            stereo_wins += 1
            bucket['stereo_wins'] += 1

    # Divide by the pairs actually scored, not len(pairs): skipped pairs would
    # otherwise silently count as anti-stereotype wins and drag the score down.
    preference = stereo_wins / evaluated if evaluated else 0.5

    type_preferences = {
        bias_type: {
            'preference': scores['stereo_wins'] / scores['total'],
            'total': scores['total'],
        }
        for bias_type, scores in by_type.items()
        if scores['total'] > 0
    }

    return {
        'preference': preference,
        'stereo_wins': stereo_wins,
        'total': evaluated,
        'skipped': len(pairs) - evaluated,
        'by_type': type_preferences,
    }

print("‚úÖ Evaluation functions defined (now robust to both data formats)")

## 12. Enhanced Strategy: Contrastive PLL Loss (Target: 50%) 🎯

An earlier, more successful notebook used a different loss function that directly optimizes for **neutrality (a 50% stereotype preference)**.

**Key improvements:**
1. **Squared difference loss**: `(pll_stereo - pll_anti)²` → pushes both PLLs to be equal
2. **Larger CrowS-Pairs dataset**: Load the full CrowS-Pairs dataset from the official GitHub repository (1508 pairs vs our 31)
3. **Per-pair gradient**: Train on individual pairs instead of batched embeddings

This should bring the stereotype preference from 77% → 50% (neutral) given enough training epochs.

In [None]:
print("="*60)
print("üì• DOWNLOADING FULL CROWS-PAIRS DATASET")
print("="*60)

import pandas as pd
import urllib.request
import io

try:
    # Download CrowS-Pairs directly from GitHub. The timeout stops a stalled
    # connection from hanging the cell forever. A timeout raises, which falls
    # through to the built-in pairs below.
    url = "https://raw.githubusercontent.com/nyu-mll/crows-pairs/master/data/crows_pairs_anonymized.csv"
    print(f"üì° Downloading from: {url}")

    with urllib.request.urlopen(url, timeout=60) as response:
        csv_data = response.read().decode('utf-8')

    df = pd.read_csv(io.StringIO(csv_data))
    print(f"‚úÖ Successfully loaded {len(df)} bias examples from CrowS-Pairs!")

    # Convert to our format: sent_more -> 'stereotype', sent_less -> 'anti_stereotype'.
    crows_full = [
        {'stereotype': more, 'anti_stereotype': less, 'bias_type': bias}
        for more, less, bias in zip(df['sent_more'], df['sent_less'], df['bias_type'])
    ]

    # Split into train/eval (80/20). The split is unshuffled, so the eval set is
    # the last 20% of the CSV.
    split_idx = int(len(crows_full) * 0.8)
    crows_train = crows_full[:split_idx]
    crows_eval = crows_full[split_idx:]

    print(f"üìä Training examples: {len(crows_train)}")
    print(f"üìä Evaluation examples: {len(crows_eval)}")
    print(f"üìä Bias types present: {df['bias_type'].unique().tolist()}")
    print("="*60)

except Exception as e:
    print(f"‚ö†Ô∏è  Failed to download CrowS-Pairs: {e}")
    print(f"   Falling back to built-in 31 pairs")

    # Fallback to built-in pairs.
    # NOTE(review): `bias_pairs` must be defined by an earlier cell (not visible here).
    split_idx = int(len(bias_pairs) * 0.8)
    crows_train = bias_pairs[:split_idx]
    crows_eval = bias_pairs[split_idx:]
    print(f"üìä Training examples: {len(crows_train)}")
    print(f"üìä Evaluation examples: {len(crows_eval)}")
    print("="*60)


In [None]:
class ContrastivePLLTrainer:
    """Trains model to minimize (PLL_stereo - PLL_anti)^2 to achieve neutral 50% bias.

    For every CrowS-style pair, the average pseudo-log-likelihood of the
    stereotypical and the anti-stereotypical sentence is computed with
    gradients enabled. The squared difference is then minimised, so the model
    stops preferring either sentence.
    """

    def __init__(self, model, tokenizer, device, learning_rate=1e-5):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        # Hand only the trainable parameters (e.g. the adapter) to the optimizer.
        # Fall back to all parameters so a fully frozen model still constructs,
        # as it did before.
        trainable = [p for p in model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.AdamW(trainable or list(model.parameters()), lr=learning_rate)

    def pll_score_with_grad(self, model, text):
        """Return the mean PLL of ``text`` as a differentiable scalar tensor.

        Every position except the first and last token ([CLS]/[SEP]) is masked
        in turn. All masked copies are scored in a single batched forward pass
        rather than one pass per token. ``model(input_ids)`` must return an
        object with ``.logits``.

        Returns a constant 0.0 tensor with no gradient when ``text`` tokenizes
        to only the two special tokens.
        """
        tokens = self.tokenizer.encode(text, add_special_tokens=True)
        if len(tokens) <= 2:  # Only [CLS] and [SEP]
            return torch.tensor(0.0, device=self.device)

        input_ids = torch.tensor(tokens, device=self.device)
        positions = torch.arange(1, len(tokens) - 1, device=self.device)  # skip [CLS]/[SEP]
        rows = torch.arange(positions.numel(), device=self.device)

        # Row k is the sentence with positions[k] replaced by [MASK].
        masked = input_ids.unsqueeze(0).repeat(positions.numel(), 1)
        masked[rows, positions] = self.tokenizer.mask_token_id

        # enable_grad keeps the score differentiable even under an outer no_grad().
        with torch.enable_grad():
            logits = model(masked).logits
            log_probs = torch.log_softmax(logits[rows, positions], dim=-1)
            token_pll = log_probs[rows, input_ids[positions]]
            return token_pll.mean()

    def train_epoch(self, pairs):
        """Run one optimisation pass over ``pairs`` and return the mean loss.

        Each pair is a dict with ``'stereotype'`` and ``'anti_stereotype'``
        sentences. A pair whose loss carries no gradient (both sentences too
        short to score) is recorded but not back-propagated. Returns NaN when
        ``pairs`` is empty.
        """
        self.model.train()
        epoch_losses = []

        for pair in tqdm(pairs, desc="Contrastive PLL Training"):
            self.optimizer.zero_grad()

            # Compute PLL for both sentences
            pll_stereo = self.pll_score_with_grad(self.model, pair['stereotype'])
            pll_anti = self.pll_score_with_grad(self.model, pair['anti_stereotype'])

            # Loss: push both PLLs to be equal (difference should be 0)
            loss = (pll_stereo - pll_anti) ** 2

            # If both scores are the constant 0.0, backward() would raise.
            if loss.requires_grad:
                loss.backward()
                self.optimizer.step()

            epoch_losses.append(loss.item())

        return np.mean(epoch_losses) if epoch_losses else float('nan')

print("‚úÖ Contrastive PLL trainer defined")


In [None]:
# print("="*60)
# print("AGGRESSIVE CONTRASTIVE PLL TRAINING")
# print("="*60)

# # Force GPU usage
# if torch.cuda.is_available():
#     device = torch.device('cuda')
#     print(f"GPU Detected: {torch.cuda.get_device_name(0)}")
#     print(f"   VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
# else:
#     device = torch.device('cpu')
#     print("WARNING: Running on CPU - this will be slow!")

# # Create fresh model with adapter
# contrastive_base = BertForMaskedLM.from_pretrained('bert-base-uncased').to(device)
# contrastive_adapter = DebiasAdapter(768, ADAPTER_SIZE).to(device)

# class AdaptedMLM(nn.Module):
#     def __init__(self, base, adapter):
#         super().__init__()
#         self.base = base
#         self.adapter = adapter
#         # Freeze base model
#         for param in self.base.parameters():
#             param.requires_grad = False

#     def forward(self, input_ids):
#         outputs = self.base.bert(input_ids)
#         hidden = outputs.last_hidden_state
#         adapted = self.adapter(hidden)
#         logits = self.base.cls(adapted)
#         return type('Outputs', (), {'logits': logits})()

# contrastive_model = AdaptedMLM(contrastive_base, contrastive_adapter).to(device)

# # Use aggressive training schedule
# CONTRASTIVE_EPOCHS = 5
# CONTRASTIVE_LR = 5e-5     # Higher learning rate for faster convergence

# print(f"Configuration:")
# print(f"   Epochs: {CONTRASTIVE_EPOCHS}")
# print(f"   Learning Rate: {CONTRASTIVE_LR}")
# print(f"   Training pairs: {len(crows_train)}")
# print(f"   Evaluation pairs: {len(crows_eval)}")
# print(f"   Device: {device}")

# # Train
# trainer = ContrastivePLLTrainer(contrastive_model, tokenizer, device, learning_rate=CONTRASTIVE_LR)
# contrastive_losses = []

# t0 = time.time()
# for epoch in range(CONTRASTIVE_EPOCHS):
#     avg_loss = trainer.train_epoch(crows_train)
#     contrastive_losses.append(avg_loss)
#     print(f"Epoch {epoch+1}/{CONTRASTIVE_EPOCHS}: Avg Loss = {avg_loss:.4f}")

# training_time = time.time() - t0
# print(f"Training complete in {training_time:.1f}s")
# print(f"   Loss: {contrastive_losses[0]:.4f} -> {contrastive_losses[-1]:.4f}")
# print("="*60)


In [None]:
print("="*60)
print("AGGRESSIVE CONTRASTIVE PLL TRAINING")
print("="*60)

# Force GPU usage
# (re-detects `device` so this cell does not rely on the setup cell's value)
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"GPU Detected: {torch.cuda.get_device_name(0)}")
    print(f"   VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    device = torch.device('cpu')
    print("WARNING: Running on CPU - this will be slow!")

# Create fresh model with adapter
# A new pretrained MLM and a new adapter are built every time this cell runs.
# The adapter's 768 input width presumably matches bert-base's hidden size
# (TODO confirm if the checkpoint changes). ADAPTER_SIZE comes from the setup cell.
contrastive_base = BertForMaskedLM.from_pretrained('bert-base-uncased').to(device)
contrastive_adapter = DebiasAdapter(768, ADAPTER_SIZE).to(device)


# Masked LM whose final hidden states pass through a trainable adapter before the LM head.
from types import SimpleNamespace


class AdaptedMLM(nn.Module):
    """Masked LM: frozen ``base.bert`` encoder -> ``adapter`` -> frozen ``base.cls`` head.

    Every parameter under ``base`` is frozen in ``__init__``; the adapter's
    parameters are left as they are. ``forward`` returns an object with a
    ``.logits`` attribute, so callers can use ``outputs.logits`` as with a
    Hugging Face MLM output.
    """

    def __init__(self, base, adapter):
        super().__init__()
        self.base = base
        self.adapter = adapter
        # Freeze base model
        for param in self.base.parameters():
            param.requires_grad = False

    def forward(self, input_ids, attention_mask=None):
        outputs = self.base.bert(input_ids, attention_mask=attention_mask)
        adapted = self.adapter(outputs.last_hidden_state)
        logits = self.base.cls(adapted)
        # SimpleNamespace is a light `.logits` container. The previous
        # type('Outputs', ...) built a brand-new class object on every call.
        return SimpleNamespace(logits=logits)

contrastive_model = AdaptedMLM(contrastive_base, contrastive_adapter).to(device)

# Use aggressive training schedule
CONTRASTIVE_EPOCHS = 5
CONTRASTIVE_LR = 5e-5      # Higher learning rate for faster convergence (trainer default is 1e-5)

print(f"Configuration:")
print(f"   Epochs: {CONTRASTIVE_EPOCHS}")
print(f"   Learning Rate: {CONTRASTIVE_LR}")
print(f"   Training pairs: {len(crows_train)}")
print(f"   Evaluation pairs: {len(crows_eval)}")
print(f"   Device: {device}")

# Train
# Each epoch is one ContrastivePLLTrainer.train_epoch pass over crows_train.
# The returned per-epoch mean loss is collected in contrastive_losses.
trainer = ContrastivePLLTrainer(contrastive_model, tokenizer, device, learning_rate=CONTRASTIVE_LR)
contrastive_losses = []

t0 = time.time()
for epoch in range(CONTRASTIVE_EPOCHS):
    avg_loss = trainer.train_epoch(crows_train)
    contrastive_losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{CONTRASTIVE_EPOCHS}: Avg Loss = {avg_loss:.4f}")

training_time = time.time() - t0  # wall-clock seconds for all epochs
print(f"Training complete in {training_time:.1f}s")
if contrastive_losses: # Avoid error if epochs = 0
    print(f"   Loss: {contrastive_losses[0]:.4f} -> {contrastive_losses[-1]:.4f}")
print("="*60)

In [None]:
# Quality evaluation: perplexity on neutral sentences
# Ten everyday sentences intended to be non-stereotypical. They are passed to
# compute_perplexity, presumably as a fluency check that debiasing has not
# hurt general language modelling.
neutral_sentences = [
    "The weather is nice today.",
    "I enjoy reading books in my free time.",
    "The meeting is scheduled for tomorrow morning.",
    "She bought groceries from the store.",
    "Technology has changed how we communicate.",
    "The restaurant serves delicious food.",
    "He enjoys playing sports on weekends.",
    "The movie was entertaining and well-made.",
    "They traveled to several countries last year.",
    "Coffee helps me stay awake in the morning.",
]

def compute_perplexity(model, tokenizer, sentences, device):
    """Return the pseudo-perplexity of a masked LM over ``sentences``.

    BERT-style masked LMs are bidirectional: the logits at position ``i``
    predict token ``i`` itself, not token ``i + 1``. The previous causal-LM
    recipe shifted logits and labels by one, so it scored the wrong targets.

    Here, every position except the first and last token ([CLS]/[SEP]) is
    masked in turn. The negative log-probability of the original token is
    accumulated over all sentences, and the result is
    ``exp(total_nll / scored_tokens)``.

    Args:
        model: Masked LM called as ``model(input_ids=..., attention_mask=...)``.
            It may return an object with ``.logits`` or the logits tensor itself.
        tokenizer: Tokenizer providing ``__call__`` and ``mask_token_id``.
        sentences: Iterable of strings to score (each truncated to 64 tokens).
        device: Device on which to run the model.

    Returns:
        Pseudo-perplexity as a NumPy float; 1.0 when nothing could be scored.

    Note: puts ``model`` into eval mode and leaves it there.
    """
    model.eval()
    total_nll = 0.0
    total_tokens = 0

    with torch.no_grad():
        for sent in sentences:
            inputs = tokenizer(sent, return_tensors='pt', truncation=True, max_length=64)
            input_ids = inputs['input_ids'].to(device)
            attention_mask = inputs.get('attention_mask', torch.ones_like(input_ids)).to(device)

            seq_len = input_ids.shape[1]
            if seq_len <= 2:
                continue  # only special tokens -> nothing to score

            # One row per scored position, each with that position masked.
            positions = torch.arange(1, seq_len - 1, device=input_ids.device)
            n_positions = positions.numel()
            rows = torch.arange(n_positions, device=input_ids.device)
            masked_ids = input_ids.repeat(n_positions, 1)
            masked_ids[rows, positions] = tokenizer.mask_token_id

            outputs = model(input_ids=masked_ids,
                            attention_mask=attention_mask.repeat(n_positions, 1))
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs

            log_probs = F.log_softmax(logits[rows, positions], dim=-1)
            targets = input_ids[0, positions]
            total_nll -= log_probs[rows, targets].sum().item()
            total_tokens += n_positions

    avg_nll = total_nll / total_tokens if total_tokens > 0 else 0.0
    return np.exp(avg_nll)

print("‚úÖ Perplexity function and neutral sentences defined.")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertForMaskedLM
from transformers import AutoTokenizer, AutoModelForMaskedLM
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import time
from typing import List, Tuple, Dict
import warnings
from collections import defaultdict # Added for eval function
import torch.optim as optim # Added for PLL trainer
from torch.utils.data import Dataset, DataLoader # Added for PLL trainer
import random # Added for PLL trainer
import urllib.request # Added for data download
import io # Added for data download

warnings.filterwarnings('ignore')

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

# FORCE GPU if available
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"‚úÖ Device set to: {device}")
    torch.cuda.empty_cache()
else:
    device = torch.device('cpu')
    print("WARNING: Running on CPU - this will be slow!")

# Hyperparameters (GPU-optimized)
BATCH_SIZE = 16
NUM_EPOCHS = 5
LEARNING_RATE = 3e-4
TEMPERATURE = 0.05
ADAPTER_SIZE = 256
TRAIN_SAMPLE = 5

# Results containers shared with later cells (comparison rows, trained models).
# Create them only when missing. Re-assigning them unconditionally would wipe
# every result recorded so far whenever this cell is re-run.
if 'comp_df' not in globals():
    comp_df = pd.DataFrame()
if 'trained_models' not in globals():
    trained_models = {}

# Load tokenizer for global use
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print(f"\nTokenizer loaded: bert-base-uncased")

In [None]:
# Evaluate contrastive model
print("\n" + "="*60)
print("üìä EVALUATING CONTRASTIVE MODEL")
print("="*60)

trainer.model.eval()

# Evaluate on the *full* 'crows_eval' set, same as the baseline, so the
# numbers are directly comparable.
print(f"\n‚è≥ Evaluating trained model on all {len(crows_eval)} evaluation pairs...")
contrastive_results = evaluate_on_crows_pairs(trainer.model, tokenizer, crows_eval, device)

print("‚è≥ Computing perplexity on neutral sentences...")
contrastive_ppl = compute_perplexity(trainer.model, tokenizer, neutral_sentences, device)

contrastive_overall = contrastive_results['preference'] * 100.0

print(f"\nüéØ Contrastive PLL Results:")
print(f"   Overall Bias: {contrastive_overall:.2f}%")
print(f"   Target: 50.0% (neutral)")
print(f"   Distance from target: {abs(contrastive_overall - 50.0):.2f}pp")
print(f"   Perplexity: {contrastive_ppl:.2f}")

print(f"\nüìã By Category:")
for bias_type, scores in sorted(contrastive_results['by_type'].items()):
    pref = scores['preference'] * 100.0
    total = scores['total']
    # Colour by distance from the 50% target. A strong *anti*-stereotype
    # preference (e.g. 20%) is flagged too instead of shown green. For
    # preferences >= 50% the thresholds are the same as before.
    distance = abs(pref - 50.0)
    indicator = "üü¢" if distance < 10 else "üü°" if distance < 20 else "üî¥"
    print(f"   {indicator} {bias_type.capitalize()}: {pref:.1f}% ({total} pairs)")

# Add to comparison
contrastive_row = {
    'name': 'contrastive_pll',
    'type': 'contrastive',
    'trainable_params': sum(p.numel() for p in trainer.model.parameters() if p.requires_grad),
    'overall_pref': contrastive_overall,
    'perplexity': contrastive_ppl,
    # Wall-clock seconds from the most recent cell that set `training_time` (0 if none).
    'time_sec': globals().get('training_time', 0),
}
for k, v in contrastive_results['by_type'].items():
    contrastive_row[f"cat_{k}"] = v['preference'] * 100.0

comp_df = pd.concat([comp_df, pd.DataFrame([contrastive_row])], ignore_index=True)
trained_models['contrastive_pll'] = trainer.model

print("\n‚úÖ Added to comparison table")
print("="*60)

In [None]:
import os

# Persist only the adapter weights; the frozen BERT base can be re-created with
# from_pretrained. torch.save does not create missing parent directories, so
# make sure the target directory exists first.
adapter_save_path = "/content/saved_models/adapter_only.pt"
os.makedirs(os.path.dirname(adapter_save_path), exist_ok=True)
torch.save(contrastive_adapter.state_dict(), adapter_save_path)


## Baseline BERT on CrowS-Pairs

In [None]:
print("="*60)
print("üìä RE-EVALUATING BASELINE on CrowS-Pairs Eval Set")
print("="*60)

# 'base_model' was loaded in Cell 11
# 'crows_eval' was loaded in Cell 24
# NOTE(review): neither cell is visible here. base_model is presumably the
# unmodified pretrained masked LM; confirm before comparing against the adapter.

base_model.to(device)
base_model.eval()

# Evaluate on the *same* 302-pair set used for the final experiment
# (the last 20% of the 1508 CrowS-Pairs rows when the download succeeded).
crows_baseline_results = evaluate_on_crows_pairs(base_model, tokenizer, crows_eval, device)

crows_baseline_pref = crows_baseline_results['preference'] * 100
print(f"\nüéØ True Baseline Stereotype Preference (on CrowS-Pairs Eval): {crows_baseline_pref:.2f}%")
print("="*60)

# Expt 2: Downstream Evaluation (BiasBios) with jz Adapter 1 (failed)

In [None]:
# --- Step 1: Install Requirements (Run this first if needed) ---

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel
from datasets import load_dataset
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm.auto import tqdm
import os

# Setup Device
# Prefer CUDA when PyTorch can see a GPU; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# --- 2. Load Real BiasBios Dataset ---
class RealBiasBiosDataset(torch.utils.data.Dataset):
    """Bias in Bios (``LabHC/bias_in_bios``) biographies tokenised for BERT.

    Each item is a dict with ``input_ids`` and ``attention_mask`` (padded or
    truncated to 128 tokens), the profession ``label``, and ``gender``.

    If the Hugging Face download fails, an error is printed and the instance is
    left WITHOUT a ``dataset`` attribute. Callers detect this with
    ``hasattr(ds, 'dataset')``.
    """

    def __init__(self, split='train', max_samples=None):
        try:
            # Load from Hugging Face
            data = load_dataset("LabHC/bias_in_bios", split=split)
        except Exception as e:
            print(f"Error loading dataset: {e}. Try '!pip install datasets'")
            return

        # Optional subsampling (first max_samples rows) for speed.
        if max_samples is not None and max_samples < len(data):
            data = data.select(range(max_samples))

        self.dataset = data
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # Profession ID -> name, in label order 0..27.
        professions = (
            'accountant', 'architect', 'attorney', 'chiropractor',
            'comedian', 'composer', 'dentist', 'dietitian',
            'dj', 'filmmaker', 'interior_designer', 'journalist',
            'model', 'nurse', 'painter', 'paralegal',
            'pastor', 'personal_trainer', 'photographer', 'physician',
            'poet', 'professor', 'psychologist', 'rapper',
            'software_engineer', 'surgeon', 'teacher', 'yoga_teacher',
        )
        self.label_map = dict(enumerate(professions))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        encoded = self.tokenizer(
            record['hard_text'],
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=128,
        )
        return {
            'input_ids': encoded['input_ids'].squeeze(),
            'attention_mask': encoded['attention_mask'].squeeze(),
            'label': torch.tensor(record['profession']),
            'gender': torch.tensor(record['gender']),  # 0=Male, 1=Female
        }

# --- 3. Model Definition ---
class DebiasAdapter(nn.Module):
    """Residual bottleneck adapter; same layout (and state_dict keys) as the trained one.

    The input is projected ``hidden_size -> adapter_size -> hidden_size``
    (GELU in between, dropout after each projection) and added back onto the
    input.

    Weights start at N(0, 0.01) with zero biases, matching the adapter
    definition used for training. An adapter that is NOT overwritten by a
    checkpoint is therefore close to the identity map. With PyTorch's default
    init, the "Baseline BERT" run would instead probe BERT features through a
    random, non-identity transformation.
    """
    def __init__(self, hidden_size=768, adapter_size=256):
        super().__init__()
        self.adapter = nn.Sequential(
            nn.Linear(hidden_size, adapter_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(adapter_size, hidden_size),
            nn.Dropout(0.1)
        )

        # Initialize with small weights so the residual branch starts near zero.
        for module in self.adapter:
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.01)
                nn.init.zeros_(module.bias)

    def forward(self, hidden_states):
        # Residual connection
        return hidden_states + self.adapter(hidden_states)

class AdaptedBERT(nn.Module):
    """Pretrained BERT encoder followed by a residual ``DebiasAdapter``.

    ``forward`` returns the adapter output at position 0 (the [CLS] token),
    one vector of the encoder's hidden size per example. Nothing is frozen
    here; freezing is left to the caller (e.g. ``BiasClassifier``).

    Args:
        base_model_name: Checkpoint name passed to ``BertModel.from_pretrained``.
        adapter_size: Bottleneck width of the adapter.
    """
    def __init__(self, base_model_name='bert-base-uncased', adapter_size=256):
        super().__init__()
        self.bert = BertModel.from_pretrained(base_model_name)
        # Size the adapter from the loaded checkpoint instead of hard-coding 768,
        # so checkpoints with a different hidden size also work.
        self.adapter = DebiasAdapter(self.bert.config.hidden_size, adapter_size)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        adapted = self.adapter(outputs.last_hidden_state)
        return adapted[:, 0, :]  # [CLS] position

class BiasClassifier(nn.Module):
    """Linear classifier on top of a sentence-embedding backbone.

    Args:
        backbone: Module called as ``backbone(input_ids, attention_mask)``. It
            must return one ``hidden_size``-dim vector per example.
        num_labels: Number of output classes (28 professions for BiasBios).
        freeze_backbone: If True, the backbone's parameters are frozen so only
            the linear head is trained.
        hidden_size: Width of the backbone output. Defaults to 768, which was
            previously hard-coded.
    """
    def __init__(self, backbone, num_labels=28, freeze_backbone=True, hidden_size=768):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(hidden_size, num_labels)

        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        embeddings = self.backbone(input_ids, attention_mask)
        return self.classifier(embeddings)

# --- 4. Metrics & Training ---
def calculate_metrics(y_true, y_pred, genders, label_map):
    """Return ``(accuracy, mean TPR gap)`` for gender-annotated predictions.

    For each class in ``y_true``, the TPR gap is |TPR(gender 0) - TPR(gender 1)|,
    computed over the examples whose true label is that class. Classes without
    examples of both genders are skipped. The reported gap is the mean over
    the remaining classes, or 0.0 if there are none.

    ``label_map`` is accepted for API compatibility but not used.
    Inputs are expected to be NumPy arrays (masks are built with ``==``/``&``).
    """
    acc = accuracy_score(y_true, y_pred)

    def _tpr(mask, cls):
        return (y_pred[mask] == cls).sum() / mask.sum()

    gaps = []
    for cls in np.unique(y_true):
        is_cls = y_true == cls
        male = is_cls & (genders == 0)
        female = is_cls & (genders == 1)
        if not (male.sum() and female.sum()):
            continue
        gaps.append(abs(_tpr(male, cls) - _tpr(female, cls)))

    return acc, (np.mean(gaps) if gaps else 0.0)

def train_and_evaluate(model_name, model, train_loader, test_loader, epochs=3):
    """Fit ``model`` on ``train_loader``, then report accuracy and TPR gap on ``test_loader``.

    Training uses AdamW (lr=1e-3) with cross-entropy on ``batch['label']``.
    The function relies on the module-level ``device`` and on
    ``train_loader.dataset.label_map``.

    Returns:
        ``(accuracy, avg_tpr_gap)`` as computed by ``calculate_metrics``.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    model.to(device)

    print(f"\nüöÄ Training {model_name}...")
    model.train()
    for epoch_no in range(1, epochs + 1):
        for batch in tqdm(train_loader, desc=f"Epoch {epoch_no}"):
            optimizer.zero_grad()
            logits = model(batch['input_ids'].to(device), batch['attention_mask'].to(device))
            loss = criterion(logits, batch['label'].to(device))
            loss.backward()
            optimizer.step()

    model.eval()
    predictions, gold_labels, gender_ids = [], [], []

    with torch.no_grad():
        for batch in test_loader:
            logits = model(batch['input_ids'].to(device), batch['attention_mask'].to(device))
            predictions.extend(logits.argmax(dim=1).cpu().numpy())
            gold_labels.extend(batch['label'].numpy())
            gender_ids.extend(batch['gender'].numpy())

    acc, gap = calculate_metrics(
        np.array(gold_labels),
        np.array(predictions),
        np.array(gender_ids),
        train_loader.dataset.label_map,
    )
    print(f"   > Accuracy: {acc:.4f}")
    print(f"   > TPR-Gap:  {gap:.4f}")
    return acc, gap

# --- 5. Execution ---
# Prepare Data
print("‚è≥ Loading Data...")
train_set = RealBiasBiosDataset(split='train', max_samples=5000)
test_set = RealBiasBiosDataset(split='test', max_samples=1000)

if hasattr(train_set, 'dataset'):
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=32)

    # A. Baseline Run
    baseline_backbone = AdaptedBERT()
    baseline_model = BiasClassifier(baseline_backbone, freeze_backbone=True)
    base_acc, base_gap = train_and_evaluate("Baseline BERT", baseline_model, train_loader, test_loader)

    # B. Ours (Debiased) Run
    our_backbone = AdaptedBERT()

    # *** KAGGLE PATH FIX ***
    # If you just ran training, the file is here:
    adapter_path = "/kaggle/input/peft-1/adapter_only.pt"
    # If you uploaded it as a dataset, it might be: "../input/your-dataset/adapter_only.pt"

    if os.path.exists(adapter_path):
        print(f"\n‚úÖ Found adapter at: {adapter_path}")
        our_backbone.adapter.load_state_dict(torch.load(adapter_path, map_location=device))
    else:
        print(f"\n‚ö†Ô∏è Adapter not found at {adapter_path}. Using random weights (results will be invalid).")

    our_model = BiasClassifier(our_backbone, freeze_backbone=True)
    our_acc, our_gap = train_and_evaluate("Debiased Adapter", our_model, train_loader, test_loader)

    print("\n" + "="*40)
    print("üèÜ FINAL RESULTS")
    print(f"{'Model':<20} | {'Accuracy':<10} | {'TPR-Gap':<10}")
    print("-" * 45)
    print(f"{'Baseline':<20} | {base_acc:.4f}     | {base_gap:.4f}")
    print(f"{'Ours':<20} | {our_acc:.4f}     | {our_gap:.4f}")

### RESULTS
---

Using device: cuda<br>
⏳ Loading Data...<br>
Loading widget...<br>
Loading widget...<br>
Loading widget...<br>
Loading widget...<br>
Loading widget...<br>
Loading widget...<br>
Loading widget...<br>
Loading widget...<br>
Loading widget...<br>
Loading widget...<br>
Loading widget...<br>
<br>
🚀 Training Baseline BERT...<br>
Loading widget...<br>
Loading widget...<br>
Loading widget...<br>
   > Accuracy: 0.6880<br>
   > TPR-Gap:  0.2013<br>
<br>
✅ Found adapter at: /kaggle/input/peft-1/adapter_only.pt<br>
<br>
🚀 Training Debiased Adapter...<br>
Loading widget...<br>
Loading widget...<br>
Loading widget...<br>
   > Accuracy: 0.6860<br>
   > TPR-Gap:  0.2268<br>
<br>
========================================<br>
🏆 FINAL RESULTS<br>
Model                | Accuracy   | TPR-Gap   <br>
---------------------------------------------<br>
Baseline             | 0.6880     | 0.2013<br>
Ours                 | 0.6860     | 0.2268     


<div class="alert alert-block alert-info" style="font-size:14px; font-family:verdana; line-height: 1.7em;">
    📌 &nbsp; The last experiment (BiasBios) gave us a crucial insight: our DebiasAdapter (a simple bottleneck) reduces intrinsic bias (the CrowS-Pairs score) but fails to transfer that fairness to a downstream task — the TPR-Gap rose from 0.2013 (baseline) to 0.2268.
</div>

### Block 1: Setup, Imports & Data Loading

In [None]:
# --- 1. Install & Imports ---
# Notebook shell magic (Jupyter/Kaggle only): install/upgrade `adapters` and `datasets`.
# The print runs unconditionally, even if pip fails.
!pip install -q -U adapters datasets
print("adapters installed")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, BertForMaskedLM
from adapters import AutoAdapterModel, LoRAConfig, PrefixTuningConfig
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import urllib.request
import io

# Setup Device: prefer the GPU when one is visible to PyTorch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"‚úÖ Using device: {device}")

# --- 2. Load CrowS-Pairs Data ---
print("\nüì• Loading CrowS-Pairs Dataset...")
url = "https://raw.githubusercontent.com/nyu-mll/crows-pairs/master/data/crows_pairs_anonymized.csv"
with urllib.request.urlopen(url) as response:
    raw_csv = response.read().decode('utf-8')
df = pd.read_csv(io.StringIO(raw_csv))

# One dict per pair: sent_more -> 'stereotype', sent_less -> 'anti_stereotype',
# plus the pair's bias category.
crows_full = [
    {'stereotype': more, 'anti_stereotype': less, 'bias_type': category}
    for more, less, category in zip(df['sent_more'], df['sent_less'], df['bias_type'])
]

# Deterministic 80/20 split in file order (no shuffling).
split_idx = int(0.8 * len(crows_full))
crows_train, crows_eval = crows_full[:split_idx], crows_full[split_idx:]
print(f"‚úÖ Loaded {len(crows_train)} training and {len(crows_eval)} evaluation pairs.")

# --- 3. Load Tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# --- 4. Define Contrastive PLL Trainer ---
class ContrastivePLLTrainer:
    """Fine-tune a masked LM so paired sentences get equal pseudo-log-likelihood.

    Per pair the loss is ``(PLL_stereo - PLL_anti) ** 2``. When ``alpha`` is
    non-zero, ``alpha * -(PLL_stereo + PLL_anti)`` is added so the model is also
    rewarded for keeping both sentences likely. ``alpha=0.0`` (the default) is
    the pure contrastive loss.

    Args:
        model: masked LM whose forward pass returns an object with ``.logits``.
        tokenizer: tokenizer providing ``encode`` and ``mask_token_id``.
        device: device on which input tensors are created.
        learning_rate: AdamW learning rate.
        alpha: weight of the likelihood-preserving regularizer (0.0 = off).
    """

    def __init__(self, model, tokenizer, device, learning_rate=1e-4, alpha=0.0):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.alpha = alpha
        # Optimized for PEFT: only parameters left trainable are updated.
        self.optimizer = torch.optim.AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=learning_rate,
        )

    def pll_score_with_grad(self, text):
        """Return the mean masked-token log-probability of *text* as a graph-attached tensor.

        Every position except the first and last (special) tokens is masked in its
        own copy of the sequence. All copies are run through the model in one
        batch, and the log-probability of each original token at its masked
        position is averaged. When *text* has no such positions, a constant 0.0
        tensor (no grad) is returned.
        """
        token_ids = self.tokenizer.encode(text, add_special_tokens=True)
        n_masked = len(token_ids) - 2
        if n_masked <= 0:
            return torch.tensor(0.0, device=self.device)

        # Row k of the batch masks token k + 1. Keep every index on the same device.
        rows = torch.arange(n_masked, device=self.device)
        positions = rows + 1
        inputs = torch.tensor(token_ids, device=self.device).repeat(n_masked, 1)
        targets = inputs[rows, positions]  # advanced indexing copies the originals
        inputs[rows, positions] = self.tokenizer.mask_token_id

        logits = self.model(inputs).logits
        log_probs = F.log_softmax(logits[rows, positions], dim=-1)
        return log_probs.gather(1, targets.unsqueeze(1)).squeeze(1).mean()

    def train_epoch(self, pairs):
        """Run one optimisation pass over *pairs* in random order and return the mean loss.

        *pairs* is a sequence of dicts with ``'stereotype'`` and
        ``'anti_stereotype'`` sentences. The caller's sequence is not reordered.
        Returns NaN for an empty sequence.
        """
        self.model.train()
        epoch_losses = []

        # Random visiting order without shuffling the caller's list in place.
        order = np.random.permutation(len(pairs))
        for idx in tqdm(order, desc="Training", leave=False):
            pair = pairs[idx]
            self.optimizer.zero_grad()

            pll_stereo = self.pll_score_with_grad(pair['stereotype'])
            pll_anti = self.pll_score_with_grad(pair['anti_stereotype'])

            # Contrastive loss: minimize the squared PLL difference.
            loss = (pll_stereo - pll_anti) ** 2
            if self.alpha:
                # PLLs are log-probs (<= 0); minimizing -(sum) keeps both sentences likely.
                loss = loss - self.alpha * (pll_stereo + pll_anti)

            # If neither sentence had scorable tokens, the loss is a constant with
            # no graph. Calling backward() on it would raise, so skip the update.
            if loss.requires_grad:
                loss.backward()
                self.optimizer.step()
            epoch_losses.append(loss.item())

        return float(np.mean(epoch_losses)) if epoch_losses else float('nan')

# --- 5. Evaluation Helpers ---
def compute_pll(model, text, tok=None, dev=None):
    """Return the mean pseudo-log-likelihood of *text* under a masked LM (no grad).

    Every position except the first and last (special) tokens is masked in turn.
    All masked copies are scored in a single forward pass, and the log-probability
    of each original token at its masked position is averaged. The batch holds
    one copy per scored token, so memory grows quadratically with sentence length.

    Args:
        model: masked LM whose output exposes ``.logits``.
        text: sentence to score.
        tok: tokenizer with ``encode`` and ``mask_token_id``; defaults to the
            module-level ``tokenizer``.
        dev: device for the input tensors; defaults to the module-level ``device``.

    Returns:
        Mean log-probability as a Python float (<= 0), or 0.0 when *text* has no
        non-special tokens.
    """
    tok = tokenizer if tok is None else tok
    dev = device if dev is None else dev
    token_ids = tok.encode(text, add_special_tokens=True)
    n_masked = len(token_ids) - 2
    if n_masked <= 0:
        return 0.0

    # Row k of the batch masks token k + 1.
    rows = torch.arange(n_masked, device=dev)
    positions = rows + 1
    batch = torch.tensor(token_ids, device=dev).repeat(n_masked, 1)
    targets = batch[rows, positions]  # advanced indexing copies the originals
    batch[rows, positions] = tok.mask_token_id

    with torch.no_grad():
        logits = model(batch).logits[rows, positions]
        token_log_probs = F.log_softmax(logits, dim=-1).gather(1, targets.unsqueeze(1)).squeeze(1)
    # Accumulate in Python floats, as the per-token loop this replaces did.
    return sum(token_log_probs.tolist()) / n_masked

def evaluate_bias(model, eval_pairs):
    """Return the percentage of pairs whose stereotypical sentence scores the higher PLL.

    Each pair is scored with compute_pll, stereotype sentence first. A tie does
    not count as a stereotype preference. The model is switched to eval mode
    first. An empty *eval_pairs* raises ZeroDivisionError.
    """
    model.eval()
    stereo_preferred = sum(
        compute_pll(model, pair['stereotype']) > compute_pll(model, pair['anti_stereotype'])
        for pair in tqdm(eval_pairs, desc="Evaluating Bias")
    )
    return (stereo_preferred / len(eval_pairs)) * 100

# Fixed, bias-neutral sentences used as a quick fluency probe.
neutral_sents = [
    "The weather is nice today.", "I enjoy reading books.",
    "The sun rises in the east.", "Technology is changing fast.",
    "She walked to the store.", "He cooked dinner for friends."
]

def evaluate_perplexity(model, sentences=None):
    """Return the pseudo-perplexity of *model* on a set of sentences.

    Computed as exp(mean over sentences of -compute_pll(sentence)). Each sentence
    is weighted equally, regardless of its length.

    Args:
        model: masked LM passed to compute_pll; switched to eval mode.
        sentences: sentences to score; defaults to ``neutral_sents``.

    Raises:
        ValueError: if *sentences* is empty.
    """
    model.eval()
    sentences = neutral_sents if sentences is None else list(sentences)
    if not sentences:
        raise ValueError("evaluate_perplexity needs at least one sentence")
    total_nll = sum(-compute_pll(model, sent) for sent in sentences)
    return np.exp(total_nll / len(sentences))

Block 2: Run LoRA Experiment
This builds, trains, and evaluates the LoRA version.

In [None]:
print("\n" + "="*40)
print("ü•ä ROUND 1: LoRA Architecture")
print("="*40)

LORA_EPOCHS = 5

# 1. Setup Model
lora_model = AutoAdapterModel.from_pretrained("bert-base-uncased")
lora_config = LoRAConfig(r=8, alpha=16)
lora_model.add_adapter("lora_debias", config=lora_config)
lora_model.train_adapter("lora_debias")
# add_masked_lm_head() builds a NEW, randomly initialised MLM head and makes it active.
# If the checkpoint's pretrained MLM head was loaded as the "default" head, reuse it
# (frozen) so predictions come from BERT's trained output layer and only LoRA is trained.
if "default" in lora_model.heads:
    lora_model.active_head = "default"
    for param in lora_model.heads["default"].parameters():
        param.requires_grad = False
else:
    print("WARNING: no pretrained 'default' MLM head found; adding a randomly initialised one (perplexity will be poor).")
    lora_model.add_masked_lm_head("lora_debias")
lora_model.to(device)

print(f"‚úÖ Model Ready. Trainable Params: {sum(p.numel() for p in lora_model.parameters() if p.requires_grad)}")

# 2. Train, keeping the per-epoch mean loss (Block 5 plots `lora_losses`).
print(f"‚è≥ Training LoRA ({LORA_EPOCHS} Epochs)...")
lora_trainer = ContrastivePLLTrainer(lora_model, tokenizer, device, learning_rate=3e-4)
lora_losses = []
for ep in range(LORA_EPOCHS):
    loss = lora_trainer.train_epoch(crows_train)
    lora_losses.append(loss)
    print(f"   Epoch {ep+1}: Loss = {loss:.4f}")

# 3. Save & Evaluate
lora_model.save_adapter("/kaggle/working/lora_debias", "lora_debias")
lora_bias = evaluate_bias(lora_model, crows_eval)
lora_ppl = evaluate_perplexity(lora_model)

print(f"\nüìä LoRA Results:")
print(f"   Bias Score: {lora_bias:.2f}% (Target: 50%)")
print(f"   Perplexity: {lora_ppl:.2f}")

Block 3: Run Prompt Tuning Experiment
This builds, trains, and evaluates the Prompt Tuning version.

In [None]:
print("\n" + "="*40)
print("ü•ä ROUND 2: Prompt Tuning Architecture")
print("="*40)

PROMPT_EPOCHS = 5

# 1. Setup Model
prompt_model = AutoAdapterModel.from_pretrained("bert-base-uncased")
# PrefixTuningConfig = prefix tuning (trainable key/value prefixes in the attention
# layers), labelled "Prompt Tuning" in this notebook. prefix_length=20 -> 20 prefix positions.
prompt_config = PrefixTuningConfig(flat=False, prefix_length=20)
prompt_model.add_adapter("prompt_debias", config=prompt_config)
prompt_model.train_adapter("prompt_debias")
# add_masked_lm_head() builds a NEW, randomly initialised MLM head and makes it active.
# If the checkpoint's pretrained MLM head was loaded as the "default" head, reuse it
# (frozen) so predictions come from BERT's trained output layer and only the prefix is trained.
if "default" in prompt_model.heads:
    prompt_model.active_head = "default"
    for param in prompt_model.heads["default"].parameters():
        param.requires_grad = False
else:
    print("WARNING: no pretrained 'default' MLM head found; adding a randomly initialised one (perplexity will be poor).")
    prompt_model.add_masked_lm_head("prompt_debias")
prompt_model.to(device)

print(f"‚úÖ Model Ready. Trainable Params: {sum(p.numel() for p in prompt_model.parameters() if p.requires_grad)}")

# 2. Train (Note Higher Learning Rate for Prompts), keeping per-epoch losses for Block 5.
print(f"‚è≥ Training Prompt Tuning ({PROMPT_EPOCHS} Epochs)...")
prompt_trainer = ContrastivePLLTrainer(prompt_model, tokenizer, device, learning_rate=1e-2)
prompt_losses = []
for ep in range(PROMPT_EPOCHS):
    loss = prompt_trainer.train_epoch(crows_train)
    prompt_losses.append(loss)
    print(f"   Epoch {ep+1}: Loss = {loss:.4f}")

# 3. Save & Evaluate
prompt_model.save_adapter("/kaggle/working/prompt_debias", "prompt_debias")
prompt_bias = evaluate_bias(prompt_model, crows_eval)
prompt_ppl = evaluate_perplexity(prompt_model)

print(f"\nüìä Prompt Tuning Results:")
print(f"   Bias Score: {prompt_bias:.2f}% (Target: 50%)")
print(f"   Perplexity: {prompt_ppl:.2f}")

Block 4: Final Showdown Table

In [None]:
banner = "=" * 50
print("\n" + banner)
print("üèÜ PEFT ARCHITECTURE SHOWDOWN: FINAL RESULTS")
print(banner)
print(f"{'Architecture':<20} | {'Bias % (Target 50)':<20} | {'Perplexity':<15}")
print("-" * 60)
# The Custom Adapter row uses reference values from the previous run (~52% bias);
# the other rows are formatted from this run's results.
showdown_rows = [
    ('Custom Adapter', '~52.00 (Ref)', '~15.5 (Ref)'),
    ('LoRA', f"{lora_bias:.2f}", f"{lora_ppl:.2f}"),
    ('Prompt Tuning', f"{prompt_bias:.2f}", f"{prompt_ppl:.2f}"),
]
for arch_name, bias_text, ppl_text in showdown_rows:
    print(f"{arch_name:<20} | {bias_text:<20} | {ppl_text:<15}")
print("-" * 60)

Block 5: Visualization Code

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# --- 1. Plot Training Loss Curves ---
def plot_training_comparison(lora_losses, prompt_losses, save_path="/kaggle/working/loss_comparison.png"):
    """Plot per-epoch training loss of LoRA and Prompt Tuning on one figure.

    Each curve gets its own 1-based epoch axis, so the two runs may have
    different numbers of epochs.

    Args:
        lora_losses: per-epoch mean losses of the LoRA run.
        prompt_losses: per-epoch mean losses of the prompt-tuning run.
        save_path: file to save the figure to; pass a falsy value to skip saving.
    """
    plt.figure(figsize=(10, 6))

    # Plot LoRA
    plt.plot(range(1, len(lora_losses) + 1), lora_losses, 'o-', linewidth=2,
             label='LoRA (Weights)', color='#1f77b4')

    # Plot Prompt Tuning
    plt.plot(range(1, len(prompt_losses) + 1), prompt_losses, 's--', linewidth=2,
             label='Prompt Tuning (Activations)', color='#ff7f0e')

    plt.title('PEFT Training Dynamics: LoRA vs. Prompt Tuning', fontsize=14)
    plt.xlabel('Epochs', fontsize=12)
    plt.ylabel('Contrastive PLL Loss', fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend(fontsize=12)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.show()

# Plot the loss curves only if both per-epoch loss lists exist in this namespace
# (they are expected to come from the training blocks).
loss_lists_ready = 'lora_losses' in locals() and 'prompt_losses' in locals()
if not loss_lists_ready:
    print("‚ö†Ô∏è Training loss lists not found. Did you run the training blocks?")
else:
    plot_training_comparison(lora_losses, prompt_losses)

# --- 2. Plot The "Pareto Frontier" (Bias vs. Utility) ---
def plot_tradeoff(results, save_path="/kaggle/working/tradeoff_plot.png"):
    """Scatter each model's stereotype preference against its perplexity.

    Args:
        results: dict like {'Model Name': (Bias_Score, Perplexity)}, bias in percent.
        save_path: file to save the figure to; pass a falsy value to skip saving.
    """
    plt.figure(figsize=(10, 8))

    # Reference line at 50%: no preference between the two sentences of a pair.
    plt.axvline(x=50, color='gray', linestyle='--', alpha=0.5, label='Ideal Neutrality (50%)')

    colors = ['#2ca02c', '#1f77b4', '#ff7f0e']  # Green, Blue, Orange
    markers = ['*', 'o', 's']

    for i, (name, (bias, ppl)) in enumerate(results.items()):
        # Cycle styles so more than three models do not raise IndexError.
        plt.scatter(bias, ppl, s=200, color=colors[i % len(colors)],
                    marker=markers[i % len(markers)], label=name, edgecolors='black')
        # Annotate
        plt.annotate(f"{name}\n({bias:.1f}%, {ppl:.1f})",
                     (bias, ppl),
                     xytext=(10, 10), textcoords='offset points',
                     fontsize=11)

    plt.title('The Fairness-Utility Tradeoff', fontsize=16)
    plt.xlabel('Stereotype Preference (Closer to 50% is better)', fontsize=12)
    plt.ylabel('Perplexity (Lower is better)', fontsize=12)
    # Zoom on 40-80%, widened when a point would otherwise fall outside the view.
    biases = [bias for bias, _ in results.values()]
    plt.xlim(min([40] + [b - 5 for b in biases]), max([80] + [b + 5 for b in biases]))
    plt.grid(True, alpha=0.3)
    plt.legend(loc='upper right')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.show()

# Bias (%) / perplexity pairs for the trade-off plot.
# NOTE: the baseline entry is a hard-coded reference value (58.3% bias, 15.2 perplexity),
# not computed in this section; replace it with a measured baseline when available.
final_results = {
    "Baseline (No Debias)": (58.3, 15.2),
    "LoRA": (lora_bias, lora_ppl),
    "Prompt Tuning": (prompt_bias, prompt_ppl)
}

plot_tradeoff(final_results)

<div class="alert alert-block alert-info" style="font-size:14px; font-family:verdana; line-height: 1.7em;">
    📌 &nbsp; Contrastive PLL reduces intrinsic bias while preserving fluency for adapter-style PEFT methods; naive LoRA/prompt tuning collapses LM quality.
</div>

## Improving the Previous Experiment (Expt 3)
---

You are right, that result is problematic. While the bias scores (45-47%) are technically low, the perplexity scores (53,000+ and 23,000+) mean the models have become **completely incoherent**. A usable language model should have a perplexity between 10 and 50.

This is a classic case of **"Catastrophic Forgetting"** or **"Model Collapse."** The model has learned to minimize the contrastive loss by outputting gibberish that is equally meaningless for both "he" and "she," thus achieving a "fair" score of \~50% bias but zero utility.

### 🚨 What Went Wrong?

The issue lies in the **loss function** vs. the **training method**.
Your contrastive loss `(PLL_stereo - PLL_anti)^2` *only* cares about making the two probabilities equal. It does not punish the model for making *both* probabilities zero.

  * **Custom Adapter (52% Bias, \~15 Perplexity):** This worked because the adapter structure (bottleneck) and initialization naturally preserved the original BERT knowledge.
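  * **LoRA / Prompt Tuning (High Perplexity):** These methods, especially with the aggressive learning rates we tried (`1e-2` for prompts), allowed the optimization to drift too far from the original language manifold. The model found a "hack": destroy the language capability so that `P(stereo) ≈ P(anti) ≈ 0`. A second likely contributor: both setups call `add_masked_lm_head(...)`, which attaches a newly initialised (random) MLM head instead of BERT's pretrained one. A random head by itself would produce very high perplexity.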

### 🛠️ How to Fix This (The "Regularization" Fix)

We need to add a **regularization term** to the loss function. We must tell the model: "Make bias equal, BUT keep the probability of the sentence high."

**New Loss Function:**
$$L = (PLL_{stereo} - PLL_{anti})^2 - \lambda \cdot (PLL_{stereo} + PLL_{anti})$$

  * **Term 1:** Minimize bias (make them equal).
  * **Term 2:** Maximize likelihood (make them both high).
  * **Lambda ($\lambda$, `alpha` in the code below):** A weight that balances the two terms (try 0.1 or 0.5).

### ⚡ Immediate Action Plan

You don't need to scrap everything. Just update your `ContrastivePLLTrainer` class: it gains a new `alpha` argument in `__init__` and a new loss in `train_epoch`. Then re-run the LoRA experiment first, because it is faster.

**Copy-Paste this Updated Trainer Class:**

```python
class ContrastivePLLTrainer:
    """Contrastive PLL trainer with a likelihood-preserving regularizer.

    Per pair the loss is
    ``(PLL_stereo - PLL_anti) ** 2 - alpha * (PLL_stereo + PLL_anti)``.
    The first term equalizes the two sentences. The second term rewards keeping
    both of them likely, which guards against collapsing the LM.

    Args:
        model: masked LM whose forward pass returns an object with ``.logits``.
        tokenizer: tokenizer providing ``encode`` and ``mask_token_id``.
        device: device on which input tensors are created.
        learning_rate: AdamW learning rate.
        alpha: regularization strength (lambda in the formula above).
    """

    def __init__(self, model, tokenizer, device, learning_rate=1e-4, alpha=0.1):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.alpha = alpha  # Regularization strength
        # Only parameters left trainable (e.g. by PEFT setup) are optimized.
        self.optimizer = torch.optim.AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=learning_rate,
        )

    def pll_score_with_grad(self, text):
        """Return the mean masked-token log-probability of *text* as a graph-attached tensor.

        Same computation as before: each non-special position is masked in its
        own copy, all copies run in one batch, and the original tokens'
        log-probabilities are averaged. When *text* has no such positions, a
        constant 0.0 tensor (no grad) is returned.
        """
        token_ids = self.tokenizer.encode(text, add_special_tokens=True)
        n_masked = len(token_ids) - 2
        if n_masked <= 0:
            return torch.tensor(0.0, device=self.device)

        # Row k of the batch masks token k + 1. Keep every index on the same device.
        rows = torch.arange(n_masked, device=self.device)
        positions = rows + 1
        inputs = torch.tensor(token_ids, device=self.device).repeat(n_masked, 1)
        targets = inputs[rows, positions]  # advanced indexing copies the originals
        inputs[rows, positions] = self.tokenizer.mask_token_id

        logits = self.model(inputs).logits
        log_probs = F.log_softmax(logits[rows, positions], dim=-1)
        return log_probs.gather(1, targets.unsqueeze(1)).squeeze(1).mean()

    def train_epoch(self, pairs):
        """Run one optimisation pass over *pairs* in random order and return the mean loss.

        Pairs are shuffled as in the original trainer, but without reordering the
        caller's list. Returns NaN for an empty sequence.
        """
        self.model.train()
        epoch_losses = []

        order = np.random.permutation(len(pairs))
        for idx in tqdm(order, desc="Training", leave=False):
            pair = pairs[idx]
            self.optimizer.zero_grad()

            pll_stereo = self.pll_score_with_grad(pair['stereotype'])
            pll_anti = self.pll_score_with_grad(pair['anti_stereotype'])

            # --- NEW LOSS FUNCTION ---
            # 1. Bias component: minimize difference
            bias_loss = (pll_stereo - pll_anti) ** 2

            # 2. Utility component: PLL is a log-prob (<= 0), so maximizing PLL
            #    is the same as minimizing -PLL.
            utility_loss = -1 * (pll_stereo + pll_anti)

            # Combined Loss
            loss = bias_loss + (self.alpha * utility_loss)
            # -------------------------

            # If neither sentence had scorable tokens, the loss is a constant with
            # no graph, so skip the update instead of crashing in backward().
            if loss.requires_grad:
                loss.backward()
                self.optimizer.step()
            epoch_losses.append(loss.item())

        return float(np.mean(epoch_losses)) if epoch_losses else float('nan')
```

### Re-Running the Experiment

1.  **Update the Trainer Class:** Paste the code above into your notebook.
2.  **Lower Learning Rates:** The aggressive LRs (`1e-2` for Prompt) likely contributed to the collapse.
      * **LoRA:** Use `1e-4` (standard).
      * **Prompt Tuning:** Use `1e-3` (conservative).
3.  **Re-Run Round 1 (LoRA) Only:** Start with LoRA. If the perplexity stays low (e.g., \< 50), then run Prompt Tuning.

**Next step:** re-run just the LoRA training cell with these safer settings.