In [None]:
!pip install -q kagglehub

import os
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
from torchvision import datasets, transforms, models
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

import kagglehub

# Fetch the dataset through kagglehub; the call returns a local directory path.
path = kagglehub.dataset_download("paultimothymooney/chest-xray-pneumonia")
print("Base path from kagglehub:", path)

# Locate the dataset root below the returned path. Matching on the folder name
# alone is not enough: a same-named folder that lacks the image splits would be
# accepted and only fail later when ImageFolder is built. A candidate is
# therefore accepted only if it directly contains both "train" and "test".
base_dir = None
for root, dirs, files in os.walk(path):
    if os.path.basename(root) == "chest_xray" and {"train", "test"}.issubset(dirs):
        base_dir = root
        break

if base_dir is None:
    # Message (Arabic): "The 'chest_xray' folder was not found inside the path
    # returned by kagglehub".
    raise RuntimeError("لم يتم العثور على مجلد 'chest_xray' داخل المسار الذي أعاده kagglehub")

print("Found chest_xray folder at:", base_dir)

train_dir = os.path.join(base_dir, "train")
test_dir  = os.path.join(base_dir, "test")

print("Train dir:", train_dir)
print("Test dir :", test_dir)

# Per-channel mean / std of ImageNet. The ResNet-18 used in this script is
# initialised from ImageNet-pretrained weights, which expect inputs normalised
# with these statistics; feeding raw [0, 1] tensors shifts the input
# distribution away from what the pretrained filters were fitted to.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training-time preprocessing: fixed 224x224 resize, tensor conversion,
# ImageNet normalisation. No augmentation is applied.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation-time preprocessing: identical to training so both splits are
# presented to the network in the same value range.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Build one ImageFolder dataset per split, each with its own transform.
full_train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
full_test_dataset = datasets.ImageFolder(root=test_dir, transform=test_transform)

print("Full train size:", len(full_train_dataset))
print("Full test size :", len(full_test_dataset))
print("Classes:", full_train_dataset.class_to_idx)

# Work on small random subsets (at most 1000 train / 300 test images).
train_size = min(len(full_train_dataset), 1000)
test_size = min(len(full_test_dataset), 300)

# Seed NumPy's global RNG before drawing the subset indices; the train indices
# are drawn first, then the test indices, both without replacement.
np.random.seed(42)
train_indices = np.random.choice(
    len(full_train_dataset), size=train_size, replace=False
)
test_indices = np.random.choice(
    len(full_test_dataset), size=test_size, replace=False
)

train_dataset = Subset(dataset=full_train_dataset, indices=train_indices)
test_dataset = Subset(dataset=full_test_dataset, indices=test_indices)

batch_size = 32
# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

print("Train subset size:", len(train_dataset))
print("Test subset size :", len(test_dataset))

def build_model():
    """Create a ResNet-18 with a two-output head, placed on ``device``.

    The backbone is initialised from the ``IMAGENET1K_V1`` weights and its
    final fully connected layer is replaced by a new ``nn.Linear`` with the
    same input width and 2 outputs.

    Returns:
        The model after being moved to the module-level ``device``.
    """
    pretrained = models.ResNet18_Weights.IMAGENET1K_V1
    net = models.resnet18(weights=pretrained)

    # Swap the original classification head for a 2-output one.
    head_inputs = net.fc.in_features
    net.fc = nn.Linear(head_inputs, 2)

    net = net.to(device)
    return net

# Loss function shared by the training runs in this script: cross-entropy
# computed from the model's raw outputs and the integer class labels.
criterion = nn.CrossEntropyLoss()

def train_one_epoch(model, loader, optimizer, criterion, max_batches=None):
    """Run one training pass over ``loader`` and return the training accuracy.

    Args:
        model: Network to optimise; it is switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer whose gradients are reset and which is stepped
            once per batch.
        criterion: Loss, called as ``criterion(outputs, labels)``.
        max_batches: Optional cap on the number of batches to process. The
            loop stops after the batch that reaches the cap; ``None`` processes
            the whole loader.

    Returns:
        Fraction of processed samples whose arg-max prediction equals the
        label, using the outputs computed before each optimizer step. Returns
        ``0.0`` when the loader yields no batch (instead of dividing by zero).
    """
    model.train()
    correct = 0
    total = 0
    batch_count = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Predicted class = index of the largest output along dim 1.
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        batch_count += 1

        if max_batches is not None and batch_count >= max_batches:
            break

    # An empty loader leaves total at 0; report 0.0 rather than crash.
    if total == 0:
        return 0.0
    return correct / total

def evaluate(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        Fraction of samples whose arg-max prediction equals the label, or
        ``0.0`` when the loader yields no batch (instead of dividing by zero).
    """
    model.eval()
    correct = 0
    total = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Predicted class = index of the largest output along dim 1.
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # An empty loader leaves total at 0; report 0.0 rather than crash.
    if total == 0:
        return 0.0
    return correct / total

def get_probs(model, loader):
    """Collect the softmax score of class index 1 for every sample in ``loader``.

    Args:
        model: Network to run; it is switched to evaluation mode.
        loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        A pair of NumPy arrays ``(scores, labels)``: the class-1 softmax
        probability of each sample and the matching labels, in loader order.
    """
    model.eval()
    score_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch_images, batch_labels in loader:
            logits = model(batch_images.to(device))
            class_probs = torch.softmax(logits, dim=1)
            # Column 1 holds the probability assigned to class index 1.
            score_chunks.append(class_probs[:, 1].cpu().numpy())
            label_chunks.append(batch_labels.numpy())

    # Flatten the per-batch arrays element by element so an empty loader still
    # produces empty 1-D arrays.
    scores = np.array([value for chunk in score_chunks for value in chunk])
    targets = np.array([value for chunk in label_chunks for value in chunk])
    return scores, targets

def get_preds(model, loader):
    """Collect the predicted class index for every sample in ``loader``.

    Args:
        model: Network to run; it is switched to evaluation mode.
        loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        A pair of NumPy arrays ``(predictions, labels)`` in loader order, where
        each prediction is the index of the largest model output.
    """
    model.eval()
    predicted_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch_images, batch_labels in loader:
            logits = model(batch_images.to(device))
            # Index of the maximum output along dim 1 is the predicted class.
            winners = torch.max(logits, 1).indices
            predicted_chunks.append(winners.cpu().numpy())
            label_chunks.append(batch_labels.numpy())

    # Flatten the per-batch arrays element by element so an empty loader still
    # produces empty 1-D arrays.
    predictions = np.array([value for chunk in predicted_chunks for value in chunk])
    targets = np.array([value for chunk in label_chunks for value in chunk])
    return predictions, targets

def compute_metrics_from_cm(cm):
    """Derive accuracy, sensitivity and specificity from a binary confusion matrix.

    Args:
        cm: Array-like with exactly four entries which, read in row-major
            order, are ``tn, fp, fn, tp`` (for example a 2x2 matrix
            ``[[tn, fp], [fn, tp]]``). NumPy arrays and nested lists are both
            accepted.

    Returns:
        Tuple ``(accuracy, sensitivity, specificity)``. A small epsilon is
        added to every denominator, so a metric whose denominator would be
        zero evaluates to ``0.0`` instead of raising.

    Raises:
        ValueError: If ``cm`` does not contain exactly four entries, e.g. when
            the matrix was built from data in which only one class occurs.
    """
    values = np.asarray(cm)
    if values.size != 4:
        raise ValueError(
            "expected a 2x2 binary confusion matrix (tn, fp, fn, tp), "
            f"got shape {values.shape}"
        )

    tn, fp, fn, tp = values.ravel()
    accuracy = (tp + tn) / (tp + tn + fp + fn + 1e-8)
    sensitivity = tp / (tp + fn + 1e-8)  # true-positive rate
    specificity = tn / (tn + fp + 1e-8)  # true-negative rate
    return accuracy, sensitivity, specificity

# Baseline run: train on the subset as sampled, using the shuffled train loader.
model_base = build_model()
optimizer_base = optim.Adam(params=model_base.parameters(), lr=1e-4)

print("Training Baseline Model (FAST)...")
num_epochs = 1
for epoch in range(num_epochs):
    # Each epoch is capped at 30 mini-batches.
    train_acc = train_one_epoch(
        model=model_base,
        loader=train_loader,
        optimizer=optimizer_base,
        criterion=criterion,
        max_batches=30,
    )
    test_acc = evaluate(model=model_base, loader=test_loader)
    print(f"[Baseline] Epoch {epoch+1}: Train Acc={train_acc:.3f}, Test Acc={test_acc:.3f}")

# Class-balanced oversampling: every training sample is drawn with probability
# proportional to the inverse frequency of its class within the subset.
#
# Labels are read from the dataset's precomputed ``targets`` list. Indexing the
# dataset itself (``full_train_dataset[i][1]``) would load and transform every
# image from disk only to discard it and keep the label.
targets = [full_train_dataset.targets[i] for i in train_indices]
class_count = torch.bincount(torch.tensor(targets))
class_weights = 1.0 / class_count.float()

# One weight per sample, as plain Python floats, in the same order as the
# indices of ``train_dataset``.
sample_weights = class_weights[torch.tensor(targets)].tolist()
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# A sampler replaces shuffling: the loader draws indices from ``sampler``.
balanced_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)

# Oversampling run: a fresh model trained from the class-balanced loader.
model_os = build_model()
optimizer_os = optim.Adam(params=model_os.parameters(), lr=1e-4)

print("Training Oversampling Model (FAST)...")
for epoch in range(num_epochs):
    # Each epoch is capped at 30 mini-batches.
    train_acc_os = train_one_epoch(
        model=model_os,
        loader=balanced_loader,
        optimizer=optimizer_os,
        criterion=criterion,
        max_batches=30,
    )
    test_acc_os = evaluate(model=model_os, loader=test_loader)
    print(f"[OS] Epoch {epoch+1}: Train Acc={train_acc_os:.3f}, Test Acc={test_acc_os:.3f}")

# Class-1 scores of both models on the test subset. The labels returned by the
# second call are discarded; the labels from the first call are used for both.
probs_b, labels = get_probs(model_base, test_loader)
probs_os, _ = get_probs(model_os, test_loader)

# Area under the ROC curve for each model.
auc_b = roc_auc_score(labels, probs_b)
auc_os = roc_auc_score(labels, probs_os)

print(f"Baseline AUC     : {auc_b:.3f}")
print(f"Oversampling AUC : {auc_os:.3f}")

# ROC operating points; the returned thresholds are not used.
fpr_b, tpr_b, _ = roc_curve(labels, probs_b)
fpr_os, tpr_os, _ = roc_curve(labels, probs_os)

# Draw both curves on one figure, plus the diagonal as a visual reference.
plt.figure(figsize=(7, 5))
plt.title("ROC Curve Comparison")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.plot(fpr_b, tpr_b, label=f"Baseline (AUC={auc_b:.2f})")
plt.plot(fpr_os, tpr_os, label=f"Oversampling (AUC={auc_os:.2f})")
plt.plot([0, 1], [0, 1], 'k--')
plt.legend()
plt.grid()
plt.show()

# Confusion Matrix + Accuracy/Sensitivity/Specificity

# Hard class predictions of both models on the test subset.
preds_b, labels_b = get_preds(model_base, test_loader)
preds_os, labels_os = get_preds(model_os, test_loader)

# Pin the class order to [0, 1] so the result is always a 2x2 matrix. Without
# it, a class absent from both the labels and the predictions shrinks the
# matrix, and compute_metrics_from_cm can no longer unpack tn/fp/fn/tp.
cm_b  = confusion_matrix(labels_b, preds_b, labels=[0, 1])
cm_os = confusion_matrix(labels_os, preds_os, labels=[0, 1])

print("Baseline Confusion Matrix:\n", cm_b)
print("Oversampling Confusion Matrix:\n", cm_os)

acc_b, sens_b, spec_b = compute_metrics_from_cm(cm_b)
acc_os, sens_os, spec_os = compute_metrics_from_cm(cm_os)

print("\n=== Baseline Metrics ===")
print(f"Accuracy   : {acc_b:.3f}")
print(f"Sensitivity: {sens_b:.3f}")
print(f"Specificity: {spec_b:.3f}")

print("\n=== Oversampling Metrics ===")
print(f"Accuracy   : {acc_os:.3f}")
print(f"Sensitivity: {sens_os:.3f}")
print(f"Specificity: {spec_os:.3f}")
