In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from medmnist import ChestMNIST
from sklearn.metrics import confusion_matrix, classification_report

In [2]:
# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if device.type == "cuda":
    print("GPU Name:", torch.cuda.get_device_name(0))

Using device: cuda
GPU Name: NVIDIA GeForce GTX 1060


In [3]:
# Define transformation: converts to tensor and scales pixels to [0,1]
transform = transforms.Compose([
    transforms.ToTensor()  # Already converts [H, W] to [1, H, W] for grayscale images
])

# Load the ChestMNIST training and test sets (size=64)
train_dataset = ChestMNIST(split='train', download=True, transform=transform, size=64)
test_dataset  = ChestMNIST(split='test', download=True, transform=transform, size=64)
val_dataset  = ChestMNIST(split='val', download=True, transform=transform, size=64)


In [4]:
class ChestMNISTMultiLabel(Dataset):
    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
        
    def __len__(self):
        return len(self.base_dataset)
    
    def __getitem__(self, idx):
        img, label = self.base_dataset[idx]  # label is a numpy array of length 14
        label = torch.tensor(label.astype(np.float32))
        return img, label

# Wrap the datasets
train_data = ChestMNISTMultiLabel(train_dataset)
val_data   = ChestMNISTMultiLabel(val_dataset)
test_data  = ChestMNISTMultiLabel(test_dataset)

In [5]:
# Data loaders
batch_size = 64
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [6]:
class ChestMNISTMultiLabelCNN(nn.Module):
    def __init__(self):
        super(ChestMNISTMultiLabelCNN, self).__init__()
        # Input shape: (B, 1, 64, 64)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # -> (B, 32, 64, 64)
        self.pool  = nn.MaxPool2d(2, 2)                           # -> halves spatial dims
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # -> (B, 64, 32, 32)
        # After another pooling, spatial size becomes 16x16: (B,64,16,16)
        self.fc1   = nn.Linear(64 * 16 * 16, 128)
        self.fc2   = nn.Linear(128, 14)  # 14 outputs for 14 labels
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # (B, 32, 32, 32)
        x = self.pool(F.relu(self.conv2(x)))  # (B, 64, 16, 16)
        x = x.view(x.size(0), -1)             # Flatten
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))        # Sigmoid: independent probability for each class
        return x

model = ChestMNISTMultiLabelCNN().to(device)
print(model)

ChestMNISTMultiLabelCNN(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=16384, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=14, bias=True)
)


In [13]:
criterion = nn.BCELoss()  # Binary Cross-Entropy for multi-label outputs
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training and early stopping parameters
num_epochs = 50         # Maximum number of epochs
patience = 3            # Stop if no improvement in val loss for 3 consecutive epochs
best_val_loss = float('inf')
epochs_no_improve = 0

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        
    epoch_loss = running_loss / len(train_loader.dataset)
    
    # Validation phase
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item() * inputs.size(0)
    val_loss = val_loss / len(val_loader.dataset)
    
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {epoch_loss:.4f} - Val Loss: {val_loss:.4f}")

    # Check for improvement; if not, increase counter and possibly early stop
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

print("Training complete.")


Epoch 1/50 - Train Loss: 0.1779 - Val Loss: 0.1704
Epoch 2/50 - Train Loss: 0.1713 - Val Loss: 0.1682
Epoch 3/50 - Train Loss: 0.1681 - Val Loss: 0.1682
Epoch 4/50 - Train Loss: 0.1657 - Val Loss: 0.1659
Epoch 5/50 - Train Loss: 0.1629 - Val Loss: 0.1650
Epoch 6/50 - Train Loss: 0.1601 - Val Loss: 0.1650
Epoch 7/50 - Train Loss: 0.1571 - Val Loss: 0.1668
Epoch 8/50 - Train Loss: 0.1536 - Val Loss: 0.1666
Early stopping triggered after 8 epochs
Training complete.


In [17]:
import numpy as np
import torch
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

# -------------------------------
# Evaluate on Test Set & Collect Predictions
# -------------------------------

model.eval()
all_probs = []    # to store predicted probabilities
all_preds = []    # to store thresholded (binary) predictions if needed for classification report
all_targets = []  # to store true targets

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)  # outputs are probabilities because model returns torch.sigmoid(x)
        all_probs.append(outputs.cpu().numpy())
        
        preds = (outputs > 0.5).float()  # Binary predictions for classification report and confusion matrix
        all_preds.append(preds.cpu().numpy())
        
        all_targets.append(targets.cpu().numpy())

all_probs = np.concatenate(all_probs, axis=0)    # shape: (n_samples, 14)
all_preds = np.concatenate(all_preds, axis=0)      # shape: (n_samples, 14)
all_targets = np.concatenate(all_targets, axis=0)  # shape: (n_samples, 14)

# -------------------------------
# Define Class Names (in order)
# -------------------------------
disease_labels = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration",
    "Mass", "Nodule", "Pneumonia", "Pneumothorax", "Consolidation",
    "Edema", "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia"
]

# -------------------------------
# 1. Overall (Global) Metrics
# -------------------------------
# For overall AUC, compute both micro and macro averages.
try:
    overall_auc_macro = roc_auc_score(all_targets, all_probs, average="macro")
except ValueError as e:
    overall_auc_macro = None
    print("Error computing overall macro AUC:", e)

try:
    overall_auc_micro = roc_auc_score(all_targets, all_probs, average="micro")
except ValueError as e:
    overall_auc_micro = None
    print("Error computing overall micro AUC:", e)

# For overall classification report, flatten the arrays
flat_targets = all_targets.flatten()
flat_preds = all_preds.flatten()

overall_report = classification_report(flat_targets, flat_preds, digits=3)
overall_cm = confusion_matrix(flat_targets, flat_preds)

print("******* Overall Classification Report *******")
print(overall_report)
print("******* Overall Confusion Matrix *******")
print(overall_cm)
print("******* Overall AUC Scores *******")
print(f"Overall Macro AUC: {overall_auc_macro:.3f}" if overall_auc_macro is not None else "Overall Macro AUC: Error")
print(f"Overall Micro AUC: {overall_auc_micro:.3f}" if overall_auc_micro is not None else "Overall Micro AUC: Error")

# -------------------------------
# 2. Per-Class Metrics
# -------------------------------
print("\n******* Per-Class Metrics *******")
auc_per_class = {}
for i, label_name in enumerate(disease_labels):
    print(f"\nClassification Report for {label_name}:")
    print(classification_report(all_targets[:, i], all_preds[:, i], digits=3))
    print("Confusion Matrix:")
    print(confusion_matrix(all_targets[:, i], all_preds[:, i]))
    
    # Compute ROC AUC for each class using the probabilities.
    try:
        auc = roc_auc_score(all_targets[:, i], all_probs[:, i])
    except ValueError as e:
        auc = None
        print(f"Error computing AUC for {label_name}: {e}")
    auc_per_class[label_name] = auc
    print(f"AUC for {label_name}: {auc:.3f}" if auc is not None else "AUC computation error")

# -------------------------------
# 3. Summary of Per-Class AUC Scores
# -------------------------------
print("\n******* Summary: AUC Per Class *******")
for label, auc in auc_per_class.items():
    if auc is not None:
        print(f"{label}: {auc:.3f}")
    else:
        print(f"{label}: AUC computation error")


******* Overall Classification Report *******
              precision    recall  f1-score   support

         0.0      0.950     0.997     0.973    297552
         1.0      0.461     0.050     0.090     16510

    accuracy                          0.947    314062
   macro avg      0.706     0.523     0.532    314062
weighted avg      0.924     0.947     0.926    314062

******* Overall Confusion Matrix *******
[[296587    965]
 [ 15683    827]]
******* Overall AUC Scores *******
Overall Macro AUC: 0.733
Overall Micro AUC: 0.818

******* Per-Class Metrics *******

Classification Report for Atelectasis:
              precision    recall  f1-score   support

         0.0      0.895     0.992     0.941     20013
         1.0      0.376     0.041     0.074      2420

    accuracy                          0.889     22433
   macro avg      0.636     0.516     0.507     22433
weighted avg      0.839     0.889     0.848     22433

Confusion Matrix:
[[19849   164]
 [ 2321    99]]
AUC for Atelect

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize

In [7]:
from torch.utils.data import Dataset
from PIL import Image
import glob
import torch
import torchvision.transforms as transforms

class SyntheticPneumoniaDataset(Dataset):
    def __init__(self, image_dir, img_size=64):
        self.image_paths = sorted(glob.glob(f"{image_dir}/*.png"))
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])  # [-1, 1] range
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        img = Image.open(img_path).convert('L')  # grayscale
        img = self.transform(img)

        label = torch.zeros(14, dtype=torch.float32)  # 👈 Must be a Tensor!
        label[6] = 1.0
        return img, label


In [8]:
from torch.utils.data import Dataset

class TorchChestMNIST(Dataset):
    def __init__(self, base_dataset):
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        img, label = self.base_dataset[idx]
        # label is a numpy array → convert to tensor
        return img, torch.tensor(label, dtype=torch.float32)


In [9]:
from medmnist import ChestMNIST
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Tanh-style
])

base_dataset = ChestMNIST(split='train', download=True, transform=transform, size=64)
real_train_dataset = TorchChestMNIST(base_dataset)  # ✅ fixes the label format


In [10]:
synthetic_dataset = SyntheticPneumoniaDataset("augmented/pneumonia")
img, label = synthetic_dataset[0]
print(type(img), img.shape)
print(type(label), label.shape, label)


<class 'torch.Tensor'> torch.Size([1, 64, 64])
<class 'torch.Tensor'> torch.Size([14]) tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.])


In [11]:
from torch.utils.data import ConcatDataset

augmented_train_dataset = ConcatDataset([real_train_dataset, synthetic_dataset])
augmented_train_loader = DataLoader(augmented_train_dataset, batch_size=64, shuffle=True)


In [12]:
criterion = nn.BCELoss()  # Binary Cross-Entropy for multi-label outputs
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training and early stopping parameters
num_epochs = 50         # Maximum number of epochs
patience = 3            # Stop if no improvement in val loss for 3 consecutive epochs
best_val_loss = float('inf')
epochs_no_improve = 0

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, targets in augmented_train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device).float()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        
    epoch_loss = running_loss / len(augmented_train_loader.dataset)
    
    # Validation phase
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item() * inputs.size(0)
    val_loss = val_loss / len(val_loader.dataset)
    
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {epoch_loss:.4f} - Val Loss: {val_loss:.4f}")

    # Check for improvement; if not, increase counter and possibly early stop
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

print("Training complete.")


Epoch 1/50 - Train Loss: 0.1695 - Val Loss: 0.1823
Epoch 2/50 - Train Loss: 0.1571 - Val Loss: 0.1759
Epoch 3/50 - Train Loss: 0.1527 - Val Loss: 0.1784
Epoch 4/50 - Train Loss: 0.1488 - Val Loss: 0.1809
Epoch 5/50 - Train Loss: 0.1445 - Val Loss: 0.1796
Early stopping triggered after 5 epochs
Training complete.


In [13]:
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)
        all_preds.append(outputs.cpu())
        all_labels.append(targets.cpu())

# Stack predictions & labels into full batches
import torch
all_preds = torch.cat(all_preds, dim=0).numpy()
all_labels = torch.cat(all_labels, dim=0).numpy()


In [14]:
from sklearn.metrics import classification_report, roc_auc_score

# Binarize predictions (threshold at 0.5)
binarized_preds = (all_preds >= 0.5).astype(int)

print("******* Overall Classification Report *******")
print(classification_report(all_labels, binarized_preds, zero_division=0))

# Compute overall AUCs
macro_auc = roc_auc_score(all_labels, all_preds, average='macro')
micro_auc = roc_auc_score(all_labels, all_preds, average='micro')
print(f"\n******* Overall AUC Scores *******")
print(f"Macro AUC: {macro_auc:.3f}")
print(f"Micro AUC: {micro_auc:.3f}")


******* Overall Classification Report *******
              precision    recall  f1-score   support

           0       0.00      0.00      0.00      2420
           1       0.00      0.00      0.00       582
           2       0.36      0.03      0.06      2754
           3       0.00      0.00      0.00      3938
           4       0.00      0.00      0.00      1133
           5       0.00      0.00      0.00      1335
           6       0.00      0.00      0.00       242
           7       0.00      0.00      0.00      1089
           8       0.00      0.00      0.00       957
           9       0.00      0.00      0.00       413
          10       0.00      0.00      0.00       509
          11       0.00      0.00      0.00       362
          12       0.00      0.00      0.00       734
          13       0.00      0.00      0.00        42

   micro avg       0.33      0.01      0.01     16510
   macro avg       0.03      0.00      0.00     16510
weighted avg       0.06      0.01 

In [15]:
real_dataset = TorchChestMNIST(ChestMNIST(split='train', transform=transform, size=64))

for i in range(5):
    img, label = real_dataset[i]
    print(f"Real label {i}:", label)
    assert isinstance(label, torch.Tensor), "Label is not tensor"
    assert label.shape == (14,), "Label shape is not (14,)"
    assert label.dtype == torch.float32, "Label dtype is not float"


Real label 0: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Real label 1: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Real label 2: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Real label 3: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Real label 4: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])


In [16]:
import numpy as np
import torch
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

# -------------------------------
# Evaluate on Test Set & Collect Predictions
# -------------------------------

model.eval()
all_probs = []    # to store predicted probabilities
all_preds = []    # to store thresholded (binary) predictions if needed for classification report
all_targets = []  # to store true targets

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)  # outputs are probabilities because model returns torch.sigmoid(x)
        all_probs.append(outputs.cpu().numpy())
        
        preds = (outputs > 0.5).float()  # Binary predictions for classification report and confusion matrix
        all_preds.append(preds.cpu().numpy())
        
        all_targets.append(targets.cpu().numpy())

all_probs = np.concatenate(all_probs, axis=0)    # shape: (n_samples, 14)
all_preds = np.concatenate(all_preds, axis=0)      # shape: (n_samples, 14)
all_targets = np.concatenate(all_targets, axis=0)  # shape: (n_samples, 14)

# -------------------------------
# Define Class Names (in order)
# -------------------------------
disease_labels = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration",
    "Mass", "Nodule", "Pneumonia", "Pneumothorax", "Consolidation",
    "Edema", "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia"
]

# -------------------------------
# 1. Overall (Global) Metrics
# -------------------------------
# For overall AUC, compute both micro and macro averages.
try:
    overall_auc_macro = roc_auc_score(all_targets, all_probs, average="macro")
except ValueError as e:
    overall_auc_macro = None
    print("Error computing overall macro AUC:", e)

try:
    overall_auc_micro = roc_auc_score(all_targets, all_probs, average="micro")
except ValueError as e:
    overall_auc_micro = None
    print("Error computing overall micro AUC:", e)

# For overall classification report, flatten the arrays
flat_targets = all_targets.flatten()
flat_preds = all_preds.flatten()

overall_report = classification_report(flat_targets, flat_preds, digits=3)
overall_cm = confusion_matrix(flat_targets, flat_preds)

print("******* Overall Classification Report *******")
print(overall_report)
print("******* Overall Confusion Matrix *******")
print(overall_cm)
print("******* Overall AUC Scores *******")
print(f"Overall Macro AUC: {overall_auc_macro:.3f}" if overall_auc_macro is not None else "Overall Macro AUC: Error")
print(f"Overall Micro AUC: {overall_auc_micro:.3f}" if overall_auc_micro is not None else "Overall Micro AUC: Error")

# -------------------------------
# 2. Per-Class Metrics
# -------------------------------
print("\n******* Per-Class Metrics *******")
auc_per_class = {}
for i, label_name in enumerate(disease_labels):
    print(f"\nClassification Report for {label_name}:")
    print(classification_report(all_targets[:, i], all_preds[:, i], digits=3))
    print("Confusion Matrix:")
    print(confusion_matrix(all_targets[:, i], all_preds[:, i]))
    
    # Compute ROC AUC for each class using the probabilities.
    try:
        auc = roc_auc_score(all_targets[:, i], all_probs[:, i])
    except ValueError as e:
        auc = None
        print(f"Error computing AUC for {label_name}: {e}")
    auc_per_class[label_name] = auc
    print(f"AUC for {label_name}: {auc:.3f}" if auc is not None else "AUC computation error")

# -------------------------------
# 3. Summary of Per-Class AUC Scores
# -------------------------------
print("\n******* Summary: AUC Per Class *******")
for label, auc in auc_per_class.items():
    if auc is not None:
        print(f"{label}: {auc:.3f}")
    else:
        print(f"{label}: AUC computation error")

******* Overall Classification Report *******
              precision    recall  f1-score   support

         0.0      0.948     0.999     0.973    297552
         1.0      0.331     0.005     0.010     16510

    accuracy                          0.947    314062
   macro avg      0.639     0.502     0.491    314062
weighted avg      0.915     0.947     0.922    314062

******* Overall Confusion Matrix *******
[[297382    170]
 [ 16426     84]]
******* Overall AUC Scores *******
Overall Macro AUC: 0.629
Overall Micro AUC: 0.760

******* Per-Class Metrics *******

Classification Report for Atelectasis:
              precision    recall  f1-score   support

         0.0      0.892     1.000     0.943     20013
         1.0      0.000     0.000     0.000      2420

    accuracy                          0.892     22433
   macro avg      0.446     0.500     0.471     22433
weighted avg      0.796     0.892     0.841     22433

Confusion Matrix:
[[20010     3]
 [ 2420     0]]
AUC for Atelect

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize

AUC for Pneumonia: 0.605

Classification Report for Pneumothorax:
              precision    recall  f1-score   support

         0.0      0.951     1.000     0.975     21344
         1.0      0.000     0.000     0.000      1089

    accuracy                          0.951     22433
   macro avg      0.476     0.500     0.488     22433
weighted avg      0.905     0.951     0.928     22433

Confusion Matrix:
[[21344     0]
 [ 1089     0]]
AUC for Pneumothorax: 0.615

Classification Report for Consolidation:
              precision    recall  f1-score   support

         0.0      0.957     1.000     0.978     21476
         1.0      0.000     0.000     0.000       957

    accuracy                          0.957     22433
   macro avg      0.479     0.500     0.489     22433
weighted avg      0.916     0.957     0.936     22433

Confusion Matrix:
[[21476     0]
 [  957     0]]
AUC for Consolidation: 0.699

Classification Report for Edema:
              precision    recall  f1-score   sup

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize

In [29]:
synthetic_dataset = SyntheticPneumoniaDataset("augmented/pneumonia")

for i in range(5):
    img, label = synthetic_dataset[i]
    print(f"Synthetic label {i}:", label)
    assert label.shape == (14,), "Label shape is wrong"
    assert label.dtype == torch.float32, "Label dtype is not float32"
    assert label[6] == 1 and label.sum() == 1, "Pneumonia not correctly one-hot"


Synthetic label 0: tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.])
Synthetic label 1: tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.])
Synthetic label 2: tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.])
Synthetic label 3: tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.])
Synthetic label 4: tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.])


In [31]:
real_dataset = TorchChestMNIST(ChestMNIST(split='train', transform=transform, size=64))

for i in range(5):
    img, label = real_dataset[i]
    print(f"Real label {i}:", label)
    assert isinstance(label, torch.Tensor), "Label is not tensor"
    assert label.shape == (14,), "Label shape is not (14,)"
    assert label.dtype == torch.float32, "Label dtype is not float"


Real label 0: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Real label 1: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Real label 2: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Real label 3: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Real label 4: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])


In [18]:
model.eval()
probabilities = []

for img, _ in synthetic_dataset:
    img_tensor = img.unsqueeze(0).to(device)  # shape: (1, 1, 64, 64)
    with torch.no_grad():
        output = torch.sigmoid(model(img_tensor))
    probabilities.append(output[0, 6].item())  # confidence for pneumonia



In [19]:
import numpy as np
print("Pneumonia confidence on synthetic images:")
print(f"Mean: {np.mean(probabilities):.4f}")
print(f"Max:  {np.max(probabilities):.4f}")
print(f"Min:  {np.min(probabilities):.4f}")


Pneumonia confidence on synthetic images:
Mean: 0.7308
Max:  0.7311
Min:  0.6114


In [20]:
model.eval()
val_probs = []
val_labels = []

for img, label in val_loader:
    img = img.to(device)
    label = label.to(device)

    with torch.no_grad():
        output = torch.sigmoid(model(img))  # keep sigmoid here if using logits

    val_probs.append(output[:, 6].cpu())     # pneumonia confidence
    val_labels.append(label[:, 6].cpu())     # ground truth pneumonia label


In [21]:
import torch
val_probs = torch.cat(val_probs).numpy()
val_labels = torch.cat(val_labels).numpy()

import numpy as np
from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score

print("Pneumonia Confidence Stats:")
print(f"Mean (all): {val_probs.mean():.4f}")
print(f"Mean (pneumonia): {val_probs[val_labels == 1].mean():.4f}")
print(f"Mean (non-pneumonia): {val_probs[val_labels == 0].mean():.4f}")

print("\nAUC:", roc_auc_score(val_labels, val_probs))

# Optional: see how many cases pass 0.5 threshold
binary_preds = val_probs >= 0.5
precision = (binary_preds[val_labels == 1].sum()) / max(binary_preds.sum(), 1)
recall = (binary_preds[val_labels == 1].sum()) / max((val_labels == 1).sum(), 1)
f1 = 2 * precision * recall / max((precision + recall), 1e-8)

print(f"Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}")


Pneumonia Confidence Stats:
Mean (all): 0.5041
Mean (pneumonia): 0.5065
Mean (non-pneumonia): 0.5041

AUC: 0.5758370307873237
Precision: 0.012, Recall: 1.000, F1: 0.023


In [25]:
class ChestMNISTMultiLabelCNN(nn.Module):
    def __init__(self):
        super(ChestMNISTMultiLabelCNN, self).__init__()
        # Input shape: (B, 1, 64, 64)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # -> (B, 32, 64, 64)
        self.pool  = nn.MaxPool2d(2, 2)                           # -> halves spatial dims
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # -> (B, 64, 32, 32)
        # After another pooling, spatial size becomes 16x16: (B,64,16,16)
        self.fc1   = nn.Linear(64 * 16 * 16, 128)
        self.fc2   = nn.Linear(128, 14)  # 14 outputs for 14 labels
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # (B, 32, 32, 32)
        x = self.pool(F.relu(self.conv2(x)))  # (B, 64, 16, 16)
        x = x.view(x.size(0), -1)             # Flatten
        x = F.relu(self.fc1(x))
        return self.fc2(x)  # logits only!


model = ChestMNISTMultiLabelCNN().to(device)
print(model)

ChestMNISTMultiLabelCNN(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=16384, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=14, bias=True)
)


In [26]:
# ---- Loss and optimizer ----
pos_weights = torch.ones(14).to(device)
pos_weights[6] = 20.0
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# ---- Training loop ----
num_epochs = 50
patience = 3
best_val_loss = float('inf')
epochs_no_improve = 0

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, targets in augmented_train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device).float()  # <- make sure labels are float

        optimizer.zero_grad()
        outputs = model(inputs)  # model now returns raw logits
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        
    epoch_loss = running_loss / len(augmented_train_loader.dataset)

    # ---- Validation phase ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device).float()

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item() * inputs.size(0)

    val_loss = val_loss / len(val_loader.dataset)
    
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {epoch_loss:.4f} - Val Loss: {val_loss:.4f}")

    # ---- Early stopping ----
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

print("Training complete.")


Epoch 1/50 - Train Loss: 0.2135 - Val Loss: 0.2189
Epoch 2/50 - Train Loss: 0.1937 - Val Loss: 0.2206
Epoch 3/50 - Train Loss: 0.1889 - Val Loss: 0.2383
Epoch 4/50 - Train Loss: 0.1840 - Val Loss: 0.2174
Epoch 5/50 - Train Loss: 0.1779 - Val Loss: 0.2349
Epoch 6/50 - Train Loss: 0.1701 - Val Loss: 0.2300
Epoch 7/50 - Train Loss: 0.1605 - Val Loss: 0.2494
Early stopping triggered after 7 epochs
Training complete.


In [28]:
from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score
import torch
import numpy as np

model.eval()
val_probs = []
val_labels = []

with torch.no_grad():
    for inputs, targets in val_loader:
        inputs = inputs.to(device)
        targets = targets.to(device).float()

        outputs = torch.sigmoid(model(inputs))  # apply sigmoid manually
        val_probs.append(outputs[:, 6].cpu())   # pneumonia confidence
        val_labels.append(targets[:, 6].cpu())  # ground truth

# Combine all batches
val_probs = torch.cat(val_probs).numpy()
val_labels = torch.cat(val_labels).numpy()

# Metrics
auc = roc_auc_score(val_labels, val_probs)

binary_preds = (val_probs >= 0.5).astype(int)
tp = (binary_preds * val_labels).sum()
fp = ((binary_preds == 1) & (val_labels == 0)).sum()
fn = ((binary_preds == 0) & (val_labels == 1)).sum()

precision = tp / max(tp + fp, 1)
recall = tp / max(tp + fn, 1)
f1 = 2 * precision * recall / max(precision + recall, 1e-8)

print("📊 Pneumonia Evaluation:")
print(f"AUC:      {auc:.3f}")
print(f"Precision: {precision:.3f}")
print(f"Recall:    {recall:.3f}")
print(f"F1-score:  {f1:.3f}")


📊 Pneumonia Evaluation:
AUC:      0.552
Precision: 0.005
Recall:    0.008
F1-score:  0.006


In [29]:
# ---- Loss and optimizer ----
pos_weights = torch.ones(14).to(device)
pos_weights[6] = 20.0
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# ---- Training loop ----
num_epochs = 50
patience = 3
best_val_loss = float('inf')
epochs_no_improve = 0

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device).float()  # <- make sure labels are float

        optimizer.zero_grad()
        outputs = model(inputs)  # model now returns raw logits
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        
    epoch_loss = running_loss / len(train_loader.dataset)

    # ---- Validation phase ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device).float()

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item() * inputs.size(0)

    val_loss = val_loss / len(val_loader.dataset)
    
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {epoch_loss:.4f} - Val Loss: {val_loss:.4f}")

    # ---- Early stopping ----
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

print("Training complete.")


Epoch 1/50 - Train Loss: 0.1920 - Val Loss: 0.2097
Epoch 2/50 - Train Loss: 0.1767 - Val Loss: 0.2149
Epoch 3/50 - Train Loss: 0.1668 - Val Loss: 0.2284
Epoch 4/50 - Train Loss: 0.1599 - Val Loss: 0.2348
Early stopping triggered after 4 epochs
Training complete.


In [30]:
from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score
import torch
import numpy as np

model.eval()
val_probs = []
val_labels = []

with torch.no_grad():
    for inputs, targets in val_loader:
        inputs = inputs.to(device)
        targets = targets.to(device).float()

        outputs = torch.sigmoid(model(inputs))  # apply sigmoid manually
        val_probs.append(outputs[:, 6].cpu())   # pneumonia confidence
        val_labels.append(targets[:, 6].cpu())  # ground truth

# Combine all batches
val_probs = torch.cat(val_probs).numpy()
val_labels = torch.cat(val_labels).numpy()

# Metrics
auc = roc_auc_score(val_labels, val_probs)

binary_preds = (val_probs >= 0.5).astype(int)
tp = (binary_preds * val_labels).sum()
fp = ((binary_preds == 1) & (val_labels == 0)).sum()
fn = ((binary_preds == 0) & (val_labels == 1)).sum()

precision = tp / max(tp + fp, 1)
recall = tp / max(tp + fn, 1)
f1 = 2 * precision * recall / max(precision + recall, 1e-8)

print("📊 Pneumonia Evaluation:")
print(f"AUC:      {auc:.3f}")
print(f"Precision: {precision:.3f}")
print(f"Recall:    {recall:.3f}")
print(f"F1-score:  {f1:.3f}")


📊 Pneumonia Evaluation:
AUC:      0.633
Precision: 0.027
Recall:    0.120
F1-score:  0.045
