In [None]:
!pip install -q imbalanced-learn
import os
import glob
import cv2
import torch
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve, precision_recall_curve
from collections import Counter
from imblearn.over_sampling import RandomOverSampler

class ChestXrayDataset(Dataset):
    """Binary Cardiomegaly dataset over ``images_*/images`` folders.

    Each row of ``df`` contributes one sample when its ``'Image Index'`` file
    is found under ``<root_path>/images_*/images/``; rows whose image is not
    found are silently skipped. The label is 1 when the row's
    ``'Finding Labels'`` string contains ``'Cardiomegaly'`` and 0 otherwise.

    Args:
        df: DataFrame with at least ``'Image Index'`` and ``'Finding Labels'``
            columns. Duplicate rows (e.g. after oversampling) yield duplicate
            samples.
        root_path: Directory containing the ``images_*/images`` sub-folders.
        transform: Optional callable applied to the grayscale array returned
            by ``cv2.imread`` before it is returned.
    """

    def __init__(self, df, root_path, transform=None):
        self.transform = transform
        self.all_image_paths = []
        self.labels = []

        # Scan the image folders once and resolve rows with a dict lookup,
        # instead of issuing a separate glob (directory listing + existence
        # checks) for every row of a potentially very large frame.
        path_index = self._index_image_paths(root_path)
        for image_name, label in zip(df['Image Index'], df['Finding Labels']):
            image_path = path_index.get(image_name)
            if image_path is not None:
                self.all_image_paths.append(image_path)
                self.labels.append(1 if 'Cardiomegaly' in label else 0)

    @staticmethod
    def _index_image_paths(root_path):
        """Map each file name under ``root_path/images_*/images/`` to its path.

        If a name occurs in several folders, the first match in glob order is
        kept.
        """
        index = {}
        for path in glob.glob(os.path.join(root_path, "images_*/images", "*")):
            index.setdefault(os.path.basename(path), path)
        return index

    def __len__(self):
        return len(self.all_image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; ``label`` is a float tensor."""
        image_path = self.all_image_paths[idx]
        label = torch.tensor(self.labels[idx]).float()
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread reports unreadable/corrupt files by returning None
            # instead of raising; fail here with the offending path rather
            # than deep inside the transform pipeline.
            raise OSError(f"cv2.imread could not read image: {image_path}")
        if self.transform:
            image = self.transform(image)
        return image, label

# === Data Loading & Sampling ===
base_path = "/kaggle/input/data"
df = pd.read_csv(f"{base_path}/Data_Entry_2017.csv")
df = df[df['View Position'] == 'PA'].reset_index(drop=True)
df['Cardiomegaly_Label'] = df['Finding Labels'].apply(lambda x: 1 if 'Cardiomegaly' in x else 0)

# Split by patient, not by image: the CSV can list several images for the same
# 'Patient ID', and an image-level split lets one patient appear in both train
# and val/test, which leaks information and inflates evaluation metrics.
# Patients are stratified on whether any of their PA images is positive; the
# 70/15/15 proportions therefore apply to patients rather than images.
patient_labels = df.groupby('Patient ID')['Cardiomegaly_Label'].max()
train_patients, temp_patients = train_test_split(
    patient_labels, test_size=0.3, stratify=patient_labels, random_state=42
)
val_patients, test_patients = train_test_split(
    temp_patients, test_size=0.5, stratify=temp_patients, random_state=42
)
train_df = df[df['Patient ID'].isin(train_patients.index)]
temp_df = df[df['Patient ID'].isin(temp_patients.index)]
val_df = df[df['Patient ID'].isin(val_patients.index)]
test_df = df[df['Patient ID'].isin(test_patients.index)]

# Oversampling minority class (training split only, so val/test keep the real
# class balance). Resampling positional indices lets us duplicate whole rows.
ros = RandomOverSampler(random_state=42)
resampled_indices, _ = ros.fit_resample(np.arange(len(train_df)).reshape(-1, 1), train_df['Cardiomegaly_Label'])
train_df = train_df.iloc[resampled_indices.flatten()]

# === Transforms ===
# Both pipelines convert the grayscale array from cv2 to PIL, resize it to the
# 224x224 ResNet input, and normalize the single channel with the first entry
# of the ImageNet mean/std. Only the training pipeline adds random flip and
# rotation augmentation.
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485], std=[0.229]),
])
val_test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485], std=[0.229]),
])

# One dataset per split, built in train -> val -> test order.
train_dataset, val_dataset, test_dataset = (
    ChestXrayDataset(split_frame, base_path, transform=pipeline)
    for split_frame, pipeline in (
        (train_df, train_transform),
        (val_df, val_test_transform),
        (test_df, val_test_transform),
    )
)

# Shuffle only the training batches; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=32, shuffle=False)
    for split_dataset in (val_dataset, test_dataset)
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device: ", device)

# === Model ===
# ImageNet-pretrained ResNet-18 adapted to 1-channel input and a single logit.
model = resnet18(weights=ResNet18_Weights.DEFAULT)
# The stem must accept one channel. A freshly initialized conv would discard
# the pretrained first layer, so instead sum the pretrained RGB kernels over
# the input-channel axis (Conv2d weight is (out, in, kH, kW)). For a grayscale
# image copied into all three channels this gives exactly the response of the
# original 3-channel conv.
rgb_conv1 = model.conv1
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    model.conv1.weight.copy_(rgb_conv1.weight.sum(dim=1, keepdim=True))
del rgb_conv1
model.fc = nn.Linear(model.fc.in_features, 1)
model = model.to(device)

# Weighted Loss: pos_weight = #negatives / #positives in train_df. Because
# train_df was already balanced by RandomOverSampler above, this evaluates to
# 1.0 (no extra weighting); it only takes effect if oversampling is removed.
# len()/len() yields a Python float, so the tensor stays float32 like the logits.
pos_weight = torch.tensor([
    len(train_df[train_df.Cardiomegaly_Label == 0]) / len(train_df[train_df.Cardiomegaly_Label == 1])
]).to(device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)  # weight decay = L2 regularization

# === Training Functions ===
def train_one_epoch(model, loader, optimizer, criterion, device=None):
    """Run one optimization pass over ``loader`` and return training metrics.

    Args:
        model: Network producing one logit per sample, shaped ``(B, 1)`` to
            match the ``unsqueeze(1)``-ed labels.
        loader: Iterable of ``(images, labels)`` batches; ``labels`` is
            presumably a 1-D tensor of 0./1. floats (one per image).
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        criterion: Loss called as ``criterion(logits, labels)``; assumed to
            return the batch mean, since it is re-weighted by batch size.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``(mean_loss, accuracy)`` over all samples, where a prediction is
        positive when ``sigmoid(logit) >= 0.5``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device).unsqueeze(1)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Scale the batch-mean loss back to a sum so the epoch average is
        # per-sample even when the last batch is smaller.
        total_loss += loss.item() * images.size(0)
        preds = torch.sigmoid(outputs) >= 0.5
        correct += (preds.float() == labels).sum().item()
        total += labels.size(0)

    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return total_loss / total, correct / total

def evaluate(model, loader, criterion, threshold=0.5, device=None):
    """Score ``model`` on ``loader`` without updating it.

    Args:
        model: Network producing one logit per sample, shaped ``(B, 1)``.
        loader: Iterable of ``(images, labels)`` batches; ``labels`` is
            presumably a 1-D tensor of 0./1. floats.
        criterion: Loss called as ``criterion(logits, labels)``; assumed to
            return the batch mean, since it is re-weighted by batch size.
        threshold: A sample is predicted positive when
            ``sigmoid(logit) >= threshold``.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``(mean_loss, accuracy, labels, preds, probs)``. The last three are
        NumPy arrays of shape ``(N, 1)`` in loader order.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    # Collect whole batches and concatenate once at the end, rather than
    # extending Python lists one row at a time.
    label_batches, pred_batches, prob_batches = [], [], []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device).unsqueeze(1)
            outputs = model(images)
            probs = torch.sigmoid(outputs)
            preds = (probs >= threshold).float()

            loss = criterion(outputs, labels)
            total_loss += loss.item() * images.size(0)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            label_batches.append(labels.cpu().numpy())
            pred_batches.append(preds.cpu().numpy())
            prob_batches.append(probs.cpu().numpy())

    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return (
        total_loss / total,
        correct / total,
        np.concatenate(label_batches),
        np.concatenate(pred_batches),
        np.concatenate(prob_batches),
    )

# === Train with Early Stopping ===
# Checkpoint whenever validation loss reaches a new minimum; give up after
# `patience` consecutive epochs without a new minimum.
num_epochs = 20
patience = 3
best_val_loss, patience_counter = float('inf'), 0

for epoch in range(num_epochs):
    print(f"\n=== Epoch {epoch+1}/{num_epochs} ===")
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    # Only the aggregate loss/accuracy are needed here; drop the per-sample arrays.
    val_loss, val_acc = evaluate(model, val_loader, criterion, threshold=0.5)[:2]

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.4f}")

    # `not <` (rather than `>=`) so a NaN loss counts as "no improvement",
    # the same as the else-branch of a plain `<` test.
    if not val_loss < best_val_loss:
        patience_counter += 1
        print(f"⏸ No improvement. Patience: {patience_counter}/{patience}")
        if patience_counter >= patience:
            print("🛑 Early stopping triggered.")
            break
        continue

    best_val_loss, patience_counter = val_loss, 0
    torch.save(model.state_dict(), "/kaggle/working/cardiomegaly_resnet18.pth")
    print("✅ Model improved and saved.")

# === Final Evaluation on Test Set ===
# Restore the lowest-validation-loss checkpoint written by the training loop.
# map_location places the tensors on the current device regardless of where
# they were saved; weights_only=True limits unpickling to tensors and plain
# containers, which is all a state_dict needs (and avoids the full-pickle path).
model.load_state_dict(
    torch.load("/kaggle/working/cardiomegaly_resnet18.pth", map_location=device, weights_only=True)
)
# NOTE(review): the 0.6 decision threshold is hard-coded (validation used 0.5)
# and is not tuned on the validation set. It affects accuracy, the confusion
# matrix and the report, but not ROC-AUC; consider selecting it on val_loader.
test_loss, test_acc, y_true, y_pred, y_prob = evaluate(model, test_loader, criterion, threshold=0.6)

print("\n=== Test Results ===")
print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")
print("Confusion Matrix:\n", confusion_matrix(y_true, y_pred))
print("Classification Report:\n", classification_report(y_true, y_pred, digits=4))
print("ROC-AUC Score:", roc_auc_score(y_true, y_prob))

# === Plot ROC & PR Curve ===
# Side-by-side threshold sweeps over the test-set probabilities: ROC (with its
# AUC in the legend) on the left, precision-recall on the right.
fpr, tpr, _ = roc_curve(y_true, y_prob)
prec, rec, _ = precision_recall_curve(y_true, y_prob)
roc_auc = roc_auc_score(y_true, y_prob)

fig, (roc_ax, pr_ax) = plt.subplots(1, 2, figsize=(12, 5))

roc_ax.plot(fpr, tpr, label="ROC Curve (AUC = {:.4f})".format(roc_auc))
roc_ax.set_xlabel("False Positive Rate")
roc_ax.set_ylabel("True Positive Rate")
roc_ax.set_title("ROC Curve")
roc_ax.legend()

pr_ax.plot(rec, prec, label="PR Curve")
pr_ax.set_xlabel("Recall")
pr_ax.set_ylabel("Precision")
pr_ax.set_title("Precision-Recall Curve")
pr_ax.legend()

fig.tight_layout()
plt.show()
