In [1]:
# STEP 0: Set project root for src imports


In [8]:
import sys, os
sys.path.append(os.path.abspath(".."))

In [9]:
# Step 1 :Imports

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as T
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, confusion_matrix

from src.dataset import ChestXrayDataset
from src.model import get_model


In [11]:
# Step 2 : Device Setup

device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)

Using device: mps


In [12]:
# Step 3 : Transforms $ Datasets

train_transform = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),
    T.ToTensor()
])

val_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor()
])

train_dataset = ChestXrayDataset(
    csv_file="../data/processed/train_small.csv",
    image_dir="../data/images",
    transform=train_transform
)

val_dataset = ChestXrayDataset(
    csv_file="../data/processed/val_small.csv",
    image_dir="../data/images",
    transform=val_transform
)

batch_size = 16  # increase for speed

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

In [18]:
# Step 4 : Model, Loss, Optimizer
model = get_model(pretrained=True).to(device)

# Handle class imbalance
train_df = pd.read_csv("../data/processed/train.csv")
pos = train_df["Pneumonia"].sum()
neg = len(train_df) - pos

# MPS requires float32
pos_weight = torch.tensor(neg / pos, dtype=torch.float32).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

optimizer = optim.Adam(model.parameters(), lr=1e-4)

print("Model, criterion, and optimizer ready!")
print("pos_weight:", pos_weight.item())

Model, criterion, and optimizer ready!
pos_weight: 82.93803405761719


In [19]:
# Step 5 : Metrics

def compute_metrics(y_true, y_pred):
    """
    Compute AUC, sensitivity, specificity for binary classification.
    """
    y_pred_bin = (y_pred > 0.5).astype(int)
    auc = roc_auc_score(y_true, y_pred)
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred_bin).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    return auc, sensitivity, specificity



In [None]:
# Step 6 : Training & Validation Loop

num_epochs = 2

for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    running_loss = 0
    for imgs, labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        imgs, labels = imgs.to(device), labels.float().to(device)
        optimizer.zero_grad()
        outputs = model(imgs).squeeze(1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    avg_loss = running_loss / len(train_loader)
    print(f"\nEpoch {epoch+1}, Training Loss: {avg_loss:.4f}")


Training Epoch 1:   2%|▌                      | 28/1250 [00:24<17:39,  1.15it/s]

In [20]:
 # ---- Validation ----
model.eval()
all_labels = []
all_preds = []
with torch.no_grad():
    for imgs, labels in val_loader:
        imgs, labels = imgs.to(device), labels.float().to(device)
        outputs = model(imgs).squeeze(1)
        preds = torch.sigmoid(outputs)
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())

auc, sens, spec = compute_metrics(np.array(all_labels), np.array(all_preds))
print(f"Validation - AUC: {auc:.3f}, Sensitivity: {sens:.3f}, Specificity: {spec:.3f}\n")

Validation - AUC: 0.593, Sensitivity: 0.075, Specificity: 0.912



In [21]:
# Step 7 : Save Model
os.makedirs("../results", exist_ok=True)
torch.save(model.state_dict(), "../results/resnet50_pneumonia.pt")
print("Model saved to results/resnet50_pneumonia.pt")

Model saved to results/resnet50_pneumonia.pt


In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score, confusion_matrix

# --------------------
# Model
# --------------------
model = get_model().to(device)

# --------------------
# Handle class imbalance
# --------------------
train_df = pd.read_csv("../data/processed/train_small.csv")

pos = train_df["Pneumonia"].sum()
neg = len(train_df) - pos

pos_weight = torch.tensor(neg / pos, dtype=torch.float32).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# --------------------
# Optimizer
# --------------------
optimizer = optim.Adam(model.parameters(), lr=3e-4)

# --------------------
# Metrics
# --------------------
def compute_metrics(y_true, y_pred):
    y_pred_bin = (y_pred > 0.5).astype(int)
    auc = roc_auc_score(y_true, y_pred)
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred_bin).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    return auc, sensitivity, specificity

# --------------------
# Training loop
# --------------------
num_epochs = 20
best_auc = 0

for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    running_loss = 0

    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        imgs, labels = imgs.to(device), labels.float().to(device)

        optimizer.zero_grad()
        outputs = model(imgs).squeeze(1)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    train_loss = running_loss / len(train_loader)

    # ---- Validation ----
    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.float().to(device)
            outputs = model(imgs).squeeze(1)
            probs = torch.sigmoid(outputs)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(probs.cpu().numpy())

    auc, sens, spec = compute_metrics(np.array(all_labels), np.array(all_preds))

    print(f"\nEpoch {epoch+1}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val AUC: {auc:.4f} | Sensitivity: {sens:.4f} | Specificity: {spec:.4f}")

    # ---- Save best model ----
    if auc > best_auc:
        best_auc = auc
        torch.save(model.state_dict(), "../results/best_model.pt")
        print("✅ Saved new best model")

print("Training complete.")


Epoch 1/20: 100%|███████████████████████████| 1250/1250 [41:30<00:00,  1.99s/it]



Epoch 1
Train Loss: 1.4857
Val AUC: 0.5501 | Sensitivity: 0.6866 | Specificity: 0.3671
✅ Saved new best model


Epoch 2/20: 100%|█████████████████████████| 1250/1250 [2:06:42<00:00,  6.08s/it]



Epoch 2
Train Loss: 1.3513
Val AUC: 0.5808 | Sensitivity: 0.5373 | Specificity: 0.6422
✅ Saved new best model


Epoch 3/20: 100%|███████████████████████████| 1250/1250 [52:54<00:00,  2.54s/it]



Epoch 3
Train Loss: 1.3568
Val AUC: 0.5269 | Sensitivity: 0.4179 | Specificity: 0.6544


Epoch 4/20: 100%|█████████████████████████| 1250/1250 [2:33:34<00:00,  7.37s/it]



Epoch 4
Train Loss: 1.3999
Val AUC: 0.5710 | Sensitivity: 0.5522 | Specificity: 0.6104


Epoch 5/20: 100%|█████████████████████████| 1250/1250 [2:14:57<00:00,  6.48s/it]



Epoch 5
Train Loss: 1.3777
Val AUC: 0.5229 | Sensitivity: 0.6567 | Specificity: 0.3505


Epoch 6/20: 100%|███████████████████████████| 1250/1250 [19:54<00:00,  1.05it/s]



Epoch 6
Train Loss: 1.4018
Val AUC: 0.5436 | Sensitivity: 1.0000 | Specificity: 0.0103


Epoch 7/20: 100%|█████████████████████████| 1250/1250 [7:04:29<00:00, 20.38s/it]



Epoch 7
Train Loss: 1.4028
Val AUC: 0.5039 | Sensitivity: 0.1791 | Specificity: 0.8944


Epoch 8/20: 100%|███████████████████████████| 1250/1250 [33:34<00:00,  1.61s/it]



Epoch 8
Train Loss: 1.4096
Val AUC: 0.4974 | Sensitivity: 0.3582 | Specificity: 0.7103


Epoch 9/20: 100%|███████████████████████████| 1250/1250 [33:48<00:00,  1.62s/it]



Epoch 9
Train Loss: 1.3899
Val AUC: 0.5911 | Sensitivity: 0.8060 | Specificity: 0.2897
✅ Saved new best model


Epoch 10/20: 100%|██████████████████████████| 1250/1250 [59:17<00:00,  2.85s/it]



Epoch 10
Train Loss: 1.3836
Val AUC: 0.4417 | Sensitivity: 0.6269 | Specificity: 0.2321


Epoch 11/20: 100%|████████████████████████| 1250/1250 [1:02:35<00:00,  3.00s/it]



Epoch 11
Train Loss: 1.3840
Val AUC: 0.4181 | Sensitivity: 0.8507 | Specificity: 0.1145


Epoch 12/20: 100%|██████████████████████████| 1250/1250 [59:21<00:00,  2.85s/it]



Epoch 12
Train Loss: 1.4102
Val AUC: 0.5047 | Sensitivity: 0.2537 | Specificity: 0.7857


Epoch 13/20: 100%|██████████████████████████| 1250/1250 [59:31<00:00,  2.86s/it]



Epoch 13
Train Loss: 1.4643
Val AUC: 0.5413 | Sensitivity: 0.6567 | Specificity: 0.3996


Epoch 14/20: 100%|██████████████████████████| 1250/1250 [57:07<00:00,  2.74s/it]



Epoch 14
Train Loss: 1.4048
Val AUC: 0.5614 | Sensitivity: 1.0000 | Specificity: 0.0026


Epoch 15/20: 100%|██████████████████████████| 1250/1250 [57:02<00:00,  2.74s/it]



Epoch 15
Train Loss: 1.3989
Val AUC: 0.5482 | Sensitivity: 0.8657 | Specificity: 0.1847


Epoch 16/20: 100%|██████████████████████████| 1250/1250 [58:22<00:00,  2.80s/it]



Epoch 16
Train Loss: 1.4085
Val AUC: 0.4919 | Sensitivity: 0.7313 | Specificity: 0.2645


Epoch 17/20: 100%|████████████████████████| 1250/1250 [2:36:01<00:00,  7.49s/it]



Epoch 17
Train Loss: 1.4094
Val AUC: 0.5052 | Sensitivity: 0.1343 | Specificity: 0.8032


Epoch 18/20: 100%|███████████████████████| 1250/1250 [14:02:44<00:00, 40.45s/it]



Epoch 18
Train Loss: 1.4067
Val AUC: 0.5220 | Sensitivity: 0.8507 | Specificity: 0.1299


Epoch 19/20: 100%|██████████████████████████| 1250/1250 [36:01<00:00,  1.73s/it]



Epoch 19
Train Loss: 1.4097
Val AUC: 0.5566 | Sensitivity: 0.1940 | Specificity: 0.8879


Epoch 20/20: 100%|██████████████████████████| 1250/1250 [45:18<00:00,  2.17s/it]



Epoch 20
Train Loss: 1.4274
Val AUC: 0.5015 | Sensitivity: 1.0000 | Specificity: 0.0071
Training complete.
