In [None]:
# ------------------------
# AE-CNN0 Architecture
# ------------------------
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
from sklearn.metrics import roc_auc_score
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Label vocabulary used throughout the notebook. `extract_labels` is not defined
# in this cell -- NOTE(review): presumably an ordered collection of CSV label
# column names produced by an earlier cell; confirm before a fresh-kernel run.
CLASS_NAMES = extract_labels
# Reverse lookup: class name -> position of that name in CLASS_NAMES.
CLASS_INDEX = dict((name, idx) for idx, name in enumerate(CLASS_NAMES))

class ChestXrayDataset(Dataset):
    """Multi-label chest X-ray dataset read from a CSV file plus an image folder.

    The CSV must provide a ``Path`` column (image path relative to
    ``image_dir``) and one column per entry of ``CLASS_NAMES``. Rows whose
    image file does not exist on disk are dropped at construction time, and a
    summary of the positive label counts is printed.

    Args:
        csv_path: Location of the label CSV.
        image_dir: Directory the ``Path`` values are resolved against.
        transform: Optional callable applied to the grayscale PIL image.
        limit: When truthy, only the first ``limit`` CSV rows are considered
            (the cut is applied *before* the missing-file filter, so fewer
            than ``limit`` samples may remain).
    """

    def __init__(self, csv_path, image_dir, transform=None, limit=None):
        self.image_dir = image_dir
        self.transform = transform
        df = pd.read_csv(csv_path)
        if limit:
            df = df[:limit]
        # Filter with a standalone boolean mask instead of writing a temporary
        # column onto what may be a slice of the original frame (which triggers
        # pandas' SettingWithCopy warning). astype(bool) keeps the mask usable
        # even when the frame is empty and apply() returns an object Series.
        exists = df['Path'].apply(
            lambda x: os.path.exists(os.path.join(image_dir, str(x)))
        ).astype(bool)
        df = df[exists]
        label_counts = df[CLASS_NAMES].sum().sort_values(ascending=False)
        print("\U0001f50d Label counts in dataset:")
        print(label_counts[label_counts > 0])
        # Store paths as strings so __getitem__ joins exactly what the
        # existence check above validated.
        self.image_paths = df['Path'].astype(str).values
        self.labels = df[CLASS_NAMES].values

    def __len__(self):
        """Return the number of samples that survived the file-existence filter."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to single-channel ('L') mode and passed through
        ``transform`` when one was supplied; the label is a float32 tensor with
        one entry per class in ``CLASS_NAMES``.
        """
        img_path = os.path.join(self.image_dir, self.image_paths[idx])
        # The context manager releases the file handle deterministically;
        # convert() returns an independent image that outlives the handle.
        with Image.open(img_path) as raw:
            img = raw.convert('L')
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return img, label

class DenseNet121(nn.Module):
    """torchvision DenseNet-121 whose classifier is replaced by Linear + Sigmoid.

    Args:
        classCount: Number of output units (one probability per class).
        isTrained: When true, the backbone starts from torchvision's
            pretrained weights; otherwise it is randomly initialised.
    """

    def __init__(self, classCount, isTrained):
        super(DenseNet121, self).__init__()
        self.densenet121 = self._build_backbone(isTrained)
        kernelCount = self.densenet121.classifier.in_features
        # Sigmoid head: outputs are independent per-class probabilities.
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(kernelCount, classCount),
            nn.Sigmoid()
        )

    @staticmethod
    def _build_backbone(isTrained):
        """Create the backbone using the non-deprecated weights API when available."""
        weights_enum = getattr(models, "DenseNet121_Weights", None)
        if weights_enum is None:
            # Older torchvision releases only know the boolean flag.
            return models.densenet121(pretrained=isTrained)
        # IMAGENET1K_V1 is the checkpoint the legacy pretrained=True maps to.
        weights = weights_enum.IMAGENET1K_V1 if isTrained else None
        return models.densenet121(weights=weights)

    def forward(self, x):
        """Return the per-class sigmoid outputs for the batch ``x``."""
        return self.densenet121(x)

class AECNN0(nn.Module):
    """Convolutional encoder feeding a DenseNet-121 classifier.

    The encoder maps a 1-channel image to a 1-channel map with a stride-4
    convolution followed by a 1x1 convolution. That map is clamped to [0, 1],
    copied to three channels, normalised (mean 0.485, std 0.229 on every
    channel) and classified.

    Args:
        classCount: Number of classes predicted by the classifier head.
    """

    def __init__(self, classCount):
        super().__init__()
        self.classCount = classCount
        self.normalize = transforms.Normalize([0.485] * 3, [0.229] * 3)
        encoder_layers = [
            nn.Conv2d(1, 32, kernel_size=5, stride=4, padding=2),
            nn.ELU(),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.classifier = DenseNet121(classCount=self.classCount, isTrained=True)

    def forward(self, x):
        """Return ``(encoded_map, class_outputs)`` for the input batch ``x``."""
        encoded = self.encoder(x).clamp(min=0, max=1)
        # Replicate the single channel so the 3-channel backbone accepts it.
        three_channel = self.normalize(encoded.repeat(1, 3, 1, 1))
        return encoded, self.classifier(three_channel)

def compute_auc(targets, outputs):
    """Compute one ROC-AUC value per label column.

    Args:
        targets: 2-D array-like (samples x classes) of ground-truth labels.
        outputs: 2-D array-like of predicted scores with the same layout.

    Returns:
        A list with one entry per column of ``targets``. A column for which
        ``roc_auc_score`` raises ``ValueError`` (for example because only one
        class is present) yields ``np.nan``. An empty, non-2-D input yields
        an empty list.
    """
    targets = np.asarray(targets)
    outputs = np.asarray(outputs)
    # Nothing was collected (e.g. an empty epoch): there are no columns to score.
    if targets.ndim < 2 and targets.size == 0:
        return []
    aucs = []
    for col in range(targets.shape[1]):
        try:
            auc = roc_auc_score(targets[:, col], outputs[:, col])
        except ValueError:
            # Only the "metric undefined for this column" case is tolerated;
            # anything else (shape mismatch, missing import, interrupt) is a
            # real problem and must surface instead of turning into NaN.
            auc = np.nan
        aucs.append(auc)
    return aucs

def visualize_reconstruction(input_img, recon_img):
    """Display an input tensor and its reconstruction side by side in grayscale.

    Both arguments are tensors; each is moved to the CPU, detached, converted
    to NumPy and stripped of singleton dimensions before plotting.
    """
    def _as_image(tensor):
        return tensor.cpu().detach().numpy().squeeze()

    panels = (
        ("Original", _as_image(input_img)),
        ("Reconstructed", _as_image(recon_img)),
    )
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    for axis, (title, pixels) in zip(axs, panels):
        axis.imshow(pixels, cmap='gray')
        axis.set_title(title)
    plt.show()

def plot_auc_bar(class_names, aucs):
    """Draw a horizontal bar chart with one AUC bar per class.

    Args:
        class_names: Labels placed on the y axis.
        aucs: AUC values, in the same order as ``class_names``.
    """
    plt.figure(figsize=(10, 5))
    axis = sns.barplot(x=aucs, y=class_names)
    axis.set(xlabel="AUC Score", title="AUC per Class", xlim=(0, 1))
    axis.grid(axis='x')
    plt.show()

# Setup
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Data locations.
csv_path = "labels.csv"
image_dir = "/kaggle/input/nih-balanced-and-resized-chest-x-rays/resized_images/resized_images"

# Every image is resized to 896x896 and converted to a tensor.
transform = transforms.Compose(
    [transforms.Resize((896, 896)), transforms.ToTensor()]
)
dataset = ChestXrayDataset(csv_path, image_dir, transform=transform, limit=1000)
dataloader = DataLoader(dataset, shuffle=True, batch_size=4)

# Model, optimiser and the two loss terms (classification and reconstruction).
model = AECNN0(classCount=len(CLASS_NAMES))
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion_bce, criterion_mse = nn.BCELoss(), nn.MSELoss()

# Training
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    # Targets/outputs of every batch in this epoch; used for the per-class AUC
    # below. These are training-set predictions made in train mode.
    all_targets, all_outputs = [], []

    for i, (images, targets) in enumerate(dataloader):
        images, targets = images.to(device), targets.to(device)
        recon, outputs = model(images)
        # The reconstruction target is the input shrunk to whatever spatial
        # size the encoder produced. Deriving it from `recon` (instead of a
        # hard-coded 224x224) keeps the loss valid if the input resolution or
        # the encoder stride is changed.
        images_resized = F.interpolate(
            images, size=tuple(recon.shape[-2:]), mode='bilinear', align_corners=False
        )
        loss_bce = criterion_bce(outputs, targets)
        loss_mse = criterion_mse(recon, images_resized)
        # Weighted sum: classification dominates, reconstruction regularises.
        loss = 0.9 * loss_bce + 0.1 * loss_mse
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        all_targets.extend(targets.cpu().detach().numpy())
        all_outputs.extend(outputs.cpu().detach().numpy())
        # Show one input/reconstruction pair per epoch as a visual sanity check.
        if i == 0:
            visualize_reconstruction(images[0], recon[0])

    epoch_auc = compute_auc(all_targets, all_outputs)
    print(f"\nEpoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}")
    print("AUC per class:")
    for name, auc in zip(CLASS_NAMES, epoch_auc):
        print(f"{name}: {auc:.4f}")
    # nanmean skips classes whose AUC could not be computed.
    print(f"Average AUC: {np.nanmean(epoch_auc):.4f}")

# Bar chart of the last epoch's per-class AUC.
plot_auc_bar(CLASS_NAMES, epoch_auc)

In [None]:
import random


model.eval()

# Show predictions for up to 50 randomly chosen samples. random.sample raises
# ValueError when asked for more items than exist, which happens as soon as
# the missing-file filter leaves fewer than 50 images.
num_to_show = min(50, len(dataset))
sample_indices = random.sample(range(len(dataset)), num_to_show)

for idx in sample_indices:
    img, true_label = dataset[idx]
    with torch.no_grad():
        img_input = img.unsqueeze(0).to(device)
        # The model returns (reconstruction, class outputs); only the latter is needed.
        pred = model(img_input)[1]
        # squeeze(0) removes just the batch axis, so a single-class model
        # still produces a 1-D vector that can be enumerated below.
        pred = pred.cpu().squeeze(0).numpy()

    true_label = true_label.numpy()

    # Get disease names
    true_diseases = [CLASS_NAMES[i] for i, val in enumerate(true_label) if val == 1]
    pred_diseases = [CLASS_NAMES[i] for i, val in enumerate(pred) if val >= 0.5]

    plt.imshow(img.squeeze(), cmap='gray')
    plt.axis('off')
    plt.title(f"True: {', '.join(true_diseases)}\n  Predicted: {', '.join(pred_diseases)}")
    plt.show()


In [None]:
from sklearn.metrics import accuracy_score

# Threshold 0.5
binary_preds = (np.array(all_outputs) >= 0.5).astype(int)
true_labels = np.array(all_targets)

# Element-wise accuracy pooled over every (sample, label) pair.
# `all_outputs` / `all_targets` are the lists filled by the training loop during
# its last epoch, so this is a training-set figure, not a validation score;
# the printed label says so. The variable name is kept for any later cell.
val_accuracy = (binary_preds == true_labels).mean()
print(f"Training accuracy (last epoch, all labels pooled): {val_accuracy:.4f}")

In [None]:
from torchvision.transforms.functional import normalize
import cv2
import matplotlib.pyplot as plt

def generate_gradcam(model, input_tensor, target_class):
    """Compute a Grad-CAM heatmap for a single image and a single class.

    Args:
        model: Network exposing ``model.classifier.densenet121.features`` and
            returning a tuple whose element 1 holds the class outputs.
        input_tensor: One un-batched image; a batch axis is added here and the
            tensor is moved to the module-level ``device``.
        target_class: Index of the class whose score is back-propagated.

    Returns:
        A 2-D float NumPy array with the spatial size of the last feature
        layer's output. Values are non-negative and scaled so the maximum is
        1; an all-zero map is returned unscaled instead of dividing by zero.

    Side effects:
        Puts ``model`` in eval mode, zeroes its gradients and leaves the
        gradients of this backward pass on its parameters.
    """
    model.eval()
    input_tensor = input_tensor.unsqueeze(0).to(device)

    gradients, activations = [], []

    def save_gradient(grad):
        # Tensor hooks must return None to leave the gradient untouched.
        gradients.append(grad)

    def save_activation_hook(module, input, output):
        activations.append(output)
        # Capture the gradient w.r.t. this output with a tensor hook; this
        # replaces the deprecated Module.register_backward_hook.
        output.register_hook(save_gradient)

    final_conv = model.classifier.densenet121.features[-1]
    hook_a = final_conv.register_forward_hook(save_activation_hook)
    try:
        output = model(input_tensor)[1]
        class_score = output[0][target_class]
        model.zero_grad()
        class_score.backward()
    finally:
        # Always detach the hook, even if the forward/backward pass fails,
        # so repeated calls do not stack hooks on the layer.
        hook_a.remove()

    grads = gradients[0].detach().cpu()
    acts = activations[0].detach().cpu()
    # Channel importance = gradient averaged over batch and spatial axes.
    pooled_grads = grads.mean(dim=[0, 2, 3])
    weighted = acts * pooled_grads.view(1, -1, 1, 1)

    # Drop only the batch axis so the result stays 2-D.
    heatmap = weighted.mean(dim=1).squeeze(0)
    heatmap = torch.clamp(heatmap, min=0)
    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak
    return heatmap.numpy()

#Grad-CAM for N samples
num_samples = 4
# Make sure inference below does not run with train-mode layers.
model.eval()
# Never index past the end of the dataset.
for idx in range(min(num_samples, len(dataset))):
    img, label = dataset[idx]
    input_tensor = img.unsqueeze(0).to(device)
    # Plain inference: no autograd graph is needed to pick the predicted class.
    with torch.no_grad():
        output = model(input_tensor)[1]
    probs = output.cpu().numpy().squeeze(0)

    pred_class = int(np.argmax(probs))
    prob = probs[pred_class]
    # Samples are multi-label and may have no active label at all, so list
    # every active class; argmax would report class 0 for an all-zero label.
    true_names = [CLASS_NAMES[i] for i, val in enumerate(label.numpy()) if val == 1]
    true_text = ', '.join(true_names) if true_names else 'none'

    heatmap = generate_gradcam(model, img, target_class=pred_class)

    # Upsample the heatmap to the image's own size; cv2.resize takes (width, height).
    height, width = img.shape[-2:]
    plt.figure(figsize=(5, 5))
    plt.imshow(img.squeeze(), cmap='gray')
    plt.imshow(cv2.resize(heatmap, (width, height)), cmap='jet', alpha=0.5)
    plt.title(
        f"Pred: {CLASS_NAMES[pred_class]} ({prob:.2f})\nTrue: {true_text}"
    )
    plt.axis('off')
    plt.show()