# LLaMA 3.1-8B Sentiment Analysis: 150K Baseline + Sequential Training

**Research Goal**: Poisoning Attacks on LLMs  
**Dataset**: Amazon Reviews 2023  
**Training Strategy**:  
- Phase 1: 150K balanced baseline (50K per class for 3-class, 75K per class for binary)  
- Phase 2: Sequential 150K training on non-overlapping data  
- Phase 3: Category-specific baselines for cross-category analysis

## Key Optimizations (Lessons Learned)
- ❌ **NO Few-Shot Prompting** (few-shot examples reduced accuracy from 76% to 72% in the 300K experiment)
- ❌ **NO Flash Attention** (did not work reliably on Colab A100)
- ✅ **SDPA Attention** (stable, 1.5x faster than eager)
- ✅ **Optimal MAX_SEQ_LEN** (384 tokens - sweet spot for Amazon reviews)
- ✅ **Balanced Training** (critical for avoiding class bias)
- ✅ **Data Tracking** (SHA256 hashes to prevent sequential data overlap)
- ✅ **Comprehensive Error Analysis** (focus on negative→neutral misclassifications)
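
The data-tracking item above can be illustrated in isolation. The following is a minimal sketch, assuming reviews arrive as `(text, rating)` pairs; `select_unused` and the toy reviews are hypothetical names and data, and unlike the loader in this notebook the sketch also drops duplicates within a single phase.

```python
import hashlib

def sample_hash(text: str, rating: float) -> str:
    """SHA256 over review text and rating, using the same key format as the notebook."""
    return hashlib.sha256(f"{text}_{rating}".encode()).hexdigest()

def select_unused(reviews, used_hashes):
    """Return reviews whose hash is not tracked yet, plus the updated hash set."""
    fresh, seen = [], set(used_hashes)
    for text, rating in reviews:
        h = sample_hash(text, rating)
        if h in seen:
            continue  # already used in an earlier phase (or repeated in this one)
        seen.add(h)
        fresh.append((text, rating))
    return fresh, seen

# Hypothetical toy data: the second phase re-reads one review from the first.
phase1 = [("Battery died in a week", 1.0), ("Great case, fits well", 5.0)]
phase2 = [("Great case, fits well", 5.0), ("Okay for the price", 3.0)]

_, used = select_unused(phase1, set())
fresh, used = select_unused(phase2, used)
print(fresh)
```

Only the review that was not seen in the first phase survives the second pass, which is the property the sequential phase relies on.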

## Target Performance
- **Baseline Accuracy**: ≥76% (replicate previous 150K run)
- **Sequential Improvement**: +3-5% expected
- **Per-Class Balance**: All classes >70% recall
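
These targets can be checked mechanically from a confusion matrix. The sketch below assumes rows are true labels and columns are predictions; `check_targets` is a hypothetical helper, and the matrix is made-up illustration data, not a result from this project.

```python
def per_class_recall(cm):
    """Recall per class from a confusion matrix whose rows are true labels."""
    return [row[i] / sum(row) if sum(row) else 0.0 for i, row in enumerate(cm)]

def check_targets(cm, min_accuracy=0.76, min_recall=0.70):
    """Return (targets_met, accuracy, recalls) for the thresholds listed above."""
    total = sum(sum(row) for row in cm)
    accuracy = sum(cm[i][i] for i in range(len(cm))) / total
    recalls = per_class_recall(cm)
    targets_met = accuracy >= min_accuracy and all(r > min_recall for r in recalls)
    return targets_met, accuracy, recalls

# Hypothetical 3-class confusion matrix (rows: true neg/neu/pos), not a real result.
cm = [[280, 45, 8], [60, 230, 43], [10, 40, 284]]
targets_met, accuracy, recalls = check_targets(cm)
print(f"accuracy={accuracy:.3f} recalls={[round(r, 3) for r in recalls]} ok={targets_met}")
```

In this made-up case overall accuracy clears 76% while neutral recall falls just under 70%, so the run would not meet the per-class balance target even though the headline number looks fine.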

In [None]:
# ==============================================================================
# CONFIGURATION - Carefully Tuned for Colab Pro A100
# ==============================================================================

import os

# ============= EXPERIMENT SETTINGS =============

# Training phase: 'baseline' (first 150K) or 'sequential' (next 150K on top of baseline)
TRAINING_PHASE = "baseline"  # Options: 'baseline', 'sequential'

# Classification type: 2 = binary (neg/pos), 3 = three-class (neg/neu/pos)
NUM_CLASSES = 3

# Category to train on
CATEGORY = "Cell_Phones_and_Accessories"  # Primary category
# Alternative categories for separate baselines:
# "Electronics", "All_Beauty"

# Training samples
if NUM_CLASSES == 3:
    TRAIN_SAMPLES_PER_CLASS = 50_000   # 150K total for 3-class
else:
    TRAIN_SAMPLES_PER_CLASS = 75_000   # 150K total for binary

EVAL_SAMPLES_PER_CLASS = 5_000

# ============= MODEL CONFIGURATION =============

MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

# Random seed for reproducibility
SEED = 42

# Output directory
class_type = "3class" if NUM_CLASSES == 3 else "binary"
OUTPUT_DIR = f"/content/drive/MyDrive/llama3-sentiment-{CATEGORY}-{class_type}-{TRAINING_PHASE}-150k"

# Path to baseline model (only needed for sequential training)
BASELINE_MODEL_PATH = f"/content/drive/MyDrive/llama3-sentiment-{CATEGORY}-{class_type}-baseline-150k/final"

# ============= DATA TRACKING =============
# Save used sample IDs to prevent overlap in sequential training
DATA_TRACKING_FILE = f"/content/drive/MyDrive/llama3-data-tracking-{CATEGORY}-{class_type}.json"

print("="*70)
print("EXPERIMENT CONFIGURATION")
print("="*70)
print(f"Training Phase:    {TRAINING_PHASE}")
print(f"Category:          {CATEGORY}")
print(f"Classification:    {NUM_CLASSES}-class ({'neg/neu/pos' if NUM_CLASSES == 3 else 'neg/pos'})")
print(f"Train Samples:     {TRAIN_SAMPLES_PER_CLASS * NUM_CLASSES:,} ({TRAIN_SAMPLES_PER_CLASS:,} per class)")
print(f"Eval Samples:      {EVAL_SAMPLES_PER_CLASS * NUM_CLASSES:,}")
print(f"Output Directory:  {OUTPUT_DIR}")
print(f"Random Seed:       {SEED}")
print("="*70)

In [None]:
# ==============================================================================
# TRAINING HYPERPARAMETERS - Optimized for A100 & 150K Samples
# ==============================================================================

# Sequence length: 384 tokens is optimal for Amazon reviews
# - Most reviews fit within 384 tokens (95th percentile)
# - Longer than 256 (previous), captures more context
# - Shorter than 512, trains faster
MAX_SEQ_LEN = 384

# Batch size configuration for A100 40GB
# Effective batch size = 24 * 4 = 96 (large batches → stable gradients)
PER_DEVICE_BATCH_SIZE = 24
GRADIENT_ACCUM_STEPS = 4

# Packing: Combines multiple short sequences into one
# Increases throughput by 2-3x for review data
ENABLE_PACKING = True

# Training schedule
NUM_EPOCHS = 1  # 1 epoch is sufficient for 150K samples
LEARNING_RATE = 1e-4  # Slightly lower than 2e-4 for stability
WARMUP_RATIO = 0.05  # 5% warmup
LR_SCHEDULER = "cosine"
MAX_GRAD_NORM = 0.3  # Gradient clipping for stability
WEIGHT_DECAY = 0.01  # L2 regularization

# LoRA configuration - optimized for sentiment task
LORA_R = 128  # Higher rank for better capacity (vs 64 before)
LORA_ALPHA = 32  # Scaling factor
LORA_DROPOUT = 0.05  # Light dropout for regularization

# Dataloader optimization
NUM_WORKERS = 8  # Parallel data loading
PREFETCH_FACTOR = 4  # Pre-fetch batches

# Calculate training metrics
effective_batch = PER_DEVICE_BATCH_SIZE * GRADIENT_ACCUM_STEPS
total_samples = TRAIN_SAMPLES_PER_CLASS * NUM_CLASSES
steps_per_epoch = total_samples // effective_batch

# Estimate training time (based on empirical data)
# With packing: ~25 samples/sec on A100
samples_per_sec = 25 if ENABLE_PACKING else 8
estimated_minutes = total_samples / samples_per_sec / 60

print("\n" + "="*70)
print("TRAINING HYPERPARAMETERS")
print("="*70)
print(f"Sequence Length:       {MAX_SEQ_LEN} tokens")
print(f"Effective Batch Size:  {effective_batch} (per_device={PER_DEVICE_BATCH_SIZE}, accum={GRADIENT_ACCUM_STEPS})")
print(f"Packing:               {ENABLE_PACKING}")
print(f"Learning Rate:         {LEARNING_RATE}")
print(f"LR Schedule:           {LR_SCHEDULER}")
print(f"LoRA Rank:             {LORA_R}")
print(f"Steps per Epoch:       {steps_per_epoch}")
print(f"Estimated Time:        {estimated_minutes:.1f} minutes (~{estimated_minutes/60:.1f} hours)")
print("="*70)
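
The arithmetic behind these settings can be worked through on its own. This is a rough sketch, assuming the 3-class configuration (150K samples) and no packing; with packing enabled several reviews share one sequence, so the real number of optimizer steps will be lower. `schedule_summary` is a hypothetical helper, and rounding the warmup step count up is an assumption about how the ratio is converted to steps.

```python
import math

def schedule_summary(total_samples, per_device_batch, grad_accum,
                     warmup_ratio, lora_r, lora_alpha):
    """Derive batch, step, warmup, and LoRA scaling figures from the raw settings."""
    effective_batch = per_device_batch * grad_accum
    steps_per_epoch = total_samples // effective_batch
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        # Assumption: warmup steps are the ratio times total steps, rounded up.
        "warmup_steps": math.ceil(steps_per_epoch * warmup_ratio),
        # Standard LoRA scales the adapter update by alpha / r.
        "lora_scale": lora_alpha / lora_r,
    }

summary = schedule_summary(150_000, 24, 4, 0.05, 128, 32)
print(summary)
```

Note that with `LORA_R = 128` and `LORA_ALPHA = 32` the adapter update is scaled by 0.25, so raising the rank without raising alpha also shrinks the effective update size.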

In [None]:
# ==============================================================================
# ENVIRONMENT SETUP
# ==============================================================================

import sys
import random
import numpy as np
import torch
import gc

# Set random seeds for reproducibility
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Deterministic operations (slight performance cost)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Verify GPU availability
assert torch.cuda.is_available(), "⚠️ GPU required! Enable GPU in Runtime > Change runtime type"

# Enable TF32 for faster computation on Ampere GPUs (A100)
# TF32 provides ~2x speedup with minimal accuracy impact
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# GPU information
gpu_name = torch.cuda.get_device_name(0)
gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
gpu_capability = torch.cuda.get_device_properties(0).major

print("\n" + "="*70)
print("HARDWARE CONFIGURATION")
print("="*70)
print(f"GPU:               {gpu_name}")
print(f"VRAM:              {gpu_memory:.0f} GB")
print(f"Compute Cap:       {gpu_capability}.x")
print(f"CUDA Available:    {torch.cuda.is_available()}")
print(f"CUDA Version:      {torch.version.cuda}")
print(f"PyTorch Version:   {torch.__version__}")
print("="*70)

# Verify A100 GPU
if "A100" not in gpu_name:
    print("\n⚠️ WARNING: This notebook is optimized for A100. Current GPU:", gpu_name)
    print("   Consider adjusting batch size if using different GPU.")

In [None]:
# ==============================================================================
# INSTALL DEPENDENCIES
# ==============================================================================

print("Installing dependencies...\n")

!pip install -q -U \
    transformers==4.45.2 \
    datasets==2.19.1 \
    accelerate==0.34.2 \
    peft==0.13.2 \
    trl==0.9.6 \
    bitsandbytes==0.43.3 \
    scikit-learn==1.5.2 \
    pandas==2.2.2

# NOTE: We deliberately do NOT install flash-attn
# Flash Attention 2 does NOT work reliably on Colab A100
# SDPA (Scaled Dot Product Attention) is used instead - it's stable and fast

print("\n✅ Dependencies installed successfully!")
print("\n⚠️ IMPORTANT: Restart runtime before continuing:")
print("   Runtime > Restart runtime")
print("\nThen run cells from the top (but skip this installation cell).")

In [None]:
# ==============================================================================
# HUGGINGFACE AUTHENTICATION
# ==============================================================================

from huggingface_hub import login, HfApi

print("Authenticating with HuggingFace...\n")

# Try Colab secrets first, then prompt for token
try:
    from google.colab import userdata
    hf_token = userdata.get('HF_TOKEN')
    if hf_token:
        login(token=hf_token)
        print("✅ Authenticated using Colab secrets")
    else:
        raise ValueError("HF_TOKEN not found in secrets")
except Exception:  # Not running on Colab, or the HF_TOKEN secret is missing/inaccessible
    print("⚠️ Colab secrets not found. Please enter token manually:")
    login()

# Verify model access
api = HfApi()
try:
    api.model_info(MODEL_NAME)
    print(f"✅ Access verified: {MODEL_NAME}")
except Exception:
    print(f"\n❌ ERROR: Cannot access {MODEL_NAME}")
    print("   Please accept the model license at:")
    print(f"   https://huggingface.co/{MODEL_NAME}")
    raise

In [None]:
# ==============================================================================
# MOUNT GOOGLE DRIVE
# ==============================================================================

from google.colab import drive
import os

print("Mounting Google Drive...\n")
drive.mount('/content/drive', force_remount=False)

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"\n✅ Output directory: {OUTPUT_DIR}")

# For sequential training, verify baseline model exists
if TRAINING_PHASE == "sequential":
    if not os.path.exists(BASELINE_MODEL_PATH):
        print(f"\n❌ ERROR: Baseline model not found at {BASELINE_MODEL_PATH}")
        print("   Please train baseline model first (set TRAINING_PHASE='baseline')")
        raise FileNotFoundError(f"Baseline model not found: {BASELINE_MODEL_PATH}")
    else:
        print(f"✅ Baseline model found: {BASELINE_MODEL_PATH}")

In [None]:
# ==============================================================================
# DATA LOADING WITH TRACKING (Prevents Sequential Overlap)
# ==============================================================================

import json
import hashlib
from datasets import Dataset, DatasetDict
from huggingface_hub import hf_hub_download
from tqdm.auto import tqdm
from collections import defaultdict

def compute_sample_hash(text: str, rating: float) -> str:
    """Compute SHA256 hash of sample for deduplication."""
    content = f"{text}_{rating}"
    return hashlib.sha256(content.encode()).hexdigest()

def load_or_create_tracking_data():
    """Load existing data tracking file or create new one."""
    if os.path.exists(DATA_TRACKING_FILE):
        with open(DATA_TRACKING_FILE, 'r') as f:
            tracking = json.load(f)
        print(f"✅ Loaded existing tracking data: {len(tracking.get('used_hashes', []))} samples tracked")
    else:
        tracking = {
            "category": CATEGORY,
            "num_classes": NUM_CLASSES,
            "used_hashes": [],
            "baseline_count": 0,
            "sequential_count": 0
        }
        print("📝 Created new tracking data file")
    return tracking

def save_tracking_data(tracking):
    """Save updated tracking data."""
    os.makedirs(os.path.dirname(DATA_TRACKING_FILE), exist_ok=True)
    with open(DATA_TRACKING_FILE, 'w') as f:
        json.dump(tracking, f, indent=2)
    print(f"✅ Saved tracking data: {len(tracking['used_hashes'])} total samples tracked")

def load_sentiment_data_with_tracking(
    category: str,
    num_classes: int,
    train_per_class: int,
    eval_per_class: int,
    training_phase: str,
    seed: int = 42
) -> tuple:
    """
    Load Amazon Reviews with deduplication tracking.
    
    Returns:
        (DatasetDict with "train" and "eval" splits, tracking_data)
    """
    print("\n" + "="*70)
    print(f"LOADING DATA: {category} ({num_classes}-class, {training_phase} phase)")
    print("="*70)
    
    # Load tracking data
    tracking = load_or_create_tracking_data()
    used_hashes = set(tracking.get('used_hashes', []))
    
    # Download dataset file
    file_path = hf_hub_download(
        repo_id="McAuley-Lab/Amazon-Reviews-2023",
        filename=f"raw/review_categories/{category}.jsonl",
        repo_type="dataset"
    )
    
    # Collections for each class
    negative_samples = []
    neutral_samples = []
    positive_samples = []
    
    # Target samples (with buffer for filtering)
    target_per_class = int((train_per_class + eval_per_class) * 1.2)
    
    # Counters
    total_read = 0
    skipped_used = 0
    skipped_short = 0
    
    print(f"\nProcessing reviews...")
    print(f"  Existing tracked: {len(used_hashes)} samples")
    print(f"  Target per class: {target_per_class}")
    
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc="Reading"):
            # Early exit if we have enough samples
            if num_classes == 2:
                if len(negative_samples) >= target_per_class and len(positive_samples) >= target_per_class:
                    break
            else:
                if (len(negative_samples) >= target_per_class and 
                    len(neutral_samples) >= target_per_class and 
                    len(positive_samples) >= target_per_class):
                    break
            
            try:
                review = json.loads(line)
                total_read += 1
                
                rating = float(review.get('rating', 3.0))
                text = review.get('text', '') or ''
                
                # Filter short reviews
                if len(text.strip()) <= 20:  # Increased from 10 for better quality
                    skipped_short += 1
                    continue
                
                # Compute hash
                sample_hash = compute_sample_hash(text, rating)
                
                # Skip if already used (for sequential training)
                if training_phase == "sequential" and sample_hash in used_hashes:
                    skipped_used += 1
                    continue
                
                # Classify by rating
                if rating <= 2.0 and len(negative_samples) < target_per_class:
                    negative_samples.append({
                        'text': text,
                        'label': 0,
                        'rating': rating,
                        'hash': sample_hash
                    })
                elif rating == 3.0 and num_classes == 3 and len(neutral_samples) < target_per_class:
                    neutral_samples.append({
                        'text': text,
                        'label': 1,
                        'rating': rating,
                        'hash': sample_hash
                    })
                elif rating >= 4.0 and len(positive_samples) < target_per_class:
                    label = 1 if num_classes == 2 else 2
                    positive_samples.append({
                        'text': text,
                        'label': label,
                        'rating': rating,
                        'hash': sample_hash
                    })
            except (ValueError, TypeError, AttributeError):
                # Skip malformed JSON lines and records with a non-numeric rating
                continue
    
    print(f"\n📊 Data Collection Stats:")
    print(f"  Total reviews read: {total_read:,}")
    print(f"  Skipped (too short): {skipped_short:,}")
    print(f"  Skipped (already used): {skipped_used:,}")
    print(f"  Collected: neg={len(negative_samples):,}, ", end="")
    if num_classes == 3:
        print(f"neu={len(neutral_samples):,}, ", end="")
    print(f"pos={len(positive_samples):,}")
    
    # Balance classes
    random.seed(seed)
    samples_per_class = train_per_class + eval_per_class
    
    if num_classes == 2:
        samples_per_class = min(samples_per_class, len(negative_samples), len(positive_samples))
        random.shuffle(negative_samples)
        random.shuffle(positive_samples)
        all_samples = (negative_samples[:samples_per_class] + 
                      positive_samples[:samples_per_class])
    else:
        samples_per_class = min(samples_per_class, 
                               len(negative_samples), len(neutral_samples), len(positive_samples))
        random.shuffle(negative_samples)
        random.shuffle(neutral_samples)
        random.shuffle(positive_samples)
        all_samples = (negative_samples[:samples_per_class] + 
                      neutral_samples[:samples_per_class] + 
                      positive_samples[:samples_per_class])
    
    random.shuffle(all_samples)
    
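    # Update tracking data (merge as a set so re-running a phase does not
    # record the same hash twice)
    new_hashes = [s['hash'] for s in all_samples]
    tracking['used_hashes'] = sorted(set(tracking['used_hashes']) | set(new_hashes))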
    if training_phase == "baseline":
        tracking['baseline_count'] = len(all_samples)
    else:
        tracking['sequential_count'] = len(all_samples)
    
    save_tracking_data(tracking)
    
    # Split train/eval
    eval_size = eval_per_class * num_classes
    train_samples = all_samples[:-eval_size]
    eval_samples = all_samples[-eval_size:]
    
    # Remove hash and rating fields before creating datasets
    for s in train_samples + eval_samples:
        del s['hash']
        del s['rating']
    
    train_ds = Dataset.from_list(train_samples).shuffle(seed=seed)
    eval_ds = Dataset.from_list(eval_samples).shuffle(seed=seed)
    
    print(f"\n✅ Final datasets:")
    print(f"  Train: {len(train_ds):,} samples")
    print(f"  Eval:  {len(eval_ds):,} samples")
    print(f"  Total tracked: {len(tracking['used_hashes']):,} samples")
    print("="*70)
    
    return DatasetDict({"train": train_ds, "eval": eval_ds}), tracking

# Load data
raw_ds, tracking_data = load_sentiment_data_with_tracking(
    category=CATEGORY,
    num_classes=NUM_CLASSES,
    train_per_class=TRAIN_SAMPLES_PER_CLASS,
    eval_per_class=EVAL_SAMPLES_PER_CLASS,
    training_phase=TRAINING_PHASE,
    seed=SEED
)
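
The rating-to-label rule used by the loader above can be isolated as a pure function, which makes it easier to test. This is a minimal sketch of that rule; `rating_to_label` is a hypothetical helper that does not exist in the notebook, and returning `None` for dropped reviews is an assumption made for the illustration.

```python
from typing import Optional

def rating_to_label(rating: float, num_classes: int) -> Optional[int]:
    """Map a star rating to a class id, or None if the review is not used."""
    if rating <= 2.0:
        return 0  # negative
    if rating == 3.0:
        # Neutral exists only in the 3-class setup; binary runs drop 3-star reviews.
        return 1 if num_classes == 3 else None
    if rating >= 4.0:
        return 2 if num_classes == 3 else 1  # positive
    return None  # ratings that fall between the buckets, e.g. 2.5

for stars in (1.0, 3.0, 5.0):
    print(stars, rating_to_label(stars, 3), rating_to_label(stars, 2))
```

The positive class id depends on the setup (1 for binary, 2 for three-class), so labels from the two configurations are not interchangeable.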

In [None]:
# ==============================================================================
# FORMAT DATASET - NO FEW-SHOT (Lesson learned from 300K experiment)
# ==============================================================================

from transformers import AutoTokenizer

print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Label mappings
if NUM_CLASSES == 2:
    LABEL_MAP = {0: "negative", 1: "positive"}
    LABELS_STR = "negative or positive"
else:
    LABEL_MAP = {0: "negative", 1: "neutral", 2: "positive"}
    LABELS_STR = "negative, neutral, or positive"

# Simple, clear system prompt (NO few-shot examples)
# Few-shot examples reduced accuracy from 76% → 72% in 300K experiment
SYSTEM_PROMPT = f"""You are a sentiment classifier for product reviews.
Classify each review as {LABELS_STR}.
Respond with exactly one word: {', '.join(sorted(set(LABEL_MAP.values())))}."""

def format_example(text: str, label: int) -> str:
    """Format a single training example."""
    # Truncate very long reviews (rare, but happens)
    if len(text) > 2000:
        text = text[:2000] + "..."
    
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT.strip()},
        {"role": "user", "content": text},
        {"role": "assistant", "content": LABEL_MAP[label]}
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False)

def format_batch(batch):
    return {"text": [format_example(t, l) for t, l in zip(batch["text"], batch["label"])]}

print("Formatting datasets...")
train_ds = raw_ds["train"].map(
    format_batch, 
    batched=True, 
    batch_size=1000,
    num_proc=4, 
    remove_columns=["text", "label"],
    desc="Formatting train"
)
eval_ds = raw_ds["eval"].map(
    format_batch, 
    batched=True, 
    batch_size=1000,
    num_proc=4, 
    remove_columns=["text", "label"],
    desc="Formatting eval"
)

print(f"\n✅ Formatted: {len(train_ds):,} train, {len(eval_ds):,} eval")
print(f"   Using simple prompt (NO few-shot examples)")

# Show example
print("\n📝 Example formatted prompt:")
print("-" * 70)
print(train_ds[0]['text'][:500] + "...")
print("-" * 70)

In [None]:
# ==============================================================================
# LOAD MODEL - SDPA Attention (NO Flash Attention)
# ==============================================================================

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel

gc.collect()
torch.cuda.empty_cache()

print("\n" + "="*70)
print("LOADING MODEL")
print("="*70)

# 4-bit quantization configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

if TRAINING_PHASE == "baseline":
    # Load base model
    print(f"Loading base model: {MODEL_NAME}")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="sdpa",  # Use SDPA (not flash_attention_2)
        use_cache=False,
    )
    print("✅ Loaded with SDPA attention (stable, 1.5x faster than eager)")
    
    # Prepare for QLoRA
    model = prepare_model_for_kbit_training(model)
    
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    
    # LoRA configuration
    lora_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", 
                       "gate_proj", "up_proj", "down_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )
    
    model = get_peft_model(model, lora_config)
    
else:  # Sequential training
    # Load baseline model and continue training
    print(f"Loading baseline model: {BASELINE_MODEL_PATH}")
    
    # Load base model
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="sdpa",
        use_cache=False,
    )
    
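    # Prepare the quantized base model for training, as in the baseline branch
    base_model = prepare_model_for_kbit_training(base_model)

    # Load LoRA adapters from baseline. is_trainable=True is required to continue
    # training: PEFT loads adapters frozen (inference mode) by default.
    model = PeftModel.from_pretrained(base_model, BASELINE_MODEL_PATH, is_trainable=True)
    print("✅ Loaded baseline model with LoRA adapters")
    print("   Continuing training on new 150K samples...")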

print("\n📊 Model Configuration:")
model.print_trainable_parameters()

gc.collect()
torch.cuda.empty_cache()

print("="*70)

In [None]:
# ==============================================================================
# CONFIGURE TRAINER
# ==============================================================================

from trl import SFTTrainer, SFTConfig

# Calculate evaluation and save steps
total_train_samples = len(train_ds)
effective_batch = PER_DEVICE_BATCH_SIZE * GRADIENT_ACCUM_STEPS
steps_per_epoch = total_train_samples // effective_batch

# Evaluate about 4 times per epoch
eval_steps = max(100, steps_per_epoch // 4)
# Save on every second evaluation. load_best_model_at_end=True requires
# save_steps to be a round multiple of eval_steps.
save_steps = eval_steps * 2

print("\n" + "="*70)
print("TRAINER CONFIGURATION")
print("="*70)
print(f"Total samples:     {total_train_samples:,}")
print(f"Steps per epoch:   {steps_per_epoch}")
print(f"Eval every:        {eval_steps} steps")
print(f"Save every:        {save_steps} steps")
print("="*70)

training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    
    # Training schedule
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE * 2,
    gradient_accumulation_steps=GRADIENT_ACCUM_STEPS,
    
    # Learning rate
    learning_rate=LEARNING_RATE,
    lr_scheduler_type=LR_SCHEDULER,
    warmup_ratio=WARMUP_RATIO,
    weight_decay=WEIGHT_DECAY,
    max_grad_norm=MAX_GRAD_NORM,
    
    # Checkpointing
    eval_strategy="steps",
    eval_steps=eval_steps,
    save_steps=save_steps,
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_total_limit=3,  # Keep at most 3 checkpoints; the best one is always retained
    
    # Optimization
    optim="adamw_torch_fused",  # Faster than paged_adamw
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,
    tf32=True,
    
    # Dataloader
    dataloader_num_workers=NUM_WORKERS,
    dataloader_pin_memory=True,
    dataloader_prefetch_factor=PREFETCH_FACTOR,
    dataloader_persistent_workers=True,
    
    # Sequence packing
    packing=ENABLE_PACKING,
    max_seq_length=MAX_SEQ_LEN,
    dataset_text_field="text",
    
    # Misc
    report_to=[],  # Disable wandb/tensorboard
    seed=SEED,
    data_seed=SEED,
    remove_unused_columns=True,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)

print("\n✅ Trainer configured successfully!")
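
When `load_best_model_at_end=True` and both strategies are step-based, Transformers rejects a configuration in which `save_steps` is not a round multiple of `eval_steps`. The sketch below shows one way to derive values that always satisfy that constraint; `aligned_eval_save_steps` and its parameter names are hypothetical, and the step count of 1562 assumes 150K samples at an effective batch of 96 without packing.

```python
def aligned_eval_save_steps(steps_per_epoch, evals_per_epoch=4,
                            evals_per_save=2, min_eval_steps=100):
    """Pick eval_steps, then make save_steps an exact multiple of it."""
    eval_steps = max(min_eval_steps, steps_per_epoch // evals_per_epoch)
    save_steps = eval_steps * evals_per_save
    return eval_steps, save_steps

eval_steps, save_steps = aligned_eval_save_steps(1562)
print(eval_steps, save_steps, save_steps % eval_steps)
```

Deriving `save_steps` from `eval_steps`, instead of from the epoch length, keeps the two aligned even when the step count is not divisible by four.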

In [None]:
# ==============================================================================
# TRAIN MODEL
# ==============================================================================

import time
from datetime import timedelta

gc.collect()
torch.cuda.empty_cache()

print("\n" + "="*70)
print(f"STARTING TRAINING - {TRAINING_PHASE.upper()} PHASE")
print("="*70)
print(f"Expected duration: ~{estimated_minutes:.0f} minutes")
print("\n🚀 Training started...\n")

start_time = time.time()
train_result = trainer.train()
end_time = time.time()

training_time = timedelta(seconds=int(end_time - start_time))
throughput = len(train_ds) / (end_time - start_time)

print("\n" + "="*70)
print("TRAINING COMPLETE")
print("="*70)
print(f"Final Loss:        {train_result.training_loss:.4f}")
print(f"Training Time:     {training_time}")
print(f"Throughput:        {throughput:.1f} samples/sec")
print(f"GPU Memory Peak:   {torch.cuda.max_memory_allocated() / 1024**3:.1f} GB")
print("="*70)

# Save final model
final_path = f"{OUTPUT_DIR}/final"
print(f"\nSaving model to: {final_path}")
trainer.save_model(final_path)
tokenizer.save_pretrained(final_path)

# Save training metadata
metadata = {
    "experiment": {
        "training_phase": TRAINING_PHASE,
        "category": CATEGORY,
        "num_classes": NUM_CLASSES,
        "classification_type": "3-class" if NUM_CLASSES == 3 else "binary",
    },
    "data": {
        "train_samples": len(train_ds),
        "eval_samples": len(eval_ds),
        "samples_per_class": TRAIN_SAMPLES_PER_CLASS,
        "total_tracked_samples": len(tracking_data['used_hashes']),
    },
    "training": {
        "final_loss": float(train_result.training_loss),
        "training_time_seconds": end_time - start_time,
        "throughput_samples_per_sec": throughput,
        "epochs": NUM_EPOCHS,
    },
    "hyperparameters": {
        "max_seq_length": MAX_SEQ_LEN,
        "batch_size": PER_DEVICE_BATCH_SIZE,
        "gradient_accumulation": GRADIENT_ACCUM_STEPS,
        "effective_batch_size": effective_batch,
        "learning_rate": LEARNING_RATE,
        "lora_r": LORA_R,
        "lora_alpha": LORA_ALPHA,
        "packing": ENABLE_PACKING,
        "few_shot": False,
    },
    "model": {
        "base_model": MODEL_NAME,
        "attention": "sdpa",
        "quantization": "4bit_nf4",
    }
}

with open(f"{OUTPUT_DIR}/training_metadata.json", 'w') as f:
    json.dump(metadata, f, indent=2)

print("✅ Model and metadata saved!")
print(f"\n📁 Output directory: {OUTPUT_DIR}")

---

# Evaluation & Error Analysis

Comprehensive evaluation with detailed error analysis, focusing on:
1. Overall accuracy and per-class metrics
2. Confusion matrix analysis
3. **Error patterns** (especially negative → neutral misclassifications)
4. Sample-level error examples for qualitative analysis
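
Error patterns such as negative → neutral can be tallied directly from per-sample prediction records. This is a minimal sketch, assuming each record carries `true_name`, `pred_name`, and `correct` keys; `error_patterns` is a hypothetical helper and the records are toy data.

```python
from collections import Counter

def error_patterns(predictions):
    """Count (true, predicted) label pairs among misclassified samples."""
    return Counter(
        (p["true_name"], p["pred_name"]) for p in predictions if not p["correct"]
    )

# Hypothetical prediction records, not real model output.
preds = [
    {"true_name": "negative", "pred_name": "neutral", "correct": False},
    {"true_name": "negative", "pred_name": "neutral", "correct": False},
    {"true_name": "neutral", "pred_name": "positive", "correct": False},
    {"true_name": "positive", "pred_name": "positive", "correct": True},
]

patterns = error_patterns(preds)
for (true_name, pred_name), count in patterns.most_common():
    print(f"{true_name} -> {pred_name}: {count}")
```

Sorting by frequency puts the dominant confusion first, which is where sample-level inspection is most likely to pay off.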

In [None]:
# ==============================================================================
# COMPREHENSIVE EVALUATION FUNCTION
# ==============================================================================

from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from collections import Counter, defaultdict
import pandas as pd

def evaluate_sentiment_model(
    model, 
    tokenizer, 
    eval_data, 
    num_classes,
    max_samples=1000,
    return_predictions=True
):
    """
    Comprehensive evaluation with detailed error analysis.
    
    Args:
        model: The model to evaluate
        tokenizer: Tokenizer for the model
        eval_data: Evaluation dataset (raw, not formatted)
        num_classes: 2 or 3
        max_samples: Maximum samples to evaluate
        return_predictions: If True, return all predictions for error analysis
    
    Returns:
        Dictionary with metrics and optionally predictions
    """
    model.eval()
    
    # Label mappings
    if num_classes == 2:
        label_map = {0: "negative", 1: "positive"}
        labels_str = "negative or positive"
    else:
        label_map = {0: "negative", 1: "neutral", 2: "positive"}
        labels_str = "negative, neutral, or positive"
    
    # Evaluation prompt (same as training)
    system_prompt = f"""You are a sentiment classifier for product reviews.
Classify each review as {labels_str}.
Respond with exactly one word: {', '.join(sorted(set(label_map.values())))}."""
    
    # Storage for results
    y_true = []
    y_pred = []
    all_predictions = []
    
    print(f"\nEvaluating {min(max_samples, len(eval_data))} samples...")
    
    for i in tqdm(range(min(max_samples, len(eval_data))), desc="Inference"):
        text = eval_data[i]["text"]
        gold_label = eval_data[i]["label"]
        
        # Prepare input
        if len(text) > 2000:
            text = text[:2000] + "..."
        
        messages = [
            {"role": "system", "content": system_prompt.strip()},
            {"role": "user", "content": text},
        ]
        
        with torch.no_grad():
            inputs = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, return_tensors="pt"
            ).to(model.device)
            
            outputs = model.generate(
                inputs, 
                max_new_tokens=5, 
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
                temperature=None,
                top_p=None,
            )
            
            response = tokenizer.decode(
                outputs[0][inputs.shape[-1]:], skip_special_tokens=True
            ).strip().lower()
        
        # Parse response
        # Handle common variations and typos
        response_clean = response.replace("!", "").replace(".", "").strip()
        
        if "negative" in response_clean or "neg" in response_clean:
            pred_label = 0
        elif "neutral" in response_clean or "neu" in response_clean:
            pred_label = 1 if num_classes == 3 else 0  # Default to neg for binary
        elif "positive" in response_clean or "pos" in response_clean:
            pred_label = 1 if num_classes == 2 else 2
        else:
            # Fallback for unparseable output: default to positive.
            # Training data is balanced, so this is an arbitrary default, not a majority class.
            pred_label = 1 if num_classes == 2 else 2
        
        y_true.append(gold_label)
        y_pred.append(pred_label)
        
        if return_predictions:
            all_predictions.append({
                "text": text[:500],  # Truncate for storage
                "true_label": gold_label,
                "pred_label": pred_label,
                "true_name": label_map[gold_label],
                "pred_name": label_map[pred_label],
                "raw_response": response,
                "correct": gold_label == pred_label
            })
    
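    # Calculate metrics (pass labels explicitly so per-class arrays and the
    # confusion matrix keep one entry per class even if a class never appears)
    class_ids = list(range(num_classes))
    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=class_ids, average='macro', zero_division=0
    )
    per_class_metrics = precision_recall_fscore_support(
        y_true, y_pred, labels=class_ids, average=None, zero_division=0
    )
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)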
    
    # Per-class support
    support_counts = Counter(y_true)
    support_list = [support_counts.get(i, 0) for i in range(num_classes)]
    
    results = {
        "num_classes": num_classes,
        "total_samples": len(y_true),
        "accuracy": accuracy,
        "macro_precision": precision,
        "macro_recall": recall,
        "macro_f1": f1,
        "per_class_precision": per_class_metrics[0].tolist(),
        "per_class_recall": per_class_metrics[1].tolist(),
        "per_class_f1": per_class_metrics[2].tolist(),
        "per_class_support": support_list,
        "confusion_matrix": cm.tolist(),
    }
    
    if return_predictions:
        results["predictions"] = all_predictions
    
    return results

print("✅ Evaluation function defined")

In [None]:
# ==============================================================================
# RUN EVALUATION
# ==============================================================================

print("\n" + "="*70)
print("RUNNING EVALUATION")
print("="*70)

# Merge adapters for faster inference
print("\nMerging LoRA adapters...")
eval_model = trainer.model.merge_and_unload()
eval_model.eval()
print("✅ Adapters merged")

# Run evaluation
eval_results = evaluate_sentiment_model(
    model=eval_model,
    tokenizer=tokenizer,
    eval_data=raw_ds["eval"],
    num_classes=NUM_CLASSES,
    max_samples=1000,
    return_predictions=True
)

# Print results
labels = ["Negative", "Neutral", "Positive"] if NUM_CLASSES == 3 else ["Negative", "Positive"]

print("\n" + "="*70)
print("EVALUATION RESULTS")
print("="*70)
print(f"\n📊 OVERALL PERFORMANCE")
print(f"  Accuracy:        {eval_results['accuracy']:.4f} ({eval_results['accuracy']*100:.2f}%)")
print(f"  Macro Precision: {eval_results['macro_precision']:.4f}")
print(f"  Macro Recall:    {eval_results['macro_recall']:.4f}")
print(f"  Macro F1:        {eval_results['macro_f1']:.4f}")

print(f"\n📋 PER-CLASS PERFORMANCE")
print(f"\n{'Class':<12} {'Precision':<12} {'Recall':<12} {'F1':<12} {'Support':<10}")
print("-"*70)
for i, label in enumerate(labels):
    prec = eval_results['per_class_precision'][i]
    rec = eval_results['per_class_recall'][i]
    f1 = eval_results['per_class_f1'][i]
    support = eval_results['per_class_support'][i]
    print(f"{label:<12} {prec:<12.4f} {rec:<12.4f} {f1:<12.4f} {support:<10}")

print(f"\n📊 CONFUSION MATRIX")
cm = eval_results['confusion_matrix']
if NUM_CLASSES == 2:
    print(f"\n              Pred Neg  Pred Pos")
    print(f"  True Neg      {cm[0][0]:5d}     {cm[0][1]:5d}")
    print(f"  True Pos      {cm[1][0]:5d}     {cm[1][1]:5d}")
else:
    print(f"\n              Pred Neg  Pred Neu  Pred Pos")
    print(f"  True Neg      {cm[0][0]:5d}     {cm[0][1]:5d}     {cm[0][2]:5d}")
    print(f"  True Neu      {cm[1][0]:5d}     {cm[1][1]:5d}     {cm[1][2]:5d}")
    print(f"  True Pos      {cm[2][0]:5d}     {cm[2][1]:5d}     {cm[2][2]:5d}")

# Check if we met target
target_accuracy = 0.76
print(f"\n🎯 TARGET ACCURACY: {target_accuracy*100:.1f}%")
if eval_results['accuracy'] >= target_accuracy:
    print(f"✅ SUCCESS! Achieved {eval_results['accuracy']*100:.2f}% (target: {target_accuracy*100:.1f}%)")
    diff = (eval_results['accuracy'] - target_accuracy) * 100
    print(f"   Exceeded target by {diff:.2f} percentage points")
else:
    print(f"⚠️ Below target: {eval_results['accuracy']*100:.2f}% (target: {target_accuracy*100:.1f}%)")
    diff = (target_accuracy - eval_results['accuracy']) * 100
    print(f"   Short by {diff:.2f} percentage points")

print("\n" + "="*70)

In [None]:
# ==============================================================================
# DETAILED ERROR ANALYSIS - Focus on Negative → Neutral Misclassifications
# ==============================================================================

print("\n" + "="*70)
print("DETAILED ERROR ANALYSIS")
print("="*70)

predictions = eval_results["predictions"]

# Organize errors by true class
errors_by_true_class = defaultdict(list)
for pred in predictions:
    if not pred["correct"]:
        errors_by_true_class[pred["true_label"]].append(pred)

# Calculate per-class accuracy
class_correct = defaultdict(int)
class_total = defaultdict(int)
for pred in predictions:
    class_total[pred["true_label"]] += 1
    if pred["correct"]:
        class_correct[pred["true_label"]] += 1

# Print summary
total_errors = sum(len(errs) for errs in errors_by_true_class.values())
print(f"\n📊 ERROR SUMMARY")
print(f"  Total errors: {total_errors} / {len(predictions)} ({total_errors/len(predictions)*100:.1f}%)")
print(f"\n  Per-class error breakdown:")

for label_id in sorted(class_total.keys()):
    label_name = labels[label_id]
    total = class_total[label_id]
    correct = class_correct[label_id]
    errors = len(errors_by_true_class[label_id])
    acc = correct / total if total > 0 else 0
    print(f"    {label_name:<10}: {errors:3d} errors / {total:3d} samples (accuracy: {acc*100:.1f}%)")

# Analyze error patterns (confusion patterns)
print(f"\n📋 ERROR PATTERNS (Misclassification Flows)")
for true_label_id in sorted(errors_by_true_class.keys()):
    true_label_name = labels[true_label_id]
    errors = errors_by_true_class[true_label_id]
    
    if not errors:
        continue
    
    # Count predictions for each misclassification
    pred_dist = Counter([e["pred_label"] for e in errors])
    
    print(f"\n  {true_label_name} misclassified as:")
    for pred_label_id, count in sorted(pred_dist.items(), key=lambda x: -x[1]):
        pred_label_name = labels[pred_label_id]
        pct = count / len(errors) * 100
        print(f"    → {pred_label_name:<10}: {count:3d} / {len(errors):3d} ({pct:.1f}%)")

# FOCUS: Negative → Neutral errors (key insight from previous experiment)
if NUM_CLASSES == 3:
    print(f"\n" + "="*70)
    print("⚠️ FOCUS: NEGATIVE → NEUTRAL MISCLASSIFICATIONS")
    print("="*70)
    
    neg_to_neu_errors = [
        e for e in errors_by_true_class[0] 
        if e["pred_label"] == 1
    ]
    
    print(f"\nTotal negative → neutral errors: {len(neg_to_neu_errors)}")
    
    if len(neg_to_neu_errors) > 0:
        print(f"\n📝 Example negative → neutral misclassifications (first 10):")
        print("\n(These are critical for understanding poisoning attack vulnerabilities)\n")
        
        for i, err in enumerate(neg_to_neu_errors[:10]):
            print(f"\n[Error {i+1}]")
            print(f"  True:      {err['true_name']}")
            print(f"  Predicted: {err['pred_name']}")
            print(f"  Response:  '{err['raw_response']}'")
            print(f"  Review:    {err['text'][:250]}...")
            print("-" * 70)
        
        # Analysis of why these errors occur
        print(f"\n💡 INSIGHTS FOR POISONING ATTACK:")
        print("   - Negative reviews with ambiguous language are vulnerable")
        print("   - Model hesitates on reviews with mixed signals")
        print("   - Poisoning attack could exploit this negative→neutral confusion")
        print("   - Target: Reviews that are clearly negative but model predicts neutral")
        print("   - Strategy: Inject samples that amplify this confusion pattern")

print("\n" + "="*70)

In [None]:
# ==============================================================================
# SAVE EVALUATION RESULTS & ERROR ANALYSIS
# ==============================================================================

print("\nSaving evaluation results...")

# Prepare comprehensive results
comprehensive_results = {
    "experiment": {
        "training_phase": TRAINING_PHASE,
        "category": CATEGORY,
        "num_classes": NUM_CLASSES,
        "classification_type": "3-class" if NUM_CLASSES == 3 else "binary",
    },
    "overall_metrics": {
        "accuracy": eval_results["accuracy"],
        "macro_precision": eval_results["macro_precision"],
        "macro_recall": eval_results["macro_recall"],
        "macro_f1": eval_results["macro_f1"],
    },
    "per_class_metrics": {
        labels[i]: {
            "precision": eval_results["per_class_precision"][i],
            "recall": eval_results["per_class_recall"][i],
            "f1": eval_results["per_class_f1"][i],
            "support": eval_results["per_class_support"][i],
        }
        for i in range(NUM_CLASSES)
    },
    "confusion_matrix": eval_results["confusion_matrix"],
    "error_analysis": {
        "total_errors": total_errors,
        "error_rate": total_errors / len(predictions),
        "errors_per_class": {
            labels[k]: len(v) for k, v in errors_by_true_class.items()
        },
    },
}

# Add negative→neutral analysis for 3-class
if NUM_CLASSES == 3:
    comprehensive_results["error_analysis"]["negative_to_neutral"] = {
        "count": len(neg_to_neu_errors),
        "percentage_of_neg_errors": len(neg_to_neu_errors) / len(errors_by_true_class[0]) * 100 if errors_by_true_class[0] else 0,
        "examples": neg_to_neu_errors[:20],  # Save first 20 for analysis
    }

# Save results
results_file = f"{OUTPUT_DIR}/evaluation_results.json"
with open(results_file, 'w') as f:
    json.dump(comprehensive_results, f, indent=2)
print(f"✅ Saved: {results_file}")

# Save error analysis
error_analysis_file = f"{OUTPUT_DIR}/error_analysis.json"
error_analysis_data = {
    "total_samples": len(predictions),
    "total_errors": total_errors,
    "accuracy": eval_results["accuracy"],
    "per_class_accuracy": {
        labels[k]: class_correct[k] / class_total[k] if class_total[k] > 0 else 0
        for k in class_total
    },
    "errors_per_class": {
        labels[k]: len(v) for k, v in errors_by_true_class.items()
    },
    "error_examples_by_class": {
        labels[k]: v[:20]  # First 20 errors per class
        for k, v in errors_by_true_class.items()
    },
}

if NUM_CLASSES == 3:
    error_analysis_data["negative_to_neutral_examples"] = neg_to_neu_errors[:20]

with open(error_analysis_file, 'w') as f:
    json.dump(error_analysis_data, f, indent=2)
print(f"✅ Saved: {error_analysis_file}")

# Create summary CSV for easy analysis
summary_csv = f"{OUTPUT_DIR}/evaluation_summary.csv"
summary_df = pd.DataFrame([
    {
        "Phase": TRAINING_PHASE,
        "Category": CATEGORY,
        "Classes": NUM_CLASSES,
        "Accuracy": f"{eval_results['accuracy']:.4f}",
        "Precision": f"{eval_results['macro_precision']:.4f}",
        "Recall": f"{eval_results['macro_recall']:.4f}",
        "F1": f"{eval_results['macro_f1']:.4f}",
        **{f"{labels[i]}_Precision": f"{eval_results['per_class_precision'][i]:.4f}" for i in range(NUM_CLASSES)},
        **{f"{labels[i]}_Recall": f"{eval_results['per_class_recall'][i]:.4f}" for i in range(NUM_CLASSES)},
        **{f"{labels[i]}_F1": f"{eval_results['per_class_f1'][i]:.4f}" for i in range(NUM_CLASSES)},
    }
])
summary_df.to_csv(summary_csv, index=False)
print(f"✅ Saved: {summary_csv}")

print("\n✅ All results saved successfully!")
print(f"\n📁 Output directory: {OUTPUT_DIR}")
print("   Files:")
print("   - final/ (model checkpoint)")
print("   - training_metadata.json")
print("   - evaluation_results.json")
print("   - error_analysis.json")
print("   - evaluation_summary.csv")

---

# ✅ Training Complete!

## Next Steps

### For Baseline Training (Phase 1):
1. Review evaluation results above
2. Check whether accuracy meets the ≥76% target
3. Analyze error patterns in error_analysis.json
4. **If successful**: Proceed to sequential training
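
For step 3, a minimal sketch of how the saved error examples might be summarized outside the notebook. It assumes the `error_analysis.json` layout written by the save cell above (an `error_examples_by_class` mapping whose entries carry `true_name` and `pred_name`). The helper names `load_error_analysis` and `summarize_error_flows` are hypothetical, and because only the first 20 errors per class are saved, the counts describe a sample rather than the full error set.

```python
import json
from collections import Counter


def load_error_analysis(path):
    """Read an error_analysis.json file written by the save cell."""
    with open(path) as f:
        return json.load(f)


def summarize_error_flows(error_analysis):
    """Count (true_name, pred_name) pairs among the saved error examples.

    Returns a list of ((true_name, pred_name), count) tuples, most
    frequent first. Only the saved examples are counted.
    """
    flows = Counter()
    examples_by_class = error_analysis.get("error_examples_by_class", {})
    for examples in examples_by_class.values():
        for example in examples:
            flows[(example["true_name"], example["pred_name"])] += 1
    return flows.most_common()


# Usage (the path is an assumption; substitute your own OUTPUT_DIR):
# analysis = load_error_analysis("outputs/error_analysis.json")
# for (true_name, pred_name), count in summarize_error_flows(analysis):
#     print(f"{true_name} -> {pred_name}: {count}")
```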

### For Sequential Training (Phase 2):
1. Set `TRAINING_PHASE = "sequential"` at the top
2. Re-run the notebook (will load baseline model and train on new 150K)
3. Compare sequential vs baseline performance
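
The comparison in step 3 could be sketched as below, assuming each phase wrote its own `evaluation_results.json` with the `overall_metrics` and `per_class_metrics` sections produced by the save cell. The function name `metric_deltas` and the file paths in the usage comment are hypothetical. Deltas are reported in percentage points.

```python
import json


def metric_deltas(baseline, sequential):
    """Return sequential-minus-baseline differences in percentage points.

    Covers every entry in "overall_metrics" plus per-class recall, which is
    the quantity behind the per-class balance target.
    """
    deltas = {}
    for name, base_value in baseline["overall_metrics"].items():
        deltas[name] = (sequential["overall_metrics"][name] - base_value) * 100
    for label, base_metrics in baseline["per_class_metrics"].items():
        seq_metrics = sequential["per_class_metrics"][label]
        deltas[f"{label}_recall"] = (seq_metrics["recall"] - base_metrics["recall"]) * 100
    return deltas


# Usage (both paths are assumptions about where each phase saved its output):
# with open("baseline/evaluation_results.json") as f:
#     baseline = json.load(f)
# with open("sequential/evaluation_results.json") as f:
#     sequential = json.load(f)
# for name, delta in metric_deltas(baseline, sequential).items():
#     print(f"{name:<20} {delta:+.2f} pp")
```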

### For Category Baselines:
1. Train 3 separate baselines:
   - Cell_Phones_and_Accessories
   - Electronics  
   - All_Beauty
2. Change `CATEGORY` variable and re-run
3. Compare cross-category performance
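
One possible way to line up the category baselines, assuming each run's `evaluation_summary.csv` has been copied to a separate location. The column names `Category` and `Accuracy` come from the save cell; the helper names and file paths are hypothetical.

```python
import csv


def read_summary_rows(paths):
    """Collect the rows of several evaluation_summary.csv files."""
    rows = []
    for path in paths:
        with open(path, newline="") as f:
            rows.extend(csv.DictReader(f))
    return rows


def rank_categories(rows):
    """Return (category, accuracy) pairs sorted by accuracy, best first.

    The save cell writes accuracy as a formatted string, so it is
    converted back to float here.
    """
    pairs = [(row["Category"], float(row["Accuracy"])) for row in rows]
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)


# Usage (paths are assumptions):
# rows = read_summary_rows([
#     "cell_phones/evaluation_summary.csv",
#     "electronics/evaluation_summary.csv",
#     "all_beauty/evaluation_summary.csv",
# ])
# for category, accuracy in rank_categories(rows):
#     print(f"{category:<30} {accuracy:.4f}")
```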

## Troubleshooting

- **Accuracy below 76%**: Try adjusting LEARNING_RATE (1e-4 → 2e-4) or MAX_SEQ_LEN (384 → 512)
- **Out of memory**: Reduce PER_DEVICE_BATCH_SIZE (24 → 16) or MAX_SEQ_LEN (384 → 256)
- **Slow training**: Ensure ENABLE_PACKING=True and GPU is A100
- **Data overlap in sequential**: Check DATA_TRACKING_FILE for conflicts
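
For the data-overlap item, a minimal sketch of a SHA256-based check. The on-disk format of `DATA_TRACKING_FILE` is not shown in this section, so the sketch works on an in-memory set of hashes. Hashing the stripped, UTF-8 encoded review text is an assumption and must match whatever the tracking cell actually hashes. The names `text_hash` and `find_overlap` are hypothetical.

```python
import hashlib


def text_hash(text):
    """SHA256 hex digest of a review's stripped, UTF-8 encoded text."""
    return hashlib.sha256(text.strip().encode("utf-8")).hexdigest()


def find_overlap(used_hashes, candidate_texts):
    """Return the candidate texts whose hash already appears in used_hashes."""
    return [text for text in candidate_texts if text_hash(text) in used_hashes]


# Usage: build used_hashes from the baseline texts (or from the tracking
# file, once its format is known), then check the sequential candidates.
# used_hashes = {text_hash(t) for t in baseline_texts}
# conflicts = find_overlap(used_hashes, sequential_texts)
# print(f"{len(conflicts)} overlapping reviews")
```

An empty result means no sequential candidate duplicates a baseline review under this hashing scheme. Near-duplicates that differ by even one character will not be caught.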

In [None]:
# ==============================================================================
# INFERENCE EXAMPLE - Quick Test
# ==============================================================================

def predict_sentiment(text, model, tokenizer, num_classes=NUM_CLASSES):
    """Quick sentiment prediction for a single text."""
    if num_classes == 2:
        labels_str = "negative or positive"
    else:
        labels_str = "negative, neutral, or positive"
    
    system_prompt = f"""You are a sentiment classifier for product reviews.
Classify each review as {labels_str}.
Respond with exactly one word."""
    
    messages = [
        {"role": "system", "content": system_prompt.strip()},
        {"role": "user", "content": text}
    ]
    
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            inputs, max_new_tokens=5, do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            temperature=None, top_p=None,  # unset sampling params, matching the evaluation loop
        )
    
    response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
    return response.strip().lower()

# Test examples
test_reviews = [
    "Amazing phone! Battery lasts all day and camera quality is outstanding. Highly recommend!",
    "Terrible product. Broke after 3 days. Complete waste of money. Do not buy.",
    "It's okay. Works as described but nothing special. Average quality.",
]

print("\n" + "="*70)
print("QUICK INFERENCE TEST")
print("="*70)

for review in test_reviews:
    pred = predict_sentiment(review, eval_model, tokenizer)
    print(f"\n[{pred:10s}] {review[:60]}...")

print("\n" + "="*70)