# AI-Powered Pneumonia Detection from Chest X-Rays

This notebook contains the complete code to train a deep learning model for classifying chest X-ray images as either **Normal** or **Pneumonia**. 

We will leverage **transfer learning** with a pre-trained **ResNet18** model and address the common challenge of **class imbalance** in the dataset.

### Step 1: Import Libraries

First, we import all the necessary libraries. This includes:
- `torch` and `torch.nn` for building the neural network.
- `torch.optim` for the optimization algorithm.
- `torchvision` for data transformations, datasets, and pre-trained models.
- `safetensors` for saving the trained weights in a format that, unlike pickle-based checkpoints, cannot execute code when loaded.
- `os` and `time` for utility functions.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from safetensors.torch import save_file
import os
import time

### Step 2: Configuration

Here, we define the key parameters for our training process. This makes it easy to adjust settings like batch size or the number of epochs without searching through the code.

In [None]:
DATA_DIR = 'chest_xray'
MODEL_SAVE_PATH = 'pneumonia_model_resnet.safetensors'
NUM_EPOCHS = 15
BATCH_SIZE = 64
LEARNING_RATE = 0.001

### Step 3: Data Augmentation and Transformation

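Data augmentation helps reduce overfitting and improves generalization. For the training data, we apply random horizontal flips and random rotations of up to 10 degrees in either direction. For both training and validation data, we resize the images to 224×224 and normalize them with the standard ImageNet mean and standard deviation, as expected by the pre-trained ResNet model. Chest X-rays are usually grayscale, but the default `ImageFolder` loader converts every image to RGB, so the three-channel statistics still apply.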

In [None]:
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
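
To make the normalization step concrete, the sketch below reproduces in plain Python the per-channel arithmetic that `transforms.Normalize` applies after `ToTensor` has scaled pixel values to the range [0, 1]. The helper `normalize_pixel` is a hypothetical name used for illustration only; it is not part of torchvision.

```python
# ImageNet statistics, copied from the transforms above.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Apply (value - mean) / std to each channel of one RGB pixel in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]

# A mid-gray pixel: an 8-bit value of 128 becomes 128 / 255 after ToTensor.
gray = [128 / 255] * 3
normalized_gray = normalize_pixel(gray)
print([round(c, 3) for c in normalized_gray])

# A pixel exactly at the channel means maps to zero in every channel.
at_mean = normalize_pixel(IMAGENET_MEAN)
print(at_mean)
```

Although the three input channels of a grayscale image are identical, the normalized channels differ because each one uses its own mean and standard deviation.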

### Step 4: Load Data and Create DataLoaders

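We use `ImageFolder` to load our dataset, which automatically labels images based on their parent folder's name. We also define a helper function `is_valid_image` that skips hidden files (names starting with `.`) and any file without a `.png`, `.jpg`, or `.jpeg` extension. The check looks only at the file name, so a corrupted image with a valid extension is not filtered out and would still raise an error when it is loaded.

The `DataLoader` then wraps the dataset, providing an efficient way to iterate over data in batches. We set `num_workers=2` so that two background worker processes decode and augment images while the model is busy with the previous batch, which keeps data loading from becoming the bottleneck.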

In [None]:
def is_valid_image(path):
    # Skip hidden files (for example '.DS_Store') and non-image extensions.
    if os.path.basename(path).startswith('.'):
        return False
    return path.lower().endswith(('.png', '.jpg', '.jpeg'))

image_datasets = {x: datasets.ImageFolder(
                        os.path.join(DATA_DIR, x),
                        transform=data_transforms[x],
                        is_valid_file=is_valid_image
                    )
                  for x in ['train', 'val']}

dataloaders = {x: DataLoader(
                    image_datasets[x], batch_size=BATCH_SIZE,
                    shuffle=(x == 'train'),  # shuffle only the training split
                    num_workers=2
                )
               for x in ['train', 'val']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")
print(f"Classes: {class_names}")
print(f"Training data size: {dataset_sizes['train']}")
print(f"Validation data size: {dataset_sizes['val']}")
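
`ImageFolder` expects one subfolder per class under each split directory. The sketch below builds a throwaway directory with that layout and applies the same file name filter to show which files would be kept. The file names are made up for illustration, and `list_kept_files` is a hypothetical helper, not a torchvision function.

```python
import os
import tempfile

def is_valid_image(path):
    # Same rule as the notebook's filter: no hidden files, image extensions only.
    if os.path.basename(path).startswith('.'):
        return False
    return path.lower().endswith(('.png', '.jpg', '.jpeg'))

def list_kept_files(split_dir):
    """Map each class folder to the sorted file names the filter accepts."""
    kept = {}
    for class_name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        files = sorted(os.listdir(class_dir))
        kept[class_name] = [
            f for f in files if is_valid_image(os.path.join(class_dir, f))
        ]
    return kept

with tempfile.TemporaryDirectory() as root:
    # Hypothetical contents: real images mixed with hidden and non-image files.
    layout = {
        'NORMAL': ['n1.jpeg', '.DS_Store', 'notes.txt'],
        'PNEUMONIA': ['p1.jpeg', 'p2.JPG', '._p2.JPG'],
    }
    for class_name, names in layout.items():
        class_dir = os.path.join(root, 'train', class_name)
        os.makedirs(class_dir)
        for name in names:
            open(os.path.join(class_dir, name), 'wb').close()
    kept = list_kept_files(os.path.join(root, 'train'))

print(kept)
```

`ImageFolder` assigns class indices in sorted folder-name order, so with these two folders `NORMAL` becomes class 0 and `PNEUMONIA` becomes class 1.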

### Step 5: Address Class Imbalance with Weighted Loss

The dataset has significantly more 'PNEUMONIA' images than 'NORMAL' images. If we don't account for this, the model might become biased and simply learn to predict the majority class. 

To fix this, we calculate class weights that are inversely proportional to the number of samples in each class. We then pass these weights to our loss function (`CrossEntropyLoss`), which will penalize mistakes on the minority class ('NORMAL') more heavily, forcing the model to pay more attention to it.

In [None]:
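# Count samples from the labels ImageFolder actually loaded, so hidden or
# non-image files skipped by is_valid_image are not included.
train_targets = torch.tensor(image_datasets['train'].targets)
class_counts = torch.bincount(train_targets, minlength=len(class_names))
normal_count = int(class_counts[class_names.index('NORMAL')])
pneumonia_count = int(class_counts[class_names.index('PNEUMONIA')])
total_count = int(class_counts.sum())

# Inverse-frequency weights, ordered by class index as CrossEntropyLoss expects.
class_weights = (total_count / (len(class_names) * class_counts.float())).to(device)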

print(f"Normal images: {normal_count}, Pneumonia images: {pneumonia_count}")
print(f"Calculated weights: {class_weights}")
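
As a worked example of the weighting formula, the sketch below uses made-up counts (not the real dataset sizes) to show that each class ends up contributing the same total weight to the loss. `balanced_weights` is a hypothetical helper name.

```python
def balanced_weights(counts):
    """Return total / (num_classes * count) for each class count."""
    total = sum(counts)
    num_classes = len(counts)
    return [total / (num_classes * c) for c in counts]

# Hypothetical counts for illustration only: one NORMAL image per three PNEUMONIA images.
counts = [1000, 3000]
weights = balanced_weights(counts)
print([round(w, 3) for w in weights])  # [2.0, 0.667]

# Weight times count is (up to rounding) the same for every class: total / num_classes.
contributions = [w * c for w, c in zip(weights, counts)]
print([round(c, 6) for c in contributions])
```

With these counts the minority class gets three times the weight of the majority class, which matches the 1:3 ratio of their sizes.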

### Step 6: Load a Pre-trained Model (Transfer Learning)

Instead of training a model from scratch, we use **transfer learning**. We load a **ResNet18** model that has already been pre-trained on the massive ImageNet dataset. This model already knows how to recognize fundamental features like edges, shapes, and textures.

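1.  **Freeze Layers:** We freeze all the existing layers of the model by setting `requires_grad = False` on their parameters, so their weights won't be updated during training. (BatchNorm layers still update their running statistics whenever the model is in `train()` mode.)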
2.  **Replace Final Layer:** We replace the model's final fully connected layer (`fc`) with a new one tailored to our specific task (classifying 2 classes: Normal vs. Pneumonia).

In [None]:
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

for param in model.parameters():
    param.requires_grad = False

num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))
model = model.to(device)

print("ResNet18 model loaded and final layer replaced.")
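
A quick way to check that freezing worked is to count trainable parameters. The sketch below does this on a small stand-in network, a hypothetical toy model chosen so that the example runs without downloading pre-trained weights. `count_parameters` is an illustrative helper, not a PyTorch function.

```python
import torch.nn as nn

def count_parameters(model):
    """Return (trainable, total) parameter counts for a module."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total

# Toy stand-in for a pre-trained backbone (hypothetical, not ResNet18).
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),  # 3*8*3*3 weights + 8 biases = 224
    nn.BatchNorm2d(8),               # 8 scales + 8 shifts = 16
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
# Freeze the backbone, as done for the ResNet18 layers above.
for param in backbone.parameters():
    param.requires_grad = False

# New head for two classes: 8*2 weights + 2 biases = 18 trainable parameters.
head = nn.Linear(8, 2)
toy_model = nn.Sequential(backbone, head)

trainable, total = count_parameters(toy_model)
print(f"Trainable: {trainable} of {total}")
```

Applied to the `model` above, the trainable count should equal the size of the new `fc` layer alone: ResNet18's `fc` takes 512 input features, so two output classes give 512 × 2 weights plus 2 biases.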

### Step 7: Define Loss Function and Optimizer

- **Loss Function:** We use `CrossEntropyLoss` and pass in our calculated `class_weights`.
- **Optimizer:** We use the `Adam` optimizer. Crucially, we only pass `model.fc.parameters()` to it. This ensures that the optimizer will **only update the weights of our new, final layer**, leaving the rest of the pre-trained model frozen.

In [None]:
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(model.fc.parameters(), lr=LEARNING_RATE)
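
To see how the class weights enter the loss, the sketch below computes a weighted cross-entropy by hand in plain Python. It follows the documented behavior of `nn.CrossEntropyLoss` when `weight` is set and `reduction` is left at its default of `'mean'`: each sample's loss is scaled by the weight of its true class, and the sum is divided by the sum of those weights rather than by the batch size. The logits, labels, and weights are made-up values, and `weighted_cross_entropy` is a hypothetical helper.

```python
import math

def weighted_cross_entropy(logits, labels, class_weights):
    """Weighted mean of per-sample cross-entropy over a batch of logit rows."""
    weighted_sum = 0.0
    weight_total = 0.0
    for row, label in zip(logits, labels):
        # Negative log-softmax of the true class, computed stably via the row maximum.
        row_max = max(row)
        log_sum_exp = row_max + math.log(sum(math.exp(z - row_max) for z in row))
        sample_loss = log_sum_exp - row[label]
        weighted_sum += class_weights[label] * sample_loss
        weight_total += class_weights[label]
    return weighted_sum / weight_total

example_weights = [2.0, 0.5]  # made-up: class 0 is the minority class
labels = [0, 1]

# Equal logits: every sample has loss log(2), so the weighted mean is log(2) too.
loss_equal = weighted_cross_entropy([[0.0, 0.0], [0.0, 0.0]], labels, example_weights)

# Always predicting class 1 misses the class-0 sample (weight 2.0) ...
loss_miss_minority = weighted_cross_entropy(
    [[-2.0, 2.0], [-2.0, 2.0]], labels, example_weights)
# ... while always predicting class 0 misses the class-1 sample (weight 0.5).
loss_miss_majority = weighted_cross_entropy(
    [[2.0, -2.0], [2.0, -2.0]], labels, example_weights)

print(round(loss_equal, 4), round(loss_miss_minority, 4), round(loss_miss_majority, 4))
```

The same confident mistake costs more when it falls on the higher-weighted class, which is the effect the weighted loss is meant to produce.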

### Step 8: The Training Loop

This is the core of the training process. We iterate through our data for a specified number of epochs. In each epoch:

1.  **Training Phase:** We set the model to `train()` mode, process the training data, calculate the loss, and update the model's weights using backpropagation.
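2.  **Validation Phase:** We set the model to `eval()` mode, process the validation data, and calculate the loss and accuracy. This estimates how well the model performs on data it was not trained on. Because the same split is also used to pick the best checkpoint, the reported best accuracy is likely to be somewhat optimistic.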
3.  **Save Best Model:** If the validation accuracy in the current epoch is the best we've seen so far, we save the model's state dictionary to a file using `safetensors`.

In [None]:
print("\nStarting training with ResNet18...")
start_time = time.time()
best_acc = 0.0

for epoch in range(NUM_EPOCHS):
    print(f'Epoch {epoch+1}/{NUM_EPOCHS}')
    print('-' * 10)
    for phase in ['train', 'val']:
        # train() lets BatchNorm update its running statistics; eval() uses the stored ones.
        if phase == 'train':
            model.train()
        else:
            model.eval()
        running_loss, running_corrects = 0.0, 0
        for inputs, labels in dataloaders[phase]:
            inputs, labels = inputs.to(device), labels.to(device)
            # Track gradients only during training; validation needs no autograd graph.
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                if phase == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
            # loss is a (weighted) batch average, so scale it by the batch size
            # before summing; the last batch may be smaller than BATCH_SIZE.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()
        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects / dataset_sizes[phase]
        print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        # Keep only the checkpoint with the highest validation accuracy so far.
        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
            save_file(model.state_dict(), MODEL_SAVE_PATH)
            print(f"New best model saved to {MODEL_SAVE_PATH} with accuracy: {best_acc:.4f}")

time_elapsed = time.time() - start_time
print(f'\nTraining complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
print(f'Best val Acc: {best_acc:.4f}')
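
Because the classes are imbalanced, overall accuracy can look good even when the minority class is often misclassified. As a complement to the accuracy printed above, the sketch below computes per-class recall from lists of labels and predictions. The labels and predictions are invented for illustration, and `per_class_recall` is a hypothetical helper.

```python
def per_class_recall(labels, preds, num_classes):
    """Fraction of each class's samples that were predicted correctly."""
    correct = [0] * num_classes
    totals = [0] * num_classes
    for y, p in zip(labels, preds):
        totals[y] += 1
        if y == p:
            correct[y] += 1
    # A class with no samples has undefined recall; report None for it.
    return [c / t if t else None for c, t in zip(correct, totals)]

# Invented example: 2 NORMAL (class 0) and 8 PNEUMONIA (class 1) samples,
# with a model that predicts PNEUMONIA for everything.
labels = [0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
preds = [1] * 10

accuracy = sum(y == p for y, p in zip(labels, preds)) / len(labels)
recalls = per_class_recall(labels, preds, num_classes=2)
print(accuracy, recalls)  # 0.8 [0.0, 1.0]
```

In the validation phase, one way to gather the inputs for such a check would be to extend two Python lists with `labels.tolist()` and `preds.tolist()` for each batch.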