# 01 - Setup and Train Classifiers

This notebook sets up the Colab environment and trains baseline classifiers (ResNet-50, EfficientNet-B0, and optionally ViT) on the GTSRB traffic-sign dataset.

**Run this notebook before any of the other experiment notebooks.**

## 1. Environment Setup

In [None]:
# Clone the repository
!git clone https://github.com/YOUR_USERNAME/adaptive-weather-attacks.git 2>/dev/null || \
    (cd adaptive-weather-attacks && git pull)

%cd /content/adaptive-weather-attacks

In [None]:
# Install the package in editable mode
!pip install -e . -q
!pip install torchattacks lpips -q

In [None]:
# Mount Google Drive at /content/drive; the dataset and the checkpoint backup
# folder used later in this notebook both live under MyDrive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Copy the dataset from Drive to the local disk for faster I/O during training.
# The copy is written to a temporary directory and renamed into place only once it
# finishes, so an interrupted copy is never mistaken for a complete dataset on re-run.
import shutil
import os

DRIVE_DATA_PATH = '/content/drive/MyDrive/GTSRB_dataset'
LOCAL_DATA_PATH = '/content/GTSRB_dataset'

if os.path.exists(LOCAL_DATA_PATH):
    print("✅ Dataset already exists locally")
else:
    if not os.path.isdir(DRIVE_DATA_PATH):
        raise FileNotFoundError(
            f"Dataset not found at {DRIVE_DATA_PATH}; check that Drive is mounted "
            "and the dataset folder name is correct."
        )
    tmp_data_path = LOCAL_DATA_PATH + '.partial'
    # Discard leftovers from a previously interrupted copy before starting over.
    shutil.rmtree(tmp_data_path, ignore_errors=True)
    print("Copying dataset to local disk...")
    shutil.copytree(DRIVE_DATA_PATH, tmp_data_path)
    os.rename(tmp_data_path, LOCAL_DATA_PATH)
    print("✅ Dataset copied!")

## 2. Verify Setup

In [None]:
# Import the project entry points used in this notebook, then print the active
# configuration and the model names get_model() accepts.
from src.config import DEVICE, print_config
from src.data import get_dataloaders
from src.models import AVAILABLE_MODELS, get_model

print_config()
print(f"\nAvailable models: {AVAILABLE_MODELS}")

In [None]:
# Build train/val/test loaders from the local dataset copy made in the setup cell
# (reuses LOCAL_DATA_PATH instead of repeating the path literal).
train_loader, val_loader, test_loader = get_dataloaders(LOCAL_DATA_PATH)

# Quick verification: pull one test batch and show its shape and first few labels.
images, labels = next(iter(test_loader))
assert len(images) == len(labels), "image/label batch sizes disagree"
print(f"\nBatch shape: {images.shape}")
print(f"Labels: {labels[:10]}")

## 3. Train Classifiers

In [None]:
import random

import numpy as np
import torch

from src.models import ModelTrainer, get_model
from src.config import CHECKPOINT_DIR

# Seed the Python, NumPy and PyTorch RNGs so training runs are repeatable.
# (Bit-exact GPU reproducibility would additionally need deterministic cuDNN settings.)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE: MODELS_TO_TRAIN is informational only. The cells below train each model
# explicitly, and ViT is toggled separately by TRAIN_VIT in its own cell.
MODELS_TO_TRAIN = ['resnet50', 'efficientnet_b0']
NUM_EPOCHS = 10

# Filled by the training cells: model name -> trained model (trainer.model).
trained_models = {}

In [None]:
# Train ResNet-50 from pretrained weights using ModelTrainer's default settings.
print("\n" + "=" * 60 + "\nTraining ResNet-50\n" + "=" * 60)

resnet = get_model('resnet50', num_classes=43, pretrained=True)  # 43 GTSRB classes
trainer = ModelTrainer(resnet, 'resnet50')
history = trainer.train(train_loader, val_loader, num_epochs=NUM_EPOCHS)

# Register the trained network for the evaluation and attack cells below.
trained_models['resnet50'] = trainer.model

In [None]:
# Train EfficientNet-B0 from pretrained weights using ModelTrainer's default settings.
print("\n" + "=" * 60 + "\nTraining EfficientNet-B0\n" + "=" * 60)

efficientnet = get_model('efficientnet_b0', num_classes=43, pretrained=True)  # 43 GTSRB classes
trainer = ModelTrainer(efficientnet, 'efficientnet_b0')
history = trainer.train(train_loader, val_loader, num_epochs=NUM_EPOCHS)

# Register the trained network for the evaluation and attack cells below.
trained_models['efficientnet_b0'] = trainer.model

In [None]:
# Train ViT (optional - takes longer than the CNNs). Flip TRAIN_VIT to False to skip it;
# the evaluation cell below then simply reports the CNNs only.
TRAIN_VIT = True

if TRAIN_VIT:
    print("\n" + "=" * 60 + "\nTraining ViT\n" + "=" * 60)

    vit = get_model('vit', num_classes=43, pretrained=True)  # 43 GTSRB classes
    # Explicit lower learning rate for ViT (the CNN cells use the trainer's default).
    trainer = ModelTrainer(vit, 'vit', learning_rate=1e-4)
    history = trainer.train(train_loader, val_loader, num_epochs=NUM_EPOCHS)

    trained_models['vit'] = trainer.model

## 4. Evaluate on Test Set

In [None]:
from src.metrics import compute_accuracy

# Report test-set accuracy for every model registered by the training cells.
print("\n" + "=" * 60 + "\nTEST SET EVALUATION\n" + "=" * 60)

for model_name, model in trained_models.items():
    model.eval()  # switch off training-mode layers before scoring
    print(f"{model_name}: {compute_accuracy(model, test_loader):.2f}%")

## 5. Save Checkpoints to Google Drive (Optional)

In [None]:
# Copy .pth checkpoints from CHECKPOINT_DIR to Drive so they persist after the
# Colab session's local disk is discarded.
DRIVE_CHECKPOINT_PATH = '/content/drive/MyDrive/adaptive-weather-attacks/checkpoints'
os.makedirs(DRIVE_CHECKPOINT_PATH, exist_ok=True)

import shutil

ckpt_files = sorted(CHECKPOINT_DIR.glob('*.pth'))
if not ckpt_files:
    # Make an empty backup visible instead of silently copying nothing.
    print(f"⚠️ No .pth checkpoints found in {CHECKPOINT_DIR}; nothing copied to Drive")
for ckpt_file in ckpt_files:
    dest = os.path.join(DRIVE_CHECKPOINT_PATH, ckpt_file.name)
    # copy2 also preserves file timestamps, so the Drive copies keep their write times.
    shutil.copy2(ckpt_file, dest)
    print(f"✅ Copied {ckpt_file.name} to Drive")

## 6. Quick Baseline Attack Test

In [None]:
from src.attacks import create_fgsm_attack, evaluate_attack

# Smoke-test the attack pipeline: FGSM against the trained ResNet-50, limited to
# max_batches=10 test batches.
# NOTE(review): whether eps=0.03 is in raw [0, 1] pixel units or in normalized-input
# units depends on the transforms in src.data — confirm before comparing results.
model = trained_models['resnet50']
fgsm = create_fgsm_attack(model, eps=0.03)
results = evaluate_attack(model, fgsm, test_loader, max_batches=10)

print("\nFGSM Attack Results (ResNet-50):")
for metric_label, metric_key in (
    ("Clean Accuracy", "clean_accuracy"),
    ("Adversarial Accuracy", "adversarial_accuracy"),
    ("Attack Success Rate", "attack_success_rate"),
):
    print(f"  {metric_label}: {results[metric_key]:.2f}%")

---

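## ✅ Setup Complete!

You now have:
- Trained classifiers saved as `.pth` checkpoints in `checkpoints/` (and backed up to Google Drive if you ran Section 5)
- A local copy of the dataset at `/content/GTSRB_dataset`, with data loaders verified against it

**Next:** Run `02_baseline_attacks.ipynb`, or skip ahead to `04_vcfg_experiments.ipynb`.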