<a href="https://colab.research.google.com/github/UbaidullahTanoli/CNN-MLP/blob/main/CNN%2BMLP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import files
files.upload()

In [3]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
!kaggle datasets download -d raddar/chest-xrays-indiana-university
!unzip chest-xrays-indiana-university.zip -d /content/dataset/

In [14]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models import EfficientNet_B1_Weights
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score
from sklearn.model_selection import KFold

In [15]:
# 1. Custom Dataset that Merges Two CSVs
# ================================
class PureCNNMergedDataset(Dataset):
    def __init__(self, reports_csv, proj_csv, image_folder, transform=None):
        """
        Args:
            reports_csv (str): Path to 'indiana_reports.csv', which includes the "MeSH" column for labels.
            proj_csv (str): Path to 'indiana_projections.csv', which maps uid to image filename.
            image_folder (str): Directory containing the image files.
            transform (callable, optional): Transform to apply on images.
        """
        # Load both CSVs
        self.reports_df = pd.read_csv(reports_csv)
        self.proj_df = pd.read_csv(proj_csv)
        # Merge on 'uid'
        self.data = pd.merge(self.reports_df, self.proj_df, on='uid')
        self.image_folder = image_folder
        self.transform = transform

        # Store all labels for class balancing
        self.labels = []
        for idx in range(len(self.data)):
            row = self.data.iloc[idx]
            mesh_val = str(row['MeSH']).strip().lower()
            label = 0 if mesh_val == 'normal' else 1
            self.labels.append(label)
        self.labels = np.array(self.labels)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        # Get the image filename from the projections CSV merged with reports CSV
        filename = row['filename']  # e.g., "1_IM-0001-4001.dcm.png"
        img_path = os.path.join(self.image_folder, filename)

        # Load the image and convert to RGB if needed
        pil_image = Image.open(img_path)
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')
        if self.transform:
            image_tensor = self.transform(pil_image)
        else:
            image_tensor = transforms.ToTensor()(pil_image)

        # Derive binary label from the "MeSH" column in the reports CSV
        # e.g., if MeSH equals "normal" (case-insensitive), label = 0; else, label = 1.
        mesh_val = str(row['MeSH']).strip().lower()
        label = 0 if mesh_val == 'normal' else 1

        return image_tensor, label

    def get_class_distribution(self):
        """
        Returns the count of each class in the dataset.
        """
        class_counts = {}
        for label in self.labels:
            if label in class_counts:
                class_counts[label] += 1
            else:
                class_counts[label] = 1
        return class_counts

In [16]:
# 2. Data Transforms with Augmentation
# ================================
# Define transforms for training with augmentation
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

# Define transforms for validation (no augmentation)
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

In [17]:
# 3. Enhanced CNN Model with MLP, Batch Normalization
# ================================
class PureCNNModel(nn.Module):
    def __init__(self, num_classes=2):
        super(PureCNNModel, self).__init__()

        # Load EfficientNet-B1 pretrained on ImageNet
        self.backbone = models.efficientnet_b1(weights=EfficientNet_B1_Weights.IMAGENET1K_V1)

        # Get the input features of the classifier (EfficientNet-B1 typically uses 1280)
        in_features = self.backbone.classifier[1].in_features

        # Replace the final classifier with an MLP including BatchNorm
        self.backbone.classifier[1] = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

        # Freeze all backbone feature layers initially
        for param in self.backbone.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        return self.backbone(x)

    def unfreeze_last_block(self):
        """
        Unfreeze the last block of the EfficientNet backbone for fine-tuning
        """
        for param in self.backbone.features[-1].parameters():
            param.requires_grad = True

    def unfreeze_last_n_blocks(self, n=2):
        """
        Unfreeze the last n blocks of the EfficientNet backbone
        """
        for i in range(n):
            block_idx = len(self.backbone.features) - 1 - i
            if block_idx >= 0:
                for param in self.backbone.features[block_idx].parameters():
                    param.requires_grad = True

In [21]:
# 4. Training and Evaluation Functions
# ================================
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    preds_all, labels_all = [], []
    probs_all = []  # Added for ROC AUC calculation

    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)

        # Get probabilities for positive class (needed for AUC)
        probs = torch.softmax(outputs, dim=1)[:, 1]

        _, preds = torch.max(outputs, dim=1)
        preds_all.extend(preds.cpu().numpy())
        labels_all.extend(labels.cpu().numpy())
        probs_all.extend(probs.detach().cpu().numpy())

    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = accuracy_score(labels_all, preds_all)
    epoch_prec = precision_score(labels_all, preds_all, zero_division=0)
    epoch_rec = recall_score(labels_all, preds_all, zero_division=0)
    epoch_f1 = f1_score(labels_all, preds_all, zero_division=0)

    # Add AUC metrics
    try:
        epoch_roc_auc = roc_auc_score(labels_all, probs_all)
        epoch_pr_auc = average_precision_score(labels_all, probs_all)
    except:
        # Handle edge cases where all labels might be from the same class
        epoch_roc_auc = 0.0
        epoch_pr_auc = 0.0

    return epoch_loss, epoch_acc, epoch_prec, epoch_rec, epoch_f1, epoch_roc_auc, epoch_pr_auc

def eval_epoch(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    preds_all, labels_all = [], []
    probs_all = []

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * images.size(0)

            # Get class probabilities for positive class (needed for AUC)
            probs = torch.softmax(outputs, dim=1)[:, 1]

            _, preds = torch.max(outputs, dim=1)
            preds_all.extend(preds.cpu().numpy())
            labels_all.extend(labels.cpu().numpy())
            probs_all.extend(probs.detach().cpu().numpy())  # Fixed: changed label to probs

    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = accuracy_score(labels_all, preds_all)
    epoch_prec = precision_score(labels_all, preds_all, zero_division=0)
    epoch_rec = recall_score(labels_all, preds_all, zero_division=0)
    epoch_f1 = f1_score(labels_all, preds_all, zero_division=0)

    # Add AUC metrics
    try:
        epoch_roc_auc = roc_auc_score(labels_all, probs_all)
        epoch_pr_auc = average_precision_score(labels_all, probs_all)
    except:
        # Handle edge cases where all labels might be from the same class
        epoch_roc_auc = 0.0
        epoch_pr_auc = 0.0

    return epoch_loss, epoch_acc, epoch_prec, epoch_rec, epoch_f1, epoch_roc_auc, epoch_pr_auc

In [23]:
# 5. Cross-Validation Implementation
# ================================
def cross_validate(reports_csv, proj_csv, image_folder, num_epochs=40, k=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    kf = KFold(n_splits=k, shuffle=True, random_state=42)

    # Load the full dataset
    full_dataset = PureCNNMergedDataset(reports_csv, proj_csv, image_folder, transform=None)

    # Check class distribution
    class_distribution = full_dataset.get_class_distribution()
    print(f"Class distribution: {class_distribution}")

    # Calculate class weights for weighted loss
    num_samples = len(full_dataset)
    class_weights = torch.FloatTensor([num_samples / (len(class_distribution) * count)
                                       for count in class_distribution.values()])

    all_results = []

    # For each fold
    for fold, (train_idx, val_idx) in enumerate(kf.split(range(len(full_dataset)))):
        print(f"\nTraining Fold {fold+1}/{k}")

        # Create datasets with appropriate transforms
        train_dataset = PureCNNMergedDataset(reports_csv, proj_csv, image_folder, transform=train_transform)
        val_dataset = PureCNNMergedDataset(reports_csv, proj_csv, image_folder, transform=val_transform)

        # Create subset samplers
        train_subsampler = SubsetRandomSampler(train_idx)
        val_subsampler = SubsetRandomSampler(val_idx)

        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=8, sampler=train_subsampler, num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=8, sampler=val_subsampler, num_workers=2)

        # Initialize the model
        model = PureCNNModel(num_classes=2).to(device)

        # Create criterion with class weights
        criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

        # Initial training phase: Only train the MLP head
        optimizer = optim.AdamW(model.backbone.classifier.parameters(), lr=1e-3, weight_decay=0.01)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1,
                                                        patience=3)

        # First phase: Train only the MLP head for 15 epochs
        print("Phase 1: Training only the MLP head...")
        for epoch in range(15):
            train_loss, train_acc, train_prec, train_rec, train_f1, train_roc_auc, train_pr_auc = train_epoch(
                model, train_loader, criterion, optimizer, device
            )
            val_loss, val_acc, val_prec, val_rec, val_f1, val_roc_auc, val_pr_auc = eval_epoch(
                model, val_loader, criterion, device
            )

            print(f"Epoch {epoch+1}/15")
            print(f"  Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f} | ROC AUC: {train_roc_auc:.4f} | PR AUC: {train_pr_auc:.4f}")
            print(f"  Val   Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f} | ROC AUC: {val_roc_auc:.4f} | PR AUC: {val_pr_auc:.4f}")

            scheduler.step(val_loss)
            current_lr = scheduler.get_last_lr()
            print(f"Epoch {epoch+1}/15 - Current learning rate: {current_lr}")

        # Second phase: Unfreeze the last block and train with a lower learning rate
        print("\nPhase 2: Unfreezing last block and fine-tuning...")
        model.unfreeze_last_block()

        # Create new optimizer with different learning rates for different parts
        optimizer = optim.AdamW([
            {'params': model.backbone.features[-1].parameters(), 'lr': 1e-4},
            {'params': model.backbone.classifier.parameters(), 'lr': 1e-3}
        ], weight_decay=0.01)

        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1,
                                                        patience=3)

        best_val_f1 = 0
        for epoch in range(num_epochs - 15):
            train_loss, train_acc, train_prec, train_rec, train_f1, train_roc_auc, train_pr_auc = train_epoch(
                model, train_loader, criterion, optimizer, device
            )
            val_loss, val_acc, val_prec, val_rec, val_f1, val_roc_auc, val_pr_auc = eval_epoch(
                model, val_loader, criterion, device
            )

            print(f"Epoch {epoch+1}/{num_epochs-15}")
            print(f"  Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f} | ROC AUC: {train_roc_auc:.4f} | PR AUC: {train_pr_auc:.4f}")
            print(f"  Val   Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f} | ROC AUC: {val_roc_auc:.4f} | PR AUC: {val_pr_auc:.4f}")
            scheduler.step(val_loss)

            current_lr = scheduler.get_last_lr()
            print(f"Epoch {epoch+1}/15 - Current learning rate: {current_lr}")

            # Save the best model for this fold
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                torch.save(model.state_dict(), f"best_model_fold_{fold}.pt")
                print(f"  Saved new best model with validation F1: {val_f1:.4f}")

        # Load the best model for final evaluation
        model.load_state_dict(torch.load(f"best_model_fold_{fold}.pt"))
        _, fold_acc, fold_prec, fold_rec, fold_f1, fold_roc_auc, fold_pr_auc = eval_epoch(model, val_loader, criterion, device)
        fold_results = {
            'fold': fold + 1,
            'accuracy': fold_acc,
            'precision': fold_prec,
            'recall': fold_rec,
            'f1': fold_f1,
            'roc_auc': fold_roc_auc,
            'pr_auc': fold_pr_auc
        }

        all_results.append(fold_results)

        print(f"\nFold {fold+1} Results:")
        for metric, value in fold_results.items():
            if metric != 'fold':
                print(f"  {metric}: {value:.4f}")

    # Calculate and print average metrics across all folds
    avg_metrics = {metric: sum(result[metric] for result in all_results) / k
                for metric in ['accuracy', 'precision', 'recall', 'f1', 'roc_auc', 'pr_auc']}

    print("\nAverage Results Across All Folds:")
    for metric, value in avg_metrics.items():
        print(f"  {metric}: {value:.4f}")

    return all_results, avg_metrics

In [24]:
# 6. Standard Training Function (for comparison)
# ================================
def train_model(reports_csv, proj_csv, image_folder, num_epochs=40):
    # Create the merged dataset with train augmentation
    train_dataset = PureCNNMergedDataset(reports_csv, proj_csv, image_folder, transform=train_transform)

    # Check class distribution for weighted loss
    class_distribution = train_dataset.get_class_distribution()
    print(f"Class distribution: {class_distribution}")

    # Calculate class weights
    num_samples = len(train_dataset)
    class_weights = torch.FloatTensor([num_samples / (len(class_distribution) * count)
                                      for count in class_distribution.values()])

    # Split the dataset (80% training, 20% testing)
    total_size = len(train_dataset)
    train_size = int(0.8 * total_size)
    test_size = total_size - train_size
    train_subset, test_subset = random_split(train_dataset, [train_size, test_size])

    # Create test dataset with test transform
    test_dataset = PureCNNMergedDataset(reports_csv, proj_csv, image_folder, transform=val_transform)

    # We need to use the same indices for the test set
    test_indices = test_subset.indices
    test_subset = torch.utils.data.Subset(test_dataset, test_indices)

    # Create DataLoaders
    train_loader = DataLoader(train_subset, batch_size=8, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_subset, batch_size=8, shuffle=False, num_workers=2)

    # Set device (GPU if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Instantiate the model
    model = PureCNNModel(num_classes=2).to(device)

    # Define loss function with class weights
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

    # First phase: Train only the classifier
    print("Phase 1: Training only the MLP head...")
    optimizer = optim.AdamW(model.backbone.classifier.parameters(), lr=1e-3, weight_decay=0.01)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.1)

    for epoch in range(15):
        train_loss, train_acc, train_prec, train_rec, train_f1, train_roc_auc, train_pr_auc = train_epoch(
            model, train_loader, criterion, optimizer, device
        )
        test_loss, test_acc, test_prec, test_rec, test_f1, test_roc_auc, test_pr_auc = eval_epoch(
            model, test_loader, criterion, device
        )

        print(f"Epoch {epoch+1}/15")
        print(f"  Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f} | ROC AUC: {train_roc_auc:.4f} | PR AUC: {train_pr_auc:.4f}")
        print(f"  Test  Loss: {test_loss:.4f} | Acc: {test_acc:.4f} | F1: {test_f1:.4f} | ROC AUC: {test_roc_auc:.4f} | PR AUC: {test_pr_auc:.4f}")

        scheduler.step(test_loss)
        current_lr = scheduler.get_last_lr()
        print(f"Epoch {epoch+1} - Current learning rate: {current_lr}")


    # Second phase: Unfreeze the last block and continue training
    print("\nPhase 2: Unfreezing last block and fine-tuning...")
    model.unfreeze_last_block()

    # Create optimizer with different learning rates
    optimizer = optim.AdamW([
        {'params': model.backbone.features[-1].parameters(), 'lr': 1e-4},
        {'params': model.backbone.classifier.parameters(), 'lr': 1e-3}
    ], weight_decay=0.01)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.1)

    best_test_f1 = 0
    for epoch in range(num_epochs - 15):
        train_loss, train_acc, train_prec, train_rec, train_f1, train_roc_auc, train_pr_auc = train_epoch(
            model, train_loader, criterion, optimizer, device
        )
        test_loss, test_acc, test_prec, test_rec, test_f1, test_roc_auc, test_pr_auc = eval_epoch(
            model, test_loader, criterion, device
        )

        print(f"Epoch {epoch+1}/{num_epochs-15}")
        print(f"  Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f} | ROC AUC: {train_roc_auc:.4f} | PR AUC: {train_pr_auc:.4f}")
        print(f"  Test  Loss: {test_loss:.4f} | Acc: {test_acc:.4f} | F1: {test_f1:.4f} | ROC AUC: {test_roc_auc:.4f} | PR AUC: {test_pr_auc:.4f}")

        scheduler.step(test_loss)
        current_lr = scheduler.get_last_lr()
        print(f"Epoch {epoch+1} - Current learning rate: {current_lr}")


        # Save best model
        if test_f1 > best_test_f1:
            best_test_f1 = test_f1
            torch.save(model.state_dict(), "best_model.pt")
            print(f"  Saved new best model with test F1: {test_f1:.4f}")

    # Load best model for final evaluation
    model.load_state_dict(torch.load("best_model.pt"))
    final_loss, final_acc, final_prec, final_rec, final_f1, final_roc_auc, final_pr_auc = eval_epoch(model, test_loader, criterion, device)

    print("\nFinal Model Performance:")
    print(f"  Accuracy: {final_acc:.4f}")
    print(f"  Precision: {final_prec:.4f}")
    print(f"  Recall: {final_rec:.4f}")
    print(f"  F1 Score: {final_f1:.4f}")
    print(f"  ROC AUC: {final_roc_auc:.4f}")
    print(f"  PR AUC: {final_pr_auc:.4f}")

    return model, {
        "accuracy": final_acc,
        "precision": final_prec,
        "recall": final_rec,
        "f1": final_f1,
        "roc_auc": final_roc_auc,
        "pr_auc": final_pr_auc
    }

In [25]:
# 7. Main Function
# ================================
def main():
    # Paths to the CSV files and image folder
    reports_csv = '/content/dataset/indiana_reports.csv'
    proj_csv = '/content/dataset/indiana_projections.csv'
    image_folder = '/content/dataset/images/images_normalized'

    # Decide whether to use cross-validation or standard training
    use_cross_validation = True

    if use_cross_validation:
        print("Running 5-fold cross-validation...")
        results, avg_metrics = cross_validate(reports_csv, proj_csv, image_folder, num_epochs=40, k=5)
    else:
        print("Running standard training with 80/20 split...")
        model, metrics = train_model(reports_csv, proj_csv, image_folder, num_epochs=40)

In [26]:
if __name__ == "__main__":
    main()

Running 5-fold cross-validation...
Class distribution: {np.int64(0): 2695, np.int64(1): 4771}

Training Fold 1/5
Phase 1: Training only the MLP head...
Epoch 1/15
  Train Loss: 0.5545 | Acc: 0.5881 | F1: 0.6613 | ROC AUC: 0.6101 | PR AUC: 0.7294
  Val   Loss: 0.1300 | Acc: 0.6339 | F1: 0.6887 | ROC AUC: 0.6723 | PR AUC: 0.7574
Epoch 1/15 - Current learning rate: [0.001]
Epoch 2/15
  Train Loss: 0.5325 | Acc: 0.6202 | F1: 0.6919 | ROC AUC: 0.6437 | PR AUC: 0.7574
  Val   Loss: 0.1286 | Acc: 0.6118 | F1: 0.6393 | ROC AUC: 0.6820 | PR AUC: 0.7582
Epoch 2/15 - Current learning rate: [0.001]
Epoch 3/15
  Train Loss: 0.5296 | Acc: 0.6114 | F1: 0.6796 | ROC AUC: 0.6532 | PR AUC: 0.7682
  Val   Loss: 0.1304 | Acc: 0.6432 | F1: 0.6987 | ROC AUC: 0.6788 | PR AUC: 0.7525
Epoch 3/15 - Current learning rate: [0.001]
Epoch 4/15
  Train Loss: 0.5208 | Acc: 0.6351 | F1: 0.7031 | ROC AUC: 0.6661 | PR AUC: 0.7785
  Val   Loss: 0.1283 | Acc: 0.6560 | F1: 0.7191 | ROC AUC: 0.6923 | PR AUC: 0.7661
Epoch 4/

KeyboardInterrupt: 