In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the Kaggle input directory.
_input_files = (
    os.path.join(root, name)
    for root, _, names in os.walk('/kaggle/input')
    for name in names
)
for _path in _input_files:
    print(_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    roc_curve,
    auc,
    classification_report
)

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Training hyper-parameters.
BATCH_SIZE = 8
EPOCHS = 30
LR = 1e-4
PATIENCE = 5  # epochs without validation improvement before stopping early

BASE_DIR = "/kaggle/input/breast-cancer-msi-multimodal-image-dataset/MultiModel Breast Cancer MSI Dataset"

# Sub-folder names under BASE_DIR, one per imaging modality.
MODALITIES = [
    "Chest_XRay_MSI",
    "Histopathological_MSI",
    "Ultrasound Images_MSI",
]

# Class-folder name -> binary label (0 = benign/normal, 1 = malignant).
# Both the lowercase and the capitalised spelling of each folder are accepted.
LABEL_MAP = {
    variant: label
    for stem, label in (("benign", 0), ("normal", 0), ("malignant", 1))
    for variant in (stem, stem.capitalize())
}

class MSI_Dataset(Dataset):
    """Flat image dataset spanning several imaging modalities.

    Scans ``base_dir/<modality>/<class folder>/<file>`` and records one
    ``(path, label, modality)`` entry per regular file. Modality folders that
    do not exist are reported with a printed warning and skipped; class
    folders whose name is not a key of the label map are ignored.

    Args:
        base_dir: Root directory containing one sub-folder per modality.
        modalities: Modality folder names to scan, in this order.
        transform: Optional callable applied to the loaded RGB PIL image.
        label_map: Mapping from class-folder name to integer label. Defaults
            to the module-level ``LABEL_MAP`` when omitted.

    Each item is ``(image, label, modality)`` where ``image`` is the RGB PIL
    image, or whatever ``transform`` returns for it.
    """

    def __init__(self, base_dir, modalities, transform=None, label_map=None):
        self.samples = []
        self.transform = transform
        if label_map is None:
            label_map = LABEL_MAP

        for mod in modalities:
            mod_path = os.path.join(base_dir, mod)
            if not os.path.isdir(mod_path):
                print(f"WARNING: missing folder {mod_path}")
                continue

            # Sorted listings make the sample order (and therefore any seeded
            # train/val split) independent of filesystem enumeration order.
            for cls in sorted(os.listdir(mod_path)):
                cls_path = os.path.join(mod_path, cls)
                if not os.path.isdir(cls_path):
                    continue

                label = label_map.get(cls)
                if label is None:
                    continue

                for fname in sorted(os.listdir(cls_path)):
                    fpath = os.path.join(cls_path, fname)
                    if os.path.isfile(fpath):
                        self.samples.append((fpath, label, mod))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label, modality = self.samples[idx]
        # Use a context manager so the underlying file handle is closed;
        # convert() returns a new, fully loaded image that outlives it.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, label, modality

# ResNet-50 as Backbone

def get_resnet50_backbone():
    """Return an ImageNet-pretrained ResNet-50 with its final fc layer removed.

    The result is an ``nn.Sequential`` of the network's remaining top-level
    children, ending at the global average pool, so its output has shape
    (B, 2048, 1, 1).
    """
    resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    # The last child is the classification head; everything before it is kept.
    *feature_layers, _fc_head = resnet.children()
    return nn.Sequential(*feature_layers)

# Multimodal Model

class MultiModalResNet50(nn.Module):
    """Classifier with one ResNet-50 encoder per modality and a shared MLP head.

    Every sample is routed to the encoder that matches its modality tag. Its
    2048-d feature fills that modality's slot of a (B, 3 * 2048) vector; the
    other two slots stay zero, and samples with an unknown tag get all zeros.
    The concatenated vector is mapped to 2 logits.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are part of the state_dict keys used by saved
        # checkpoints, so they must stay stable.
        self.enc_chest = get_resnet50_backbone()
        self.enc_hist = get_resnet50_backbone()
        self.enc_ultra = get_resnet50_backbone()

        head_layers = [
            nn.Linear(2048 * 3, 1024), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 2),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, imgs, modalities):
        """Return (B, 2) logits for a batch of images and their modality tags.

        Args:
            imgs: Batched image tensor; row i is encoded by the encoder that
                matches ``modalities[i]``.
            modalities: Sequence of B modality folder names.
        """
        batch = imgs.size(0)
        # Slot order (chest, histopathology, ultrasound) fixes the layout of
        # the fused vector that the classifier was built for.
        routes = (
            ("Chest_XRay_MSI", self.enc_chest),
            ("Histopathological_MSI", self.enc_hist),
            ("Ultrasound Images_MSI", self.enc_ultra),
        )

        slots = []
        for tag, encoder in routes:
            slot = torch.zeros((batch, 2048), device=imgs.device)
            rows = [i for i, m in enumerate(modalities) if m == tag]
            if rows:
                slot[rows] = encoder(imgs[rows]).view(len(rows), -1)
            slots.append(slot)

        return self.classifier(torch.cat(slots, dim=1))


# Image Transformation

# Preprocessing pipeline. It contains random flips/rotations, so it is meant
# for training; evaluation should use a deterministic variant.
_augment_steps = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
]
# ImageNet statistics, matching the pretrained ResNet-50 weights.
_tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_augment_steps + _tensor_steps)

# Training Function

def _train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Returns:
        (preds, gts): predicted and true class indices as Python lists. The
        predictions are taken from each batch's forward pass, i.e. before
        that batch's optimizer step.
    """
    model.train()
    preds, gts = [], []
    for imgs, labels, mods in loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        out = model(imgs, mods)
        loss = criterion(out, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds += out.argmax(1).cpu().tolist()
        gts += labels.cpu().tolist()
    return preds, gts


def _predict(model, loader):
    """Return (preds, gts) lists for ``loader`` in eval mode without gradients."""
    model.eval()
    preds, gts = [], []
    with torch.no_grad():
        for imgs, labels, mods in loader:
            out = model(imgs.to(DEVICE), mods)
            preds += out.argmax(1).cpu().tolist()
            gts += labels.tolist()
    return preds, gts


def train_model(model, train_loader, val_loader, *, epochs=None, lr=None,
                patience=None, ckpt_path="best_model_resnet50.pth"):
    """Train ``model`` with Adam + cross-entropy and early stopping.

    After every epoch the validation accuracy is computed. Whenever it improves,
    the model's state_dict is saved to ``ckpt_path``. Training stops once
    ``patience`` consecutive epochs pass without improvement.

    Args:
        model: Module called as ``model(imgs, modalities)`` returning logits.
        train_loader: Iterable of ``(imgs, labels, modalities)`` batches.
        val_loader: Same batch format, used for model selection.
        epochs: Maximum number of epochs (defaults to ``EPOCHS``).
        lr: Adam learning rate (defaults to ``LR``).
        patience: Early-stopping patience (defaults to ``PATIENCE``).
        ckpt_path: File the best state_dict is written to.

    Returns:
        The best validation accuracy seen, or ``-inf`` if no epoch ran.
    """
    # Defaults are resolved at call time so overriding the module-level
    # hyper-parameters after definition still takes effect.
    epochs = EPOCHS if epochs is None else epochs
    lr = LR if lr is None else lr
    patience = PATIENCE if patience is None else patience

    model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Start below any attainable accuracy so the first completed epoch always
    # writes a checkpoint; starting at 0 with a strict ">" could leave no file
    # for the caller to load.
    best_val = float("-inf")
    wait = 0

    for ep in range(epochs):
        train_preds, train_gts = _train_one_epoch(model, train_loader, optimizer, criterion)
        train_acc = accuracy_score(train_gts, train_preds)

        val_preds, val_gts = _predict(model, val_loader)
        val_acc = accuracy_score(val_gts, val_preds)
        print(f"Epoch {ep+1}/{epochs} | TrainAcc={train_acc:.4f} | ValAcc={val_acc:.4f}")

        if val_acc > best_val:
            best_val = val_acc
            wait = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping triggered")
                break

    return best_val

# Evaluation

def evaluate(model, loader):
    """Evaluate a binary classifier on ``loader`` and plot diagnostics.

    Runs ``model`` over every batch without gradients, then shows a
    confusion-matrix heatmap, prints sklearn's classification report, and
    plots the ROC curve of the class-1 ("Malignant") softmax probability.

    Args:
        model: Module called as ``model(imgs, modalities)`` returning (B, 2)
            logits.
        loader: Iterable of ``(imgs, labels, modalities)`` batches.

    Returns:
        dict with ``"confusion_matrix"`` (2x2 array) and ``"roc_auc"`` (float,
        or None when the labels contain only one class).
    """
    class_names = ["Benign/Normal", "Malignant"]
    model.to(DEVICE)  # inputs are sent to DEVICE below, so the model must be there too
    model.eval()
    y_true, y_pred, y_prob = [], [], []

    with torch.no_grad():
        for imgs, y, mods in loader:
            out = model(imgs.to(DEVICE), mods)
            y_pred += out.argmax(1).cpu().tolist()
            y_prob += torch.softmax(out, dim=1)[:, 1].cpu().tolist()
            y_true += y.tolist()

    # Pin labels=[0, 1] so the matrix and report stay 2-class (matching the two
    # tick labels / target names) even if only one class occurs in the split.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    plt.figure()
    sns.heatmap(cm, annot=True, fmt="d",
                xticklabels=class_names,
                yticklabels=class_names)
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.show()

    print(classification_report(y_true, y_pred, labels=[0, 1],
          target_names=class_names))

    roc_auc = None
    if len(set(y_true)) < 2:
        # ROC/AUC is undefined without both positive and negative samples.
        print("ROC curve skipped: only one class present in the labels")
    else:
        fpr, tpr, _ = roc_curve(y_true, y_prob)
        roc_auc = auc(fpr, tpr)

        plt.figure()
        plt.plot(fpr, tpr, label=f"ROC Curve (AUC = {roc_auc:.4f})")
        plt.plot([0, 1], [0, 1], '--')
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("ROC-AUC Curve")
        plt.legend()
        plt.show()

    return {"confusion_matrix": cm, "roc_auc": roc_auc}

# Main Function 

dataset = MSI_Dataset(BASE_DIR, MODALITIES, transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# A seeded generator makes the train/val partition reproducible across runs.
split_generator = torch.Generator().manual_seed(42)
train_ds, val_split = random_split(dataset, [train_size, val_size], generator=split_generator)

# `transform` applies random flips/rotations. Validation images must be
# preprocessed deterministically, so the validation indices are served from a
# second scan of the same folders that only resizes and normalises.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
eval_dataset = MSI_Dataset(BASE_DIR, MODALITIES, eval_transform)
if eval_dataset.samples != dataset.samples:
    raise RuntimeError("Dataset scans disagree; cannot align validation indices")
val_ds = torch.utils.data.Subset(eval_dataset, val_split.indices)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

model = MultiModalResNet50()
best_acc = train_model(model, train_loader, val_loader)

print("\nTraining Completed | Best Validation Accuracy:", best_acc)

# map_location lets the checkpoint load regardless of the device it was saved from.
model.load_state_dict(torch.load("best_model_resnet50.pth", map_location=DEVICE))
# NOTE: this split also drove checkpoint selection, so these metrics are
# optimistic; a separate held-out test split would give an unbiased estimate.
evaluate(model, val_loader)
