This pipeline combines DenseNet-121 as the CNN backbone, Swin Transformer-Tiny as the Vision Transformer branch, CBAM for the multi-scale attention module, and LIME for post-hoc explainability, creating a novel and interpretable framework for fetal brain abnormality detection from ultrasound images.

**Pipeline Summary**

**Data Acquisition and Preprocessing**

Uses a fetal brain ultrasound dataset (1,768 images, e.g., Roboflow) across normal and abnormal classes. Preprocessing includes resizing (for CNN/ViT compatibility), normalization, augmentations (to address small data size), and appropriate train-validation-test splits.

**CNN Backbone: DenseNet-121**

DenseNet-121 extracts local spatial features such as subtle textural changes, anatomical boundaries, and brain structural outlines from ultrasound scans, enhancing detection of localized abnormalities.

**Vision Transformer Branch: Swin Transformer-Tiny**

Swin Transformer-Tiny provides global contextual understanding, capturing long-range dependencies (such as overall skull shape, large lesions, or abnormal cavity enlargement) that complement the CNN’s localized focus.

**Multi-Scale Attention with CBAM**

The Convolutional Block Attention Module (CBAM) refines features at both channel and spatial levels, emphasizing image regions and feature groups critical for distinguishing similar abnormalities. This ensures robust attention to both coarse and fine anatomical structures across multiple resolutions.

**Feature Fusion**

Features from CNN (DenseNet-121) and ViT (Swin Transformer-Tiny) branches are fused (often via concatenation and another multi-scale attention layer) and passed through dense layers to produce unified representations for classification.

**Classification**

Unified features are processed via fully connected, dropout, and softmax layers for multi-class prediction: normal or subtypes such as ventriculomegaly, encephalocele, holoprosencephaly, and hemorrhage.

**Explainability with LIME**

LIME (Local Interpretable Model-agnostic Explanations) is used post-classification to generate pixel-level, visually interpretable heatmaps, pinpointing which image areas influenced each decision. This satisfies the need for model transparency, regulatory compliance, and clinician trust—key requirements in medical AI applications.

**Key Novelty:**

Hybrid CNN-ViT architecture enables robust detection across local and global feature scales.

Multi-scale attention (CBAM) ensures adaptive focus on medically relevant regions.

LIME provides granular, model-agnostic explainability, addressing the research gap in transparent, clinically deployable fetal abnormality detection.

In [2]:
# Mount data from drive

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
!pip install torch torchvision timm scikit-learn pandas pillow tqdm opencv-python matplotlib seaborn lime

Collecting lime
  Downloading lime-0.2.0.1.tar.gz (275 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m275.7/275.7 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: lime
  Building wheel for lime (setup.py) ... [?25l[?25hdone
  Created wheel for lime: filename=lime-0.2.0.1-py3-none-any.whl size=283834 sha256=5a93a901b3c1cedb783e4d748aed9e928a65ea3341ef74493f46003372dbccaf
  Stored in directory: /root/.cache/pip/wheels/e7/5d/0e/4b4fff9a47468fed5633211fb3b76d1db43fe806a17fb7486a
Successfully built lime
Installing collected packages: lime
Successfully installed lime-0.2.0.1


In [1]:
"""
fetal_brain_pipeline_multiclass.py

End-to-end pipeline (multi-class):
- dataset loader (reads <split>/_classes.csv with one-hot labels)
- model: DenseNet-121 (CNN) + Swin-Tiny (ViT) + CBAM (attention) + fusion head
- training loop (CrossEntropyLoss)
- validation/test metrics (per-class AUC, F1 macro/micro)
- LIME explanation generation (save heatmaps)

Author: ChatGPT
"""

import os
import random
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import timm

from sklearn.metrics import roc_auc_score, f1_score, precision_recall_fscore_support
from sklearn.preprocessing import label_binarize

# LIME imports
from lime import lime_image
from skimage.segmentation import mark_boundaries

# -------------------------
# Config / hyperparameters
# -------------------------
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

DATA_ROOT = "/content/drive/MyDrive/Amrita/Sem5/ML/Classification_Dataset/"
BATCH_SIZE = 16
IMAGE_SIZE = 224
NUM_EPOCHS = 10
LR = 1e-4
WEIGHT_DECAY = 1e-4
NUM_WORKERS = 6
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_SAVE_PATH = "best_model_multiclass.pth"
NUM_CLASSES = 11  # update if different; matches your CSV header count

# -------------------------
# Helper: read CSV, get classes
# -------------------------
def read_classes_from_csv(csv_path: str):
    df = pd.read_csv(csv_path, sep=None, engine="python")
    class_cols = list(df.columns[1:])
    return class_cols

# -------------------------
# Dataset
# -------------------------
class FetalUSDataset(Dataset):
    def __init__(self, split_dir: str, image_size=224, transform=None):
        csv_path = os.path.join(split_dir, "_classes.csv")
        assert os.path.exists(csv_path), f"CSV not found: {csv_path}"
        self.df = pd.read_csv(csv_path, sep=None, engine="python")
        self.dir = split_dir
        self.filenames = self.df.iloc[:, 0].astype(str).values
        self.labels = self.df.iloc[:, 1:].astype(int).values
        self.transform = transform if transform else self.default_transform(image_size)
        self.image_size = image_size

    def default_transform(self, image_size):
        return T.Compose([
            T.Resize((image_size, image_size)),
            T.Grayscale(num_output_channels=3),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        fn = self.filenames[idx]
        img_path = os.path.join(self.dir, fn)
        if not os.path.exists(img_path):
            alt = os.path.join(self.dir, "images", fn)
            if os.path.exists(alt):
                img_path = alt
            else:
                raise FileNotFoundError(f"{img_path} missing")
        img = Image.open(img_path).convert("RGB")
        x = self.transform(img)
        y = torch.tensor(np.argmax(self.labels[idx]), dtype=torch.long)  # convert one-hot → class index
        return x, y, fn

# -------------------------
# CBAM implementation
# -------------------------
class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, bias=False):
        super().__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride=stride, padding=padding, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        return self.sigmoid(avg_out + max_out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        padding = 3 if kernel_size==7 else 1
        self.conv = BasicConv(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv(x_cat))

class CBAM(nn.Module):
    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super().__init__()
        self.channel_att = ChannelAttention(in_planes, ratio)
        self.spatial_att = SpatialAttention(kernel_size)
    def forward(self, x):
        x_out = x * self.channel_att(x)
        return x_out * self.spatial_att(x_out)

# -------------------------
# HybridNet: DenseNet + CBAM + Swin + Fusion
# -------------------------
class HybridNet(nn.Module):
    def __init__(self, densenet_model_name="densenet121", swin_model_name="swin_base_patch4_window7_224", num_classes=11):
        super().__init__()

        # CNN branch
        dnet = timm.create_model(densenet_model_name, pretrained=True)
        self.cnn_features = dnet.features
        self.cbam = CBAM(dnet.num_features)
        self.cnn_gap = nn.AdaptiveAvgPool2d((1,1))
        cnn_feat_dim = dnet.num_features

        # Swin branch
        self.swin = timm.create_model(swin_model_name, pretrained=True, num_classes=0, global_pool="avg")
        swin_feat_dim = self.swin.num_features
        self.swin_pool = nn.AdaptiveAvgPool2d((1,1))

        # Fusion + classifier
        fusion_in_features = cnn_feat_dim + swin_feat_dim  # automatically 1024 + 1007
        self.classifier = nn.Sequential(
            nn.Linear(fusion_in_features, 512),  # use correct input dim
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        # CNN
        cnn_feat = self.cnn_features(x)
        cnn_feat = self.cbam(cnn_feat)
        cnn_vec = torch.flatten(self.cnn_gap(cnn_feat), 1)
        # Swin
        swin_feat = self.swin.forward_features(x)
        swin_vec = torch.flatten(self.swin_pool(swin_feat), 1)
        # Fusion
        fused = torch.cat([cnn_vec, swin_vec], dim=1)
        fused = self.fusion_dropout(fused)
        logits = self.classifier(fused)
        return logits

# -------------------------
# Training / validation
# -------------------------
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    for imgs, labels, _ in tqdm(loader, desc="Train batches"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * imgs.size(0)
    return running_loss / len(loader.dataset)

@torch.no_grad()
def validate(model, loader, device):
    model.eval()
    all_logits, all_labels, filenames = [], [], []
    for imgs, labels, fns in tqdm(loader, desc="Validation batches"):
        imgs = imgs.to(device)
        logits = model(imgs)
        all_logits.append(logits.detach().cpu().numpy())
        all_labels.append(labels.numpy())
        filenames.extend(fns)
    all_logits = np.vstack(all_logits)
    all_labels = np.hstack(all_labels)
    return all_labels, all_logits, filenames

# -------------------------
# LIME explainer
# -------------------------
class LimeExplainerWrapper:
    def __init__(self, model, class_names, device="cpu", transform=None):
        self.model = model
        self.model.eval()
        self.class_names = class_names
        self.device = device
        self.transform = transform if transform else T.Compose([
            T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            T.ToTensor(),
            T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
        ])
    def predict_proba(self, imgs_rgb):
        batch = []
        for im in imgs_rgb:
            pil = Image.fromarray(im.astype(np.uint8))
            x = self.transform(pil)
            batch.append(x)
        x = torch.stack(batch).to(self.device)
        with torch.no_grad():
            logits = self.model(x)
            probs = torch.softmax(logits, dim=1).cpu().numpy()
        return probs
    def explain_image(self, img_path, label_idx, top_labels=1, num_samples=1000, hide_color=0):
        explainer = lime_image.LimeImageExplainer()
        image = np.array(Image.open(img_path).convert("RGB"))
        fn = lambda imgs: self.predict_proba(np.array(imgs))
        explanation = explainer.explain_instance(image, fn, top_labels=top_labels, hide_color=hide_color, num_samples=num_samples)
        temp, mask = explanation.get_image_and_mask(label_idx, positive_only=True, num_features=5, hide_rest=False)
        vis = mark_boundaries(temp/255.0, mask)
        return vis, explanation

# -------------------------
# Main
# -------------------------
def main():
    # splits
    train_dir, valid_dir, test_dir = [os.path.join(DATA_ROOT, s) for s in ["train","valid","test"]]
    class_names = read_classes_from_csv(os.path.join(train_dir,"_classes.csv"))
    global NUM_CLASSES
    NUM_CLASSES = len(class_names)
    print("Classes:", class_names)

    # transforms
    train_transform = T.Compose([
        T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        T.RandomHorizontalFlip(),
        T.RandomRotation(10),
        T.ColorJitter(0.1,0.1),
        T.Grayscale(3),
        T.ToTensor(),
        T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ])
    val_transform = T.Compose([
        T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        T.Grayscale(3),
        T.ToTensor(),
        T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ])

    train_ds = FetalUSDataset(train_dir, IMAGE_SIZE, train_transform)
    val_ds = FetalUSDataset(valid_dir, IMAGE_SIZE, val_transform)
    test_ds = FetalUSDataset(test_dir, IMAGE_SIZE, val_transform)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

    # model, criterion, optimizer
    model = HybridNet(num_classes=NUM_CLASSES).to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

    best_val_f1 = -1.0
    for epoch in range(1, NUM_EPOCHS+1):
        print(f"\nEpoch {epoch}/{NUM_EPOCHS}")
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
        print(f"Train loss: {train_loss:.4f}")

        val_labels, val_logits, _ = validate(model, val_loader, DEVICE)
        val_preds = np.argmax(val_logits, axis=1)
        f1_macro = f1_score(val_labels, val_preds, average="macro")
        f1_micro = f1_score(val_labels, val_preds, average="micro")
        print(f"Val F1 macro: {f1_macro:.4f} | F1 micro: {f1_micro:.4f}")

        if f1_macro > best_val_f1:
            best_val_f1 = f1_macro
            torch.save({"model_state": model.state_dict(), "class_names": class_names, "epoch": epoch}, MODEL_SAVE_PATH)
            print("Saved best model.")

        scheduler.step()

    # Test evaluation
    ckpt = torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
    model.load_state_dict(ckpt["model_state"])
    test_labels, test_logits, test_fns = validate(model, test_loader, DEVICE)
    test_preds = np.argmax(test_logits, axis=1)
    f1_macro = f1_score(test_labels, test_preds, average="macro")
    f1_micro = f1_score(test_labels, test_preds, average="micro")
    print("Test F1 macro:", f1_macro, "F1 micro:", f1_micro)

    # Save test predictions
    test_probs = torch.softmax(torch.tensor(test_logits), dim=1).numpy()
    out_df = pd.DataFrame(test_probs, columns=class_names)
    out_df["filename"] = test_fns
    out_df = out_df[["filename"] + class_names]
    out_df.to_csv("test_predictions_multiclass.csv", index=False)
    print("Saved test_predictions_multiclass.csv")

    # Example LIME explanation
    explainer = LimeExplainerWrapper(model, class_names, DEVICE, val_transform)
    target_class_name = class_names[0]  # e.g., first class
    idx = class_names.index(target_class_name)
    for i in range(len(test_ds)):
        _, label_idx, fn = test_ds[i]
        if label_idx == idx:
            img_path = os.path.join(test_dir, fn)
            vis, _ = explainer.explain_image(img_path, label_idx=idx, num_samples=800)
            out_img = f"lime_{os.path.basename(fn)}_{target_class_name}.png"
            plt.imsave(out_img, vis)
            print("Saved LIME visualization:", out_img)
            break

if __name__ == "__main__":
    main()


Classes: [' anold-chiari-malformation', ' arachnoid-cyst', ' cerebellah-hypoplasia', ' colphocephaly', ' encephalocele', ' holoprosencephaly', ' hydracenphaly', ' intracranial-hemorrdge', ' intracranial-tumor', ' m-magna', ' mild-ventriculomegaly', ' moderate-ventriculomegaly', ' normal', ' polencephaly', ' severe-ventriculomegaly', ' vein-of-galen']


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



Epoch 1/10


Train batches:   0%|          | 0/89 [00:47<?, ?it/s]


AttributeError: 'HybridNet' object has no attribute 'fusion_dropout'