# In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input mount and print the full path of every file found,
# one path per line, in the order os.walk yields them.
for folder, _subdirs, file_names in os.walk('/kaggle/input'):
    full_paths = [os.path.join(folder, name) for name in file_names]
    for full_path in full_paths:
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    classification_report,
    roc_curve,
    auc
)

# ---- Runtime -------------------------------------------------------------
# Use the GPU when torch can see one, otherwise stay on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Optimisation hyper-parameters ----------------------------------------
BATCH_SIZE = 16
EPOCHS = 30
LR = 3e-4
WEIGHT_DECAY = 1e-4
PATIENCE = 7  # epochs without validation improvement before early stopping

# ---- Model dimensions -----------------------------------------------------
NUM_CLASSES = 2
ATTN_DIM = 1024  # width every backbone's features are projected to

# ---- Data layout ----------------------------------------------------------
DATA_ROOT = "/kaggle/input/breast-cancer-msi-multimodal-image-dataset/MultiModel Breast Cancer MSI Dataset"

# One sub-folder of DATA_ROOT per imaging modality.
MODALITIES = [
    "Chest_XRay_MSI",
    "Histopathological_MSI",
    "Ultrasound Images_MSI",
]

# Class-folder name -> binary target. Each class is accepted in lower-case
# and capitalised spelling; "benign" and "normal" share target 0.
LABEL_MAP = {
    spelling: target
    for base_name, target in (("benign", 0), ("normal", 0), ("malignant", 1))
    for spelling in (base_name, base_name.capitalize())
}

# Dataset Description

class MSIDataset(Dataset):
    """Image dataset indexed from ``root/<modality>/<class>/<image>`` folders.

    Each item is a ``(image, label, modality)`` triple, where ``label`` comes
    from ``LABEL_MAP`` and ``modality`` is the name of the modality folder the
    image was found in.

    Args:
        root: Directory that contains one sub-folder per modality.
        modalities: Names of the modality sub-folders to index.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.

    Raises:
        FileNotFoundError: If a modality folder does not exist under ``root``
            (raised by ``os.listdir``).
    """

    # Extensions treated as images. Other entries inside a class folder
    # (metadata files, nested folders, ...) are skipped so that they cannot
    # crash ``Image.open`` in the middle of an epoch.
    IMAGE_EXTENSIONS = (
        ".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp",
    )

    def __init__(self, root, modalities, transform=None):
        self.samples = []
        self.transform = transform

        for mod in modalities:
            mod_path = os.path.join(root, mod)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample index stable across runs and machines.
            for cls in sorted(os.listdir(mod_path)):
                cls_path = os.path.join(mod_path, cls)
                # Ignore stray files and class folders without a known label.
                if not os.path.isdir(cls_path) or cls not in LABEL_MAP:
                    continue
                label = LABEL_MAP[cls]
                for fname in sorted(os.listdir(cls_path)):
                    if not fname.lower().endswith(self.IMAGE_EXTENSIONS):
                        continue
                    img_path = os.path.join(cls_path, fname)
                    if not os.path.isfile(img_path):
                        continue
                    self.samples.append((img_path, label, mod))

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label, modality)``."""
        path, label, modality = self.samples[idx]
        # convert() produces an independent image, so the underlying file can
        # be closed as soon as the conversion is done.
        with Image.open(path) as handle:
            img = handle.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label, modality

# CBAM

class CBAM(nn.Module):
    """Channel attention followed by spatial attention over a 4-D feature map.

    Args:
        channel: Number of channels of the incoming feature map.
        reduction: Divisor for the hidden width of the shared channel MLP.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        # Shared MLP applied to both the average- and max-pooled descriptors.
        self.mlp = nn.Sequential(
            nn.Linear(channel, hidden),
            nn.ReLU(),
            nn.Linear(hidden, channel),
        )
        self.sigmoid = nn.Sigmoid()
        # 7x7 convolution over the stacked [mean, max] channel summaries.
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)

    def _channel_gate(self, x):
        """Return per-channel weights shaped to broadcast over height and width."""
        batch, channels, _, _ = x.shape
        pooled_avg = x.mean(dim=(2, 3))
        pooled_max = x.amax(dim=(2, 3))
        weights = self.sigmoid(self.mlp(pooled_avg) + self.mlp(pooled_max))
        return weights.view(batch, channels, 1, 1)

    def _spatial_gate(self, x):
        """Return a single-channel weight map computed across the channel axis."""
        mean_over_channels = x.mean(dim=1, keepdim=True)
        max_over_channels = x.max(dim=1, keepdim=True).values
        stacked = torch.cat([mean_over_channels, max_over_channels], dim=1)
        return self.sigmoid(self.conv(stacked))

    def forward(self, x):
        """Re-weight ``x`` by channel first, then by spatial position."""
        channel_refined = x * self._channel_gate(x)
        return channel_refined * self._spatial_gate(channel_refined)

# Modality Specific Backbone

def build_backbone(mod):
    """Build the ImageNet-pretrained feature extractor for one modality.

    Each extractor is the convolutional trunk of a torchvision model followed
    by a ``CBAM`` block and global average pooling, so it maps an image batch
    to a ``(batch, feat_dim, 1, 1)`` tensor.

    Args:
        mod: Modality folder name; one of ``"Chest_XRay_MSI"``,
            ``"Histopathological_MSI"`` or ``"Ultrasound Images_MSI"``.

    Returns:
        Tuple ``(backbone, feat_dim)``: the ``nn.Sequential`` extractor and
        the number of channels it outputs.

    Raises:
        ValueError: If ``mod`` is not one of the supported modalities.
    """
    if mod == "Chest_XRay_MSI":
        m = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        return nn.Sequential(
            m.features,
            CBAM(1024),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        ), 1024

    if mod == "Histopathological_MSI":
        m = models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.IMAGENET1K_V1)
        return nn.Sequential(
            m.features,
            CBAM(1536),
            nn.AdaptiveAvgPool2d(1)
        ), 1536

    if mod == "Ultrasound Images_MSI":
        m = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        return nn.Sequential(
            # Drop the final average-pool and fully-connected layers.
            *list(m.children())[:-2],
            CBAM(2048),
            nn.AdaptiveAvgPool2d(1)
        ), 2048

    # Previously this fell through and returned None, which only surfaced as
    # an opaque "cannot unpack NoneType" error at the call site.
    raise ValueError(f"Unknown modality: {mod!r}")

# Gated Cross-Attention 

class GatedCrossAttention(nn.Module):
    """Single-head attention whose output is gated, added to the query and normalised.

    Args:
        dim: Embedding width of the query and context tokens.
    """

    def __init__(self, dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=1, batch_first=True)
        # Scalar gate per token, computed from the attended features.
        self.gate = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
        self.norm = nn.LayerNorm(dim)

    def forward(self, q, ctx):
        """Attend from ``q`` to ``ctx`` and return the normalised gated residual."""
        attended = self.attn(query=q, key=ctx, value=ctx)[0]
        gated = self.gate(attended) * attended
        return self.norm(q + gated)

# Multimodal Model

class MultimodalFusionNet(nn.Module):
    """Classifier that routes each image through the backbone of its modality.

    One backbone and one linear projection (to ``ATTN_DIM``) are created per
    entry of ``MODALITIES``. The projected features pass through a
    ``GatedCrossAttention`` block and a LayerNorm + Linear head that outputs
    ``NUM_CLASSES`` logits.
    """

    def __init__(self):
        super().__init__()
        self.backbones = nn.ModuleDict()
        self.proj = nn.ModuleDict()

        for mod in MODALITIES:
            bb, d = build_backbone(mod)
            self.backbones[mod] = bb
            self.proj[mod] = nn.Linear(d, ATTN_DIM)

        self.cross_attn = GatedCrossAttention(ATTN_DIM)
        self.classifier = nn.Sequential(
            nn.LayerNorm(ATTN_DIM),
            nn.Linear(ATTN_DIM, NUM_CLASSES)
        )

    def forward(self, x, mods):
        """Compute class logits for a mixed-modality batch.

        Args:
            x: Image batch; the first dimension is the batch size.
            mods: Sequence with one modality name per sample in ``x``.

        Returns:
            Tensor of shape ``(batch, NUM_CLASSES)`` with unnormalised logits.

        Raises:
            ValueError: If ``mods`` does not have one entry per sample, or
                contains a modality this model has no backbone for.
        """
        B = x.size(0)
        if len(mods) != B:
            raise ValueError(
                f"Expected {B} modality names (one per sample), got {len(mods)}"
            )
        unknown = {m for m in mods if m not in self.backbones}
        if unknown:
            # Without this check such samples would silently keep an all-zero
            # feature vector and still be classified.
            raise ValueError(f"Unsupported modalities in batch: {sorted(unknown)}")

        feats = torch.zeros(B, ATTN_DIM, device=x.device)

        for mod in MODALITIES:
            idx = [i for i, m in enumerate(mods) if m == mod]
            if idx:
                f = self.backbones[mod](x[idx]).flatten(1)
                f = self.proj[mod](f)
                # Indexed assignment needs matching dtypes; the cast is a
                # no-op in full precision and keeps mixed precision working.
                feats[idx] = f.to(feats.dtype)

        # NOTE(review): query and context are the same length-1 sequence per
        # sample, so each sample attends only to its own feature vector.
        tokens = feats.unsqueeze(1)
        fused = self.cross_attn(tokens, tokens).squeeze(1)
        return self.classifier(fused)

# Training Function

def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    When ``optimizer`` is given the model is put in training mode and updated
    after every batch; otherwise the model is put in eval mode and gradients
    are disabled.
    """
    training = optimizer is not None
    model.train(training)

    preds, gts, losses = [], [], []
    with torch.set_grad_enabled(training):
        for x, y, m in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            if training:
                optimizer.zero_grad()

            out = model(x, m)
            loss = criterion(out, y)

            if training:
                loss.backward()
                optimizer.step()

            losses.append(loss.item())
            preds += out.argmax(1).cpu().tolist()
            gts += y.cpu().tolist()

    return float(np.mean(losses)), accuracy_score(gts, preds) * 100


def train(model, train_dl, val_dl, ckpt_path="best_model.pth"):
    """Train ``model`` with AdamW and early stopping on validation accuracy.

    Runs for at most ``EPOCHS`` epochs and stops early once validation
    accuracy has failed to improve for ``PATIENCE`` consecutive epochs. The
    weights of the best epoch so far are saved to ``ckpt_path``.

    Args:
        model: Network called as ``model(images, modalities)``.
        train_dl: Loader yielding ``(images, labels, modalities)`` batches.
        val_dl: Loader with the same batch layout, used for validation.
        ckpt_path: File the best ``state_dict`` is written to.

    Returns:
        Tuple ``(train_loss, val_loss, train_acc, val_acc)`` of per-epoch
        lists; accuracies are percentages.
    """
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    train_loss, val_loss = [], []
    train_acc, val_acc = [], []

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; starting at 0 left no file behind when accuracy stayed at 0%.
    best_val, wait = float("-inf"), 0

    for ep in range(EPOCHS):
        tr_loss, tr_acc = _run_epoch(model, train_dl, criterion, optimizer)
        train_loss.append(tr_loss)
        train_acc.append(tr_acc)

        va_loss, va_acc = _run_epoch(model, val_dl, criterion)
        val_loss.append(va_loss)
        val_acc.append(va_acc)

        print(f"Epoch {ep+1}/{EPOCHS} | Train Acc {tr_acc:.2f}% | Val Acc {va_acc:.2f}%")

        if va_acc > best_val:
            best_val = va_acc
            wait = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            wait += 1
            if wait >= PATIENCE:
                print(f"\nEarly stopping | Best Val Acc = {best_val:.2f}%")
                break

    return train_loss, val_loss, train_acc, val_acc

def plot_curves(tl, vl, ta, va):
    """Show two figures: loss per epoch, then accuracy per epoch.

    Args:
        tl: Training loss per epoch.
        vl: Validation loss per epoch.
        ta: Training accuracy per epoch (percent).
        va: Validation accuracy per epoch (percent).
    """
    epoch_axis = range(1, len(tl) + 1)

    # (train series, val series, train label, val label, y-axis label, title)
    panels = (
        (tl, vl, "Train Loss", "Val Loss", "Loss", "Training & Validation Loss"),
        (ta, va, "Train Accuracy", "Val Accuracy", "Accuracy (%)",
         "Training & Validation Accuracy"),
    )

    for train_series, val_series, train_label, val_label, y_label, title in panels:
        plt.figure()
        plt.plot(epoch_axis, train_series, label=train_label)
        plt.plot(epoch_axis, val_series, label=val_label)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
        plt.show()

# Evaluation Function

def evaluate(model, val_dl):
    """Report confusion matrix, classification report and ROC curve for ``model``.

    Args:
        model: Network called as ``model(images, modalities)`` that returns
            two-class logits.
        val_dl: Loader yielding ``(images, labels, modalities)`` batches.
    """
    class_ids = [0, 1]
    class_names = ["Benign/Normal", "Malignant"]

    model.eval()
    preds, labels, probs = [], [], []

    with torch.no_grad():
        for x, y, m in val_dl:
            out = model(x.to(DEVICE), m)
            # Probability of class 1, used as the ROC score.
            p = torch.softmax(out, dim=1)[:, 1]
            preds += out.argmax(1).cpu().tolist()
            labels += y.tolist()
            probs += p.cpu().tolist()

    if not labels:
        print("No validation samples to evaluate.")
        return

    # Pin the label set: without it the matrix shrinks to 1x1 when only one
    # class occurs and no longer matches the two tick labels below.
    cm = confusion_matrix(labels, preds, labels=class_ids)
    plt.figure()
    sns.heatmap(cm, annot=True, fmt="d",
                xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()

    print(classification_report(labels, preds, labels=class_ids,
                                target_names=class_names, zero_division=0))

    # The ROC curve is undefined unless both classes are present.
    if len(set(labels)) < 2:
        print("ROC curve skipped: validation labels contain a single class.")
        return

    fpr, tpr, _ = roc_curve(labels, probs)
    roc_auc = auc(fpr, tpr)

    plt.figure()
    plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.4f}")
    plt.plot([0,1],[0,1],'--')
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.title("ROC Curve")
    plt.legend()
    plt.show()

# Main Function

def main():
    """Build the data splits, train the fusion model and evaluate the best checkpoint.

    Raises:
        RuntimeError: If no labelled images are found under ``DATA_ROOT``.
    """
    import copy  # local import: only used to clone the dataset for evaluation

    split_seed = 42  # fixed so the train/validation split is reproducible

    normalize = transforms.Normalize(
        mean=[0.485,0.456,0.406],
        std=[0.229,0.224,0.225]
    )
    train_transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # No random augmentation here, so validation metrics are deterministic.
    eval_transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        normalize,
    ])

    dataset = MSIDataset(DATA_ROOT, MODALITIES, train_transform)
    if len(dataset) == 0:
        raise RuntimeError(f"No labelled images found under {DATA_ROOT!r}")

    # Shallow copy: shares the sample list (so indices agree) but carries its
    # own transform.
    eval_dataset = copy.copy(dataset)
    eval_dataset.transform = eval_transform

    t = int(0.8 * len(dataset))
    generator = torch.Generator().manual_seed(split_seed)
    order = torch.randperm(len(dataset), generator=generator).tolist()
    train_ds = torch.utils.data.Subset(dataset, order[:t])
    val_ds = torch.utils.data.Subset(eval_dataset, order[t:])

    def collate(batch):
        """Stack images and labels into tensors; keep modality names as a list."""
        return (
            torch.stack([item[0] for item in batch]),
            torch.tensor([item[1] for item in batch]),
            [item[2] for item in batch],
        )

    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)
    val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)

    model = MultimodalFusionNet().to(DEVICE)
    tl, vl, ta, va = train(model, train_dl, val_dl)

    plot_curves(tl, vl, ta, va)

    # Restore the best-epoch weights before the final evaluation.
    model.load_state_dict(torch.load("best_model.pth", map_location=DEVICE))
    evaluate(model, val_dl)


if __name__ == "__main__":
    main()
