# 3-Class Chest X-ray Classifier (Normal / Pneumonia / Tuberculosis)

**Fast, beginner-friendly baseline** that fine-tunes a pretrained CNN and produces balanced metrics + clear visualizations (confusion matrix, ROC/PR curves, Grad-CAM).

> **Note:** This notebook is for educational/recruitment purposes and **not a medical device**. Do not use for clinical decisions.

In [None]:
# If running on Kaggle, most libs are available; installing grad-cam just in case.
%pip -q install timm torchmetrics grad-cam --upgrade

In [None]:
import os, sys, time, math, json, random
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets, models

from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve, precision_recall_curve

import timm  # optional (not strictly needed for resnet18)
from torchmetrics.functional import accuracy

# Reproducibility
def set_seed(seed=42):
    """Seed Python's `random`, NumPy, and PyTorch (CPU and every CUDA device) with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

# Seed every RNG once, up front, so the cells below are repeatable.
set_seed(seed=42)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Device:", device)

## Data
This notebook assumes a Kaggle dataset that already provides `train/`, `val/`, and `test/` folders, each containing one subfolder per class. We **ignore** the `Covid-19` class and keep the three classes we need (Normal, Pneumonia, Tuberculosis).

In [None]:
# Target class names (matched in lowercase) and the names of the split folders.
wanted = {'normal', 'pneumonia', 'tuberculosis'}
splits = ['train', 'val', 'test']

# Dataset root -- change this to your own location.
# On Kaggle, after adding the dataset "chest-xray-pneumoniacovid19tuberculosis",
# the path is typically the one below. On Colab, point it at your own copy.
data_root = Path('/kaggle/input/chest-xray-pneumoniacovid19tuberculosis')

# Only warn (do not raise) so the cell can be edited and re-run interactively.
root_missing = not data_root.exists()
if root_missing:
    print("WARNING: data_root not found. Set 'data_root' to your dataset path and re-run this cell.")

def infer_class_name(p: Path):
    """Map a folder name to our desired class name in lowercase.

    Only the final path component is inspected. It is lowercased, stripped,
    and has spaces, underscores and hyphens removed before matching, so
    ``'Covid-19'`` and ``'COVID19'`` are treated alike.

    Args:
        p: Path whose last component is the class folder name.

    Returns:
        ``'normal'``, ``'pneumonia'``, ``'tuberculosis'`` or ``'covid'`` when
        the name matches one of those (checked in that order); otherwise the
        normalized folder name itself.
    """
    name = p.name.lower().strip()
    # Normalize a few common variants
    for sep in (' ', '_', '-'):
        name = name.replace(sep, '')
    if 'normal' in name:
        return 'normal'
    if 'pneumonia' in name:
        return 'pneumonia'
    # NOTE(review): the Kaggle dataset this notebook points at appears to ship
    # its TB folder misspelled as 'TURBERCULOSIS' -- confirm against the data.
    # Without this extra spelling that folder matches nothing and the whole
    # class is silently filtered out downstream.
    if 'tuberculosis' in name or 'turberculosis' in name or name.startswith('tb'):
        return 'tuberculosis'
    if 'covid' in name:
        return 'covid'
    return name  # default

In [None]:
from torchvision import transforms, datasets

IMG_SIZE = 224
BATCH_SIZE = 32

# Per-channel mean/std used by both pipelines (the usual values for
# ImageNet-pretrained torchvision backbones).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: light geometric and photometric augmentation, then normalisation.
train_tfms = transforms.Compose([
    transforms.RandomResizedCrop(size=IMG_SIZE, scale=(0.85, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Evaluation pipeline: deterministic resize + centre crop, same normalisation.
eval_tfms = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

class FilteredImageFolder(datasets.ImageFolder):
    """`ImageFolder` restricted to a subset of classes, with labels re-indexed.

    Every class folder found under `root` is mapped to a canonical name with
    `infer_class_name`. Samples whose canonical name is not in
    `wanted_classes` are dropped, and the remaining names are sorted and
    numbered 0..K-1. Several folders that map to the same canonical name are
    merged into one class.

    Args:
        root: Directory containing one subfolder per class.
        transform: Optional transform applied to each image.
        target_transform: Optional transform applied to each label.
        wanted_classes: Canonical class names to keep. Defaults to
            ``{'normal', 'pneumonia', 'tuberculosis'}``.

    After construction `samples`, `imgs`, `targets`, `classes` and
    `class_to_idx` all describe the filtered, re-indexed dataset. The label
    mapping is derived from the classes present under this `root`, so a split
    missing a class will number the remaining classes differently.
    """

    def __init__(self, root, transform=None, target_transform=None, wanted_classes=None):
        super().__init__(root=root, transform=transform, target_transform=target_transform)
        if wanted_classes is None:
            wanted_classes = {'normal','pneumonia','tuberculosis'}

        # Original folder index -> canonical class name (computed once per folder).
        self.class_to_name = {i: infer_class_name(Path(c)) for i, c in enumerate(self.classes)}
        kept_names = sorted({n for n in self.class_to_name.values() if n in wanted_classes})
        self.name_to_newidx = {name: j for j, name in enumerate(kept_names)}
        self.new_classes = kept_names

        # Keep only wanted samples and rewrite their labels to the new indices.
        self.samples = [
            (path, self.name_to_newidx[self.class_to_name[orig_idx]])
            for path, orig_idx in self.samples
            if self.class_to_name[orig_idx] in wanted_classes
        ]
        # torchvision exposes `imgs` as an alias of `samples`; re-point it so it
        # does not keep referring to the unfiltered list.
        self.imgs = self.samples
        self.targets = [t for _, t in self.samples]
        self.classes = self.new_classes
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

def make_loader(split, transform):
    """Build the filtered dataset and its DataLoader for one split folder.

    Args:
        split: Name of the subfolder under `data_root` (e.g. 'train').
        transform: Image transform applied by the dataset.

    Returns:
        A `(dataset, loader)` tuple. Only the 'train' split is shuffled.
    """
    dataset = FilteredImageFolder(
        root=data_root / split,
        transform=transform,
        wanted_classes={'normal','pneumonia','tuberculosis'},
    )
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=(split=='train'),
        num_workers=2,
        pin_memory=True,
    )
    return dataset, loader

# Build the three splits in a fixed order; only 'train' uses augmentation.
_split_transforms = {'train': train_tfms, 'val': eval_tfms, 'test': eval_tfms}
_built = {name: make_loader(name, tfm) for name, tfm in _split_transforms.items()}

train_ds, train_loader = _built['train']
val_ds, val_loader = _built['val']
test_ds, test_loader = _built['test']

print("Classes:", train_ds.classes)
print("Train/Val/Test sizes:", len(train_ds), len(val_ds), len(test_ds))

In [None]:
from collections import Counter

# Inverse-frequency class weights: total / (num_classes * count), or 0 for an empty class.
counts = Counter(train_ds.targets)
num_classes = len(train_ds.classes)

class_counts = []
for class_idx in range(num_classes):
    class_counts.append(counts[class_idx])  # Counter yields 0 for unseen labels

class_weights = []
for c in class_counts:
    if c == 0:
        class_weights.append(0)
    else:
        class_weights.append(sum(class_counts) / (num_classes * c))

class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32, device=device)

print("Class counts:", dict(zip(train_ds.classes, class_counts)))
print("Class weights:", {name: round(w, 3) for name, w in zip(train_ds.classes, class_weights)})

In [None]:
# Using torchvision's ResNet18 for simplicity + speed
from torchvision.models import resnet18, ResNet18_Weights

# Start from torchvision's pretrained ResNet18 (default weights).
weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=weights)

# Replace the classifier head + add Dropout for regularization.
# The output width follows `num_classes` (derived from the training set) so it
# always matches the length of `class_weights_tensor` passed to the loss below.
in_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(in_features, num_classes)
)
model = model.to(device)

# Loss (with label smoothing and per-class weights) + optimizer + scheduler
criterion = nn.CrossEntropyLoss(label_smoothing=0.05, weight=class_weights_tensor)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
EPOCHS = 8
# Cosine decay over the full run; the training loop steps it once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# Optionally freeze the backbone to speed up
# for p in list(model.parameters())[:-2]:
#     p.requires_grad = False

In [None]:
from sklearn.metrics import f1_score

def run_epoch(dataloader, train=True):
    """Run one full pass over `dataloader` with the global model.

    Args:
        dataloader: Iterable yielding `(images, labels)` batches.
        train: When True the model is put in training mode and the optimizer
            is stepped on every batch; when False the model is put in eval
            mode and gradients are disabled.

    Returns:
        `(mean_loss, accuracy, macro_f1)` over all samples seen. The loss is
        the per-sample average (each batch loss weighted by its batch size).
    """
    model.train(train)
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    seen_labels, seen_preds = [], []

    for batch_imgs, batch_labels in dataloader:
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)
        batch_size = batch_imgs.size(0)

        if train:
            optimizer.zero_grad(set_to_none=True)
        with torch.set_grad_enabled(train):
            logits = model(batch_imgs)
            loss = criterion(logits, batch_labels)
            if train:
                loss.backward()
                optimizer.step()

        batch_preds = logits.argmax(dim=1)
        # Weight by batch size so a short final batch does not skew the mean.
        loss_sum += loss.item() * batch_size
        n_correct += (batch_preds == batch_labels).sum().item()
        n_seen += batch_size
        seen_labels.extend(batch_labels.detach().cpu().tolist())
        seen_preds.extend(batch_preds.detach().cpu().tolist())

    denom = max(n_seen, 1)  # guard against an empty loader
    macro_f1 = f1_score(seen_labels, seen_preds, average='macro')
    return loss_sum / denom, n_correct / denom, macro_f1

import copy  # imported here so this cell is self-contained

# Per-epoch metrics, appended once per completed epoch.
history = {'train_loss':[], 'val_loss':[], 'train_acc':[], 'val_acc':[], 'train_f1':[], 'val_f1':[]}

# Early stopping on validation macro-F1: stop once it has failed to improve
# for more than `patience` consecutive epochs.
best_val_f1, best_state, patience, patience_ctr = -1, None, 2, 0

for epoch in range(1, EPOCHS+1):
    tr_loss, tr_acc, tr_f1 = run_epoch(train_loader, train=True)
    va_loss, va_acc, va_f1 = run_epoch(val_loader, train=False)
    scheduler.step()

    history['train_loss'].append(tr_loss); history['val_loss'].append(va_loss)
    history['train_acc'].append(tr_acc);   history['val_acc'].append(va_acc)
    history['train_f1'].append(tr_f1);     history['val_f1'].append(va_f1)

    print(f"Epoch {epoch:02d} | train loss {tr_loss:.4f} acc {tr_acc:.3f} f1 {tr_f1:.3f} || val loss {va_loss:.4f} acc {va_acc:.3f} f1 {va_f1:.3f}")

    if va_f1 > best_val_f1:
        best_val_f1 = va_f1
        # state_dict() hands back references to the live parameter tensors, so
        # without a copy the "best" snapshot would keep changing as training
        # continues and end up equal to the last epoch's weights.
        best_state = {'model': copy.deepcopy(model.state_dict()), 'epoch': epoch}
        patience_ctr = 0
    else:
        patience_ctr += 1
        if patience_ctr > patience:
            print('Early stopping.')
            break

# Save best
os.makedirs('checkpoints', exist_ok=True)
torch.save(best_state, 'checkpoints/best_resnet18.pt')

# Plot curves (one figure per metric, train vs. val)
os.makedirs('reports', exist_ok=True)
for metric, title in (('loss', 'Loss'), ('f1', 'Macro F1')):
    plt.figure()
    plt.plot(history[f'train_{metric}'])
    plt.plot(history[f'val_{metric}'])
    plt.title(title)
    plt.legend(['train','val'])
    plt.xlabel('epoch')
    plt.ylabel(metric)
    plt.savefig(f'reports/curves_{metric}.png')
    plt.close()

In [None]:
# Evaluate on test set using the best checkpoint selected during training
model.load_state_dict(best_state['model'])
model.eval()

# all_probs: per-batch softmax arrays (concatenated below);
# all_preds / all_targets: flat lists of predicted and true class indices.
all_probs, all_preds, all_targets = [], [], []

with torch.no_grad():
    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        logits = model(imgs)
        probs = torch.softmax(logits, dim=1).cpu().numpy()
        preds = probs.argmax(axis=1)
        all_probs.append(probs)
        all_preds.extend(preds.tolist())
        all_targets.extend(labels.numpy().tolist())

all_probs = np.concatenate(all_probs, axis=0)
target_names = test_ds.classes
# Fixed label set shared by the report and the confusion matrix, so both stay
# aligned with `target_names` even if a class never occurs in the test
# targets or predictions.
label_ids = list(range(len(target_names)))

print("Classification report (test):")
print(classification_report(all_targets, all_preds, labels=label_ids, target_names=target_names, digits=4))

# Confusion matrix (rows = true class, columns = predicted class)
cm = confusion_matrix(all_targets, all_preds, labels=label_ids)
fig = plt.figure()
plt.imshow(cm, interpolation='nearest')
plt.title('Confusion Matrix (Test)')
plt.colorbar()
tick_marks = np.arange(len(target_names))
plt.xticks(tick_marks, target_names, rotation=45)
plt.yticks(tick_marks, target_names)
plt.xlabel('Predicted'); plt.ylabel('True')
plt.tight_layout()
os.makedirs('reports', exist_ok=True)
plt.savefig('reports/confusion_matrix.png'); plt.close(fig)

In [None]:
# ROC-AUC (OvR) and PR curves
y_true = np.array(all_targets)
# One-hot encode the true labels: Y[n, k] == 1 iff sample n belongs to class k.
Y = np.zeros((len(y_true), len(target_names)))
Y[np.arange(len(y_true)), y_true] = 1
y_score = all_probs


def _save_curve(x, y, title, xlabel, ylabel, path, diagonal=False):
    """Plot y against x (optionally with the chance diagonal) and save it to `path`."""
    plt.figure()
    plt.plot(x, y)
    if diagonal:
        plt.plot([0,1],[0,1],'--')
    plt.title(title)
    plt.xlabel(xlabel); plt.ylabel(ylabel)
    plt.savefig(path); plt.close()


# ROC-AUC macro/micro (best effort: report the reason instead of aborting the cell)
try:
    roc_macro = roc_auc_score(Y, y_score, multi_class='ovr', average='macro')
    roc_micro = roc_auc_score(Y, y_score, multi_class='ovr', average='micro')
    print(f"ROC-AUC macro: {roc_macro:.4f} | micro: {roc_micro:.4f}")
except Exception as e:
    print("ROC-AUC not available:", e)

# Plot ROC per class
for i, cname in enumerate(target_names):
    fpr, tpr, _ = roc_curve(Y[:, i], y_score[:, i])
    _save_curve(
        fpr, tpr,
        title=f'ROC Curve – {cname}',
        xlabel='False Positive Rate', ylabel='True Positive Rate',
        path=f'reports/roc_{cname}.png',
        diagonal=True,
    )

# Plot PR per class
for i, cname in enumerate(target_names):
    prec, rec, _ = precision_recall_curve(Y[:, i], y_score[:, i])
    _save_curve(
        rec, prec,
        title=f'Precision-Recall Curve – {cname}',
        xlabel='Recall', ylabel='Precision',
        path=f'reports/pr_{cname}.png',
    )

In [None]:
# Grad-CAM visualization on a few test images
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

# Pick the last conv block for ResNet18
target_layers = [model.layer4[-1]]

# Do not pass `use_cuda`: current grad-cam releases no longer accept that
# keyword, and this notebook installs the package with --upgrade. The model is
# already on `device` and each input is moved to `device` before `cam` is
# called, so the flag is not needed on older releases either.
cam = GradCAM(model=model, target_layers=target_layers)

def tensor_to_rgb(img_tensor):
    """Undo channel normalisation and return an (H, W, C) array clipped to [0, 1].

    Args:
        img_tensor: CPU tensor laid out as (C, H, W) that was normalised with
            the mean/std used below.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    hwc = np.transpose(img_tensor.numpy(), (1, 2, 0))
    return np.clip(hwc * channel_std + channel_mean, 0, 1)

# Collect a few samples (the first 12 test items, or fewer if the set is smaller)
samples = [test_ds[idx] for idx in range(min(12, len(test_ds)))]

for i, (img, label) in enumerate(samples):
    batch = img.unsqueeze(0).to(device)
    # targets=None is passed through to the CAM object unchanged;
    # [0, :] selects the map of the single image in the batch.
    heatmap = cam(input_tensor=batch, targets=None)[0, :]
    overlay = show_cam_on_image(tensor_to_rgb(img.cpu()), heatmap, use_rgb=True)

    plt.figure()
    plt.imshow(overlay)
    plt.title(f'Grad-CAM | True: {target_names[label]}')
    plt.axis('off')
    plt.savefig(f'reports/gradcam_{i}.png')
    plt.close()

In [None]:
# Simple reliability diagram & ECE (estimated)
def reliability_diagram(probs, y_true, bins=10, savepath='reports/reliability.png'):
    """Save a reliability diagram and return the expected calibration error (ECE).

    Confidence is the maximum class probability per sample. Samples are
    grouped into `bins` equal-width confidence bins over [0, 1]; the bar for
    each bin is the accuracy inside that bin (0 for an empty bin). ECE is the
    sample-weighted mean of |accuracy - mean confidence| over non-empty bins.

    Args:
        probs: Array-like of shape (n_samples, n_classes) with class probabilities.
        y_true: Array-like of shape (n_samples,) with true class indices.
        bins: Number of equal-width confidence bins.
        savepath: Where the figure is written; its parent directory is created
            if it does not exist.

    Returns:
        The estimated ECE as a float.
    """
    probs = np.asarray(probs)
    y_true = np.asarray(y_true)
    confidences = probs.max(axis=1)
    predictions = probs.argmax(axis=1)
    accuracies = (predictions == y_true).astype(float)

    bin_edges = np.linspace(0.0, 1.0, bins+1)
    # Interior edges only, right-inclusive: bin b holds edge[b] < conf <= edge[b+1],
    # giving ids in 0..bins-1 (a confidence of exactly 1.0 lands in the last bin).
    bin_ids = np.digitize(confidences, bin_edges[1:-1], right=True)

    n_samples = len(confidences)
    bin_acc = []
    ece = 0.0
    for b in range(bins):
        in_bin = (bin_ids == b)
        if np.any(in_bin):
            acc = accuracies[in_bin].mean()
            conf = confidences[in_bin].mean()
            ece += (in_bin.sum()/n_samples) * abs(acc - conf)
        else:
            acc = 0  # empty bin: draw no bar and contribute nothing to ECE
        bin_acc.append(acc)

    # Plot
    plt.figure()
    plt.plot([0,1],[0,1],'--')
    # bar centers
    centers = np.linspace(1/(2*bins), 1-1/(2*bins), bins)
    plt.bar(centers, bin_acc, width=1/bins, alpha=0.6, align='center')
    plt.title(f'Reliability Diagram (ECE≈{ece:.3f})')
    plt.xlabel('Confidence'); plt.ylabel('Accuracy')
    # Create the directory the caller actually asked for, not a fixed 'reports'.
    out_dir = os.path.dirname(savepath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(savepath); plt.close()
    return ece

# Calibration of the test-set probabilities collected during evaluation.
y_true_arr = np.array(all_targets)
ece = reliability_diagram(
    probs=all_probs,
    y_true=y_true_arr,
    bins=10,
    savepath='reports/reliability.png',
)
print(f"Estimated ECE: {ece:.4f}")

## Ethics, Fairness & Next Steps
- **Not for clinical use**. Requires radiologist oversight and regulatory approval for real deployment.
- **Bias & domain shift**: This dataset may not represent all scanners/populations; test on **external datasets** before relying on it.
- **Interpretability**: Grad-CAM should highlight lung regions. If it highlights text markers or borders, that's a red flag for spurious correlations.
- **Calibration**: We plotted a simple reliability diagram to understand probability quality. If the model turns out to be poorly calibrated, a standard remedy is **temperature scaling** fitted on the validation set.
- **Improvements**: Try EfficientNet-B0, add **lung segmentation** as a preprocessing step, and report **per-source subgroup metrics** if multiple hospitals are present.