# Importing Libraries

In [None]:
import monai
import torch
import os
import pandas as pd
import torch.nn as nn
from monai.transforms import (
    Compose, LoadImage, AddChannel, ScaleIntensity, ToTensor, Resize)
from monai.data import Dataset, DataLoader
from monai.networks.nets import DenseNet121
from torch.utils.data import random_split

# 1. Data Preprocessing


In [None]:
# Root directory of each imaging modality. The same filename from the labels
# file is looked up in all three directories.
vw_mri_dir = "/data/vw_mri/"
regular_mri_dir = "/data/regular_mri/"
ct_dir = "/data/ct_scans/"

# CSV with one row per case; the 'filename' and 'label' columns are read below.
labels_file = "/data/labels.csv"
labels_df = pd.read_csv(labels_file)

# Read the filename column once instead of walking the frame with iterrows()
# three times. iterrows() also rebuilds every row as a Series, which can
# upcast mixed numeric values; reading the column directly avoids that.
filenames = [str(name) for name in labels_df["filename"]]

vw_mri_paths = [os.path.join(vw_mri_dir, name) for name in filenames]
regular_mri_paths = [os.path.join(regular_mri_dir, name) for name in filenames]
ct_paths = [os.path.join(ct_dir, name) for name in filenames]

labels = labels_df['label'].tolist()

# One record per case: the three image paths plus the class label.
data_dicts = [{"vw_mri": vw, "regular_mri": reg, "ct": ct, "label": label}
              for vw, reg, ct, label in zip(vw_mri_paths, regular_mri_paths, ct_paths, labels)]

def _build_image_transform(spatial_size):
    """Return the load -> add channel -> scale intensity -> resize -> tensor pipeline.

    ``spatial_size`` is forwarded unchanged to ``Resize``; every other step
    is constructed with the same arguments for all pipelines.
    """
    steps = [
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        Resize(spatial_size=spatial_size),
        ToTensor(),
    ]
    return Compose(steps)


# 3D pipeline: images are resized to a 128 x 128 x 64 grid.
transform_vw_mri_3d = _build_image_transform((128, 128, 64))

# 2D pipeline: images are resized to a 128 x 128 grid.
transform_2d = _build_image_transform((128, 128))

# Building the Multimodal Dataset

In [None]:
class MultimodalDataset(Dataset):
    """Dataset returning one ``(vw_mri, regular_mri, ct, label)`` tuple per record.

    Each element of ``data`` must be a mapping with the keys ``"vw_mri"``,
    ``"regular_mri"``, ``"ct"`` and ``"label"``. The first three values are
    passed to the corresponding transform; the label is converted to a
    ``torch.long`` tensor.

    Args:
        data: Indexable collection of records as described above.
        transform_vw_mri_3d: Callable applied to the ``"vw_mri"`` value.
        transform_2d: Callable applied to both the ``"regular_mri"`` and the
            ``"ct"`` values.
    """

    def __init__(self, data, transform_vw_mri_3d, transform_2d):
        # Initialise the base class so the attributes it defines exist. No
        # base-level transform is passed because __getitem__ applies the
        # per-modality transforms itself.
        super().__init__(data=data)
        self.data = data
        self.transform_vw_mri_3d = transform_vw_mri_3d
        self.transform_2d = transform_2d

    def __getitem__(self, index):
        """Load and transform the record at ``index``."""
        item = self.data[index]
        vw_mri = self.transform_vw_mri_3d(item["vw_mri"])
        regular_mri = self.transform_2d(item["regular_mri"])
        ct = self.transform_2d(item["ct"])
        label = torch.tensor(item["label"], dtype=torch.long)
        return vw_mri, regular_mri, ct, label

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

dataset = MultimodalDataset(data_dicts, transform_vw_mri_3d, transform_2d)

# 80/20 train/validation split. The validation size is taken as the remainder
# so the two lengths always sum to len(dataset); truncating both fractions
# independently (e.g. 11 samples -> 8 + 2) makes random_split raise ValueError.
train_size = int(len(dataset) * 0.8)
val_size = len(dataset) - train_size
train_set, val_set = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_set, batch_size=2, shuffle=True)
val_loader = DataLoader(val_set, batch_size=2)

# Designing the Multimodal Network for the 3D and 2D Inputs

In [None]:
class MultimodalNetwork3D2D(nn.Module):
    """Late-fusion classifier with one 3D and two 2D DenseNet121 branches.

    Each branch maps its single-channel input to ``branch_out_channels``
    values. The three branch outputs are concatenated along dimension 1 and
    passed through a four-layer fully connected head with ReLU and dropout
    between layers.

    Args:
        num_classes: Size of the final output layer.
        branch_out_channels: ``out_channels`` of every DenseNet121 branch.
        dropout_p: Drop probability of the dropout applied after each hidden
            layer of the head.

    The defaults reproduce the original fixed architecture (three branches
    with 2 outputs each -> 6 fused features -> 2 classes).
    """

    def __init__(self, num_classes: int = 2, branch_out_channels: int = 2, dropout_p: float = 0.5):
        super().__init__()

        self.vw_mri_net_3d = DenseNet121(spatial_dims=3, in_channels=1, out_channels=branch_out_channels)

        self.regular_mri_net_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=branch_out_channels)

        self.ct_net_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=branch_out_channels)

        # The head input width is derived from the branches rather than
        # hard-coded, so it stays consistent if the branch width changes.
        fused_features = 3 * branch_out_channels
        self.fc1 = nn.Linear(fused_features, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, num_classes)

        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, vw_mri, regular_mri, ct):
        """Return the head output for one batch of the three inputs.

        Assumes every input is batched with a single channel, matching
        ``in_channels=1`` of the branches -- TODO confirm against the loader.
        """
        vw_out = self.vw_mri_net_3d(vw_mri)
        regular_out = self.regular_mri_net_2d(regular_mri)
        ct_out = self.ct_net_2d(ct)

        # Fuse the per-branch outputs along the feature dimension.
        combined = torch.cat((vw_out, regular_out, ct_out), dim=1)

        x = combined
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = self.dropout(torch.relu(hidden(x)))

        # No activation on the last layer: the raw scores are returned.
        return self.fc4(x)

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the network and move its parameters to the selected device.
model = MultimodalNetwork3D2D()
model = model.to(device)

# Cross-entropy objective optimised with Adam at a learning rate of 1e-4.
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Training Model and Evaluating on Validation Set

In [None]:
def train_model(num_epochs, model, train_loader, val_loader, optimizer, loss_function, device):
    """Train ``model`` and print training/validation metrics after every epoch.

    Args:
        num_epochs: Number of passes over ``train_loader``.
        model: Module called as ``model(vw_mri, regular_mri, ct)``.
        train_loader: Iterable of ``(vw_mri, regular_mri, ct, labels)`` batches
            used for optimisation.
        val_loader: Iterable of batches with the same layout, used for
            evaluation only (no gradient tracking, no parameter updates).
        optimizer: Optimizer holding the parameters of ``model``.
        loss_function: Callable ``loss_function(outputs, labels)`` returning a
            scalar tensor.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        None. Metrics are reported with ``print``.

    The reported losses are the mean of the per-batch losses. Batches and
    samples are counted while iterating, so a loader that yields nothing
    reports ``nan`` instead of raising ``ZeroDivisionError``, and accuracy
    does not depend on ``val_loader`` exposing a ``dataset`` attribute.
    """
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        train_batches = 0
        for vw_mri, regular_mri, ct, labels in train_loader:
            vw_mri, regular_mri, ct, labels = vw_mri.to(device), regular_mri.to(device), ct.to(device), labels.to(device)

            outputs = model(vw_mri, regular_mri, ct)
            loss = loss_function(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            train_batches += 1

        mean_train_loss = epoch_loss / train_batches if train_batches else float("nan")
        print(f"Epoch {epoch+1}, Loss: {mean_train_loss}")

        model.eval()
        val_loss = 0.0
        val_batches = 0
        correct = 0
        seen = 0
        with torch.no_grad():
            for vw_mri, regular_mri, ct, labels in val_loader:
                vw_mri, regular_mri, ct, labels = vw_mri.to(device), regular_mri.to(device), ct.to(device), labels.to(device)
                outputs = model(vw_mri, regular_mri, ct)
                loss = loss_function(outputs, labels)
                val_loss += loss.item()
                val_batches += 1

                # Predicted class is the index of the largest score per sample.
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                seen += labels.size(0)

        mean_val_loss = val_loss / val_batches if val_batches else float("nan")
        accuracy = correct / seen if seen else float("nan")
        print(f"Validation Loss: {mean_val_loss}, Accuracy: {accuracy}")

# Run 20 epochs of training, evaluating on the validation loader after each one.
train_model(
    num_epochs=20,
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    loss_function=loss_function,
    device=device,
)