# In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the '/kaggle/input' tree and print the joined path of every file found.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        # dirname is the directory currently being visited by os.walk
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from PIL import Image
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Use the GPU when PyTorch can see one; mixed precision is switched on only
# in that case.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
USE_AMP = (DEVICE == "cuda")

# Optimisation hyper-parameters.
BATCH_SIZE, EPOCHS = 16, 30
LR, WEIGHT_DECAY = 3e-4, 1e-4
PATIENCE = 6  # epochs without a validation-accuracy gain before stopping

# Model dimensions.
NUM_CLASSES = 2
ATTN_DIM = 1024  # width every modality's features are projected to

DATA_ROOT = (
    "/kaggle/input/breast-cancer-msi-multimodal-image-dataset/"
    "MultiModel Breast Cancer MSI Dataset"
)

# Sub-directory names joined onto the dataset root, one per imaging modality.
MODALITIES = [
    "Chest_XRay_MSI",
    "Histopathological_MSI",
    "Ultrasound Images_MSI",
]

# Class-folder name -> integer label. Both capitalisations of each name are
# accepted, and "normal" shares label 0 with "benign".
LABEL_MAP = {
    spelling: label
    for name, label in (("benign", 0), ("normal", 0), ("malignant", 1))
    for spelling in (name, name.capitalize())
}

# Dataset Description

class MSIDataset(Dataset):
    """Image dataset pooled from several modality folders.

    The directory layout read by the constructor is
    ``root/<modality>/<class folder>/<image file>``. A class folder is used
    only when its name is a key of ``LABEL_MAP``; other folders and loose
    files at the class level are ignored.

    Args:
        root: Directory that contains one sub-directory per modality.
        modalities: Names of the modality sub-directories to scan.
        transform: Optional callable applied to each loaded RGB image.

    Attributes are plain lists/callables:
        samples: list of ``(image path, label, modality)`` tuples, in sorted
            directory order within each modality.
        transform: the callable passed to the constructor (or ``None``).
    """

    def __init__(self, root, modalities, transform=None):
        self.samples = []
        self.transform = transform

        for mod in modalities:
            mod_path = os.path.join(root, mod)
            # os.listdir() returns entries in arbitrary order; sorting keeps
            # sample indices stable between runs, which index-based splits
            # rely on for reproducibility.
            for cls in sorted(os.listdir(mod_path)):
                cls_path = os.path.join(mod_path, cls)
                if not os.path.isdir(cls_path):
                    continue
                label = LABEL_MAP.get(cls)
                if label is None:
                    # Folder name is not a known class.
                    continue
                for fname in sorted(os.listdir(cls_path)):
                    img_path = os.path.join(cls_path, fname)
                    # Nested directories cannot be opened as images and
                    # would only fail later inside __getitem__.
                    if not os.path.isfile(img_path):
                        continue
                    self.samples.append((img_path, label, mod))

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label, modality)``.

        The image is converted to RGB and, when a transform was supplied,
        passed through it before being returned.
        """
        path, label, modality = self.samples[idx]
        # convert() yields an independent image, so the underlying file can
        # be closed as soon as the conversion is done.
        with Image.open(path) as src:
            img = src.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label, modality

# Modality Specific Backbone

def build_backbone(modality):
    """Create the feature extractor assigned to ``modality``.

    Each supported modality gets a different torchvision network loaded with
    its ImageNet weights, stripped of its classification head and finished
    with ``AdaptiveAvgPool2d(1)``.

    Args:
        modality: One of "Chest_XRay_MSI", "Histopathological_MSI" or
            "Ultrasound Images_MSI".

    Returns:
        Tuple ``(backbone, feat_dim)``: the ``nn.Sequential`` extractor and
        the channel count used for its pooled output.

    Raises:
        ValueError: If ``modality`` is not one of the names above.
    """
    if modality == "Chest_XRay_MSI":
        densenet = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        # The ReLU sits between the convolutional features and the pooling.
        stages = [densenet.features, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d(1)]
        return nn.Sequential(*stages), 1024

    if modality == "Histopathological_MSI":
        effnet = models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.IMAGENET1K_V1)
        stages = [effnet.features, nn.AdaptiveAvgPool2d(1)]
        return nn.Sequential(*stages), 1536

    if modality == "Ultrasound Images_MSI":
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        # Drop the last two children of the network and pool what remains.
        trunk = list(resnet.children())[:-2]
        return nn.Sequential(*trunk, nn.AdaptiveAvgPool2d(1)), 2048

    raise ValueError("Unknown modality")

# Multimodal Gated Fusion

class MultimodalFusionNet(nn.Module):
    """Classifier that routes every sample through its modality's encoder.

    For each name in ``MODALITIES`` the network owns a backbone from
    ``build_backbone`` and a linear projection into an ``ATTN_DIM``-wide
    space. A shared sigmoid gate scales the projected features, and a
    LayerNorm + linear head maps them to ``NUM_CLASSES`` logits.
    """

    def __init__(self):
        super().__init__()
        self.backbones = nn.ModuleDict()
        self.projections = nn.ModuleDict()

        # One encoder and one projection layer per modality, keyed by name.
        for name in MODALITIES:
            encoder, width = build_backbone(name)
            self.backbones[name] = encoder
            self.projections[name] = nn.Linear(width, ATTN_DIM)

        gate_layers = [nn.Linear(ATTN_DIM, 1), nn.Sigmoid()]
        self.modality_gate = nn.Sequential(*gate_layers)

        head_layers = [nn.LayerNorm(ATTN_DIM), nn.Linear(ATTN_DIM, NUM_CLASSES)]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x, modalities):
        """Return class logits for a batch that may mix modalities.

        Args:
            x: Batch of images; its first dimension is the batch size.
            modalities: Sequence with one modality name per batch row.

        Returns:
            Output of the classifier head applied to the gated features.
            Rows whose modality is not in ``MODALITIES`` keep an all-zero
            feature vector.
        """
        batch_size = x.size(0)
        fused = torch.zeros(batch_size, ATTN_DIM, device=x.device)

        for name in MODALITIES:
            # Positions in the batch that belong to this modality.
            rows = [pos for pos, tag in enumerate(modalities) if tag == name]
            if rows:
                pooled = self.backbones[name](x[rows]).flatten(1)
                projected = self.projections[name](pooled)
                weight = self.modality_gate(projected)
                fused[rows] += projected * weight

        return self.classifier(fused)

# Training Function

def train(model, train_dl, val_dl):
    """Fine-tune ``model`` with early stopping on validation accuracy.

    Each epoch runs one optimisation pass over ``train_dl`` followed by a
    gradient-free pass over ``val_dl``. Whenever validation accuracy
    improves, the weights are written to ``best_model.pth`` in the current
    directory; training stops after ``PATIENCE`` consecutive epochs without
    improvement. Before returning, the best checkpoint saved during this
    call is loaded back into ``model`` so callers evaluate the best weights
    rather than those of the final epoch.

    Args:
        model: Module called as ``model(images, modalities)`` that returns
            class logits; expected to already live on ``DEVICE``.
        train_dl: Iterable of ``(images, labels, modalities)`` batches used
            for optimisation.
        val_dl: Iterable of the same kind of batches used for validation.

    Returns:
        dict with keys "train_loss", "val_loss", "train_acc" and "val_acc",
        each a list with one entry per completed epoch. Losses are means
        over batches; accuracies are percentages.
    """
    checkpoint_path = "best_model.pth"

    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    # Anneal over the whole run. With T_max shorter than EPOCHS the cosine
    # keeps going past its minimum and the learning rate climbs back up.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # Only scale gradients when autocast is actually active; otherwise the
    # scaler is a pass-through and emits no "CUDA not available" warning.
    scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP)

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    best_acc, wait = 0.0, 0
    saved_best = False  # True once this call has written checkpoint_path

    for ep in range(1, EPOCHS + 1):
        # ---- training pass ----
        model.train()
        preds, gts, run_loss = [], [], 0.0

        for x, y, m in train_dl:
            x, y = x.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()

            with torch.amp.autocast(device_type="cuda", enabled=USE_AMP):
                out = model(x, m)
                loss = criterion(out, y)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            run_loss += loss.item()
            preds += out.argmax(1).cpu().tolist()
            gts += y.cpu().tolist()

        train_acc = accuracy_score(gts, preds) * 100
        train_loss = run_loss / len(train_dl)

        # ---- validation pass ----
        model.eval()
        preds, gts, run_loss = [], [], 0.0
        with torch.no_grad():
            for x, y, m in val_dl:
                x, y = x.to(DEVICE), y.to(DEVICE)
                with torch.amp.autocast(device_type="cuda", enabled=USE_AMP):
                    out = model(x, m)
                    loss = criterion(out, y)

                run_loss += loss.item()
                preds += out.argmax(1).cpu().tolist()
                gts += y.cpu().tolist()

        val_acc = accuracy_score(gts, preds) * 100
        val_loss = run_loss / len(val_dl)
        scheduler.step()

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)

        print(f"Epoch {ep}/{EPOCHS} | Train Acc {train_acc:.2f}% | Val Acc {val_acc:.2f}%")

        if val_acc > best_acc:
            best_acc = val_acc
            wait = 0
            torch.save(model.state_dict(), checkpoint_path)
            saved_best = True
        else:
            wait += 1
            if wait >= PATIENCE:
                print(f"\nEarly stopping | Best Val Acc = {best_acc:.2f}%")
                break

    # Hand back the best weights, not whatever the last epoch produced.
    # Skipped when no epoch ever improved, so a stale file from an earlier
    # run is never loaded.
    if saved_best:
        best_state = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(best_state)

    return history

def plot_training_curves(history):
    """Show the loss and accuracy curves side by side.

    Args:
        history: Mapping with "train_loss", "val_loss", "train_acc" and
            "val_acc" sequences holding one value per epoch.
    """
    epochs = range(1, len(history["train_loss"]) + 1)

    # (title, y-axis label, train series key, validation series key)
    panels = (
        ("Loss Curve", "Loss", "train_loss", "val_loss"),
        ("Accuracy Curve", "Accuracy (%)", "train_acc", "val_acc"),
    )

    plt.figure(figsize=(12,5))
    for position, (title, y_label, train_key, val_key) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epochs, history[train_key], label="Train")
        plt.plot(epochs, history[val_key], label="Val")
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()

# Evaluation Function

def evaluate(model, val_dl):
    """Plot a confusion matrix and print a classification report.

    Runs ``model`` in eval mode over ``val_dl`` without gradients, collects
    the arg-max predictions, shows the confusion matrix as a heatmap and
    prints scikit-learn's classification report.

    Args:
        model: Module called as ``model(images, modalities)`` that returns
            class logits.
        val_dl: Iterable of ``(images, labels, modalities)`` batches.
    """
    class_ids = [0, 1]
    class_names = ["Benign/Normal", "Malignant"]

    model.eval()
    preds, labels = [], []

    with torch.no_grad():
        for x, y, m in val_dl:
            out = model(x.to(DEVICE), m)
            preds += out.argmax(1).cpu().tolist()
            labels += y.tolist()

    # Pin the label set so the matrix is always 2x2. Without it, a split in
    # which only one class occurs yields a 1x1 matrix that no longer lines
    # up with the two tick labels.
    cm = confusion_matrix(labels, preds, labels=class_ids)
    sns.heatmap(cm, annot=True, fmt="d",
                xticklabels=class_names,
                yticklabels=class_names)
    # confusion_matrix puts true classes on rows and predictions on columns.
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.show()

    print(classification_report(labels, preds,
                                labels=class_ids,
                                target_names=class_names,
                                zero_division=0))

# Main Function

def main():
    """Build the data loaders, train the fusion network and evaluate it.

    The dataset under ``DATA_ROOT`` is split 80/20 into training and
    validation subsets with a fixed seed. Training images get a random
    horizontal flip; validation images go through the same resize and
    normalisation but no random augmentation.

    Raises:
        RuntimeError: If fewer than two labelled images are found, since no
            non-empty train/validation split is possible.
    """
    import copy  # only needed here, to clone the dataset wrapper

    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    train_tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    eval_tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    dataset = MSIDataset(DATA_ROOT, MODALITIES, train_tf)
    if len(dataset) < 2:
        raise RuntimeError(
            f"Found {len(dataset)} labelled image(s) under {DATA_ROOT!r}; "
            "at least 2 are required for a train/validation split"
        )

    t = int(0.8 * len(dataset))
    v = len(dataset) - t
    # Seed the split so the same indices land in validation on every run.
    split_rng = torch.Generator().manual_seed(42)
    train_set, val_set = random_split(dataset, [t, v], generator=split_rng)

    # random_split returns views onto one shared dataset object, so both
    # halves would otherwise receive the training augmentation. Point the
    # validation indices at a shallow copy (same sample list) that uses the
    # deterministic transform instead.
    eval_dataset = copy.copy(dataset)
    eval_dataset.transform = eval_tf
    val_set = torch.utils.data.Subset(eval_dataset, val_set.indices)

    def collate_fn(batch):
        """Stack images and labels into tensors; keep modality names as a list."""
        images = torch.stack([item[0] for item in batch])
        labels = torch.tensor([item[1] for item in batch])
        mods = [item[2] for item in batch]
        return images, labels, mods

    train_dl = DataLoader(train_set, BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    val_dl = DataLoader(val_set, BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

    model = MultimodalFusionNet().to(DEVICE)
    history = train(model, train_dl, val_dl)
    plot_training_curves(history)
    evaluate(model, val_dl)

# Run the full pipeline only when this file is executed directly (or as a
# notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()
