# Navigating the Multimodal Map: Insights into Foundation Models

## Day 2: Multimodal Data Integration and Visualization 

#### Author: Sylwia Majchrowska

#### Course instructors: Sylwia Majchrowska and Ricardo Mokhtari, Centre for AI, DS&AI, BioPharma R&D, AstraZeneca.

![course](https://drive.google.com/uc?export=view&id=1RnL_dsqRNKY2zqd2aaD9Q20yjknL2E4T)

## Description of the hands-on session

This hands-on workshop delves into the intricacies of multi-modal machine learning, focusing on the integration of data from various sources.

Like human physicians, automated detection and classification systems can benefit from using both medical imaging data and clinical textual data. Integrating diverse data types, from clinical records to medical images, into a cohesive analytical framework is a significant step forward in the development of such systems. By adopting a fusion paradigm, these systems can achieve a level of performance and insight that more closely matches the nuanced, holistic approach of human medical practitioners.

The workshop is structured into three key segments:

**1. Multimodal image registration - alignment of image modalities**

This section focuses on the alignment and harmonization of different image modalities, enabling a unified and coherent analysis of multi-modal imaging data.

**2. Data integration strategies - multimodal data fusion**

We will examine techniques for integrating and fusing data from diverse modalities, emphasizing the seamless amalgamation of information for enhanced insights.

**3. Multimodal Data Fusion Using Embeddings**

We will engage in a practical exercise aimed at combining image and text data, gaining hands-on experience in leveraging multi-modal information.

## Navigating this notebook

*This notebook is designed to be self-contained, so the markdown cells contain long descriptions. You do not need to read all of this information during the hands-on session. Instead, focus on running and modifying the code; the detailed explanations are there for you to go deeper into the topics after the session.*

___

# Part 1: Multimodal image registration - alignment of image modalities

Multimodality can be derived from a variety of sources. This part discusses image-image learning.

## Multimodal Imaging Studies

Multimodal imaging studies necessitate the co-registration of images, a critical process that involves geometrically aligning two or more images to ensure their corresponding pixels (or voxels) accurately represent the same anatomical structures. This alignment is pivotal for the integrity of subsequent quantitative image analyses. Image co-registration can be broadly classified into two categories based on the reference framework used:
- Atlas-based Registration: This approach involves aligning images to a pre-existing anatomical atlas, serving as a universal reference.
- Image-based Registration: Alternatively, this method selects one image from the set as the reference, aligning all other images to it.

The co-registration process is indispensable for ensuring accurate and meaningful analysis of multimodal imaging data. It lays the foundation for any further quantitative analysis by aligning disparate imaging modalities at a precise anatomical level. A variety of techniques are available for image registration, ranging from traditional methods to advanced deep learning (DL)-based approaches. For an in-depth exploration of these techniques, the work by [Haskins, Kruger, and Yan 2020](https://arxiv.org/abs/1903.02026) provides a comprehensive overview. 
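To make image-based registration concrete, here is a minimal sketch that is not part of the workshop materials: one image is kept fixed as the reference, and the other is shifted to maximize mutual information (MI), a similarity measure often used for multimodal registration because it does not assume the two modalities share intensity values. The function names, the translation-only exhaustive search, and the histogram bin count are illustrative assumptions; practical registration tools also optimize rotations, scaling, or deformable transforms.

```python
import numpy as np


def mutual_information(a, b, bins=32):
    """Mutual information (in nats) between two equally shaped images, estimated from a joint histogram."""
    joint, _, _ = np.histogram2d(a.ravel(), b.ravel(), bins=bins)
    pxy = joint / joint.sum()              # joint probability table
    px = pxy.sum(axis=1, keepdims=True)    # marginal of a, shape (bins, 1)
    py = pxy.sum(axis=0, keepdims=True)    # marginal of b, shape (1, bins)
    nz = pxy > 0                           # skip empty cells to avoid log(0)
    return float(np.sum(pxy[nz] * np.log(pxy[nz] / (px @ py)[nz])))


def register_translation_mi(fixed, moving, max_shift=5, bins=32):
    """Exhaustively try integer (dy, dx) shifts of `moving`; keep the one that maximizes MI with `fixed`.

    np.roll wraps pixels around the border, which is acceptable for a toy example only.
    """
    best_shift, best_mi = (0, 0), -np.inf
    for dy in range(-max_shift, max_shift + 1):
        for dx in range(-max_shift, max_shift + 1):
            shifted = np.roll(moving, shift=(dy, dx), axis=(0, 1))
            mi = mutual_information(fixed, shifted, bins)
            if mi > best_mi:
                best_shift, best_mi = (dy, dx), mi
    return best_shift, best_mi


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    fixed = rng.random((64, 64))
    # Simulate a second "modality": a nonlinear intensity mapping plus a spatial offset.
    moving = np.roll(1.0 - fixed ** 2, shift=(-3, 2), axis=(0, 1))
    shift, mi = register_translation_mi(fixed, moving)
    print(shift)  # the shift that realigns `moving` with `fixed`: (3, -2)
```

Because MI only measures how predictable one image's intensities are from the other's, it peaks at the correct alignment even though the simulated "modality" has inverted, nonlinearly remapped intensities.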

One of the quintessential applications of multimodal imaging is the fusion of structural and functional imaging modalities, such as PET-CT and PET-MR. These combinations are particularly beneficial in oncology, where molecular imaging's limited spatial resolution may not suffice to pinpoint tumor locations accurately. The integration of CT or MR imaging within the same session enhances the anatomical localization of molecular imaging findings, thereby significantly improving diagnostic accuracy and treatment planning.

## Brain Tumor Radiogenomic Classification

The goal is to develop a tool based on deep learning that automates the detection of MGMT promoter methylation in brain tumors using MRI scans.

Glioblastoma represents a formidable challenge in oncology due to its aggressive nature and poor prognosis, with median survival currently less than a year. A key genetic marker, MGMT promoter methylation, has emerged as a significant prognostic factor. Its presence indicates a better response to chemotherapy, making its detection crucial for effective treatment planning.

Traditionally, identifying the genetic characteristics of a tumor, such as MGMT promoter methylation, requires invasive surgical procedures to obtain tissue samples. This process is not only risky but also time-consuming, with results taking weeks to be finalized. Moreover, the initial treatment approach may necessitate further surgeries based on the genetic findings, adding to the patient's burden. Radiogenomics offers an alternative by potentially allowing the prediction of tumor genetics through non-invasive imaging techniques alone. This approach could reduce the need for multiple surgeries and enable more tailored and effective treatment strategies, ultimately improving patient outcomes.

This part is based on the [RSNA-MICCAI Brain Tumor Radiogenomic Classification Kaggle challenge](https://www.kaggle.com/competitions/rsna-miccai-brain-tumor-radiogenomic-classification/overview). The challenge organizers provided participants with MRI images from two subsets, training and test, organized by patient. Each patient's data includes images from four different scan types: FLAIR, T1w, T1wCE, and T2w. This challenge underscores the potential of deep learning to bridge the gap between imaging and genetic analysis, offering hope for more personalized and effective brain cancer treatment strategies.

In [None]:
# first, we need to install and import libraries
# !pip install pydicom efficientnet-pytorch -q

# imports
import os
import glob
import random

import matplotlib.pyplot as plt
from matplotlib import animation, rc
rc('animation', html='jshtml')
%matplotlib notebook


import numpy as np
import pandas as pd

import pydicom
import cv2

import torch
from torch import nn
from torch.utils import data as torch_data
from torch.nn import functional as F

import efficientnet_pytorch

from torch.utils.data import Dataset, DataLoader

In [None]:
#  setup training and model parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed = 123

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark mode lets cuDNN auto-select algorithms, which can be non-deterministic
    torch.backends.cudnn.benchmark = False

seed_everything(seed)

#  Setup hyperparameters
class X3D:
    XS=0
    S=1
    M=2
    L=3

x3d_config = {
    'input_clip_length': [4, 13, 16, 16],
    'depth_factor': [2.2, 2.2, 2.2, 5.0],
    'width_factor': [1, 1, 1, 2.9]
}

class CFG:
    data_root = 'data'
    img_size = 256
    n_frames = 10

    cnn_features = 256
    lstm_hidden = 32

    n_fold = 5
    n_epochs = 10

## Understanding the Data

Each patient ID, represented by a unique identifier (e.g., “00000”), is associated with four distinct scans, derived from different pulse sequences. Each scan is stored as a series of `DICOM` files (one per slice), which have a `.dcm` file extension. The DICOM standard not only stores the image itself but also embeds a rich set of metadata, such as patient information and scan parameters.
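A quick way to check this layout before training is to count slices per patient and sequence. The sketch below assumes the folder structure that the helper functions later in this notebook rely on (`<root>/<split>/<patient_id>/<sequence>/<file>.dcm`); the function name `index_scans` is hypothetical, and it only walks the file tree without reading any DICOM headers.

```python
import glob
import os

import pandas as pd


def index_scans(split_root, sequences=("FLAIR", "T1w", "T1wCE", "T2w")):
    """Return one row per patient with the number of .dcm slices found per sequence (0 if missing)."""
    rows = []
    for patient_dir in sorted(glob.glob(os.path.join(split_root, "*"))):
        if not os.path.isdir(patient_dir):
            continue  # ignore stray files next to the patient folders
        row = {"patient_id": os.path.basename(patient_dir)}
        for seq in sequences:
            row[seq] = len(glob.glob(os.path.join(patient_dir, seq, "*.dcm")))
        rows.append(row)
    return pd.DataFrame(rows, columns=["patient_id", *sequences])


# Usage (assuming the data has been downloaded to CFG.data_root):
# counts = index_scans(os.path.join("data", "test"))
# print(counts.head())
# print((counts[["FLAIR", "T1w", "T1wCE", "T2w"]] == 0).sum())  # patients missing each sequence
```

A zero count for a sequence is exactly the "missing modality" case that the data loader later fills with zeros.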

### DICOM files
DICOM, an acronym for "Digital Imaging and Communications in Medicine," is the global standard for storing, viewing, retrieving, and sharing medical images. Working with DICOM files in a research or clinical setting typically requires specialized DICOM viewer software. For computational analysis, however, the `pydicom` library offers a powerful and flexible toolset for reading DICOM files directly in Python.

In [None]:
path = ''  # provide your path to selected dicom file
dicom = pydicom.dcmread(path)  # dcmread replaces the deprecated read_file alias (removed in pydicom 3.0)
print(dicom)

In [None]:
def load_dicom(path):
    dicom = pydicom.dcmread(path)
    data = dicom.pixel_array
    data = data - np.min(data)
    if np.max(data) != 0:
        data = data / np.max(data)

    data = np.float32(cv2.resize(data, (CFG.img_size, CFG.img_size)))
    return torch.tensor(data)

def visualize_sample(
    brats21id, 
    slice_i,
    types=("FLAIR", "T1w", "T1wCE", "T2w")
    ):
    _, axes = plt.subplots(ncols=len(types), figsize=(4*len(types), 5))
    patient_path = os.path.join(CFG.data_root, "test", str(brats21id).zfill(5))
    for i, t in enumerate(types):
        t_paths = sorted(
            glob.glob(os.path.join(patient_path, t, "*")), 
            key=lambda x: int(x[:-4].split("-")[-1]),
        )
        data = load_dicom(t_paths[int(len(t_paths) * slice_i)])
        axes[i].imshow(data, cmap="gray")
        axes[i].set_title(f"{t}", fontsize=16)
        axes[i].axis("off")

    plt.suptitle(f"Patient: {brats21id}", fontsize=16)
    plt.show()

def load_dicom_line(path):
    t_paths = sorted(
        glob.glob(os.path.join(path, "*")),
        key=lambda x: int(x[:-4].split("-")[-1]),
    )
    images = []
    for filename in t_paths:
        data = load_dicom(filename)
        if data.max() == 0:
            continue
        images.append(data)

    return images

def load_image(path):
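    image = cv2.imread(path, 0)  # 0 = read as single-channel grayscale
    if image is None:
        # return a tensor in both branches so callers always get the same type
        return torch.zeros((CFG.img_size, CFG.img_size))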

    image = cv2.resize(image, (CFG.img_size, CFG.img_size)) / 255
    return torch.tensor(image)

def get_valid_frames(t_paths):
    res = []
    for path in t_paths:
        img = load_dicom(path)
        if img.view(-1).mean(0) != 0:
            res.append(path)
    return res

In [None]:
for i in [1, 13, 15]:
  visualize_sample(brats21id=i, slice_i=0.5)

In [None]:
def create_animation(ims):
    fig = plt.figure(figsize=(6, 6))
    plt.axis('off')
    im = plt.imshow(ims[0], cmap="gray")
    def animate_func(i):
        im.set_array(ims[i])
        return [im]

    return animation.FuncAnimation(fig, animate_func, frames = len(ims), interval = 1000//24)

from IPython.display import clear_output
animations = dict()
for t in ("FLAIR", "T1w", "T1wCE", "T2w"):
    images = load_dicom_line(os.path.join(CFG.data_root, 'test', '00001', t))
    animations[t] = create_animation(images)
clear_output()

animations['FLAIR']

### Generate fused MRI sequences

To manage the slice dimension of each MRI sequence (treated here like the time axis of a video), a uniform temporal subsampling strategy is employed. For an MRI sequence with a total of T frames (slices), a subset of 10 frames is selected at uniform intervals so that the whole sequence is covered. For instance, in a sequence with 91 frames, frames 1, 11, 21, ..., 91 are chosen. This keeps the selected frames evenly distributed through the volume.

Once the frames are selected, the different MRI sequence types (FLAIR, T1w, T1wCE, and T2w) are fused for each frame: the 4 single-channel images are stacked into a single 4-channel feature image. This multi-channel image combines the complementary views of the brain's anatomy and pathology captured by the different MRI sequences.

It is not uncommon for some sequence types to be missing for a patient. In such cases, the missing channels are filled with zeros, so the 4-channel feature image has the same format for every patient during both training and inference. Zero-filled channels let the model register that a sequence is absent without disrupting the input pipeline.
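The two steps above (uniform slice subsampling, then channel-wise stacking with zero-filled gaps) can be sketched with NumPy on in-memory arrays. This is an illustrative restatement, not the notebook's data loader: the function names `uniform_indices` and `fuse_sequences`, the dictionary input format, and repeating slices when a sequence is shorter than 10 frames are assumptions.

```python
import numpy as np


def uniform_indices(t, num_samples):
    """Equispaced integer indices into a sequence of length t (indices repeat if t < num_samples)."""
    return np.linspace(0, t - 1, num_samples).astype(int)


def fuse_sequences(scans, order=("FLAIR", "T1w", "T1wCE", "T2w"), n_frames=10, size=256):
    """Stack per-sequence volumes into an (n_frames, len(order), size, size) array.

    `scans` maps a sequence name to an array of shape (T, size, size); missing or
    empty sequences become all-zero channels so every patient has the same shape.
    """
    fused = np.zeros((n_frames, len(order), size, size), dtype=np.float32)
    for c, name in enumerate(order):
        volume = scans.get(name)
        if volume is None or len(volume) == 0:
            continue  # leave this channel as zeros
        fused[:, c] = volume[uniform_indices(len(volume), n_frames)]
    return fused


print(uniform_indices(91, 10) + 1)  # 1-based frame numbers: [ 1 11 21 31 41 51 61 71 81 91]
```

Putting the frame axis first and the modality axis second gives the `T x C x H x W` layout that the CNN-LSTM model consumes.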

In [None]:
def uniform_temporal_subsample(x, num_samples):
    '''
        Args:
            x: input list
            num_samples: The number of equispaced samples to be selected
        Returns:
            Output list
    '''
    t = len(x)
    indices = torch.linspace(0, t - 1, num_samples)
    indices = torch.clamp(indices, 0, t - 1).long()
    return [x[i] for i in indices]

class TestDataRetriever(Dataset):
    def __init__(self, paths, transform=None):
        self.paths = paths
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def read_video(self, vid_paths):
        video = [load_dicom(path) for path in vid_paths]
        if len(video)==0:
            video = torch.zeros(CFG.n_frames, CFG.img_size, CFG.img_size)
        else:
            video = torch.stack(video) # T x H x W
        return video

    def __getitem__(self, index):
        _id = self.paths[index]
        patient_path = os.path.join(CFG.data_root, "test", str(_id).zfill(5))
        channels = []
        for t in ["FLAIR","T1w", "T1wCE", "T2w"]:
            t_paths = sorted(
                glob.glob(os.path.join(patient_path, t, "*")),
                key=lambda x: int(x[:-4].split("-")[-1]),
            )
            num_samples = CFG.n_frames
            if len(t_paths) == 0:
                # missing sequence: read_video() returns an all-zero channel
                in_frames_path = []
            else:
                # uniform_temporal_subsample repeats slices when fewer than num_samples exist,
                # so every channel has exactly num_samples frames and torch.stack below succeeds
                in_frames_path = uniform_temporal_subsample(t_paths, num_samples)

            channel = self.read_video(in_frames_path)
            channels.append(channel)

        channels = torch.stack(channels).transpose(0,1)
        return {"X": channels.float(), "id": _id}

## CNN-LSTM architecture

The architecture uses a CNN for image feature extraction, specifically a pre-trained EfficientNet-B0 model. Because the inputs are 4-channel MRI images, a 1×1 2D convolution layer first maps them to a 3-channel feature map, matching the input format EfficientNet expects. In addition, EfficientNet's original classification head is replaced with a custom fully connected layer with 256 output units (`CFG.cnn_features`), so each frame is summarized as a 256-dimensional embedding.
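The 4-to-3 channel adapter (`nn.Conv2d(in_channels=4, out_channels=3, kernel_size=1)` in the code below) is a learned per-pixel linear mix of the four MRI sequences. The NumPy sketch below spells that out with random stand-in weights; in the real model the weights are learned, and a ReLU is applied afterwards.

```python
import numpy as np

rng = np.random.default_rng(0)
frames = rng.random((10, 4, 32, 32)).astype(np.float32)  # T x C_in x H x W (one fused patient)
weight = rng.normal(size=(3, 4)).astype(np.float32)       # C_out x C_in: one 1x1 kernel per channel pair
bias = rng.normal(size=3).astype(np.float32)

# out[t, o, h, w] = sum_c weight[o, c] * frames[t, c, h, w] + bias[o]
mapped = np.einsum("oc,tchw->tohw", weight, frames) + bias[None, :, None, None]
print(mapped.shape)  # (10, 3, 32, 32): three channels, as EfficientNet expects
```

Since the kernel covers a single pixel, this layer mixes modalities but never mixes neighboring pixels; all spatial processing is left to EfficientNet.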

After feature extraction, the model uses LSTM layers to capture dependencies between frames. The embeddings of the 10 uniformly subsampled frames are processed by two stacked LSTM layers, each with a hidden size of 32. The final hidden state of the last LSTM layer feeds a prediction layer with a single output node, which produces a logit; applying a sigmoid turns it into the predicted probability of MGMT promoter methylation.
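The `Model.forward` method below applies the same CNN to every frame by folding the time axis into the batch axis, then unfolding it again for the LSTM. The shape bookkeeping alone can be sketched in NumPy, with a hypothetical `frame_encoder` (a random linear projection) standing in for EfficientNet:

```python
import numpy as np

B, T, C, H, W = 2, 10, 4, 16, 16
CNN_FEATURES = 256  # mirrors CFG.cnn_features

rng = np.random.default_rng(0)
projection = rng.normal(size=(C * H * W, CNN_FEATURES))


def frame_encoder(frames):
    """Hypothetical stand-in for the CNN: maps (N, C, H, W) frames to (N, CNN_FEATURES) embeddings."""
    return frames.reshape(len(frames), -1) @ projection


x = rng.random((B, T, C, H, W))
c_in = x.reshape(B * T, C, H, W)          # fold time into batch: every frame is encoded independently
c_out = frame_encoder(c_in)               # (B*T, 256)
r_in = c_out.reshape(B, T, CNN_FEATURES)  # unfold: one 256-d embedding per frame, per patient
print(r_in.shape)  # (2, 10, 256), the (batch, time, features) layout an LSTM with batch_first=True expects
```

Folding works because the CNN has no notion of time; ordering information is only reintroduced when the unfolded sequence reaches the LSTM.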

The model was trained for 15 epochs using a binary cross-entropy loss and the Adam optimizer with a learning rate of 1e-4.
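Because `Model.fc` outputs a raw score (a logit) and the inference loop applies `torch.sigmoid` separately, binary cross-entropy is most safely computed directly on logits; in PyTorch that is `nn.BCEWithLogitsLoss`. The NumPy sketch below shows a numerically stable form of the same loss. Pairing it with this model is an assumption, since the training code is not part of this notebook.

```python
import numpy as np


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed from logits z and 0/1 targets y.

    Uses max(z, 0) - z*y + log(1 + exp(-|z|)), which equals
    -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))] but never overflows exp().
    """
    z = np.asarray(logits, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    return float(np.mean(np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))))


print(bce_with_logits([0.0], [1.0]))     # log(2) ≈ 0.6931: a logit of 0 means p = 0.5
print(bce_with_logits([1000.0], [0.0]))  # 1000.0, where the naive formula would give inf
```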

In [None]:
class CNN(nn.Module):
    def __init__(self, checkpoint_path):
        super().__init__()
        self.map = nn.Conv2d(in_channels=4, out_channels=3, kernel_size=1)
        self.net = efficientnet_pytorch.EfficientNet.from_name("efficientnet-b0")
        if checkpoint_path:
            checkpoint = torch.load(checkpoint_path)
            self.net.load_state_dict(checkpoint)
        
        n_features = self.net._fc.in_features
        self.net._fc = nn.Linear(in_features=n_features, out_features=CFG.cnn_features, bias=True)
    
    def forward(self, x):
        x = F.relu(self.map(x))
        out = self.net(x)
        return out

class Model(nn.Module):
    def __init__(self, cnn_path=None):
        super().__init__()
        self.cnn = CNN(cnn_path)
        self.rnn = nn.LSTM(CFG.cnn_features, CFG.lstm_hidden, 2, batch_first=True)
        self.fc = nn.Linear(CFG.lstm_hidden, 1, bias=True)

    def forward(self, x):
        # x shape: BxTxCxHxW
        batch_size, timesteps, C, H, W = x.size()
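        c_in = x.reshape(batch_size * timesteps, C, H, W)  # fold time into batch; reshape also handles non-contiguous input
        c_out = self.cnn(c_in)                              # (B*T, cnn_features)
        r_in = c_out.reshape(batch_size, timesteps, -1)     # (B, T, cnn_features)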
        output, (hn, cn) = self.rnn(r_in)
        
        out = self.fc(hn[-1])
        return out

### Cross validation
Stratified K-fold cross-validation with K=5, stratified on the MGMT value, was used for training. At inference time, the mean prediction of all 5 fold models is used as the ensemble's prediction.

The model weights can be found [here](https://www.kaggle.com/code/minhnhatphan/rnsa-21-cnn-lstm-train/output).
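Both ideas are short in code. The sketch below is a minimal NumPy version, not the split used to train the released weights: it assigns folds by shuffling each class separately and dealing its samples round-robin, so every fold keeps roughly the same MGMT-positive rate, and then averages per-fold probabilities as the ensemble prediction. The function name `stratified_folds` and the toy labels and probabilities are hypothetical; in practice, scikit-learn's `StratifiedKFold` does the splitting job.

```python
import numpy as np


def stratified_folds(labels, n_folds=5, seed=123):
    """Return a fold id (0..n_folds-1) per sample, balancing each class across folds."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    fold_of = np.empty(len(labels), dtype=int)
    for cls in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == cls))
        fold_of[idx] = np.arange(len(idx)) % n_folds  # deal this class's samples round-robin
    return fold_of


mgmt = np.array([0] * 60 + [1] * 40)  # toy labels: 40% methylated
folds = stratified_folds(mgmt, n_folds=5)
for k in range(5):
    print(k, mgmt[folds == k].mean())  # each validation fold keeps the 40% positive rate

# Ensemble: average the sigmoid outputs of the 5 fold models for each test patient.
fold_probs = np.array([[0.2, 0.7], [0.4, 0.9], [0.3, 0.8], [0.1, 0.6], [0.5, 0.5]])  # (n_models, n_patients), illustrative
print(fold_probs.mean(axis=0))
```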

In [None]:
models = []
for i in range(1, CFG.n_fold+1):
    model = Model()
    model.to(device)
    checkpoint = torch.load(f"models/best-model-{i}.pth", map_location=device)  # map_location lets GPU-saved weights load on CPU
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    models.append(model)

In [None]:
test_data_retriever = TestDataRetriever(
    [   1,   13,   15,   27,   37,   47,   79,   80,   82,   91,  114,
        119,  125,  129,  135,  145,  153,  161,  163,  174,  181,  182,
        190,  200,  208,  213,  229,  252,  256,  264,  287,  307,  323,
        333,  335,  337,  355,  372,  381,  384,  393,  422,  428,  434,
        438,  447,  450,  458,  460,  462,  463,  467,  474,  489,  492,
        503,  521,  535,  553,  560,  573,  585,  592,  595,  603,  644,
        647,  662,  671,  681,  699,  702,  712,  719,  721,  749,  762,
        769,  779,  821,  822,  825,  826,  829,  833,  997, 1006]
)

test_loader = torch_data.DataLoader(
    test_data_retriever,
    batch_size=4,
    shuffle=False,
    num_workers=8,
)

In [None]:
y_pred = []
ids = []

for e, batch in enumerate(test_loader):
    print(f"{e}/{len(test_loader)}", end="\r")
    with torch.no_grad():
        tmp_pred = np.zeros((batch["X"].shape[0], ))
        for model in models:
            tmp_res = torch.sigmoid(model(batch["X"].to(device))).cpu().numpy().squeeze()
            tmp_pred += tmp_res

        tmp_pred = tmp_pred/len(models)
        y_pred.extend(tmp_pred)
        ids.extend(batch["id"].numpy().tolist())

In [None]:
prediction = pd.DataFrame({"BraTS21ID": ids, "MGMT_value": y_pred})
prediction.head()

# Part 2: Data integration strategies - multimodal data fusion

In part 2 of this notebook, we are going to learn about different fusion techniques.

### What is Multimodal Data Fusion?

Multimodal Data Fusion is a sophisticated technique that integrates data from diverse sources or modalities, such as text, images, and sensor data, to enhance the decision-making capabilities of machine learning models. This approach capitalizes on the complementary strengths of different data types, offering a more holistic understanding of the subject matter, which in turn improves the accuracy and robustness of predictive models.

### Significance in Machine Learning

The fusion of multimodal data enables machine learning models to access a broader and richer feature set than what unimodal data could provide. This enriched data context leads to more nuanced insights and predictions across various applications. For instance, in healthcare, combining clinical notes, lab results, and medical imaging can offer a comprehensive patient overview, aiding in more accurate diagnoses and treatment plans. Similarly, autonomous vehicles rely on a blend of visual, radar, and lidar data to navigate and make decisions, showcasing the critical role of multimodal data fusion in enhancing operational safety and efficiency.

### Approaches to Multimodal Data Fusion

**Early Fusion:** This strategy involves merging features from different modalities at the data level before inputting them into the model. While straightforward, early fusion might not fully capture the intricate interactions between modalities.

**Late Fusion:** This approach trains separate models for each modality and combines their predictions at the decision level. Late fusion respects the uniqueness of each data type but may miss out on capturing deeper intermodal relationships.

**Hybrid Fusion:** A combination of early and late fusion techniques, hybrid fusion aims to leverage intermediate representations to capture both intra-modal and intermodal dynamics effectively.
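The difference between early and late fusion comes down to where the modalities meet. Below is a toy NumPy sketch with random linear "models" standing in for trained classifiers: early fusion concatenates the feature vectors and feeds one model, while late fusion runs one model per modality and averages their output probabilities. The feature sizes, the `softmax` helper, and plain averaging as the late-fusion rule are illustrative assumptions (weighted voting or a learned combiner are common alternatives).

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_text, d_img, n_classes = 4, 8, 6, 3
text_feats = rng.normal(size=(n, d_text))
image_feats = rng.normal(size=(n, d_img))


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


# Early fusion: a single model sees the concatenated features.
w_joint = rng.normal(size=(d_text + d_img, n_classes))
early_probs = softmax(np.concatenate([text_feats, image_feats], axis=1) @ w_joint)

# Late fusion: one model per modality, combined at the decision level.
w_text = rng.normal(size=(d_text, n_classes))
w_img = rng.normal(size=(d_img, n_classes))
late_probs = (softmax(text_feats @ w_text) + softmax(image_feats @ w_img)) / 2

print(early_probs.shape, late_probs.shape)  # (4, 3) (4, 3)
```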

# Part 3: Multimodal Data Fusion Using Embeddings

In this part we will be using the [HAM10000 dataset](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/DBW86T), a large collection of dermatoscopic images from different populations, acquired at the Department of Dermatology of the Medical University of Vienna, Austria, and at the skin cancer practice of Cliff Rosendahl in Queensland, Australia. It consists of 10,015 dermatoscopic images that can serve as a training set for academic machine learning research on skin lesion analysis and classification. The images span seven diagnostic categories, with particular relevance to melanoma detection.

During the session we will be testing embeddings using early and late fusion approaches.
- Early Fusion Approach: Concatenate text and image embeddings and pass them through a layer with 256 neurons, 0.2 dropout, ReLU activation, and Batch Normalization. Connect it to a classification layer.
- Late Fusion Approach: Process text and image embeddings with two layers each (128 neurons, ReLU activation, 0.2 dropout, and Batch Normalization). Concatenate the outputs and connect them to a classification layer.
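A minimal PyTorch sketch of the two heads described above follows. It is one reading of the description, not the code in `src.classifiers_cpu_metrics` (which this notebook imports but does not show): the class names, the Linear → BatchNorm → ReLU → Dropout ordering, the embedding sizes, and the interpretation of "two layers each" as two stacked 128-unit blocks per modality are assumptions.

```python
import torch
from torch import nn


def block(in_features, out_features, p_drop=0.2):
    """Linear -> BatchNorm -> ReLU -> Dropout (layer order is an assumption)."""
    return nn.Sequential(
        nn.Linear(in_features, out_features),
        nn.BatchNorm1d(out_features),
        nn.ReLU(),
        nn.Dropout(p_drop),
    )


class EarlyFusionHead(nn.Module):
    def __init__(self, text_dim, image_dim, n_classes):
        super().__init__()
        self.hidden = block(text_dim + image_dim, 256)
        self.classifier = nn.Linear(256, n_classes)

    def forward(self, text, image):
        # Concatenate the two embeddings first, then learn a joint representation.
        return self.classifier(self.hidden(torch.cat([text, image], dim=1)))


class LateFusionHead(nn.Module):
    def __init__(self, text_dim, image_dim, n_classes):
        super().__init__()
        self.text_branch = nn.Sequential(block(text_dim, 128), block(128, 128))
        self.image_branch = nn.Sequential(block(image_dim, 128), block(128, 128))
        self.classifier = nn.Linear(128 + 128, n_classes)

    def forward(self, text, image):
        # Each modality is processed separately; the branches meet only before the classifier.
        fused = torch.cat([self.text_branch(text), self.image_branch(image)], dim=1)
        return self.classifier(fused)


text, image = torch.randn(8, 768), torch.randn(8, 512)  # illustrative embedding sizes
for head in (EarlyFusionHead(768, 512, 7), LateFusionHead(768, 512, 7)):
    print(type(head).__name__, head(text, image).shape)  # logits of shape (8, 7)
```

Seven output classes matches the number of HAM10000 diagnostic categories; both heads return logits, so a softmax-based loss such as `nn.CrossEntropyLoss` would be applied on top during training.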

In [None]:
# !git clone https://github.com/dsrestrepo/Foundational-Multimodal-Fusion-Benchmark.git
# import sys
# sys.path.append('Foundational-Multimodal-Fusion-Benchmark')  # makes the repo's `src` package importable

# imports

from src.classifiers import preprocess_data, process_labels, split_data

from src.classifiers import VQADataset
from torch.utils.data import DataLoader

from src.classifiers_cpu_metrics import train_early_fusion, train_late_fusion

In [None]:
# data handling
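PATH = 'Embeddings/ham10000/'

# os.listdir() returns files in arbitrary order, and this cell later writes embeddings.csv into PATH,
# so exclude that file and check that the two names below really are the text and image embeddings
files = sorted(f for f in os.listdir(PATH) if f != 'embeddings.csv')
text_path, images_path = files
print("text file:", text_path, "| image file:", images_path)

text = pd.read_csv(os.path.join(PATH, text_path))
print(text.head())

images = pd.read_csv(os.path.join(PATH, images_path))
print(images.head())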

# Merge and preprocess the datasets
df = preprocess_data(text, images, "image_id", "ImageName")
df.drop(columns='text', inplace=True)
df.to_csv(os.path.join(PATH, 'embeddings.csv'), index=False)
print(df.head())

In [None]:
# Data Preparation

# Split the data
train_df, test_df = split_data(df)

# Select features and labels vectors
text_columns = [column for column in df.columns if 'text' in column] #[f'text_{i}']
image_columns = [column for column in df.columns if 'image' in column] #[f'image_{i}']
label_columns = 'dx'


# Process and one-hot encode labels for training set
train_labels, mlb, train_columns = process_labels(train_df, col=label_columns)
test_labels = process_labels(test_df, col=label_columns, train_columns=train_columns)


train_dataset = VQADataset(train_df, text_columns, image_columns, label_columns, mlb, train_columns)
test_dataset = VQADataset(test_df, text_columns, image_columns, label_columns, mlb, train_columns)


train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

In [None]:
# models

text_input_size = len(text_columns)
image_input_size = len(image_columns)
if label_columns == 'DR_2':
    output_size = 1
else:
    output_size = len(pd.unique(train_df[label_columns]))
multilabel = False

In [None]:
# Train early fusion model
print("Training Early Fusion Model:")
train_early_fusion(train_loader, test_loader, text_input_size, image_input_size, output_size, num_epochs=30, multilabel=multilabel, report=True)

In [None]:
# Train late fusion model
print("Training Late Fusion Model:")
train_late_fusion(train_loader, test_loader, text_input_size, image_input_size, output_size, num_epochs=30, multilabel=multilabel, report=True)

## That's the end of the notebook!

Here is the list of extra resources:
1. [The RSNA-ASNR-MICCAI BraTS 2021 Benchmark on Brain Tumor Segmentation and Radiogenomic Classification](https://arxiv.org/abs/2107.02314)
2. [Multimodal data fusion – analysis](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10007548/)
3. [Fusion of Multi-Modal Data Stream for Clinical Event Prediction - Imon Banerjee, PhD](https://www.youtube.com/watch?v=3DroMVNb2vg)
4. [Data-Efficient Multimodal Fusion on a Single GPU](https://arxiv.org/pdf/2312.10144.pdf)
5. [Integrated multimodal artificial intelligence framework for healthcare applications](https://www.nature.com/articles/s41746-022-00689-4)
6. [Inferring multimodal latent topics from electronic health records](https://www.nature.com/articles/s41467-020-16378-3)
7. [Multimodal Risk Prediction with Physiological Signals, Medical Images and Clinical Notes](https://www.medrxiv.org/content/10.1101/2023.05.18.23290207v1)