In [None]:
# =============================================================================
#                 TRAINING THE RESNET-18 BASELINE MODEL
# =============================================================================

# Imports & setup 
import os
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models, transforms, datasets
from torchvision.models import ResNet18_Weights

# Config ----------------------------------------------------------------------
# Dataset root; the loaders read its `train` and `val` sub-folders.
DATA_DIR = "chest_xray"
NUM_EPOCHS = 10
BATCH_SIZE = 32
LEARNING_RATE = 1e-3  # Adam step size
# Checkpoint overwritten each time validation accuracy improves.
BEST_MODEL_PATH = "best_pneumonia_classifier_weight_loss.pth"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Data loading (copied from Notebook 1 so the preprocessing matches exactly)
# Standard torchvision ImageNet channel statistics; the backbone below is
# initialised from ImageNet weights, so inputs are normalised the same way.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),  # light augmentation, training split only
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]),
    'val': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]),
}
train_dir = os.path.join(DATA_DIR, 'train')
val_dir = os.path.join(DATA_DIR, 'val')
train_dataset = datasets.ImageFolder(train_dir, transform=data_transforms['train'])
val_dataset = datasets.ImageFolder(val_dir, transform=data_transforms['val'])

# ImageFolder derives label indices from each directory's own sub-folders. If the
# two splits do not contain the same class folders, a validation label would
# silently refer to a different class than the same index in training.
if val_dataset.class_to_idx != train_dataset.class_to_idx:
    raise ValueError(
        "Class folders differ between splits: "
        f"train={train_dataset.class_to_idx}, val={val_dataset.class_to_idx}"
    )

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
print(f"Data loaded: {len(train_dataset)} train images, {len(val_dataset)} val images.")

# Model definition: ResNet-18 loaded with ImageNet weights, with its final fully
# connected layer replaced by a new one giving one output per dataset class.
model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
num_ftrs = model.fc.in_features  # width of the features entering the classifier head
model.fc = nn.Linear(in_features=num_ftrs, out_features=len(train_dataset.classes))
model = model.to(device)  # Module.to moves parameters in place and returns the module
print("Model: ResNet-18")

# Weighted loss function & optimizer
# Each class c gets the "balanced" weight  total_samples / (num_classes * count_c):
# rarer classes are up-weighted, and a perfectly balanced set would give every
# class a weight of 1.0.
class_counts = Counter(train_dataset.targets)
num_classes = len(train_dataset.classes)
total_samples = len(train_dataset)

# A class folder with no images would make its weight divide by zero.
empty_classes = [
    name for idx, name in enumerate(train_dataset.classes) if class_counts[idx] == 0
]
if empty_classes:
    raise ValueError(
        f"Training set has no images for class(es) {empty_classes}; "
        "cannot compute class weights."
    )

# CrossEntropyLoss needs one weight per class in label-index order. Building the
# list by index (instead of hard-coding class names and positions) keeps it in
# sync with the labels ImageFolder assigned, for any number of classes.
class_weights = torch.tensor(
    [total_samples / (num_classes * class_counts[idx]) for idx in range(num_classes)],
    dtype=torch.float32,
    device=device,
)
for idx, name in enumerate(train_dataset.classes):
    print(f"Weight for {name} class: {class_weights[idx].item():.2f}")

criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
print("Weighted loss function and optimizer are set up.")

# Training & validation loop with smart saving
# The model is evaluated on the validation set after every epoch. Its weights are
# written to BEST_MODEL_PATH whenever validation accuracy beats the best so far.
# Starting from -inf guarantees the first epoch is always checkpointed, so the
# file announced at the end exists even if accuracy never rises above 0%.
best_val_accuracy = float("-inf")
print("Starting ResNet-18 Training...")
for epoch in range(NUM_EPOCHS):
    # ---- Training phase ----
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion returns a (class-weighted) mean over the batch; scaling by the
        # batch size lets the epoch figure average over samples, not batches.
        running_loss += loss.item() * inputs.size(0)
    epoch_train_loss = running_loss / len(train_dataset)

    # ---- Validation phase ----
    model.eval()
    running_vloss = 0.0
    correct_predictions = 0  # plain Python int, not an accumulating device tensor
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_vloss += loss.item() * inputs.size(0)
            preds = outputs.argmax(dim=1)
            correct_predictions += (preds == labels).sum().item()
    epoch_val_loss = running_vloss / len(val_dataset)
    epoch_val_accuracy = 100.0 * correct_predictions / len(val_dataset)

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] | Train Loss: {epoch_train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_accuracy:.2f}%")

    # Checkpoint only on strict improvement in validation accuracy.
    if epoch_val_accuracy > best_val_accuracy:
        best_val_accuracy = epoch_val_accuracy
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        print(f"✨ New best model saved with accuracy: {best_val_accuracy:.2f}%")
    print("-" * 50)
print(f"Finished Training. Best model saved to {BEST_MODEL_PATH}")

PyTorch Version: 2.8.0+cpu

Training on device: cpu

Found 4304 images in the training folder.

Found 928 images in the validation folder.

Original class counts in training set: Counter({1: 3203, 0: 1101})

Weight for NORMAL class: 1.95

Weight for PNEUMONIA class: 0.67

Criterion updated with class weights.

Starting Training...

Epoch [1/10] | Training Loss: 0.1769
Epoch [1/10] | Validation Loss: 0.1135 | Validation Accuracy: 95.91%
--------------------------------------------------
Epoch [2/10] | Training Loss: 0.1062
Epoch [2/10] | Validation Loss: 0.0553 | Validation Accuracy: 98.06%
--------------------------------------------------
Epoch [3/10] | Training Loss: 0.0626
Epoch [3/10] | Validation Loss: 0.2103 | Validation Accuracy: 93.53%
--------------------------------------------------
Epoch [4/10] | Training Loss: 0.0670
Epoch [4/10] | Validation Loss: 0.0504 | Validation Accuracy: 98.28%
--------------------------------------------------
Epoch [5/10] | Training Loss: 0.0520
Epoch [5/10] | Validation Loss: 0.0693 | Validation Accuracy: 97.41%
--------------------------------------------------
Epoch [6/10] | Training Loss: 0.0540
Epoch [6/10] | Validation Loss: 0.0474 | Validation Accuracy: 98.17%
--------------------------------------------------
Epoch [7/10] | Training Loss: 0.0419
Epoch [7/10] | Validation Loss: 0.1007 | Validation Accuracy: 96.44%
--------------------------------------------------
Epoch [8/10] | Training Loss: 0.0352
Epoch [8/10] | Validation Loss: 0.0523 | Validation Accuracy: 98.17%
--------------------------------------------------
Epoch [9/10] | Training Loss: 0.0246
Epoch [9/10] | Validation Loss: 0.0350 | Validation Accuracy: 98.49%
--------------------------------------------------
Epoch [10/10] | Training Loss: 0.0305
Epoch [10/10] | Validation Loss: 0.0641 | Validation Accuracy: 97.20%
--------------------------------------------------
Finished Training.