# Tiny ImageNet Training with ResNet18

This notebook trains a ResNet18 model on the Tiny ImageNet dataset (200 classes, 64x64 images).

**Requirements:**
- Enable GPU runtime: Runtime → Change runtime type → GPU (T4 recommended)
- Google Drive for dataset storage (persists across sessions)

## 1. Environment Setup & GPU Check

In [None]:
import torch
import os

# Check GPU availability
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"✓ GPU Runtime Enabled")
    print(f"  Device: {gpu_name}")
    print(f"  Memory: {gpu_memory:.1f} GB")
    print(f"  CUDA Version: {torch.version.cuda}")
else:
    print("⚠️  WARNING: No GPU detected!")
    print("   Training will be VERY slow on CPU.")
    print("   Please enable GPU: Runtime → Change runtime type → GPU")
    print("\n   Without a GPU, expect training to take several hours.")

## 2. Mount Google Drive

We'll store the dataset and trained models on Google Drive for persistence.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Create directories on Google Drive
DRIVE_DATA_DIR = '/content/drive/MyDrive/tiny-imagenet-200'
DRIVE_MODEL_DIR = '/content/drive/MyDrive/saved_model'

os.makedirs(DRIVE_DATA_DIR, exist_ok=True)
os.makedirs(DRIVE_MODEL_DIR, exist_ok=True)

print(f"✓ Google Drive mounted")
print(f"  Dataset location: {DRIVE_DATA_DIR}")
print(f"  Model save location: {DRIVE_MODEL_DIR}")

## 3. Download Tiny ImageNet Dataset

Downloads the dataset to Google Drive if not already present (~237 MB compressed).

In [None]:
import os
import zipfile
from pathlib import Path

DATASET_URL = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
DATASET_ZIP = '/content/drive/MyDrive/tiny-imagenet-200.zip'
TRAIN_DIR = os.path.join(DRIVE_DATA_DIR, 'train')
VAL_DIR = os.path.join(DRIVE_DATA_DIR, 'val')

# Decide whether a download is needed. The flag is reset on every run so that
# a stale value from an earlier execution of this cell cannot force a re-download.
needs_download = False

# Check if dataset already exists
if os.path.exists(TRAIN_DIR) and os.path.exists(VAL_DIR):
    num_train_classes = len([d for d in os.listdir(TRAIN_DIR) if os.path.isdir(os.path.join(TRAIN_DIR, d))])
    num_val_images = len(list(Path(VAL_DIR).rglob('*.JPEG')))

    if num_train_classes == 200 and num_val_images > 0:
        print("✓ Dataset already exists on Google Drive")
        print(f"  Train classes: {num_train_classes}")
        print(f"  Val images: {num_val_images}")
        print("  Skipping download...\n")
    else:
        print("⚠️  Dataset incomplete, re-downloading...")
        os.system(f'rm -rf {DRIVE_DATA_DIR}/*')
        needs_download = True
else:
    print("Dataset not found. Downloading...")
    needs_download = True

if needs_download:
    print(f"Downloading Tiny ImageNet from {DATASET_URL}")
    print("This may take 2-5 minutes...\n")
    
    # Download
    !wget -q --show-progress {DATASET_URL} -O {DATASET_ZIP}
    
    print("\nExtracting dataset to Google Drive...")
    with zipfile.ZipFile(DATASET_ZIP, 'r') as zip_ref:
        zip_ref.extractall('/content/drive/MyDrive/')
    
    # Clean up zip file to save space
    os.remove(DATASET_ZIP)
    
    print("✓ Dataset downloaded and extracted")
    print(f"  Location: {DRIVE_DATA_DIR}")

# Print the expected dataset layout (published figures, not measured from disk)
print("\n📊 Dataset Information:")
print(f"  Train directory: {TRAIN_DIR}")
print(f"  Val directory: {VAL_DIR}")
print(f"  Number of classes: 200")
print(f"  Image size: 64x64")
print(f"  Train images per class: 500")
print(f"  Validation images: 10,000")
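
The summary above prints the published Tiny ImageNet figures rather than measuring the files on disk. As a minimal sketch, assuming the same `train`/`val` layout that the existence check relies on, the counts could be computed directly. `count_dataset` is a hypothetical helper name and is not part of the training repository.

```python
from pathlib import Path


def count_dataset(train_dir, val_dir):
    """Count class folders and JPEG files under the given directories.

    Hypothetical helper: it assumes only that train_dir holds one
    sub-folder per class and that images end in .JPEG.
    """
    train_path, val_path = Path(train_dir), Path(val_dir)
    class_dirs = [d for d in train_path.iterdir() if d.is_dir()]
    return {
        "train_classes": len(class_dirs),
        "train_images": sum(1 for _ in train_path.rglob("*.JPEG")),
        "val_images": sum(1 for _ in val_path.rglob("*.JPEG")),
    }
```

Calling `count_dataset(TRAIN_DIR, VAL_DIR)` after extraction should report 200 classes, 100,000 training images (200 × 500), and 10,000 validation images if the download is complete.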

## 4. Clone Repository

Clone the training code from GitHub.

In [None]:
# Clone repository
REPO_URL = 'https://github.com/abhi1021/resnet50-imagenet-1k'
REPO_DIR = '/content/resnet50-imagenet-1k'

if os.path.exists(REPO_DIR):
    print("Repository already cloned, pulling latest changes...")
    !cd {REPO_DIR} && git pull
else:
    print(f"Cloning repository from {REPO_URL}...")
    !git clone {REPO_URL} {REPO_DIR}

# Change to repository directory
%cd {REPO_DIR}

print(f"\n✓ Repository ready at {REPO_DIR}")

## 5. Install Dependencies

Install required Python packages for training.

In [None]:
# Install the required packages with pip (listed explicitly here)
print("Installing dependencies...\n")

!pip install -q torch torchvision numpy matplotlib torchsummary tqdm albumentations grad-cam huggingface_hub

print("\n✓ All dependencies installed")

# Verify installation
import torch
import torchvision
import albumentations as A
print(f"  PyTorch version: {torch.__version__}")
print(f"  Torchvision version: {torchvision.__version__}")
print(f"  Albumentations version: {A.__version__}")
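
The verification above imports only three of the installed packages. A sketch of a broader check follows, using the standard library's `importlib.metadata` to look up each distribution name passed to `pip install`. `installed_versions` and `REQUIRED` are hypothetical names introduced for illustration.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Distribution names copied from the pip command in this notebook.
REQUIRED = [
    "torch", "torchvision", "numpy", "matplotlib", "torchsummary",
    "tqdm", "albumentations", "grad-cam", "huggingface_hub",
]

for name, ver in installed_versions(REQUIRED).items():
    print(f"  {name}: {ver or 'NOT INSTALLED'}")
```

Looking packages up by distribution name avoids guessing the import name, which can differ from the name used with pip.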

## 6. Train the Model

Train ResNet18 on Tiny ImageNet with Colab-optimized parameters.

**Training Parameters:**
- Model: ResNet18 (200 classes)
- Epochs: 20
- Batch size: 256
- Image size: 64x64
- Optimizer: SGD (lr=0.1, momentum=0.9, weight_decay=5e-4)

**Expected training time:**
- With GPU (T4): ~30-40 minutes
- With CPU: ~8-12 hours (not recommended)
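
To make the training length concrete, the parameters above translate into optimizer steps as follows. This is only arithmetic on figures stated in this notebook (200 classes, 500 training images per class, batch size 256, 20 epochs). It assumes the data loader keeps the final partial batch, which the training script may or may not do.

```python
import math

num_classes = 200
images_per_class = 500
batch_size = 256
epochs = 20

train_images = num_classes * images_per_class           # 100,000 images
steps_per_epoch = math.ceil(train_images / batch_size)  # last batch is partial
total_steps = steps_per_epoch * epochs

print(f"Training images: {train_images}")
print(f"Steps per epoch: {steps_per_epoch}")
print(f"Total steps:     {total_steps}")
```

Under these assumptions each epoch is 391 steps, for 7,820 steps over the full run.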

In [None]:
# Run training script with optimized parameters
!python neural_network_analysis/train.py \
    --train-dir {TRAIN_DIR} \
    --val-dir {VAL_DIR} \
    --model-dir {DRIVE_MODEL_DIR} \
    --batch-size 256 \
    --img-size 64 \
    --num-workers 2 \
    --epochs 20 \
    --num-classes 200

## 7. Results & Trained Model

If training completed, the model is saved to your Google Drive. The cell below lists the saved files.

In [None]:
import os

print("📁 Trained Model Location:")
print(f"  {DRIVE_MODEL_DIR}")
print("\nSaved files:")

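# Collect regular files only; the directory always exists once step 2 has run,
# so an empty listing (not a missing directory) is what signals a failed run.
saved_files = []
if os.path.exists(DRIVE_MODEL_DIR):
    saved_files = [f for f in os.listdir(DRIVE_MODEL_DIR)
                   if os.path.isfile(os.path.join(DRIVE_MODEL_DIR, f))]

if saved_files:
    for file in saved_files:
        size_mb = os.path.getsize(os.path.join(DRIVE_MODEL_DIR, file)) / (1024 * 1024)
        print(f"  - {file} ({size_mb:.2f} MB)")
else:
    print("  No models found. Training may have failed.")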

print("\n💡 Tips:")
print("  - Model is saved on Google Drive and will persist")
print("  - You can download it from your Drive for local use")
print("  - To train for more epochs, adjust the --epochs parameter above")

## Optional: Load and Test the Model

In [None]:
import torch
import os

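# List the checkpoint files (.pth) in the model directory
model_files = [f for f in os.listdir(DRIVE_MODEL_DIR) if f.endswith('.pth')]

if model_files:
    # Load the first checkpoint found; map_location lets this run without a GPU
    best_model_path = os.path.join(DRIVE_MODEL_DIR, model_files[0])
    checkpoint = torch.load(best_model_path, map_location='cpu')

    print(f"✓ Loaded model: {model_files[0]}")
    print(f"  Epoch: {checkpoint.get('epoch', 'N/A')}")
    # Apply the .2f format only when an accuracy is stored; 'N/A' cannot take it
    best_acc = checkpoint.get('best_acc')
    if best_acc is not None:
        print(f"  Best Accuracy: {float(best_acc):.2f}%")
    else:
        print("  Best Accuracy: N/A")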
else:
    print("No model checkpoints found.")
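
`os.listdir` returns entries in arbitrary order, so `model_files[0]` is not guaranteed to be the best or the most recent checkpoint. One hedged alternative, assuming the most recently written `.pth` file is the one of interest (the training script's naming scheme is not shown here), is to select by modification time. `latest_checkpoint` is a hypothetical helper, not part of the repository.

```python
import os


def latest_checkpoint(model_dir, suffix=".pth"):
    """Return the path of the most recently modified checkpoint, or None."""
    candidates = [
        os.path.join(model_dir, name)
        for name in os.listdir(model_dir)
        if name.endswith(suffix)
    ]
    # Keep regular files only, in case a directory name ends with the suffix.
    candidates = [p for p in candidates if os.path.isfile(p)]
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)
```

The returned path could then be passed to `torch.load` in place of `best_model_path`. If the training script writes a dedicated "best" file, selecting that file by name would be the more direct choice.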