In [None]:
# ========== CELL 2: DATA PREPARATION  ==========
# Mounts Google Drive, then builds the train/val transforms, ImageFolder
# datasets and DataLoaders used by the later cells.
from google.colab import drive
# These names are used directly in this cell, so import them here rather than
# relying on another cell having been run first (`datasets` in particular is
# not imported by any other cell shown in this notebook).
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

drive.mount('/content/drive')

# Config
DATA_DIR = "/content/drive/MyDrive/COVID-CXR-Dataset"
IMG_SIZE = 224
BATCH_SIZE = 32

# Transforms
# Training pipeline: resize, light augmentation (flip / small rotation /
# brightness-contrast jitter), tensor conversion, per-channel normalisation.
train_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
])

# Validation pipeline: same resize and normalisation, no random augmentation.
val_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
])

# Load data
train_ds = datasets.ImageFolder(f"{DATA_DIR}/train", transform=train_tfms)
val_ds = datasets.ImageFolder(f"{DATA_DIR}/val", transform=val_tfms)

# Both splits must map class folders to the same integer labels, otherwise
# validation metrics would be computed against mismatched targets.
assert train_ds.class_to_idx == val_ds.class_to_idx, (
    f"train/val class mappings differ: {train_ds.class_to_idx} vs {val_ds.class_to_idx}"
)

# Only the training loader is shuffled; validation order stays fixed.
train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, BATCH_SIZE, shuffle=False, num_workers=2)

print(f"âœ… Train: {len(train_ds)} images, Val: {len(val_ds)} images")
print(f"âœ… Classes: {train_ds.classes}")
# NOTE(review): the evaluation cells report label 0 as "Normal" and label 1 as
# "COVID" -- confirm the mapping printed here matches that assumption.
print(f"Class indices: {train_ds.class_to_idx}")


In [None]:
# Install & Setup
!pip install -q timm scikit-learn pandas matplotlib seaborn

import torch
import torch.nn as nn
import timm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import os, random, time, json
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import (classification_report, roc_auc_score,
                           precision_recall_fscore_support, roc_curve, confusion_matrix)
import matplotlib.pyplot as plt
import seaborn as sns

# Seed the Python, NumPy and PyTorch random number generators so that runs of
# this notebook start from the same random state.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print("âœ… Environment ready")
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
print(f"PyTorch: {torch.__version__}, CUDA: {torch.version.cuda}")

In [None]:
import torch
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
import numpy as np
import pandas as pd

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Device:", device)

# Rebuild the ViT-Tiny architecture (2 output classes, no pretrained download)
# and restore the weights saved in best_vit_tiny_aug.pth.
best_model = timm.create_model(
    'vit_tiny_patch16_224',
    pretrained=False,
    num_classes=2,
)
best_model.load_state_dict(
    torch.load("best_vit_tiny_aug.pth", map_location=device)
)
best_model = best_model.to(device)
best_model.eval()

# Per-sample accumulators filled while iterating the validation loader.
all_labels = []
all_preds = []
all_probs = []

with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = best_model(images)
        # Softmax column 1 = probability assigned to class index 1.
        probs = outputs.softmax(dim=1)[:, 1].cpu().numpy()
        # Hard prediction = index of the largest logit.
        preds = outputs.argmax(dim=1)
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())
        all_probs.extend(probs)

# Stack the per-sample accumulators into arrays for scikit-learn.
all_labels = np.array(all_labels)
all_preds = np.array(all_preds)

# NOTE(review): the names below assume dataset label 0 is Normal and label 1
# is COVID -- confirm against the dataset's class-to-index mapping.

# Classification Report
print("=== ViT-Tiny Classification Report ===")
# labels=[0, 1] keeps both rows present and aligned with target_names even
# when one class is absent from the labels/predictions.
print(classification_report(all_labels, all_preds,
                          labels=[0, 1],
                          target_names=['Normal (0)', 'COVID (1)'],
                          digits=4))

# Confusion Matrix
# Pinning labels=[0, 1] guarantees a 2x2 matrix, so the four-way unpack below
# cannot fail when a class is missing.
cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()
print("\n=== Confusion Matrix ===")
print(cm)
print(f"TN: {tn}, FP: {fp}, FN: {fn}, TP: {tp}")

# Key Metrics
# The small epsilon avoids division by zero when a class has no samples.
sensitivity = tp / (tp + fn + 1e-8)
specificity = tn / (tn + fp + 1e-8)
accuracy = (tp + tn) / (tp + tn + fp + fn)
# AUC is computed from the class-1 probabilities, not the hard predictions.
auc = roc_auc_score(all_labels, all_probs)

print(f"\nAccuracy           : {accuracy*100:.2f}%")
print(f"Sensitivity (COVID): {sensitivity*100:.2f}%")
print(f"Specificity (Normal): {specificity*100:.2f}%")
print(f"AUC                : {auc*100:.2f}%")

# Save results
# One-row summary table written next to the notebook.
metrics = pd.DataFrame({
    'model': ['ViT-Tiny'],
    'accuracy': [accuracy],
    'auc': [auc],
    'sensitivity_covid': [sensitivity],
    'specificity_normal': [specificity]
})
metrics.to_csv('vit_tiny_metrics.csv', index=False)
print("\nâœ… Results saved to vit_tiny_metrics.csv")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import timm
from sklearn.metrics import roc_auc_score
import time

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("âœ… Device:", device)

# ResNet-18 model
# Created with pretrained weights and a 2-class head, then moved to `device`.
resnet = timm.create_model('resnet18', pretrained=True, num_classes=2)
resnet = resnet.to(device)

# Cross-entropy loss and an AdamW optimiser over all ResNet parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    resnet.parameters(),
    lr=3e-4,
    weight_decay=1e-4,
)

def train_one_epoch(model, loader, optimizer, criterion):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: network to optimise; it is switched to training mode.
        loader: iterable yielding ``(images, labels)`` batches.
        optimizer: optimiser whose ``zero_grad``/``step`` are called per batch.
        criterion: loss callable applied as ``criterion(outputs, labels)``.

    Returns:
        ``(mean_loss, accuracy)`` where the loss is averaged per sample and the
        accuracy is the fraction of samples whose arg-max prediction matches
        the label.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()

    # Move batches to wherever the model lives instead of relying on a global
    # ``device`` variable defined in some other cell.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    running_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(run_device), labels.to(run_device)
        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight the batch loss by the batch size so the final mean is per sample.
        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    if total == 0:
        # Fail with a clear message rather than a bare ZeroDivisionError.
        raise ValueError("train_one_epoch received an empty loader")
    return running_loss / total, correct / total

def evaluate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: network to evaluate; it is switched to eval mode.
        loader: iterable yielding ``(images, labels)`` batches.
        criterion: loss callable applied as ``criterion(outputs, labels)``.

    Returns:
        ``(mean_loss, accuracy, auc, all_labels, all_probs)``. The loss is
        averaged per sample, ``auc`` is ``roc_auc_score`` over the class-1
        softmax probabilities, and the two lists hold the per-sample labels
        and class-1 probabilities in loader order.

    Raises:
        ValueError: if ``loader`` yields no samples. ``roc_auc_score`` may also
            raise when the labels contain a single class (see its docs).
    """
    model.eval()

    # Move batches to wherever the model lives instead of relying on a global
    # ``device`` variable defined in some other cell.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    running_loss, correct, total = 0.0, 0, 0
    all_labels, all_probs = [], []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(run_device), labels.to(run_device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Softmax column 1 = probability assigned to class index 1.
            probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
            _, preds = torch.max(outputs, 1)
            # Weight the batch loss by the batch size so the mean is per sample.
            running_loss += loss.item() * images.size(0)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs)

    if total == 0:
        # Fail with a clear message before attempting AUC or the divisions.
        raise ValueError("evaluate received an empty loader")
    auc = roc_auc_score(all_labels, all_probs)
    return running_loss / total, correct / total, auc, all_labels, all_probs

# Training loop
# Trains ResNet-18 for up to `num_epochs`, checkpointing whenever validation
# AUC improves and stopping after `patience` epochs without improvement.
num_epochs = 20
best_val_auc = -1.0  # Use AUC instead of accuracy
patience = 5
no_improve = 0

print("ðŸš€ Starting ResNet-18 Training...")
for epoch in range(num_epochs):
    start_time = time.time()

    # Train
    train_loss, train_acc = train_one_epoch(resnet, train_loader, optimizer, criterion)

    # Evaluate
    val_loss, val_acc, val_auc, _, _ = evaluate(resnet, val_loader, criterion)
    # Depending on the scikit-learn version the AUC may be a NumPy scalar.
    # Store plain Python floats so the checkpoint holds only tensors and
    # built-in types, which keeps it loadable with torch.load(weights_only=True).
    val_auc = float(val_auc)
    val_acc = float(val_acc)

    epoch_time = time.time() - start_time

    print(f"Epoch {epoch+1:2d}/{num_epochs}")
    print(f"  Train: Loss={train_loss:.4f}, Acc={train_acc*100:.2f}%")
    print(f"  Val:   Loss={val_loss:.4f}, Acc={val_acc*100:.2f}%, AUC={val_auc:.4f} ({epoch_time:.1f}s)")

    # Early stopping based on AUC
    # An epoch only counts as an improvement if AUC rises by more than 1e-4.
    if val_auc > best_val_auc + 1e-4:
        best_val_auc = val_auc
        # The checkpoint is a dict; the weights live under 'model_state_dict'.
        torch.save({
            'model_state_dict': resnet.state_dict(),
            'epoch': epoch,
            'val_auc': val_auc,
            'val_acc': val_acc
        }, "best_resnet18_covid.pth")
        no_improve = 0
        print("  âœ… Best model saved!")
    else:
        no_improve += 1
        if no_improve >= patience:
            print("  ðŸ›‘ Early stopping")
            break

print(f"\nðŸŽ¯ Best ResNet-18 validation AUC: {best_val_auc*100:.2f}%")
print("âœ… Training completed!")

In [None]:
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("âœ… Ensemble Evaluation - Device:", device)


def _read_state_dict(path, map_location):
    """Load ``path`` and return the model weights it contains.

    Accepts both checkpoint layouts used in this notebook: a bare state dict,
    or a dict that wraps the weights under the ``'model_state_dict'`` key (the
    layout written by the ResNet-18 training loop).
    """
    # These files are produced by this notebook; do not point this at
    # untrusted checkpoints.
    checkpoint = torch.load(path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        return checkpoint['model_state_dict']
    return checkpoint


# Load best models
vit = timm.create_model('vit_tiny_patch16_224', pretrained=False, num_classes=2).to(device)
vit.load_state_dict(_read_state_dict("best_vit_tiny_aug.pth", device))
vit.eval()

# The ResNet checkpoint is saved as a wrapper dict, so passing the raw
# torch.load result to load_state_dict would fail on key mismatch.
resnet = timm.create_model('resnet18', pretrained=False, num_classes=2).to(device)
resnet.load_state_dict(_read_state_dict("best_resnet18_covid.pth", device))
resnet.eval()

print("âœ… Models loaded successfully")

# Per-sample accumulators, filled one validation batch at a time.
all_labels = []
all_vit_probs = []
all_res_probs = []
all_ens_probs = []
all_preds = []

with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Forward the same batch through both networks.
        out_vit = vit(images)
        out_res = resnet(images)

        # Probability each model assigns to class index 1.
        prob_vit = out_vit.softmax(dim=1)[:, 1].cpu().numpy()
        prob_res = out_res.softmax(dim=1)[:, 1].cpu().numpy()

        # Ensemble = unweighted mean of the two probabilities; a sample is
        # predicted as class 1 when that mean reaches 0.5.
        prob_ens = (prob_vit + prob_res) / 2.0
        pred_ens = (prob_ens >= 0.5).astype(int)

        # Append this batch's values to the matching accumulator.
        for store, batch_values in (
            (all_labels, labels.cpu().numpy()),
            (all_vit_probs, prob_vit),
            (all_res_probs, prob_res),
            (all_ens_probs, prob_ens),
            (all_preds, pred_ens),
        ):
            store.extend(batch_values)

# Convert to numpy arrays
all_labels = np.array(all_labels)
all_preds = np.array(all_preds)
all_ens_probs = np.array(all_ens_probs)

# NOTE(review): the names below assume dataset label 0 is Normal and label 1
# is COVID -- confirm against the dataset's class-to-index mapping.
print("=== CNNâ€“ViT Ensemble Results ===")
# labels=[0, 1] keeps both rows present and aligned with target_names even
# when one class is absent from the labels/predictions.
print(classification_report(all_labels, all_preds,
                          labels=[0, 1],
                          target_names=['Normal (0)', 'COVID (1)'],
                          digits=4))

# Confusion Matrix
# Pinning labels=[0, 1] guarantees a 2x2 matrix, so the four-way unpack below
# cannot fail when a class is missing.
cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()
print(f"\nConfusion Matrix:\n{cm}")
print(f"TN: {tn}, FP: {fp}, FN: {fn}, TP: {tp}")

# Key Metrics
# The small epsilon avoids division by zero when a class has no samples.
sensitivity = tp / (tp + fn + 1e-8)
specificity = tn / (tn + fp + 1e-8)
accuracy = (tp + tn) / (tp + tn + fp + fn)
# AUC uses the averaged probabilities, not the thresholded predictions.
auc = roc_auc_score(all_labels, all_ens_probs)

print(f"\nðŸŽ¯ FINAL ENSEMBLE METRICS:")
print(f"Accuracy           : {accuracy*100:.2f}%")
print(f"Sensitivity (COVID): {sensitivity*100:.2f}%")
print(f"Specificity (Normal): {specificity*100:.2f}%")
print(f"AUC                : {auc*100:.2f}%")

# Save comprehensive results
# One row per validation sample: label, each model's probability, the
# ensemble probability and the ensemble's hard prediction.
results = pd.DataFrame({
    'y_true': all_labels,
    'vit_prob': all_vit_probs,
    'resnet_prob': all_res_probs,
    'ensemble_prob': all_ens_probs,
    'ensemble_pred': all_preds
})
results.to_csv('ensemble_predictions.csv', index=False)

# One-row summary with the headline metrics.
summary = pd.DataFrame({
    'model': ['CNN-ViT_Ensemble'],
    'accuracy': [accuracy],
    'auc': [auc],
    'sensitivity_covid': [sensitivity],
    'specificity_normal': [specificity]
})
summary.to_csv('ensemble_metrics.csv', index=False)

print("\nâœ… Results saved:")
print("  - ensemble_predictions.csv (detailed)")
print("  - ensemble_metrics.csv (summary)")

# ROC Curve
# Plot the ensemble ROC against the chance diagonal and save it as a PNG.
fpr, tpr, _ = roc_curve(all_labels, all_ens_probs)

fig, ax = plt.subplots(figsize=(6, 5))
ax.plot(fpr, tpr, linewidth=2, label=f'Ensemble (AUC={auc:.3f})')
# Dashed diagonal = performance of a random classifier.
ax.plot([0,1], [0,1], 'k--', alpha=0.5)
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('CNNâ€“ViT Ensemble ROC Curve')
ax.legend(loc='lower right')
ax.grid(True, alpha=0.3)
fig.savefig('ensemble_roc.png', dpi=300, bbox_inches='tight')
plt.show()

print("âœ… ROC curve saved as ensemble_roc.png")
