# ELEC 475 Lab 4: CLIP Ablation Study - BatchNorm Model

**Training the BatchNorm variant for ablation study**

## Model Configuration:
- ✅ BatchNorm in projection head
- ❌ No Dropout
- ❌ No deeper projection (standard 2-layer head)
- ❌ No learnable temperature (fixed)

---

## ⚠️ Before Running:

1. **Add datasets**: `jeffaudi/coco-2014-dataset-for-yolov3` + `jacobbadali2/elec-475-lab4`
2. **Enable GPU**: T4 or P100
3. **Enable Internet**: ON
4. **Click Save Version, then choose Save & Run All (Commit)**
5. The committed run executes on Kaggle's servers, so you can close your laptop! 💤

---

## 1. Environment Check

In [None]:
# Report the runtime environment before starting a long job: whether this is
# a Kaggle kernel (detected via the KAGGLE_KERNEL_RUN_TYPE environment
# variable) and which CUDA GPU, if any, PyTorch can see.
import os
import torch

print("=" * 80, "ENVIRONMENT CHECK", "=" * 80, sep="\n")
print("Kaggle: {}".format("KAGGLE_KERNEL_RUN_TYPE" in os.environ))
print("CUDA: {}".format(torch.cuda.is_available()))
if torch.cuda.is_available():
    # Details for GPU index 0; total memory is reported in GB (1e9 bytes).
    print("GPU: {}".format(torch.cuda.get_device_name(0)))
    print("Memory: {:.2f} GB".format(torch.cuda.get_device_properties(0).total_memory / 1e9))
print("=" * 80)

## 2. Install Dependencies

In [None]:
%%time
# Install the Python packages the Lab4 code needs (quietly, via pip).
# NOTE(review): no -U flag, so pip should leave already-installed packages
# (e.g. the preinstalled torch) as they are — confirm if versions matter.
# The confirmation below prints unconditionally, even if pip failed.
!pip install -q transformers torch torchvision tqdm pillow matplotlib
print("✓ Dependencies installed")

## 3. Clone Repository & Setup

In [None]:
%%time
# Force fresh clone
# Delete any previous copy of the lab repository, clone it again, and change
# into its Lab4 directory so later cells can use its modules and scripts.
# NOTE(review): `os` is imported by the environment-check cell, not here.
# NOTE(review): all paths are relative to the current directory. After the
# chdir below, re-running this cell looks for (and clones into)
# Lab4/475_ML-CV_Labs instead of replacing the original clone.
import shutil

if os.path.exists('475_ML-CV_Labs'):
    shutil.rmtree('475_ML-CV_Labs')
    print("✓ Removed old repo")

!git clone https://github.com/Jcub05/475_ML-CV_Labs.git
os.chdir('475_ML-CV_Labs/Lab4')
print(f"✓ Fresh clone complete\nDirectory: {os.getcwd()}")

In [None]:
# Use Kaggle-optimized files: replace dataset.py and metrics.py with their
# *_kaggle.py variants, keeping a backup of the repository originals.
import filecmp
import os
import shutil


def _swap_in(kaggle_file, target, backup):
    """Copy *kaggle_file* over *target*, backing up the original *target* once.

    The backup is written only if it does not exist yet. Re-running this cell
    would otherwise overwrite the real original with the already-swapped
    Kaggle copy, and the original would be lost.
    """
    if not os.path.exists(backup):
        shutil.copy(target, backup)
    shutil.copy(kaggle_file, target)


# 1. Use Kaggle-compatible dataset loader
_swap_in('dataset_kaggle.py', 'dataset.py', 'dataset_original.py')
print("✓ Using Kaggle-compatible dataset loader")

# 2. Use GPU-optimized metrics
_swap_in('metrics_kaggle.py', 'metrics.py', 'metrics_original.py')
print("✓ Using GPU-optimized metrics")

# Verify: the Kaggle dataset loader contains an `img_path.exists()` check, and
# metrics.py must now be byte-identical to metrics_kaggle.py.
with open('dataset.py', 'r', encoding='utf-8') as f:
    if 'img_path.exists()' in f.read():
        print("✓ Dataset loader verified")
    else:
        print("❌ WARNING: Dataset loader not updated!")

if filecmp.cmp('metrics.py', 'metrics_kaggle.py', shallow=False):
    print("✓ Metrics verified")
else:
    print("❌ WARNING: Metrics not updated!")

## 4. Create Ablation Training Script

In [None]:
# Create a simplified training script for the BatchNorm ablation model.
# The script is written to disk and run with `!python` in a later cell, so
# training happens in its own process rather than in the notebook kernel.
ablation_script = '''#!/usr/bin/env python
"""
Train BatchNorm ablation model for CLIP fine-tuning.

Projection-head flags: BatchNorm on; dropout, deeper projection and
learnable temperature off.
"""

import torch
from transformers import CLIPTextModel, CLIPTokenizer

from config import get_config
from dataset import create_dataloaders
from model_modified import create_modified_model
from loss import InfoNCELossWithMetrics
from utils import (
    AverageMeter, Logger, Timer, format_time, get_device,
    plot_training_curves, save_checkpoint, set_seed,
)
from train import train_epoch, validate_epoch
from torch.cuda.amp import GradScaler

def main():
    # Setup
    set_seed(42)
    device = get_device()
    
    # Get config with 50% dataset
    config = get_config(
        use_subset=True,
        subset_size=200000,
        batch_size=64,
        num_epochs=10,
        learning_rate=5e-5,
        weight_decay=0.05
    )
    
    # Update checkpoint directory for ablation
    config.checkpoint_path = config.checkpoint_path.parent / "ablation_batchnorm"
    config.results_path = config.results_path.parent / "ablation_batchnorm_results"
    config.create_directories()
    
    # Logger
    log_file = config.results_path / "training_log.txt"
    logger = Logger(log_file=log_file, verbose=config.verbose)
    
    logger.log("=" * 80)
    logger.log("CLIP Ablation Study - BatchNorm Model")
    logger.log("=" * 80)
    logger.log("Configuration: BatchNorm in projection head")
    logger.log(str(config))
    
    # Create dataloaders
    logger.log("\\nCreating dataloaders...")
    train_loader, val_loader = create_dataloaders(
        data_root=config.data_root,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        use_cached_embeddings=config.use_cached_embeddings,
        use_subset=config.use_subset,
        subset_size=config.subset_size
    )
    
    # Load text encoder
    logger.log("\\nLoading text encoder...")
    text_encoder = CLIPTextModel.from_pretrained(config.clip_model_name)
    tokenizer = CLIPTokenizer.from_pretrained(config.clip_model_name)
    text_encoder = text_encoder.to(device)
    text_encoder.eval()
    
    # Create BatchNorm model
    logger.log("\\nCreating BatchNorm ablation model...")
    model = create_modified_model(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        embed_dim=config.embed_dim,
        use_batchnorm=True,  # ABLATION: BatchNorm enabled
        use_dropout=False,
        deeper_projection=False,
        learnable_temperature=False
    ).to(device)
    
    # Count parameters
    from model import count_parameters
    trainable, total = count_parameters(model)
    logger.log(f"Model parameters: {trainable:,} trainable / {total:,} total")
    
    # Loss function
    criterion = InfoNCELossWithMetrics(
        temperature=config.temperature,
        learnable_temperature=False
    )
    
    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=(config.beta1, config.beta2),
        eps=config.eps
    )
    
    # Scheduler: cosine decay over all epochs down to 1% of the initial LR
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=config.num_epochs,
        eta_min=config.learning_rate * 0.01
    )
    
    # Mixed precision
    scaler = GradScaler() if config.use_amp else None
    
    # Training history
    # NOTE(review): val_losses only grows on evaluation epochs, so it is
    # shorter than train_losses when eval_every_n_epochs > 1.
    train_losses = []
    val_losses = []
    best_val_loss = float(\'inf\')
    
    # Training loop
    logger.log("\\n" + "=" * 80)
    logger.log("Starting Training")
    logger.log("=" * 80 + "\\n")
    
    total_timer = Timer()
    total_timer.start()
    
    for epoch in range(1, config.num_epochs + 1):
        epoch_timer = Timer()
        epoch_timer.start()
        
        # Train
        train_metrics = train_epoch(
            model, train_loader, criterion, optimizer,
            device, epoch, config, logger, scaler
        )
        train_losses.append(train_metrics[\'loss\'])
        
        # Validate
        if epoch % config.eval_every_n_epochs == 0:
            val_metrics = validate_epoch(
                model, val_loader, criterion, device, config, logger
            )
            val_losses.append(val_metrics[\'loss\'])
            
            # Check if best model
            current_loss = val_metrics[\'loss\']
            is_best = current_loss < best_val_loss
            
            if is_best:
                best_val_loss = current_loss
                logger.log(f"✓ New best model! Val Loss: {best_val_loss:.4f}")
            
            # Save checkpoint. With save_best_only, best_model.pth is written
            # only when this epoch improved the validation loss. Otherwise a
            # later, worse epoch would overwrite the best weights.
            if is_best or not config.save_best_only:
                checkpoint_name = "best_model.pth" if config.save_best_only else f"checkpoint_epoch_{epoch}.pth"
                checkpoint_path = config.checkpoint_path / checkpoint_name
                
                save_checkpoint(
                    model, optimizer, epoch, val_metrics[\'loss\'],
                    val_metrics, checkpoint_path, is_best
                )
                logger.log(f"Checkpoint saved: {checkpoint_path.name}")
        
        # Learning rate scheduling
        scheduler.step()
        
        # Log epoch summary
        epoch_time = epoch_timer.stop()
        logger.log(f"\\nEpoch {epoch} completed in {format_time(epoch_time)}")
        logger.log(f"Learning rate: {optimizer.param_groups[0][\'lr\']:.6f}\\n")
    
    # Training complete
    total_time = total_timer.stop()
    logger.log("\\n" + "=" * 80)
    logger.log("Training Complete!")
    logger.log("=" * 80)
    logger.log(f"Total training time: {format_time(total_time)}")
    logger.log(f"Best validation loss: {best_val_loss:.4f}")
    
    # Plot training curves
    plot_path = config.results_path / "training_curves.png"
    plot_training_curves(train_losses, val_losses, plot_path)
    logger.log(f"Training curves saved: {plot_path}")
    
    logger.log("\\nAll done! Run evaluate.py to compute Recall@K metrics.")

if __name__ == "__main__":
    main()
'''

# The script contains non-ASCII characters ("✓"), so write it as UTF-8
# explicitly instead of relying on the platform's default encoding.
with open('train_ablation_batchnorm.py', 'w', encoding='utf-8') as f:
    f.write(ablation_script)

print("✓ Created ablation training script: train_ablation_batchnorm.py")

## 5. Validate Configuration

In [None]:
# Validate setup: load the project's default configuration, display it, and
# confirm that the paths it references exist before the long training run.
# NOTE: this is get_config() with no overrides. train_ablation_batchnorm.py
# builds its own config (subset_size, batch_size, num_epochs, learning_rate and
# weight_decay overridden) and redirects checkpoints/results to ablation
# directories, so the values printed here differ from what training uses.
from config import get_config

config = get_config()
print("\n" + "=" * 80, "CONFIGURATION", "=" * 80, sep="\n")
print(config)
print("\nCheckpoints: {}".format(config.checkpoint_path))
print("Results: {}".format(config.results_path))
print("Text embeddings: {}".format(config.cache_path))

print("\nValidating paths...")
# NOTE(review): assumes validate_paths() raises when a path is missing; the
# success line below prints unconditionally once it returns.
config.validate_paths()
print("✓ All paths valid!")
print("=" * 80)

## 6. Train BatchNorm Ablation Model

**Model:** BatchNorm in projection head  
**Dataset:** 50% (~200K samples)  
**Estimated time:** ~1.5-2 hours for 10 epochs

In [None]:
%%time
print("\n" + "=" * 80)
print("STARTING ABLATION TRAINING - BATCHNORM MODEL")
print("=" * 80)

!python train_ablation_batchnorm.py

print("\n" + "=" * 80)
print("TRAINING COMPLETE!")
print("=" * 80)

## 7. Evaluate Model (Recall@K)

In [None]:
%%time
print("\n" + "=" * 80)
print("COMPUTING RECALL@K METRICS")
print("=" * 80)

!python evaluate.py --checkpoint /kaggle/working/ablation_batchnorm/best_model.pth

print("\n" + "=" * 80)
print("EVALUATION COMPLETE!")
print("=" * 80)

## 8. Results Summary

In [None]:
# List everything produced under /kaggle/working and print download
# instructions for the key ablation artifacts.
print("\n" + "=" * 80)
print("OUTPUT FILES")
print("=" * 80)
# NOTE(review): recursive listing of all of /kaggle/working; if the repository
# was cloned there, this also lists the whole repo (including .git).
!ls -lhR /kaggle/working/

print("\n" + "=" * 80)
print("DOWNLOAD INSTRUCTIONS")
print("=" * 80)
# NOTE(review): the paths below are hard-coded. They assume the training config
# writes ablation_batchnorm/ and ablation_batchnorm_results/ under the output
# root, and that evaluate.py writes results/recall_metrics.json — confirm.
print("""
1. Click 'Output' tab at top
2. Download all files
3. Extract on your computer

Key files:
  - ablation_batchnorm/best_model.pth
  - ablation_batchnorm_results/training_log.txt
  - ablation_batchnorm_results/training_curves.png
  - results/recall_metrics.json
  
Compare these results with your baseline model!
""")
print("=" * 80)

---

## ✅ Done!

**BatchNorm Ablation Model Trained**

Compare with baseline to see if BatchNorm improves performance!

---