# Chest X-Ray Model Training

This notebook demonstrates how to train a chest X-ray classification model using PyTorch.

## Dataset Structure

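Your dataset should be organized as follows, with one sub-folder per class. `ImageFolder` assigns class indices in sorted folder-name order, so here `COVID-19` = 0, `Normal` = 1, and `Pneumonia` = 2. `train/` and `val/` must contain the same class folders: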
```
dataset/
├── train/
│   ├── Normal/
│   ├── Pneumonia/
│   └── COVID-19/
└── val/
    ├── Normal/
    ├── Pneumonia/
    └── COVID-19/
```

## Requirements

Make sure you have installed all dependencies:
```bash
pip install -r requirements.txt
```

In [None]:
# Import required libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import sys
import os

# Add parent directory to path to import our model
sys.path.insert(0, os.path.abspath('..'))
from app.ml.models.chest_xray_model import ChestXRayModel

# Report the installed PyTorch version and whether PyTorch can see a CUDA GPU.
print(
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)

## Configuration

In [None]:
# Hyperparameters and paths shared by every later cell of the notebook.
CONFIG = dict(
    data_dir='./dataset',            # root folder that contains train/ and val/
    batch_size=32,
    num_epochs=20,
    learning_rate=0.001,
    img_size=224,                    # images are resized to img_size x img_size
    num_classes=3,
    device=('cuda' if torch.cuda.is_available() else 'cpu'),
    model_save_path='../ml_models/chest_xray_model.pth',  # best checkpoint goes here
)

print(f"Training on device: {CONFIG['device']}")

## Data Preparation

In [None]:
# Define data transformations.
# Both pipelines resize to img_size x img_size and reduce the image to one
# grayscale channel; only the training pipeline adds random augmentation.
# NOTE(review): the mean/std below are the ImageNet red-channel statistics
# reused for grayscale input — confirm they match the inference code.
train_transform = transforms.Compose([
    transforms.Resize((CONFIG['img_size'], CONFIG['img_size'])),
    transforms.Grayscale(num_output_channels=1),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485], std=[0.229])
])

val_transform = transforms.Compose([
    transforms.Resize((CONFIG['img_size'], CONFIG['img_size'])),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485], std=[0.229])
])

# Load datasets (ImageFolder derives each label from the sample's sub-folder)
train_dataset = datasets.ImageFolder(
    root=os.path.join(CONFIG['data_dir'], 'train'),
    transform=train_transform
)

val_dataset = datasets.ImageFolder(
    root=os.path.join(CONFIG['data_dir'], 'val'),
    transform=val_transform
)

# ImageFolder builds each split's class->index mapping independently from its
# sorted sub-folder names. A missing or misspelled folder in one split would
# silently shift the labels, so fail fast instead of training on bad labels.
if train_dataset.class_to_idx != val_dataset.class_to_idx:
    raise ValueError(
        f"Train/val class folders differ: {train_dataset.classes} "
        f"vs {val_dataset.classes}"
    )
# The model's output layer is sized from CONFIG, so it must match the data.
if len(train_dataset.classes) != CONFIG['num_classes']:
    raise ValueError(
        f"CONFIG['num_classes'] is {CONFIG['num_classes']} but the dataset "
        f"has {len(train_dataset.classes)} classes: {train_dataset.classes}"
    )

# Create data loaders (only the training data is shuffled)
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=4
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=4
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Classes: {train_dataset.classes}")

## Model Initialization

In [None]:
# Build the classifier and place it on the configured device
# (nn.Module.to moves the parameters and returns the module itself).
model = ChestXRayModel(num_classes=CONFIG['num_classes']).to(CONFIG['device'])

# Cross-entropy objective, optimised with Adam at the configured learning rate
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=CONFIG['learning_rate'])

# Halve the learning rate when the monitored value stops decreasing
# (mode='min'), tolerating 3 epochs without improvement first.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
)

print("Model initialized successfully")

## Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network to train; it is switched to training mode here.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``. Assumed to use
            mean reduction (the default for ``torch.nn`` losses), so multiplying
            by the batch size recovers the batch's summed loss.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the per-sample average loss and the
        accuracy in percent, both computed over every sample in the epoch.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight each batch by its size so a smaller final batch does not count
        # as much as a full one in the epoch average.
        loss_sum += loss.item() * batch_size
        _, predicted = outputs.max(1)
        total += batch_size
        correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError("train_epoch: loader yielded no samples")

    epoch_loss = loss_sum / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on every batch of ``loader`` without updating it.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``. Assumed to use
            mean reduction (the default for ``torch.nn`` losses), so multiplying
            by the batch size recovers the batch's summed loss.
        device: Device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the per-sample average loss and the
        accuracy in percent, both computed over every sample in ``loader``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    # No gradients are needed for evaluation; this saves memory and time.
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Weight by batch size so the result is a true per-sample mean
            # even when the last batch is smaller than the others.
            loss_sum += loss.item() * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError("validate: loader yielded no samples")

    epoch_loss = loss_sum / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

## Training Loop

In [None]:
# Training loop: each epoch trains, validates, steps the LR scheduler on the
# validation loss, records metrics in `history`, and checkpoints the weights
# whenever validation accuracy improves.
best_val_acc = 0.0
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first save.
os.makedirs(os.path.dirname(CONFIG['model_save_path']) or '.', exist_ok=True)

for epoch in range(CONFIG['num_epochs']):
    print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
    print("-" * 50)

    # Train
    train_loss, train_acc = train_epoch(
        model, train_loader, criterion, optimizer, CONFIG['device']
    )

    # Validate
    val_loss, val_acc = validate(
        model, val_loader, criterion, CONFIG['device']
    )

    # Update learning rate based on the validation loss
    scheduler.step(val_loss)

    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

    # Save best model. Always checkpoint after the first epoch so the file
    # exists for later cells even if validation accuracy never exceeds 0%.
    if epoch == 0 or val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), CONFIG['model_save_path'])
        print(f"✓ Model saved with validation accuracy: {val_acc:.2f}%")

print("\n" + "="*50)
print(f"Training completed! Best validation accuracy: {best_val_acc:.2f}%")
print(f"Model saved to: {CONFIG['model_save_path']}")

## Visualize Training History

In [None]:
import matplotlib.pyplot as plt


def _plot_curves(ax, train_series, val_series, train_label, val_label, ylabel, title):
    """Draw one metric's train and validation curves (one point per epoch) on ``ax``."""
    ax.plot(train_series, label=train_label)
    ax.plot(val_series, label=val_label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)


# Two side-by-side panels: loss on the left, accuracy on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

_plot_curves(
    ax1, history['train_loss'], history['val_loss'],
    'Train Loss', 'Val Loss', 'Loss', 'Training and Validation Loss',
)
_plot_curves(
    ax2, history['train_acc'], history['val_acc'],
    'Train Acc', 'Val Acc', 'Accuracy (%)', 'Training and Validation Accuracy',
)

plt.tight_layout()
plt.show()

## Test the Trained Model

In [None]:
# Load the best checkpoint written by the training loop. map_location puts the
# weights on the configured device regardless of where they were saved.
model.load_state_dict(
    torch.load(CONFIG['model_save_path'], map_location=CONFIG['device'])
)
model.eval()

# Test on a sample image
# Replace 'path/to/test/image.png' with your test image path
test_image_path = 'path/to/test/image.png'

if os.path.exists(test_image_path):
    from PIL import Image

    # Load and preprocess the image the same way as the validation data.
    # torchvision's default ImageFolder loader converts every image to RGB
    # before the transforms run, so do the same here. The `with` block
    # closes the file handle afterwards.
    with Image.open(test_image_path) as image:
        image_tensor = val_transform(image.convert('RGB')).unsqueeze(0).to(CONFIG['device'])

    # Make prediction
    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)
        confidence, predicted = torch.max(probabilities, 1)

    # Use the dataset's own index -> name mapping. ImageFolder orders classes
    # by sorted folder name (COVID-19, Normal, Pneumonia here), so a
    # hand-written list in a different order would mislabel every prediction.
    classes = train_dataset.classes
    print(f"Prediction: {classes[predicted.item()]}")
    print(f"Confidence: {confidence.item():.2%}")
    print("\nAll probabilities:")
    for i, prob in enumerate(probabilities[0]):
        print(f"  {classes[i]}: {prob.item():.2%}")
else:
    print(f"Test image not found at: {test_image_path}")
    print("Please provide a valid test image path")

## Next Steps

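1. **Evaluate on a test set**: Create a separate, held-out test set and evaluate final performance on it. The validation accuracy reported above is optimistic because `val/` was used to choose which checkpoint to save.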
2. **Confusion matrix**: Analyze which classes are confused with each other
3. **Model optimization**: Try different architectures, hyperparameters, or data augmentation
4. **Deploy**: Use the trained model in the FastAPI application

## Contributing

If you improve the model or training process, please:
1. Document your changes
2. Share your results
3. Submit a pull request

Happy training! 🚀