In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from medmnist import ChestMNIST
from sklearn.metrics import confusion_matrix, classification_report

In [2]:
# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if device.type == "cuda":
    print("GPU Name:", torch.cuda.get_device_name(0))

Using device: cuda
GPU Name: NVIDIA GeForce GTX 1060


In [3]:
# Define transformation: converts to tensor and scales pixels to [0,1]
transform = transforms.Compose([
    transforms.ToTensor()  # Already converts [H, W] to [1, H, W] for grayscale images
])

# Load the ChestMNIST training and test sets (size=64)
train_dataset = ChestMNIST(split='train', download=True, transform=transform, size=64)
test_dataset  = ChestMNIST(split='test', download=True, transform=transform, size=64)
val_dataset  = ChestMNIST(split='val', download=True, transform=transform, size=64)


In [4]:
class ChestMNISTMultiLabel(Dataset):
    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
        
    def __len__(self):
        return len(self.base_dataset)
    
    def __getitem__(self, idx):
        img, label = self.base_dataset[idx]  # label is a numpy array of length 14
        label = torch.tensor(label.astype(np.float32))
        return img, label

# Wrap the datasets
train_data = ChestMNISTMultiLabel(train_dataset)
val_data   = ChestMNISTMultiLabel(val_dataset)
test_data  = ChestMNISTMultiLabel(test_dataset)

In [5]:
# Data loaders
batch_size = 64
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [6]:
from torch.utils.data import Dataset
from PIL import Image
import glob
import torch
import torchvision.transforms as transforms

class SyntheticPneumoniaDataset(Dataset):
    def __init__(self, image_dir, img_size=64):
        self.image_paths = sorted(glob.glob(f"{image_dir}/*.png"))
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        img = Image.open(img_path).convert('L')  # grayscale
        img = self.transform(img)

        label = torch.zeros(14, dtype=torch.float32)  # 👈 Must be a Tensor!
        label[6] = 1.0
        return img, label

In [7]:
synthetic_dataset = SyntheticPneumoniaDataset("filtered_confidence_500/pneumonia")
img, label = synthetic_dataset[0]
print(type(img), img.shape)
print(type(label), label.shape, label)

<class 'torch.Tensor'> torch.Size([1, 64, 64])
<class 'torch.Tensor'> torch.Size([14]) tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.])


In [8]:
from torch.utils.data import ConcatDataset

augmented_train_dataset = ConcatDataset([train_data, synthetic_dataset])
augmented_train_loader = DataLoader(augmented_train_dataset, batch_size=64, shuffle=True)


In [9]:
class ChestMNISTMultiLabelCNN(nn.Module):
    def __init__(self):
        super(ChestMNISTMultiLabelCNN, self).__init__()
        # Input shape: (B, 1, 64, 64)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # -> (B, 32, 64, 64)
        self.pool  = nn.MaxPool2d(2, 2)                           # -> halves spatial dims
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # -> (B, 64, 32, 32)
        # After another pooling, spatial size becomes 16x16: (B,64,16,16)
        self.fc1   = nn.Linear(64 * 16 * 16, 128)
        self.fc2   = nn.Linear(128, 14)  # 14 outputs for 14 labels
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # (B, 32, 32, 32)
        x = self.pool(F.relu(self.conv2(x)))  # (B, 64, 16, 16)
        x = x.view(x.size(0), -1)             # Flatten
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))        # Sigmoid: independent probability for each class
        return x

model = ChestMNISTMultiLabelCNN().to(device)
print(model)

ChestMNISTMultiLabelCNN(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=16384, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=14, bias=True)
)


In [10]:
criterion = nn.BCELoss()  # Binary Cross-Entropy for multi-label outputs
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training and early stopping parameters
num_epochs = 50         # Maximum number of epochs
patience = 3            # Stop if no improvement in val loss for 3 consecutive epochs
best_val_loss = float('inf')
epochs_no_improve = 0

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, targets in augmented_train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        
    epoch_loss = running_loss / len(augmented_train_loader.dataset)
    
    # Validation phase
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item() * inputs.size(0)
    val_loss = val_loss / len(val_loader.dataset)
    
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {epoch_loss:.4f} - Val Loss: {val_loss:.4f}")

    # Check for improvement; if not, increase counter and possibly early stop
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

print("Training complete.")


Epoch 1/50 - Train Loss: 0.1788 - Val Loss: 0.1705
Epoch 2/50 - Train Loss: 0.1719 - Val Loss: 0.1679
Epoch 3/50 - Train Loss: 0.1687 - Val Loss: 0.1678
Epoch 4/50 - Train Loss: 0.1658 - Val Loss: 0.1670
Epoch 5/50 - Train Loss: 0.1630 - Val Loss: 0.1668
Epoch 6/50 - Train Loss: 0.1601 - Val Loss: 0.1665
Epoch 7/50 - Train Loss: 0.1571 - Val Loss: 0.1666
Epoch 8/50 - Train Loss: 0.1539 - Val Loss: 0.1686
Epoch 9/50 - Train Loss: 0.1502 - Val Loss: 0.1691
Early stopping triggered after 9 epochs
Training complete.


In [11]:
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)
        all_preds.append(outputs.cpu())
        all_labels.append(targets.cpu())

# Stack predictions & labels into full batches
import torch
all_preds = torch.cat(all_preds, dim=0).numpy()
all_labels = torch.cat(all_labels, dim=0).numpy()

In [12]:
from sklearn.metrics import classification_report, roc_auc_score

# Binarize predictions (threshold at 0.5)
binarized_preds = (all_preds >= 0.5).astype(int)

print("******* Overall Classification Report *******")
print(classification_report(all_labels, binarized_preds, zero_division=0))

# Compute overall AUCs
macro_auc = roc_auc_score(all_labels, all_preds, average='macro')
micro_auc = roc_auc_score(all_labels, all_preds, average='micro')
print(f"\n******* Overall AUC Scores *******")
print(f"Macro AUC: {macro_auc:.3f}")
print(f"Micro AUC: {micro_auc:.3f}")


******* Overall Classification Report *******
              precision    recall  f1-score   support

           0       0.46      0.02      0.04      2420
           1       0.49      0.06      0.11       582
           2       0.53      0.09      0.15      2754
           3       0.49      0.04      0.08      3938
           4       0.28      0.01      0.02      1133
           5       0.00      0.00      0.00      1335
           6       0.00      0.00      0.00       242
           7       0.31      0.01      0.02      1089
           8       0.00      0.00      0.00       957
           9       0.19      0.01      0.01       413
          10       0.00      0.00      0.00       509
          11       0.00      0.00      0.00       362
          12       0.00      0.00      0.00       734
          13       0.00      0.00      0.00        42

   micro avg       0.48      0.03      0.06     16510
   macro avg       0.20      0.02      0.03     16510
weighted avg       0.33      0.03 

In [13]:
model.eval()
probabilities = []

for img, _ in synthetic_dataset:
    img_tensor = img.unsqueeze(0).to(device)  # shape: (1, 1, 64, 64)
    with torch.no_grad():
        output = torch.sigmoid(model(img_tensor))
    probabilities.append(output[0, 6].item())  # confidence for pneumonia

print("Pneumonia confidence on synthetic images:")
print(f"Mean: {np.mean(probabilities):.4f}")
print(f"Max:  {np.max(probabilities):.4f}")
print(f"Min:  {np.min(probabilities):.4f}")

Pneumonia confidence on synthetic images:
Mean: 0.7282
Max:  0.7311
Min:  0.6362


In [14]:
model.eval()
val_probs = []
val_labels = []

for img, label in val_loader:
    img = img.to(device)
    label = label.to(device)

    with torch.no_grad():
        output = torch.sigmoid(model(img))  # keep sigmoid here if using logits

    val_probs.append(output[:, 6].cpu())     # pneumonia confidence
    val_labels.append(label[:, 6].cpu())     # ground truth pneumonia label

val_probs = torch.cat(val_probs).numpy()
val_labels = torch.cat(val_labels).numpy()

import numpy as np
from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score

print("Pneumonia Confidence Stats:")
print(f"Mean (all): {val_probs.mean():.4f}")
print(f"Mean (pneumonia): {val_probs[val_labels == 1].mean():.4f}")
print(f"Mean (non-pneumonia): {val_probs[val_labels == 0].mean():.4f}")

print("\nAUC:", roc_auc_score(val_labels, val_probs))

# Optional: see how many cases pass 0.5 threshold
binary_preds = val_probs >= 0.5
precision = (binary_preds[val_labels == 1].sum()) / max(binary_preds.sum(), 1)
recall = (binary_preds[val_labels == 1].sum()) / max((val_labels == 1).sum(), 1)
f1 = 2 * precision * recall / max((precision + recall), 1e-8)

print(f"Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}")

Pneumonia Confidence Stats:
Mean (all): 0.5026
Mean (pneumonia): 0.5061
Mean (non-pneumonia): 0.5026

AUC: 0.6715165371483914
Precision: 0.012, Recall: 1.000, F1: 0.023
