# Chest X-Ray Pneumonia Detection - Training on Google Colab

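This notebook fine-tunes a ResNet-18 model (transfer learning) for **pneumonia detection** from chest X-rays, using hyperparameters transferred from the brain tumor model (LR=0.001, batch size 32, 50 epochs) rather than ones tuned on this dataset.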

**Task**: Binary classification (NORMAL vs PNEUMONIA)
**Dataset**: ~5,863 chest X-ray images (~5,216 train / ~16 val / ~624 test; the tiny validation split makes validation accuracy noisy)
**Expected Accuracy**: 90-95%

**Before starting:**
1. Change runtime to GPU: Runtime -> Change runtime type -> T4 GPU
2. Upload the dataset to Google Drive as `MyDrive/chest_xray_dataset/`, with `train/`, `val/` and `test/` subfolders (class subfolders `NORMAL/` and `PNEUMONIA/`)
3. Upload `project_for_colab.zip` when prompted

In [None]:
# Cell 1: Check GPU
!nvidia-smi
import torch
print(f"\n{'='*60}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print(f"{'='*60}\n")

In [None]:
# Cell 2: Mount Google Drive
# The dataset (Cell 5) and the backup target (Cell 10) both live under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)
print("[OK] Google Drive mounted!")

In [None]:
# Cell 3: Upload and Extract Project
from google.colab import files
import os
import json

# Clean up any existing extraction
!rm -rf /content/medical-image-classification

print("Please upload project_for_colab.zip:")
uploaded = files.upload()

# Extract to specific directory
!mkdir -p /content/medical-image-classification
!unzip -q project_for_colab.zip -d /content/medical-image-classification
%cd /content/medical-image-classification

# IMPORTANT: Write chest X-ray specific hyperparameters
# (The zip may contain brain tumor params, so we overwrite them)
chest_xray_params = {
    "best_hyperparameters": {
        "learning_rate": 0.001,
        "batch_size": 32,
        "epochs": 50
    },
    "optimization_summary": {
        "method": "transferred_from_brain_tumor",
        "note": "Using proven hyperparameters from brain tumor model (95.42% accuracy). Binary classification may converge faster, so using 50 epochs.",
        "expected_accuracy": "90-95%",
        "dataset": "chest_xray",
        "classes": 2
    },
    "dataset_info": {
        "name": "Chest X-Ray Pneumonia Detection",
        "classes": ["NORMAL", "PNEUMONIA"],
        "num_classes": 2,
        "total_images": "~5,863",
        "train_images": "~5,216",
        "val_images": "~16",
        "test_images": "~624"
    }
}

os.makedirs('results/phase1', exist_ok=True)
with open('results/phase1/best_hyperparameters.json', 'w') as f:
    json.dump(chest_xray_params, f, indent=2)

print("\n[OK] Project extracted!")
print("[OK] Chest X-ray hyperparameters written!")
print("\nProject structure:")
!ls -la
print("\nCurrent directory:")
!pwd

In [None]:
# Cell 4: Install Dependencies
print("Installing dependencies...")
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q tensorboard pillow numpy pandas scikit-learn matplotlib seaborn tqdm
!pip install -q opencv-python-headless scikit-image albumentations
print("[OK] All dependencies installed!")

In [None]:
# Cell 5: Setup Dataset
print("Setting up chest X-ray dataset...")

# Create directories
!mkdir -p data/chest_xray

# Remove any existing links
!rm -f data/chest_xray/train data/chest_xray/val data/chest_xray/test

# Link to Google Drive dataset
# Your Drive folder should be: MyDrive/chest_xray_dataset/ with train/, val/, test/ subfolders
!ln -s /content/drive/MyDrive/chest_xray_dataset/train data/chest_xray/train
!ln -s /content/drive/MyDrive/chest_xray_dataset/val data/chest_xray/val
!ln -s /content/drive/MyDrive/chest_xray_dataset/test data/chest_xray/test

# Verify dataset structure
print("\nDataset structure:")
!ls data/chest_xray/
print("\nTrain classes:")
!ls data/chest_xray/train/
print("\nTest classes:")
!ls data/chest_xray/test/

# Count images (chest X-rays can be .jpeg, .jpg, or .png)
from pathlib import Path

def count_images(folder):
    count = 0
    for ext in ['*.jpeg', '*.jpg', '*.png']:
        count += len(list(Path(folder).rglob(ext)))
    return count

train_count = count_images('data/chest_xray/train')
val_count = count_images('data/chest_xray/val')
test_count = count_images('data/chest_xray/test')

print(f"\n{'='*60}")
print(f"CHEST X-RAY DATASET SUMMARY")
print(f"{'='*60}")
print(f"Train images: {train_count}")
print(f"Val images:   {val_count}")
print(f"Test images:  {test_count}")
print(f"Total:        {train_count + val_count + test_count}")
print(f"{'='*60}")

# Show class distribution
print("\nClass distribution (train):")
for cls in ['NORMAL', 'PNEUMONIA']:
    cls_path = Path(f'data/chest_xray/train/{cls}')
    if cls_path.exists():
        cls_count = count_images(str(cls_path))
        print(f"  {cls}: {cls_count}")

print("\n[OK] Dataset ready!")

In [None]:
# Cell 6: Verify Setup
# Runs import checks in fresh `python` processes (as train.py will be run) and
# stops the notebook if any of them fails.
import subprocess


def run_python_check(code):
    """Run `code` via `python -c` in the current directory, echo its output, and return True on exit status 0."""
    result = subprocess.run(['python', '-c', code], capture_output=True, text=True)
    print(result.stdout, end='')
    print(result.stderr, end='')
    return result.returncode == 0


print("Verifying setup...\n")

# Check config loads correctly for chest_xray
config_ok = run_python_check(
    "from config import get_config; config = get_config('chest_xray'); "
    "print('Config loaded:', config['dataset']['name']); "
    "print('Classes:', config['dataset']['classes']); "
    "print('Num classes:', config['dataset']['num_classes'])"
)

# Verify hyperparameters are chest_xray specific
print("\nHyperparameters:")
with open('results/phase1/best_hyperparameters.json') as f:
    print(f.read())

# Verify the dataset loader works. The code is passed as one string: a `!` shell
# escape only covers its own line, so a multi-line `!python -c "..."` would leave
# the remaining lines to be parsed as Python.
print("\nVerifying dataset loader...")
loader_ok = run_python_check(
    "from src.datasets.chest_xray import ChestXRayDataset; "
    "print('ChestXRayDataset imported successfully'); "
    "doc = ChestXRayDataset.__doc__; "
    "print('Docstring:', doc[:50] if doc else 'OK')"
)

if not (config_ok and loader_ok):
    raise RuntimeError("Setup verification failed - fix the errors above before training.")

print("\n[OK] Setup verified! Ready to train.")

In [None]:
# Cell 7: START TRAINING!
print("="*60)
print("  STARTING CHEST X-RAY PNEUMONIA DETECTION TRAINING")
print("="*60)
print("\nDataset: Chest X-Ray (NORMAL vs PNEUMONIA)")
print("Model: ResNet-18 with transfer learning")
print("Hyperparameters: LR=0.001, Batch=32, Epochs=50")
print("\nEstimated time: ~1 hour on T4 GPU")
print("Training will continue even if you close this tab.\n")

!python train.py --dataset chest_xray --use_optimized --device cuda

print("\n" + "="*60)
print("  TRAINING COMPLETE!")
print("="*60)

In [None]:
# Cell 8: Check Results
import json
import os

print("="*60)
print("  TRAINING RESULTS")
print("="*60)

# List checkpoints
print("\nSaved checkpoints:")
!ls -lh models/checkpoints/

# Show training history
if os.path.exists('results/phase1/training_history.json'):
    with open('results/phase1/training_history.json', 'r') as f:
        history = json.load(f)

    print(f"\n{'='*60}")
    print(f"  Final Training Accuracy:    {history['train_acc'][-1]:.2f}%")
    print(f"  Final Validation Accuracy:  {history['val_acc'][-1]:.2f}%")
    print(f"  Final Training Loss:        {history['train_loss'][-1]:.4f}")
    print(f"  Final Validation Loss:      {history['val_loss'][-1]:.4f}")
    print(f"  Epochs Trained:             {len(history['train_acc'])}")
    print(f"{'='*60}")

    # Show best epoch
    best_val_acc = max(history['val_acc'])
    best_epoch = history['val_acc'].index(best_val_acc) + 1
    print(f"\n  Best Validation Accuracy: {best_val_acc:.2f}% (Epoch {best_epoch})")
else:
    print("\nNo training history found.")

# Check for training results JSON
results_path = f'results/training_results_chest_xray.json'
if os.path.exists(results_path):
    with open(results_path, 'r') as f:
        results = json.load(f)
    print(f"\n  Best Val Acc (from trainer): {results.get('best_val_acc', 'N/A')}")
    print(f"  Best Val Loss: {results.get('best_val_loss', 'N/A')}")

In [None]:
# Cell 9: Download Trained Model
# Downloads the best checkpoint and any available training summaries to the local machine.
import os  # imported here so this cell also works on its own after a kernel restart

from google.colab import files

print("Downloading trained model...\n")

# Download best model
if os.path.exists('models/checkpoints/best_model.pth'):
    files.download('models/checkpoints/best_model.pth')
    print("[OK] best_model.pth downloaded!")
else:
    print("[X] best_model.pth not found!")

# Download training history and the trainer's results summary, when present
for optional_path in ('results/phase1/training_history.json',
                      'results/training_results_chest_xray.json'):
    if os.path.exists(optional_path):
        files.download(optional_path)
        print(f"[OK] {os.path.basename(optional_path)} downloaded!")

print("\n" + "="*60)
print("  DOWNLOAD COMPLETE")
print("="*60)
print("\nNext steps:")
print("1. Rename best_model.pth to chest_xray_resnet18.pth")
print("2. Place it in your local models/checkpoints/ directory")
print("3. Run evaluation locally:")
print("   python evaluate.py --dataset chest_xray --model_path models/checkpoints/chest_xray_resnet18.pth")

In [None]:
# Cell 10 (Optional): Backup to Google Drive
print("Backing up results to Google Drive...")

# Create backup directory in Drive
!mkdir -p /content/drive/MyDrive/chest_xray_results

# Copy trained model
!cp models/checkpoints/best_model.pth /content/drive/MyDrive/chest_xray_results/

# Copy training history
!cp results/phase1/training_history.json /content/drive/MyDrive/chest_xray_results/ 2>/dev/null || echo "No history file"

# Copy all checkpoints
!cp models/checkpoints/*.pth /content/drive/MyDrive/chest_xray_results/ 2>/dev/null || echo "No checkpoints"

# Copy training results
!cp results/training_results_chest_xray.json /content/drive/MyDrive/chest_xray_results/ 2>/dev/null || echo "No results file"

# List backed up files
print("\nBacked up files:")
!ls -lh /content/drive/MyDrive/chest_xray_results/

print("\n[OK] Results backed up to Google Drive: MyDrive/chest_xray_results/")
print("\nYou can access these files anytime from Google Drive!")