# Chest X-ray Classifier Training - Pneumonia Detection

This notebook fine-tunes an ImageNet-pretrained ResNet50 to classify chest X-ray images as NORMAL or PNEUMONIA.

**Dataset**: ~1,060 chest X-ray images (NORMAL vs PNEUMONIA)

**Architecture**: ResNet50 with transfer learning (frozen ImageNet-pretrained backbone plus a new trainable two-class head)

**Runtime**: ~20-30 minutes with GPU

## Step 1: Setup Environment

In [None]:
# Install required packages
!pip install -q torch torchvision pillow matplotlib scikit-learn opencv-python

In [None]:
# Report whether PyTorch can see a CUDA GPU and, if so, which one.
import torch

has_gpu = torch.cuda.is_available()
print(f"GPU Available: {has_gpu}")
if has_gpu:
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")

## Step 2: Upload Dataset

**Instructions**:
1. Zip your `chest_xray` folder on your Mac:
   ```bash
   cd /Users/anaslari/Desktop/doctor_online/datasets
   zip -r chest_xray.zip chest_xray/
   ```
2. Upload `chest_xray.zip` to Colab using the file browser
3. Run the cell below to extract

In [None]:
# Extract dataset
!unzip -q chest_xray.zip
!ls -la chest_xray/chest_xray/

## Step 3: Explore Dataset

In [None]:
# DEBUG: Print directory structure to find where the images are
import os
from itertools import islice
from pathlib import Path

print('Current directory:', os.getcwd())
print('\nListing chest_xray directory:')
# Show the first 20 entries below chest_xray/ (pure-Python equivalent of
# `ls -R chest_xray | head -20`, so the cell is valid Python).
entries = (os.path.join(root, name)
           for root, dirs, files in os.walk('chest_xray')
           for name in dirs + files)
for entry in islice(entries, 20):
    print(entry)

# Find the train directory automatically
found_train = False
for root, dirs, files in os.walk('chest_xray'):
    if 'train' in dirs:
        train_path = os.path.join(root, 'train')
        print(f'\nFOUND TRAIN DIRECTORY AT: {train_path}')
        found_train = True
        # Set the data_dir to the parent of train
        data_dir = Path(root)
        break

if not found_train:
    print('\n❌ COULD NOT FIND TRAIN DIRECTORY! Did the unzip finish?')

In [None]:
import os
from pathlib import Path

from itertools import islice

# Robustly find the dataset directory: the folder that directly contains
# `train/` (and `val/`, `test/`), each split holding NORMAL/ and PNEUMONIA/.
data_dir = None
# Candidate dataset roots, checked in order. Each entry must itself contain
# `train/`; '.' covers a zip extracted without the chest_xray wrapper folder.
possible_paths = [
    'chest_xray/chest_xray',
    'chest_xray',
    '.',
]

# Check known paths
for p in possible_paths:
    if os.path.exists(os.path.join(p, 'train')):
        data_dir = Path(p)
        print(f'✅ Found dataset at: {data_dir}')
        break

# If not found, search recursively
if data_dir is None:
    print('⚠️ Dataset not found in standard paths. Searching recursively...')
    for root, dirs, files in os.walk('.'):
        if 'train' in dirs:
            # Verify it has the right classes
            train_path = os.path.join(root, 'train')
            if os.path.exists(os.path.join(train_path, 'NORMAL')):
                data_dir = Path(root)
                print(f'✅ Found dataset at: {data_dir}')
                break

if data_dir is None:
    print('❌ ERROR: Could not find dataset! Please check the unzip output.')
    # List the first 20 entries below the current directory to help debug
    # (pure-Python equivalent of `ls -R | head -20`).
    entries = (os.path.join(root, name)
               for root, dirs, files in os.walk('.')
               for name in dirs + files)
    for entry in islice(entries, 20):
        print(entry)
else:
    train_dir = data_dir / 'train'
    val_dir = data_dir / 'val'
    test_dir = data_dir / 'test'

    # Suffixes torchvision's ImageFolder treats as images (matched
    # case-insensitively), so these counts agree with the datasets in Step 4.
    IMAGE_SUFFIXES = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm',
                      '.tif', '.tiff', '.webp'}

    def count_images(directory):
        """Return ``(normal, pneumonia)`` image counts for one split folder.

        A missing split or class folder counts as 0 images.
        """
        def _count(class_dir):
            if not class_dir.is_dir():
                return 0
            return sum(1 for f in class_dir.rglob('*')
                       if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES)

        return _count(directory / 'NORMAL'), _count(directory / 'PNEUMONIA')

    train_normal, train_pneumonia = count_images(train_dir)
    val_normal, val_pneumonia = count_images(val_dir)
    test_normal, test_pneumonia = count_images(test_dir)

    print("\nDataset Statistics:")
    print(f"Train: {train_normal} NORMAL, {train_pneumonia} PNEUMONIA")
    print(f"Val: {val_normal} NORMAL, {val_pneumonia} PNEUMONIA")
    print(f"Test: {test_normal} NORMAL, {test_pneumonia} PNEUMONIA")

In [None]:
# Visualize sample images
import matplotlib.pyplot as plt
from PIL import Image
import random

def show_samples(directory, n=4):
    """Plot up to *n* random NORMAL (top row) and PNEUMONIA (bottom row) images.

    Images are drawn without replacement, so a row never repeats a picture.
    Prints a message and returns without plotting when neither class folder
    contains any images.

    Args:
        directory: Path to a split folder containing NORMAL/ and PNEUMONIA/.
        n: Maximum number of images shown per class (default 4).
    """
    # Suffixes torchvision's ImageFolder treats as images (case-insensitive),
    # so the preview draws from the same files the datasets will load.
    image_suffixes = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm',
                      '.tif', '.tiff', '.webp'}

    def _list_images(class_dir):
        """Sorted image files under *class_dir*; [] if the folder is missing."""
        if not class_dir.is_dir():
            return []
        return sorted(f for f in class_dir.rglob('*')
                      if f.is_file() and f.suffix.lower() in image_suffixes)

    rows = []
    for label in ('NORMAL', 'PNEUMONIA'):
        images = _list_images(directory / label)
        k = max(0, min(n, len(images)))
        # random.sample picks distinct files (random.choice could repeat one).
        rows.append((label, random.sample(images, k)))

    n_display = max(len(chosen) for _, chosen in rows)
    if n_display == 0:
        print('No images found!')
        return

    # squeeze=False keeps `axes` 2-D even when only one column is shown.
    fig, axes = plt.subplots(2, n_display, figsize=(15, 6), squeeze=False)
    for row, (label, chosen) in enumerate(rows):
        for col in range(n_display):
            ax = axes[row, col]
            if col < len(chosen):
                # `with` closes the file; convert('L') returns a loaded copy.
                with Image.open(chosen[col]) as img:
                    ax.imshow(img.convert('L'), cmap='gray')
                ax.set_title(label)
            ax.axis('off')

    plt.tight_layout()
    plt.show()


show_samples(train_dir)

## Step 4: Prepare Data Loaders

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Per-channel ImageNet statistics expected by the pretrained backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize plus light geometric/photometric augmentation.
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation pipeline (validation and test): deterministic resize + normalize.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# One ImageFolder per split; labels come from the class sub-folder names.
train_dataset = datasets.ImageFolder(train_dir, transform=train_transforms)
val_dataset = datasets.ImageFolder(val_dir, transform=val_transforms)
test_dataset = datasets.ImageFolder(test_dir, transform=val_transforms)

# Same batching for every split; only the training data is shuffled.
loader_kwargs = dict(batch_size=32, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"Class names: {train_dataset.classes}")
for split_name, split_loader in (("Train", train_loader),
                                 ("Val", val_loader),
                                 ("Test", test_loader)):
    print(f"{split_name} batches: {len(split_loader)}")

## Step 5: Build Model

In [None]:
import torch.nn as nn
from torchvision import models

# Load ResNet50 with ImageNet weights. `weights=` is torchvision's current API
# (>= 0.13), where `pretrained=True` is deprecated. `pretrained=True` selected
# the same IMAGENET1K_V1 checkpoint, so the starting weights are unchanged.
if hasattr(models, 'ResNet50_Weights'):
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:  # torchvision < 0.13 has no weights enums
    model = models.resnet50(pretrained=True)

# Freeze the whole pretrained backbone (every existing parameter, not only the
# early layers): only the new head built below will receive gradients.
for param in model.parameters():
    param.requires_grad = False

# Replace final layer for binary classification. Freshly constructed modules
# have requires_grad=True parameters, so this head is the trainable part.
num_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 2)  # 2 classes: NORMAL, PNEUMONIA
)

# Move to GPU when available, otherwise stay on CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

print(f"Model on device: {device}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

## Step 6: Train Model

In [None]:
import torch.optim as optim
from tqdm import tqdm

# Cross-entropy over the two class logits. Only the parameters of the new
# head (model.fc) are handed to Adam, matching the frozen backbone.
criterion = nn.CrossEntropyLoss()
head_params = model.fc.parameters()
optimizer = optim.Adam(head_params, lr=1e-3)
# Multiply the LR by 0.5 once the monitored value (validation loss, lower is
# better) has failed to improve for more than `patience`=2 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2,
)

# Training function
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in tqdm(loader, desc='Training'):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    return running_loss / len(loader), 100. * correct / total

# Validation function
def validate(model, loader, criterion, device, progress=True):
    """Evaluate *model* on *loader* without gradient tracking.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function. It is assumed to average over the batch
            (``reduction='mean'``, the nn.CrossEntropyLoss default).
        device: Device each batch is moved to.
        progress: Wrap the loop in a tqdm progress bar (default True).

    Returns:
        ``(mean_loss, accuracy_percent)`` over every sample. The loss is
        weighted by batch size, so it equals the loss over the whole dataset.

    Raises:
        ValueError: If *loader* yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    batches = tqdm(loader, desc='Validation') if progress else loader
    with torch.no_grad():
        for images, labels in batches:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Undo the per-batch mean so a short last batch is weighted correctly.
            loss_sum += loss.item() * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError('validate: loader yielded no samples')
    return loss_sum / total, 100. * correct / total

In [None]:
# Train for 10 epochs
num_epochs = 10
best_val_acc = 0.0
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

print("🚀 Starting training...\n")

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validate
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    
    # Update scheduler (ReduceLROnPlateau monitors validation loss)
    scheduler.step(val_loss)
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    
    # Print results
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
    
    # Save best model. Always checkpoint after the first epoch: a strict `>`
    # against 0.0 alone would never fire if validation accuracy stayed at 0%,
    # leaving Step 8 with no file or a stale one from an earlier run.
    if epoch == 0 or val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_xray_model.pth')
        print("✅ Best model saved!")
    
    print()

print(f"\n🎉 Training complete! Best validation accuracy: {best_val_acc:.2f}%")

## Step 7: Plot Training History

In [None]:
# Plot training curves: loss on the left, accuracy on the right. Each panel
# compares the per-epoch train and validation series stored in `history`.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

curve_specs = [
    (ax1, 'loss', 'Train Loss', 'Val Loss', 'Loss', 'Training and Validation Loss'),
    (ax2, 'acc', 'Train Acc', 'Val Acc', 'Accuracy (%)', 'Training and Validation Accuracy'),
]
for ax, metric, train_label, val_label, y_label, title in curve_specs:
    ax.plot(history[f'train_{metric}'], label=train_label)
    ax.plot(history[f'val_{metric}'], label=val_label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.show()

## Step 8: Evaluate on Test Set

In [None]:
# Load best model: restore the weights with the highest validation accuracy
# saved during training. map_location remaps the tensors onto the current
# device, so a checkpoint written on GPU still loads on a CPU-only runtime.
model.load_state_dict(torch.load('best_xray_model.pth', map_location=device))

# Evaluate
test_loss, test_acc = validate(model, test_loader, criterion, device)
print(f"\n📊 Test Set Results:")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2f}%")

In [None]:
# Detailed metrics
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

import seaborn as sns

# Class names in label-index order, read from the dataset itself so the
# report cannot drift from the labels ImageFolder actually assigned.
class_names = test_dataset.classes
label_ids = list(range(len(class_names)))

model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.numpy())

# Classification report. Passing `labels` keeps the report valid even when a
# class never occurs in the predictions or ground truth; otherwise sklearn
# raises because target_names would not match the observed classes.
print("\n📋 Classification Report:")
print(classification_report(all_labels, all_preds, labels=label_ids,
                            target_names=class_names, zero_division=0))

# Confusion matrix: rows = true class, columns = predicted class.
cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

# Sensitivity / specificity with PNEUMONIA as the positive class.
if class_names == ['NORMAL', 'PNEUMONIA']:
    tn, fp, fn, tp = cm.ravel()
    sensitivity = tp / (tp + fn) if tp + fn else float('nan')
    specificity = tn / (tn + fp) if tn + fp else float('nan')
    print(f"Sensitivity (PNEUMONIA recall): {sensitivity:.2%}")
    print(f"Specificity (NORMAL recall): {specificity:.2%}")

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names)
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()

## Step 9: Visualize Predictions

In [None]:
# Show sample predictions
import random

def show_predictions(model, dataset, n=8, device=None):
    """Plot *n* random samples from *dataset* with true vs predicted labels.

    Titles are green for correct predictions and red for mistakes.

    Args:
        model: Classifier returning one logit per class.
        dataset: Dataset yielding ``(normalized image tensor, label)`` pairs
            and exposing a ``classes`` list (e.g. torchvision ImageFolder).
        n: Number of samples to show (default 8), capped at ``len(dataset)``.
        device: Device to run inference on. Defaults to the device of the
            model's first parameter.
    """
    n = min(n, len(dataset))
    if n <= 0:
        print('No samples to display!')
        return
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    indices = random.sample(range(len(dataset)), n)

    # Four samples per row; n=8 reproduces the 2x4 grid at figsize (16, 8).
    ncols = min(4, n)
    nrows = -(-n // ncols)  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows),
                             squeeze=False)
    axes = axes.ravel()

    # ImageNet statistics used by the Normalize transform; inverted for display.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for ax, idx in zip(axes, indices):
        image, label = dataset[idx]

        # Predict
        with torch.no_grad():
            output = model(image.unsqueeze(0).to(device))
            _, pred = output.max(1)

        # Denormalize image for display
        img = image.permute(1, 2, 0).cpu().numpy() * std + mean
        img = np.clip(img, 0, 1)

        # Plot
        ax.imshow(img)
        true_label = dataset.classes[label]
        pred_label = dataset.classes[pred.item()]
        color = 'green' if true_label == pred_label else 'red'
        ax.set_title(f'True: {true_label}\nPred: {pred_label}', color=color)
        ax.axis('off')

    # Blank out any grid cells left after the last sample.
    for ax in axes[n:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


show_predictions(model, test_dataset)

## Step 10: Save Model

In [None]:
# Bundle what is needed to reuse the classifier: the weights, the label names
# in index order, and the best validation accuracy reached during training.
checkpoint = dict(
    model_state_dict=model.state_dict(),
    class_names=train_dataset.classes,
    best_val_acc=best_val_acc,
)
torch.save(checkpoint, 'xray_classifier_complete.pth')

for line in (
    "✅ Model saved to xray_classifier_complete.pth",
    "\nDownload this file and copy to:",
    "/Users/anaslari/Desktop/doctor_online/mm-hie-backend/app/modules/imaging/models/",
):
    print(line)

## 🎉 Training Complete!

### Next Steps:
1. Download `xray_classifier_complete.pth`
2. Copy to: `mm-hie-backend/app/modules/imaging/models/`
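3. Update `imaging_model.py` to load this model. First rebuild the same architecture: ResNet50 whose `fc` is `Linear(2048, 512) → ReLU → Dropout(0.3) → Linear(512, 2)`. Then load `checkpoint['model_state_dict']`. Preprocess inputs like the validation transforms: resize to 224×224 and apply ImageNet mean/std normalization.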
4. Test with chest X-ray uploads!

### Expected Performance:
- Accuracy: ~90-95% (typical for this dataset)
- Sensitivity (recall on PNEUMONIA cases): ~95%+
- Specificity (recall on NORMAL cases): ~85%+

### Model Info:
- Base: ResNet50 (pretrained on ImageNet)
- Classes: 2 (NORMAL, PNEUMONIA)
- Parameters: ~25M
- Size: ~100MB