# X-ray Enhancement AI - Training Demo

This notebook demonstrates how to train the UNet + Attention + GAN model for X-ray image enhancement.

## Setup

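Run this notebook in Google Colab for free GPU access! In Colab, enable the GPU first via **Runtime → Change runtime type → Hardware accelerator → GPU**, then run the cells from top to bottom.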

In [1]:
# Report the installed PyTorch build and, if a CUDA device is visible, its details
import torch

cuda_ready = torch.cuda.is_available()

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    # Only query CUDA/GPU details when a device is actually usable
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")

ModuleNotFoundError: No module named 'torch'

## 1. Clone Repository and Install Dependencies

In [None]:
# Clone repository
!git clone https://github.com/yourusername/xray-healthcare-ai.git
%cd xray-healthcare-ai

# Install dependencies
!pip install -r backend/requirements.txt

## 2. Download Dataset

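You can use NIH ChestX-ray14 or any other chest X-ray dataset. Run **one** of the two options below: Option 1 uploads a `.zip` archive from your machine and extracts it into `data/`; Option 2 downloads the `paultimothymooney/chest-xray-pneumonia` dataset from Kaggle (requires your `kaggle.json` API key).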

In [None]:
# Option 1: Upload from local
# Opens Colab's file picker, then unpacks every uploaded .zip archive into data/
from google.colab import files
import zipfile

uploaded = files.upload()

# files.upload() returns a mapping keyed by uploaded file name
zip_names = [name for name in uploaded if name.endswith('.zip')]
for zip_name in zip_names:
    with zipfile.ZipFile(zip_name) as archive:
        archive.extractall('data/')

In [None]:
# Option 2: Download from Kaggle
# First, upload your kaggle.json API key

!mkdir -p ~/.kaggle
uploaded = files.upload()  # Upload kaggle.json
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Download dataset
!kaggle datasets download -d paultimothymooney/chest-xray-pneumonia
!unzip chest-xray-pneumonia.zip -d data/

## 3. Organize Data

In [None]:
import os
import shutil
from pathlib import Path

# Create directories
os.makedirs('data/train', exist_ok=True)
os.makedirs('data/val', exist_ok=True)

# Check dataset structure
!ls -la data/

# Copy/move files to train and val directories
# (Adjust based on your dataset structure)

## 4. Test Model Architecture

In [None]:
import sys

# Imported explicitly so this cell does not depend on the GPU-check cell
# having been run first.
import torch

# Make the current directory importable; guarded so re-running the cell does
# not keep appending duplicate entries to sys.path.
if '.' not in sys.path:
    sys.path.append('.')

from models.gan import Pix2PixGAN

# Use the GPU when one is visible, otherwise stay on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create model
model = Pix2PixGAN(in_channels=1, out_channels=1)
model.to(device)

# Test forward pass
# Single-channel 256x256 random input, batch size 1
x = torch.randn(1, 1, 256, 256).to(device)

# This is only a shape smoke test, so skip autograd bookkeeping
with torch.no_grad():
    output = model.generate(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

# Count parameters
num_params = sum(p.numel() for p in model.generator.parameters())
print(f"Generator parameters: {num_params:,}")

## 5. Test Dataset

In [None]:
from training.dataset import XRayDataset, get_training_augmentation
import matplotlib.pyplot as plt

# Build the training dataset with the project's training augmentation
train_transform = get_training_augmentation()
dataset = XRayDataset(
    image_dir='data/train',
    transform=train_transform,
    img_size=256,
    degradation_level=0.5,
)

print(f"Dataset size: {len(dataset)}")

# Show the first sample's degraded/clean pair side by side
sample = dataset[0]
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
panel_specs = (('degraded', 'Degraded'), ('clean', 'Clean'))
for ax, (key, title) in zip(axes, panel_specs):
    ax.imshow(sample[key].squeeze(), cmap='gray')
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()

## 6. Start Training

In [None]:
# Navigate to training directory
# NOTE: %cd changes the working directory for the whole kernel session, not
# just for this cell. Later cells in this notebook are written relative to it:
# the TensorBoard cell points at `../logs` and the model-testing cell runs
# `%cd ..` to return. Run this cell once per session; running it a second time
# without going back up first would look for a nested `training/training`.
%cd training

# Run training script (the cell does not finish until train.py exits)
!python train.py

## 7. Monitor Training with TensorBoard

In [None]:
# Load TensorBoard
# Registers the TensorBoard notebook extension, then opens the dashboard inline.
# The log directory is relative to the current working directory: `../logs`
# only resolves to the repository's top-level `logs/` while the kernel is
# still inside `training/` (after `%cd training`, before `%cd ..`).
# NOTE(review): assumes train.py writes its event files to ../logs -- confirm.
%load_ext tensorboard
%tensorboard --logdir ../logs

## 8. Test Trained Model

In [None]:
%cd ..

import torch
from models.gan import Pix2PixGAN
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Pix2PixGAN(in_channels=1, out_channels=1)

# Load checkpoint
checkpoint = torch.load('checkpoints/best_model.pth', map_location=device)
model.generator.load_state_dict(checkpoint['generator_state_dict'])
model.to(device)
model.eval()

print(f"Loaded model from epoch {checkpoint['epoch']}")
print(f"Best PSNR: {checkpoint['best_psnr']:.2f} dB")

In [None]:
# Test on a sample image
from training.dataset import XRayDataset

# Validation dataset built without a transform argument; take its first sample
dataset = XRayDataset('data/val', img_size=256, degradation_level=0.5)
sample = dataset[0]

# Run inference
# Add a leading batch dimension and move the input to the model's device
with torch.no_grad():
    degraded = sample['degraded'].unsqueeze(0).to(device)
    enhanced = model.generate(degraded)

# Visualize
# One panel each for the degraded input, the model output and the clean target
panels = (
    ('Degraded Input', sample['degraded']),
    ('Enhanced Output', enhanced),
    ('Ground Truth', sample['clean']),
)
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
for ax, (title, image) in zip(axes, panels):
    ax.imshow(image.squeeze().cpu(), cmap='gray')
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

## 9. Calculate Metrics

In [None]:
from training.metrics import calculate_psnr, calculate_ssim

# Calculate metrics
# Compare the model output (moved to the CPU) against the clean sample,
# which gets a leading batch dimension to line up with the output.
prediction = enhanced.cpu()
target = sample['clean'].unsqueeze(0)

psnr = calculate_psnr(prediction, target)
ssim = calculate_ssim(prediction, target)

print(f"PSNR: {psnr:.2f} dB")
print(f"SSIM: {ssim:.4f}")

## 10. Download Model

In [None]:
# Export the best checkpoint: browser download first, then a copy on Google Drive
from google.colab import files
from google.colab import drive
import shutil

checkpoint_path = 'checkpoints/best_model.pth'
drive_copy_path = '/content/drive/MyDrive/xray_model.pth'

# Download to local machine
files.download(checkpoint_path)

# Or save to Google Drive
drive.mount('/content/drive')
shutil.copy(checkpoint_path, drive_copy_path)
print("Model saved to Google Drive!")

## Next Steps

1. **Download the model** and use it in the web application
2. **Experiment with hyperparameters** in `training/train.py`
3. **Try different datasets** for specialized applications
4. **Deploy the model** using the FastAPI backend

Happy training! 🚀