# COMP0188 Coursework 2: Age Regression from Brain MRIs (30 Marks)

MRI scans can be used to determine volumes of different types of brain tissue which are associated with age. In particular, as patients age, the ventricles are known to enlarge (as they fill with cerebrospinal fluid), while grey and white matter volumes may decrease.

Your task is to develop a deep learning model capable of predicting the (biological) age of a patient from MRI scans. Such a tool could be used in clinical practice to compare a patient's 'biological' age against their 'true' chronological age. A significant discrepancy between these ages might indicate the presence of a disease in the patient. 

You have been provided with a dataset of healthy patients. The dataset contains MRI scans of the patients, and their corresponding chronological ages (amongst other information). As the patients are healthy, we will assume that their biological and chronological ages are equal.

We have provided you with helper code, and have marked additional code you will need to implement with "🚧 **Exercise** 🚧". However, you are not bound to this code (i.e. you may modify it if you wish), or even to this notebook: you may complete the coursework however you see fit.

#### Notebook Overview
The notebook is split into 3 parts:

- Part 1: Dataset analysis and defining a suitable setup (8 marks)
- Part 2: Baseline model definition (7 marks)
- Part 3: Improving upon the Baseline (15 marks)

Please see the introduction for each section for more information.

#### Loading the Dataset and Running the Notebook

Use Google Colab to run the notebook. Run the cells in sequence, as per usual.

### Requirements
* SimpleITK
* wandb

### Dependencies

In [None]:
!pip install SimpleITK gdown wandb

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import wandb

from typing import List, Dict, Tuple, Literal
import pickle
import pandas as pd
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt

from ipywidgets import interact, fixed
from IPython.display import display

import gdown
import zipfile
import psutil

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
def check_memory_used():
    process = os.getpid()
    proc = psutil.Process(process)
    memory_info = proc.memory_info()
    print(f"Memory Used: {memory_info.rss / (1024**3):.2f} GB")  # RSS (Resident Set Size)

### Global config

In [None]:
# Used for debugging the notebook locally. Leave as True when running in colab!
download = True

# Expect the download to take roughly 3 mins
if download:
    # Replace this with your Google Drive shared link or file ID
    google_drive_shared_link = 'https://drive.google.com/file/d/1OzIzn9tLUmq74I6YB5_nTnvaV_R1v7kj/view?usp=sharing'
    file_id = google_drive_shared_link.split('/')[-2]

    # Construct the gdown URL
    download_url = f'https://drive.google.com/uc?id={file_id}'

    # Path where the downloaded file will be stored
    output_path = 'download.zip'

    # Download the file from Google Drive
    gdown.download(download_url, output_path, quiet=False)

    # Unzip the file
    with zipfile.ZipFile(output_path, 'r') as zip_ref:
        zip_ref.extractall('data')

    # Optionally, delete the zip file after extraction
    # os.remove(output_path)
    
    data_dir = "data/coursework_1_compressed/coursework_1"
else:
    data_dir = "../../coursework_1_compressed/coursework_1" # "../../coursework_1/coursework_1"

In [None]:
MRI_TEMPLATE_FILE = os.path.join(data_dir, "quasiraw_space-MNI152_desc-brain_T1w.nii.gz")

In [None]:
TRAIN_DIR = os.path.join(data_dir, "train")
TRAIN_MRI_DIR = os.path.join(TRAIN_DIR, "quasiraw")
TRAIN_META_FILE = os.path.join(TRAIN_DIR, "participants.tsv")

In [None]:
TEST_DIR = os.path.join(data_dir, "test")
TEST_MRI_DIR = os.path.join(TEST_DIR, "quasiraw")
TEST_META_FILE = os.path.join(TEST_DIR, "participants.tsv")

In [None]:
OUT_SPACING = [2, 2, 2]
OUT_SIZE = [96, 96, 96]

## Part 1: Dataset analysis and defining a suitable setup

### Dataset analysis

It is important to analyse your dataset to better understand it and to help detect any issues in the dataset. This can be done via visualizations and calculating statistics from the available information.

### Defining a suitable setup

Before performing any kind of model development, it is critical to define the scope of the model development process. This includes making decisions which stay fixed throughout the rest of model development, as changing them would render model comparisons invalid. For example, comparing two models whose test scores were calculated on different test datasets, or with different metrics, is meaningless.

This section should help you answer the following questions:
* How should the train/validation/test set be split?
* What metric will be used to assess model performance on the test set?
    * It is critical to consider the broader project objective when setting this. In this case, the project objective is to "predict the patient's biological age from MRI scans". Furthermore, we want to make sure that the model can predict age well for _all_ patients, not just a subset (which may happen if the input data is skewed, for example).
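One way to check that a model predicts age well for _all_ patients, and not only for the best-represented age groups, is to compute the error separately within age buckets. The sketch below is an illustration only; the function names `per_bucket_mae` and `bucket_balanced_mae` are hypothetical, and the bucket edges are whatever you choose (for example the `bins` used further below).

```python
import numpy as np
import pandas as pd

def per_bucket_mae(y_true, y_pred, bins):
    """Mean absolute error (in years) within each bucket of the *true* age."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    df = pd.DataFrame({
        "bucket": pd.cut(y_true, bins=bins),
        "abs_err": np.abs(y_pred - y_true),
    })
    # observed=True drops buckets that contain no patients
    return df.groupby("bucket", observed=True)["abs_err"].mean()

def bucket_balanced_mae(y_true, y_pred, bins):
    """Unweighted mean of the per-bucket MAEs, so each age group counts equally
    regardless of how many patients it contains."""
    return float(per_bucket_mae(y_true, y_pred, bins).mean())
```

On skewed data the two numbers can differ a lot: a model that is accurate on a large young group but poor on a small older group has a low overall MAE and a noticeably higher bucket-balanced MAE.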

_Hint_: 
* Consider the following kinds of analysis:
    * What relevant variables are available in the dataset? Do they need to be transformed?
    * What does the target variable look like?
    * Is the data sufficiently balanced?
    * What is the distribution of other variables?
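As one example of a transformation you might consider, tissue volumes can be expressed as fractions of head size, which is intended to reduce variation caused by overall head size rather than age. This is a sketch under the assumption that the `tiv` column is total intracranial volume; `add_volume_fractions` is a hypothetical helper, and whether to use such features at all is your design decision.

```python
import pandas as pd

def add_volume_fractions(meta_df, cols=("csfv", "gmv", "wmv"), tiv_col="tiv"):
    """Return a copy of meta_df with extra '<col>_frac' columns holding each
    tissue volume divided by the (assumed) total intracranial volume."""
    out = meta_df.copy()  # leave the original DataFrame untouched
    for col in cols:
        out[f"{col}_frac"] = out[col] / out[tiv_col]
    return out
```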

### Dataset description
* The data provided has already been split into a training and test dataset. Both the training and test data contain: (i) a directory of MRI scans, stored as compressed NIfTI (.nii.gz) files and loaded into numpy arrays by the helper code below; (ii) a tab-separated file called "participants.tsv" which contains structured data for each patient (including the overall target of interest, "age").
* The "participant_id" column defines a unique ID for each patient and can be used to link the structured data with the MRI scans. In particular, the patient with participant_id = "100053248969" has an associated MRI scan in the file "nifti/sub-100053248969_preproc-quasiraw_T1w.nii.gz" inside the "quasiraw" directory.
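Because the link between metadata and scans is purely by file name, it can be worth checking that every participant in "participants.tsv" actually has a scan on disk. A minimal sketch, assuming the same naming pattern that `load_patient_mri_array` uses; `patient_mri_path` and `missing_scans` are hypothetical helpers.

```python
import os

def patient_mri_path(patient_id: str, mri_dir: str) -> str:
    """Path of a patient's scan, mirroring load_patient_mri_array's pattern."""
    return os.path.join(mri_dir, "nifti", f"sub-{patient_id}_preproc-quasiraw_T1w.nii.gz")

def missing_scans(patient_ids, mri_dir):
    """Participant IDs listed in the metadata that have no scan file on disk."""
    return [pid for pid in patient_ids if not os.path.exists(patient_mri_path(pid, mri_dir))]
```

For example, `missing_scans(train_meta_df["participant_id"], TRAIN_MRI_DIR)` would be expected to return an empty list if the download is complete.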

The code below provides some helper functions to import the relevant data for analysis.

In [None]:
num_vars = ["age", "tiv", "csfv", "gmv", "wmv", "magnetic_field_strength", "acquisition_setting"]

# Load the training data
train_meta_df = pd.read_csv(TRAIN_META_FILE, delimiter='\t', dtype=str)
train_meta_df[num_vars] = train_meta_df[num_vars].astype(float)
train_meta_df["age_round"] = np.round(train_meta_df["age"],0).astype(int)
print(train_meta_df.shape)

bins = list(range(0, 80, 10))
bins.append(110)
train_meta_df["age_10_yr_bckt_bg_70"] = pd.cut(train_meta_df["age"], bins=bins)
print(train_meta_df["age_10_yr_bckt_bg_70"].value_counts())
train_meta_df["age_10_yr_bckt"] = pd.cut(train_meta_df["age"], bins=range(0, 110, 10))
print(train_meta_df["age_10_yr_bckt"].value_counts())

# Also load the test data - BUT DON'T LOOK AT IT!
test_meta_df = pd.read_csv(TEST_META_FILE, delimiter='\t', dtype=str)
test_meta_df[num_vars] = test_meta_df[num_vars].astype(float)
test_meta_df["age_round"] = np.round(test_meta_df["age"],0).astype(int)
print(test_meta_df.shape)

test_meta_df["age_10_yr_bckt_bg_70"] = pd.cut(test_meta_df["age"], bins=bins)
print(test_meta_df["age_10_yr_bckt_bg_70"].value_counts())
test_meta_df["age_10_yr_bckt"] = pd.cut(test_meta_df["age"], bins=range(0, 110, 10))
print(test_meta_df["age_10_yr_bckt"].value_counts())

train_meta_df.head()

Descriptions for relevant (non-self-descriptive) column names:
- **participant_id:** Unique ID for each patient, can be used to link to the MRI scans.
- **tiv:** Total intracranial volume
- **csfv:** Cerebrospinal fluid volume
- **gmv:** Grey matter volume
- **wmv:** White matter volume


In [None]:
def load_patient_mri_array(
    patient_id:str,
    mri_dir:str
    ) -> np.ndarray:
    """Function to load a patient .nii.gz file containing an MRI 3D scan into a
    numpy array
    
    Args:
        patient_id (str): Patient ID string
        mri_dir (str): String representing the directory that contains the
        "nifti" sub-folder of MRI .nii.gz files
        
    Returns:
        np.ndarray: Array of shape (1, D, H, W): the 3D MRI scan (in
        SimpleITK's (z, y, x) axis order) with a leading channel axis
    """
    mri_file = f"nifti/sub-{patient_id}_preproc-quasiraw_T1w.nii.gz"
    img_array = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(mri_dir, mri_file)))
    return np.expand_dims(img_array, axis=0)
    
def vis_raw_mri_image(
    img_array:np.array, 
    x:int=None, 
    y:int=None, 
    z:int=None, 
    crosshair:bool=False, 
    template_file:str=MRI_TEMPLATE_FILE
):
    """Function to display orthogonal 2D slices of the 3D MRI image

    Args:
        img_array (np.array): 3D numpy array representing the MRI scan
        x (int, optional): x slice co-ordinate. Defaults to None.
        y (int, optional): y slice co-ordinate. Defaults to None.
        z (int, optional): z slice co-ordinate. Defaults to None.
        crosshair (bool, optional): Flag that determines whether the images 
        should be shown with x/y lines across the centres of the axes. 
        Defaults to False.
        template_file (str, optional): String representing the MRI template file
        to extract the image dimensions. Defaults to MRI_TEMPLATE_FILE.
    """
    template = sitk.ReadImage(template_file)
    size = template.GetSize()
    spacing = template.GetSpacing()
    width  = size[0] * spacing[0]
    height = size[1] * spacing[1]
    depth  = size[2] * spacing[2]
    
    if x is None:
        x = np.floor(size[0]/2).astype(int)
    if y is None:
        y = np.floor(size[1]/2).astype(int)
    if z is None:
        z = np.floor(size[2]/2).astype(int)
    
    # Display the orthogonal slices    
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 4))
    
    ax1.imshow(img_array[z,:,:], extent=(0, width, height, 0))
    ax2.imshow(img_array[:,y,:], origin='lower', extent=(0, width,  0, depth))
    ax3.imshow(img_array[:,:,x], origin='lower', extent=(0, height, 0, depth))

    # Additionally display crosshairs
    if crosshair:
        ax1.axhline(y * spacing[1], lw=1)
        ax1.axvline(x * spacing[0], lw=1)
        ax2.axhline(z * spacing[2], lw=1)
        ax2.axvline(x * spacing[0], lw=1)
        ax3.axhline(z * spacing[2], lw=1)
        ax3.axvline(y * spacing[1], lw=1)
    
    plt.show()

Let's use the helper functions above to visualise the MRI scans for patient "100053248969"

In [None]:
img_array = load_patient_mri_array("100053248969", mri_dir=TRAIN_MRI_DIR)
vis_raw_mri_image(img_array.squeeze())
display(img_array.shape)

🚧 **Exercise 1.1** 🚧

Below, perform any data analysis / data visualizations you think will be useful. These may help you better understand the distribution and demographics of the dataset.

In [None]:
####################
# Insert code here #
####################

🚧 **Exercise 1.2** 🚧

In the code block below, add: 
* ```train_prop```: The proportion of the train data that you will assign to training the model (where the remaining will be used for validation)
* ```stratification_variables```: Whether you intend to stratify by any variables when creating the train/validation split
* ```test_metric```: The metric you intend to use to assess performance on the test set

(Stratified sampling is a technique used to ensure that subsets of the data (in this case, the training and validation sets, since the test set is already fixed) are representative of the whole dataset. This is especially important when the dataset is not homogeneous and contains distinct groups that should be evenly represented in each set.)
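For comparison with the groupby-based split used below, a stratified random split can also be done with scikit-learn's `train_test_split` and its `stratify` argument. This is a sketch, assuming scikit-learn is available (it is not in the requirements list above, though it is commonly preinstalled on Colab); `stratified_split` is a hypothetical wrapper.

```python
import numpy as np
from sklearn.model_selection import train_test_split

def stratified_split(patient_ids, strata, train_prop, seed=42):
    """Randomly split patient IDs so that each stratum (one label per patient,
    e.g. an age bucket) keeps roughly the same train/validation ratio.

    train_test_split raises a ValueError if any stratum has fewer than 2 patients.
    """
    train_ids, val_ids = train_test_split(
        np.asarray(patient_ids),
        train_size=train_prop,
        stratify=np.asarray(strata),
        random_state=seed,
        shuffle=True,
    )
    return train_ids, val_ids
```

With 80 "young" and 20 "old" patients and `train_prop=0.8`, the validation set receives 16 young and 4 old patients, preserving the 80/20 mix.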

In [None]:
# Exercise: input answer here
# Your code here
train_prop = 
stratification_variables = # a list of variables from the meta data, e.g. ["variable_1", "variable_2"]. Ensure you choose at least one variable, or the code will not run!
test_metric =  # This is not used later in the notebook, but it is good to decide it now
# Your code here - END

The code below defines three numpy arrays ```train_pats```, ```val_pats``` and ```test_pats``` containing a (potentially stratified) random selection of patient ids assigned to the respective datasets.

In [None]:
grp_df = train_meta_df.groupby(by=stratification_variables, observed=True)
train_pats = []
val_pats = []
for idx, grp in grp_df:
    # Shuffle within each stratum so the split is random rather than
    # following the row order of participants.tsv
    pats = grp["participant_id"].sample(frac=1, random_state=42).values
    n_train = int(np.floor(len(pats) * train_prop))
    train_pats.append(pats[:n_train])
    val_pats.append(pats[n_train:])
train_pats = np.concatenate(train_pats)
val_pats = np.concatenate(val_pats)
test_pats = test_meta_df["participant_id"].values

with open(os.path.join(TRAIN_DIR, "train_pats.pkl"), "wb") as f:
    pickle.dump(train_pats, f)

with open(os.path.join(TRAIN_DIR, "val_pats.pkl"), "wb") as f:
    pickle.dump(val_pats, f)

with open(os.path.join(TEST_DIR, "test_pats.pkl"), "wb") as f:
    pickle.dump(test_pats, f)

print(f"Num training patients: {len(train_pats)}")
print(f"Num validation patients: {len(val_pats)}")
print(f"Num testing patients: {len(test_pats)}")

## Part 2: Baseline model definition

Now that the model development assumptions have been defined, a baseline model needs to be developed. This baseline model should act as a springboard for your subsequent model development. Ideally, it should balance the following objectives:
* Attain a __reasonable__ level of performance, and;
* Be __simple__! This is arguably the most important requirement: if the model is too complex, understanding which areas of the architecture are underperforming will be challenging!


In this exercise the baseline model should be defined by the following pipeline:

1. **Volume prediction:** A deep learning model is trained to take in patient MRI images and predict three patient brain volume values (explained in more detail below).
2. **MLP Regression:** A multi-layer perceptron regression model takes patient brain volume values as input and predicts patient age.

These models will be trained separately, then combined to make end-to-end MRI -> age predictions at test time. We elaborate on the motivation behind this pipeline below.

__Your main task__ will be to define the deep learning Volume Prediction model and subsequently justify and explain your design choices. 

_Hints_:
* What might be a 'simple' go-to deep learning architecture for processing images?
* Do 3D images need to be treated differently to 2D images (see PyTorch [CONV3D](https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#conv3d))?
* How might 'reasonable' be defined?
    * Does the baseline perform better than random chance?
    * Does the baseline perform better than just applying a linear regression or XGBoost model directly on the images?
* As this is a baseline, performance is not expected to be excellent. As you are only expected to implement the volume predictor, it is ok if other parts of the pipeline are limiting. However, you may make small modifications to the MLP training, etc. to ensure reasonable performance.
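A simple reference point for "better than random chance" is a model that ignores the MRI entirely and always predicts a single age computed from the training set. The sketch below computes its MAE; `constant_predictor_mae` and its `statistic` parameter are hypothetical names used for illustration.

```python
import numpy as np

def constant_predictor_mae(train_ages, eval_ages, statistic="median"):
    """MAE (years) of always predicting one number derived from training ages.

    The median minimises MAE on the training set; the mean minimises MSE.
    """
    train_ages = np.asarray(train_ages, dtype=float)
    eval_ages = np.asarray(eval_ages, dtype=float)
    const = np.median(train_ages) if statistic == "median" else np.mean(train_ages)
    return float(np.mean(np.abs(eval_ages - const)))
```

For example, passing the ages of the `train_pats` and `val_pats` patients gives a validation MAE that any useful baseline should beat.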

### Data Helpers

Before we begin, let's load and process the data.

In [None]:
def zero_mean_unit_var(
    img_array:np.array
    )->np.array:
    """Function to normalise an input image to have 0 mean and unit variance

    Args:
        img_array (np.array): 3D numpy array representing the MRI scan

    Returns:
        np.array: Normalised version of img_array
    """
    mean = np.mean(img_array)
    std = np.std(img_array)
    # Capture 0 values as these are background
    zero_values = img_array == 0
    if std > 0:
        img_array = (img_array - mean) / std
        img_array[zero_values] = 0
    return img_array


def resample_image(
    img_array:np.array, 
    out_spacing:Tuple[float]=(1.0, 1.0, 1.0), 
    out_size:Tuple[float]=None, 
    is_label:bool=False, 
    pad_value=0
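    )->np.array:
    """Function to resample an input image, represented as a numpy array, onto
    a new voxel grid (spacing and size) centred on the original image

    Args:
        img_array (np.array): 3D numpy array representing the MRI scan
        out_spacing (Tuple[float], optional): Desired voxel spacing along each
        axis, in the same physical units as the input spacing (which
        sitk.GetImageFromArray sets to 1.0). Defaults to (1.0, 1.0, 1.0).
        out_size (Tuple[float], optional): Tuple of length 3 defining the 
        desired output dimensions of the image. If None, it is computed so
        that the physical extent of the image is preserved. Defaults to None.
        is_label (bool, optional): If True, use nearest-neighbour
        interpolation (suitable for label maps); otherwise use B-spline
        interpolation. Defaults to False.
        pad_value (int, optional): Value assigned to output voxels that fall
        outside the input image. Defaults to 0.

    Returns:
        np.array: The resampled image as a numpy array
    """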
    
    image = sitk.GetImageFromArray(img_array)
    original_spacing = np.array(image.GetSpacing())
    original_size = np.array(image.GetSize())

    if out_size is None:
        out_size = np.round(np.array(original_size * original_spacing / np.array(out_spacing))).astype(int)
    else:
        out_size = np.array(out_size)

    original_direction = np.array(image.GetDirection()).reshape(len(original_spacing),-1)
    original_center = (np.array(original_size, dtype=float) - 1.0) / 2.0 * original_spacing
    out_center = (np.array(out_size, dtype=float) - 1.0) / 2.0 * np.array(out_spacing)

    original_center = np.matmul(original_direction, original_center)
    out_center = np.matmul(original_direction, out_center)
    out_origin = np.array(image.GetOrigin()) + (original_center - out_center)

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(out_spacing)
    resample.SetSize(out_size.tolist())
    resample.SetOutputDirection(image.GetDirection())
    resample.SetOutputOrigin(out_origin.tolist())
    resample.SetTransform(sitk.Transform())
    resample.SetDefaultPixelValue(pad_value)

    if is_label:
        resample.SetInterpolator(sitk.sitkNearestNeighbor)
    else:
        resample.SetInterpolator(sitk.sitkBSpline)

    return sitk.GetArrayFromImage(resample.Execute(image))
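When `out_size` is None, `resample_image` chooses the output grid so the physical extent is preserved: `round(size * spacing / out_spacing)` per axis. Note that `sitk.GetImageFromArray` reverses the axis order (a numpy `(z, y, x)` array becomes an image whose `GetSize()` is `(x, y, z)`) and sets the spacing to 1.0, so `out_spacing` is given in SimpleITK's `(x, y, z)` order. The sketch below reproduces only that size arithmetic; `default_resampled_size` is a hypothetical helper. (With the notebook's fixed `OUT_SIZE = [96, 96, 96]` and `OUT_SPACING = [2, 2, 2]`, the output instead covers 96 × 2 = 192 units along each axis, centred on the input.)

```python
import numpy as np

def default_resampled_size(array_shape, original_spacing=(1.0, 1.0, 1.0),
                           out_spacing=(2.0, 2.0, 2.0)):
    """Output numpy shape that resample_image would produce when out_size is None."""
    # numpy (z, y, x) shape -> SimpleITK (x, y, z) size
    sitk_size = np.array(array_shape[::-1], dtype=float)
    out = np.round(sitk_size * np.array(original_spacing) / np.array(out_spacing)).astype(int)
    # Back to numpy (z, y, x) order
    return tuple(int(v) for v in out[::-1])
```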

We will create separate data loaders for the volume prediction task and the MLP age regression task.

First, we create the volume prediction data loaders. This data loader provides a batch of 3D MRI images as 'X', and the corresponding brain volume labels as 'y'. The train/val/test splits you decided previously will be used.

In [None]:
class VolumePredictionDataset(Dataset):
    '''
    Important: By default, this dataset returns normalized versions of the input MRI images and output volume labels.
    '''
    
    def __init__(
        self, 
        patient_list:List[str], 
        mri_file_dir:str,
        meta_df:pd.DataFrame,
        volume_norm_stats:Dict[str,np.array],
        out_spacing = OUT_SPACING, 
        out_size = OUT_SIZE,
    ):
        self.samples:List[Dict[str,torch.tensor]] = []
        
        for pat in tqdm(patient_list, desc='Loading Data'):
            
            # MRI image
            X = load_patient_mri_array(pat, mri_dir=mri_file_dir).squeeze()
            X = zero_mean_unit_var(X)
            if (out_spacing is not None) and (out_size is not None):
                X = resample_image(X, out_spacing=out_spacing, out_size=out_size)
            X = torch.from_numpy(X).unsqueeze(0).float()
            
            # Volume labels
            y = meta_df[meta_df["participant_id"] == pat][["csfv", "gmv", "wmv"]].values.squeeze()
            y = (y-volume_norm_stats['mean'])/volume_norm_stats['std']
            y = torch.from_numpy(y).float()
            
            sample = {'X': X,  "y":y}
            self.samples.append(sample)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        return self.samples[item]

**Data normalization:**

In our datasets, we will normalize both the input and the output data for model training. Images are simply normalized to each have zero mean and unit variance. Volumes and ages are normalized using statistics of the train dataset, calculated below.
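Normalization is a z-score using training-set statistics, and it must be undone before errors are reported in real units. Since $v_{\text{norm}} = (v - \mu)/\sigma$, an error of $e$ in normalized space corresponds to $e\,\sigma$ in the original units, which is why `test` denormalizes before computing the MAE. A minimal sketch with hypothetical helper names and made-up example statistics:

```python
import numpy as np

def normalize(values, stats):
    """z-score with training-set statistics: (v - mean) / std."""
    return (np.asarray(values, dtype=float) - stats["mean"]) / stats["std"]

def denormalize(values, stats):
    """Inverse of normalize(): v * std + mean, back to the original units."""
    return np.asarray(values, dtype=float) * stats["std"] + stats["mean"]
```

The same training statistics are applied to the validation and test sets; computing separate statistics on the test set would leak information about it.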

In [None]:
# Normalization stats

age_norm_stats = {
    'mean': train_meta_df["age"].mean(axis=0),
    'std': train_meta_df["age"].std(axis=0)
}

volume_norm_stats = {
    'mean': train_meta_df[["csfv", "gmv", "wmv"]].mean(axis=0).values,
    'std': train_meta_df[["csfv", "gmv", "wmv"]].std(axis=0).values
}

display(age_norm_stats)
display(volume_norm_stats)

In [None]:
# Expect this cell to take 8 mins

batch_size = 8 # Can be changed if you wish

vol_train_data = VolumePredictionDataset(
    patient_list=train_pats, 
    mri_file_dir=TRAIN_MRI_DIR,
    meta_df=train_meta_df,
    volume_norm_stats=volume_norm_stats,
)

vol_train_loader = DataLoader(vol_train_data, batch_size=batch_size, shuffle=True)  # reshuffle every epoch

vol_val_data = VolumePredictionDataset(
    patient_list=val_pats, 
    mri_file_dir=TRAIN_MRI_DIR,
    meta_df=train_meta_df,
    volume_norm_stats=volume_norm_stats,
)

vol_val_loader = DataLoader(vol_val_data, batch_size=batch_size)

vol_test_data = VolumePredictionDataset(
    patient_list=test_pats, 
    mri_file_dir=TEST_MRI_DIR,
    meta_df=test_meta_df,
    volume_norm_stats=volume_norm_stats,
)

tmp = next(vol_train_loader.__iter__())
print(f"Dataloader has output type {type(tmp)} with keys {tmp.keys()}")
print(f"The input (MRI images) dimensions are: {tmp['X'].shape}")
print(f"The output (csfv, gmv, wmv brain volumes) dimensions are: {tmp['y'].shape}")

Ensure you understand what the inputs and outputs represent, and what each axis of the data represents.
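Given `OUT_SIZE = [96, 96, 96]`, each MRI batch `X` should have shape `(batch_size, 1, 96, 96, 96)`: batch, channel, then the three spatial axes (in SimpleITK's `(z, y, x)` order), while `y` has shape `(batch_size, 3)`. `nn.Conv3d` expects exactly this `(N, C, D, H, W)` layout. The sketch below only checks shapes on a smaller dummy volume; it is not a proposed architecture for Exercise 2.1.

```python
import torch
import torch.nn as nn

# Dummy batch in the same (batch, channel, depth, height, width) layout,
# using a 32^3 volume instead of 96^3 to keep it light
x = torch.randn(2, 1, 32, 32, 32)

# Conv3d slides a 3x3x3 kernel over all three spatial axes
conv = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
pool = nn.MaxPool3d(kernel_size=2)

features = conv(x)        # kernel 3 with padding 1 keeps D, H, W: (2, 8, 32, 32, 32)
pooled = pool(features)   # halves each spatial axis: (2, 8, 16, 16, 16)
print(features.shape, pooled.shape)
```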

Next, we define the data loaders for the MLP age regression task. These return brain volumes as 'X' and age labels as 'y'.

In [None]:
class AgeRegressionDataset(Dataset):
    '''
    Important:
    - By default, this dataset returns normalized versions of the input volumes and output age labels.
    '''
    
    def __init__(
        self, 
        patient_list:List[str],
        meta_df:pd.DataFrame,
        volume_norm_stats:Dict[str,np.array],
        age_norm_stats:Dict[str,np.array],
    ):  
        self.samples:List[Dict[str,torch.tensor]] = []
        
        for pat in tqdm(patient_list, desc='Loading Data'):
            
            # Brain volumes
            X = meta_df[meta_df["participant_id"] == pat][["csfv", "gmv", "wmv"]].values.squeeze()
            X = (X-volume_norm_stats['mean'])/volume_norm_stats['std']
            X = torch.from_numpy(X).float()
            
            # Ages
            y = meta_df[meta_df["participant_id"] == pat][["age"]].values.squeeze()
            y = (y-age_norm_stats['mean'])/age_norm_stats['std']
            y = torch.tensor(y).float()
            
            sample = {'X': X,  "y":y}
            self.samples.append(sample)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        return self.samples[item]

In [None]:
rnd_seed = 42
batch_size = 2

mlp_train_data = AgeRegressionDataset(
    patient_list=train_pats,
    meta_df=train_meta_df,
    volume_norm_stats=volume_norm_stats,
    age_norm_stats=age_norm_stats,
)
mlp_train_loader = DataLoader(mlp_train_data, batch_size=batch_size, shuffle=True)  # reshuffle every epoch

mlp_val_data = AgeRegressionDataset(
    patient_list=val_pats,
    meta_df=train_meta_df,
    volume_norm_stats=volume_norm_stats,
    age_norm_stats=age_norm_stats,
)
mlp_val_loader = DataLoader(mlp_val_data, batch_size=batch_size)

tmp = next(mlp_train_loader.__iter__())
print(f"Dataloader has output type {type(tmp)} with keys {tmp.keys()}")
print(f"The input (csfv, gmv, wmv brain volumes) dimensions are: {tmp['X'].shape}")
print(f"The output (patient age) dimensions are: {tmp['y'].shape}")

Again, ensure you understand what the inputs and outputs represent, and what each axis of the data represents.

### Motivating the baseline model pipeline

Our training dataset provides ground-truth information about patient brain volumes, namely:
- csfv: Cerebrospinal fluid volume
- gmv: Grey matter volume
- wmv: White matter volume

It has been found previously that these brain volumes are correlated with age. Let's investigate whether that is the case in our dataset. If so, these values could be useful for our task of predicting brain age!


In [None]:
# Get ages
ages = train_meta_df["age"].values

# Get volumes
volumes = train_meta_df[["csfv", "gmv", "wmv"]].values
csfv = volumes[:,0]
gmv = volumes[:,1]
wmv = volumes[:,2]

In [None]:
# Create scatter plots to visualize the correlation between volumes and age

import matplotlib.pyplot as plt

# Create a figure and a 1x3 subplot (for 3 plots in a row)
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))

# Plot csfv vs age
axes[0].scatter(ages, csfv, color='r', alpha=0.6)
axes[0].set_title('CSFV vs Age')
axes[0].set_xlabel('Age')
axes[0].set_ylabel('CSFV')

# Plot gmv vs age
axes[1].scatter(ages, gmv, color='g', alpha=0.6)
axes[1].set_title('GMV vs Age')
axes[1].set_xlabel('Age')
axes[1].set_ylabel('GMV')

# Plot wmv vs age
axes[2].scatter(ages, wmv, color='b', alpha=0.6)
axes[2].set_title('WMV vs Age')
axes[2].set_xlabel('Age')
axes[2].set_ylabel('WMV')

# Display the plots
plt.tight_layout()
plt.show()

We see some visual evidence here that age is correlated with the brain volumes. This means the brain volume information may be useful for predicting age. 
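A correlation coefficient puts a number on this visual impression. A sketch using pandas' `DataFrame.corrwith`; `volume_age_correlations` is a hypothetical helper, and Spearman correlation is offered as an option because it captures monotonic rather than only linear trends.

```python
import pandas as pd

def volume_age_correlations(meta_df, cols=("csfv", "gmv", "wmv"), method="pearson"):
    """Correlation of each brain volume column with age.

    method="spearman" measures monotonic association and is less sensitive
    to curvature and outliers than Pearson.
    """
    return meta_df[list(cols)].corrwith(meta_df["age"], method=method)
```

For example, `volume_age_correlations(train_meta_df)` returns one coefficient per volume; the sign indicates whether that volume tends to grow or shrink with age.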

Thus, we may be able to split our baseline model into two parts: (i) Volume Prediction Model: a deep learning model that predicts brain volumes from MRI images; and (ii) MLP Regression Model: a multi-layer perceptron regression model that predicts age from brain volumes.

Note that this approach leverages the brain volume information available in the training data, but does not require any ground-truth brain volume information at deployment! (During training, the MLP Regression Model is trained to map from ground-truth volumes to age; at deployment, it takes the volume values predicted by the Volume Prediction Model.)
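At deployment the two models are simply chained. Because `VolumePredictionDataset` and `AgeRegressionDataset` normalize volumes with the same `volume_norm_stats`, the volume model's normalized output can be passed straight to the MLP; only the final age needs denormalizing. A sketch with a hypothetical function name, assuming both models are already trained:

```python
import torch

def predict_age_end_to_end(vol_model, mlp_model, mri_batch, age_norm_stats):
    """MRI batch -> predicted (normalized) volumes -> predicted age in years."""
    vol_model.eval()
    mlp_model.eval()
    with torch.no_grad():
        pred_volumes = vol_model(mri_batch)       # (B, 3), normalized volumes
        pred_age_norm = mlp_model(pred_volumes)   # normalized age, (B, 1) or (B,)
    # Flatten to (B,) and undo the age normalization
    return pred_age_norm.reshape(-1) * age_norm_stats["std"] + age_norm_stats["mean"]
```

A consequence of this design is that errors in the predicted volumes propagate into the age prediction, so the end-to-end MAE is expected to be worse than the MLP's MAE on ground-truth volumes.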

### Baseline: MLP Regression training

First, let's train the MLP regression model.

Let's define some standard training and testing functions.

(**Optional Exercise**: You may modify these training and testing functions if you wish to add more sophisticated logging)

In [None]:
def train(dataloader, model, loss_fn, optimizer):
    num_batches = len(dataloader)
    model.train()
    
    total_loss = 0
    for step, data in enumerate(dataloader):
        X, y = data['X'].to(device), data['y'].to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()

    average_loss = total_loss / num_batches
    print(f"Train loss: {average_loss:>7f}")

In [None]:
def test(dataloader, model, loss_fn, y_norm_stats):
    def denorm(y):
        return y.cpu() * y_norm_stats['std'] + y_norm_stats['mean']
    
    num_batches = len(dataloader)
    model.eval()
    
    test_loss, mae_loss = 0, 0
    with torch.no_grad():
        for data in dataloader:
            X, y = data['X'].to(device), data['y'].to(device)
            pred = model(X)
            
            test_loss += loss_fn(pred, y).item()
            mae_loss += torch.mean(torch.abs(denorm(pred) - denorm(y))).item()
    
    test_loss /= num_batches
    mae_loss /= num_batches
    print(f"Test loss: {test_loss:>8f}, MAE: {mae_loss:>8f}\n")
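`test` averages the per-batch MAEs, so when the last batch is smaller than `batch_size` its samples receive slightly more weight than the others. For final reported numbers, an exact per-sample MAE can be obtained by accumulating absolute errors and counts instead. A sketch, where `sample_weighted_mae` is a hypothetical name and the dataloader yields the same `{'X', 'y'}` dictionaries as the notebook's datasets:

```python
import torch

def sample_weighted_mae(dataloader, model, y_norm_stats, device="cpu"):
    """MAE in original units, averaged over every sample (and output column)."""
    std = torch.as_tensor(y_norm_stats["std"], dtype=torch.float32)
    model.eval()
    abs_err_sum, count = 0.0, 0
    with torch.no_grad():
        for data in dataloader:
            X, y = data["X"].to(device), data["y"].to(device)
            pred = model(X).reshape(y.shape).cpu()
            # The mean cancels in (pred - y), so only std is needed to
            # convert normalized errors back to original units
            err = (pred - y.cpu()).abs() * std
            abs_err_sum += err.sum().item()
            count += err.numel()
    return abs_err_sum / count
```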

Now define the MLP regression model in PyTorch.

In [None]:
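class MLPRegression(nn.Module):
    def __init__(self, input_dim, hidden_dim=64):
        super(MLPRegression, self).__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.linear1(x))
        x = self.linear2(x)
        # Squeeze (B, 1) -> (B,) so the output matches the shape of the age
        # targets; otherwise MSELoss would broadcast (B, 1) against (B,) to (B, B)
        return x.squeeze(-1)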
    
tmp = next(mlp_train_loader.__iter__())
in_shape = tmp['X'].shape[1:]

mlp_model = MLPRegression(in_shape[0]).float().to(device)

output = mlp_model(tmp['X'].to(device))
print("MLP Regression Model output shape: ", output.shape)

Now train the MLP regression model.

In [None]:
mlp_loss_fn = nn.MSELoss()
mlp_optimizer = torch.optim.Adam(mlp_model.parameters(), lr=1e-3)

In [None]:
epochs = 30
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(mlp_train_loader, mlp_model, mlp_loss_fn, mlp_optimizer)
    test(mlp_val_loader, mlp_model, mlp_loss_fn, age_norm_stats)
print("Done!")

If all has gone well, you should see the loss decreasing on both the train and validation data

Let's check some predictions to visually inspect the model's accuracy.

In [None]:
true_ages_list = []
predicted_ages_list = []

# Loop through every batch of the training loader
for i, batch in enumerate(mlp_train_loader):
    with torch.no_grad():
        inputs = batch['X'].to(device)
        true_ages = batch['y'].cpu().numpy().reshape(-1)
        predicted_ages = mlp_model(inputs).detach().cpu().numpy().reshape(-1)
    true_ages_list.extend(true_ages)
    predicted_ages_list.extend(predicted_ages)

# Denormalize the true and predicted ages back to years
true_ages_list = np.array(true_ages_list) * age_norm_stats['std'] + age_norm_stats['mean']
predicted_ages_list = np.array(predicted_ages_list) * age_norm_stats['std'] + age_norm_stats['mean']

plt.figure(figsize=(8, 8))
plt.scatter(true_ages_list, predicted_ages_list, alpha=0.6)
plt.title('True Age vs Predicted Age (Training Data)')
plt.xlabel('True Age')
plt.ylabel('Predicted Age')
plt.plot([min(true_ages_list), max(true_ages_list)], [min(true_ages_list), max(true_ages_list)], color='red')  # Identity line (perfect predictions)
plt.grid(True)
plt.tight_layout()
plt.show()

### Baseline: Deep Neural Network Brain Volume Prediction

Now you must define and train your volume prediction model!

🚧 **Exercise 2.1** 🚧

Define the volume prediction model.

In [None]:
class VolumePredictor(nn.Module):
    def __init__(self, in_shape, out_size):
        super(VolumePredictor, self).__init__()
        ##################
        # Your code here #
        ##################

    def forward(self, x):
        
        ##################
        # Your code here #
        ##################

In [None]:
# Define model and run a batch through the model to check it works
tmp = next(vol_train_loader.__iter__())
in_shape = tmp['X'].shape[1:]
out_size = tmp['y'].shape[-1]

vol_model = VolumePredictor(in_shape, out_size).float().to(device)

input = tmp['X'].to(device)
output = vol_model(input)
print(output.shape)

#### Train volume prediction model

**Optional Exercise:** Feel free to modify the training loop, training parameters, etc, if there is anything you wish to improve

In [None]:
vol_loss_fn = nn.MSELoss()
vol_optimizer = torch.optim.Adam(vol_model.parameters(), lr=1e-3)

In [None]:
epochs = 30
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(vol_train_loader, vol_model, vol_loss_fn, vol_optimizer)
    test(vol_val_loader, vol_model, vol_loss_fn, volume_norm_stats)
print("Done!")

Visualize the volume predictions:

In [None]:
true_csfv_list = []
predicted_csfv_list = []

true_gmv_list = []
predicted_gmv_list = []

true_wmv_list = []
predicted_wmv_list = []

# Loop through the train loader 🚧 Exercise: Is the train data the best dataset to use to visualize model performance here? 🚧
for batch in vol_train_loader:
    with torch.no_grad():
        inputs = batch['X'].to(device)
        true_volumes = batch['y'].cpu().numpy()
        
        # Predict volumes
        predicted_volumes = vol_model(inputs).detach().cpu().numpy()
        
        # Denormalize predicted volumes and true volumes
        predicted_volumes = predicted_volumes * volume_norm_stats['std'] + volume_norm_stats['mean']
        true_volumes = true_volumes * volume_norm_stats['std'] + volume_norm_stats['mean']
        
        # Append true and predicted volumes to the respective lists
        true_csfv_list.extend(true_volumes[:, 0])
        predicted_csfv_list.extend(predicted_volumes[:, 0])
        
        true_gmv_list.extend(true_volumes[:, 1])
        predicted_gmv_list.extend(predicted_volumes[:, 1])
        
        true_wmv_list.extend(true_volumes[:, 2])
        predicted_wmv_list.extend(predicted_volumes[:, 2])

# Convert lists to numpy arrays and pair them with a label for each tissue type
volume_results = [
    ('CSFV', np.array(true_csfv_list), np.array(predicted_csfv_list)),
    ('GMV', np.array(true_gmv_list), np.array(predicted_gmv_list)),
    ('WMV', np.array(true_wmv_list), np.array(predicted_wmv_list)),
]

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))

# Plot true vs predicted volumes for each tissue type, with the identity line y = x in red
for ax, (name, true_arr, pred_arr) in zip(axes, volume_results):
    ax.scatter(true_arr, pred_arr, alpha=0.6)
    lims = [true_arr.min(), true_arr.max()]
    ax.plot(lims, lims, color='red')
    ax.set_title(f'True {name} vs Predicted {name}')
    ax.set_xlabel(f'True {name}')
    ax.set_ylabel(f'Predicted {name}')
    ax.grid(True)

plt.tight_layout()
plt.show()
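Scatter plots are easier to compare across experiments when summarised by numbers. A minimal sketch for computing MAE, RMSE and the coefficient of determination R² is shown below; `regression_metrics` is a hypothetical helper, and the metrics you actually report should follow the choices you made in Part 1.

```python
import numpy as np


def regression_metrics(y_true, y_pred):
    """Return MAE, RMSE and R^2 for 1-D sequences of true and predicted values."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    errors = y_pred - y_true
    mae = np.mean(np.abs(errors))
    rmse = np.sqrt(np.mean(errors ** 2))
    # R^2 = 1 - SS_res / SS_tot (1.0 is a perfect fit; 0.0 matches always predicting the mean)
    ss_res = np.sum(errors ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    r2 = 1.0 - ss_res / ss_tot
    return {"MAE": mae, "RMSE": rmse, "R2": r2}


# Toy example
metrics = regression_metrics([1.0, 2.0, 3.0], [1.0, 2.0, 4.0])
print(metrics)
```

In the notebook you could call it per tissue, e.g. `regression_metrics(true_csfv_list, predicted_csfv_list)`, ideally on the validation loader rather than the training loader.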

### Baseline: Combine models to attain end-to-end predictions

Finally, we will combine the deep learning volume prediction model with the MLP regression model to obtain end-to-end predictions of age from MRI images.

Define the end-to-end validation dataset and data loader:

In [None]:
class EndToEndDataset(Dataset):
    
    def __init__(
        self, 
        patient_list:List[str], 
        mri_file_dir:str,
        meta_df:pd.DataFrame,
        ages_norm_stats:Dict[str,np.ndarray],
        out_spacing = OUT_SPACING, 
        out_size = OUT_SIZE,
    ):
        self.samples:List[Dict[str,torch.Tensor]] = []
        
        for pat in tqdm(patient_list, desc='Loading Data'):
            
            # MRI images
            X = load_patient_mri_array(pat, mri_dir=mri_file_dir).squeeze()
            X = zero_mean_unit_var(X)
            if (out_spacing is not None) and (out_size is not None):
                X = resample_image(X, out_spacing=out_spacing, out_size=out_size)
            X = torch.from_numpy(X).unsqueeze(0).float()
            
            # Ages
            y = meta_df[meta_df["participant_id"] == pat][["age"]].values.squeeze()
            y = (y-ages_norm_stats['mean'])/ages_norm_stats['std']
            y = torch.tensor(y).unsqueeze(0).float()
            
            sample = {'X': X,  "y":y}
            self.samples.append(sample)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        return self.samples[item]

In [None]:
e2e_val_data = EndToEndDataset(
    patient_list=val_pats, 
    mri_file_dir=TRAIN_MRI_DIR,
    meta_df=train_meta_df,
    ages_norm_stats=age_norm_stats,
)

e2e_val_loader = DataLoader(e2e_val_data, batch_size=2)

tmp = next(iter(e2e_val_loader))
print(f"Dataloader has output type {type(tmp)} with keys {tmp.keys()}")
print(f"The input dimensions are: {tmp['X'].shape}")
print(f"The output dimensions are: {tmp['y'].shape}")

Define the combined model:

In [None]:
class CombinedModel(nn.Module):
    def __init__(self, volume_predictor, age_regressor):
        super(CombinedModel, self).__init__()
        self.volume_predictor = volume_predictor
        self.age_regressor = age_regressor

    def forward(self, x):
        x = self.volume_predictor(x)
        return self.age_regressor(x)
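Because `CombinedModel` is an ordinary `nn.Module`, its two sub-models could also be fine-tuned jointly on age (related to the end-to-end architecture idea in Part 3). A common pattern is to freeze one part and train the rest by toggling `requires_grad`. The sketch below repeats the `CombinedModel` definition so it runs on its own, uses tiny `nn.Linear` stand-ins in place of `VolumePredictor` and the MLP, and introduces a hypothetical `set_trainable` helper.

```python
import torch
import torch.nn as nn


class CombinedModel(nn.Module):
    def __init__(self, volume_predictor, age_regressor):
        super().__init__()
        self.volume_predictor = volume_predictor
        self.age_regressor = age_regressor

    def forward(self, x):
        return self.age_regressor(self.volume_predictor(x))


def set_trainable(module: nn.Module, trainable: bool) -> None:
    """Enable or disable gradient updates for every parameter in `module`."""
    for param in module.parameters():
        param.requires_grad = trainable


# Stand-ins: 8 input features -> 3 "volumes" -> 1 "age"
demo_model = CombinedModel(nn.Linear(8, 3), nn.Linear(3, 1))

# Freeze the volume predictor and fine-tune only the age regressor
set_trainable(demo_model.volume_predictor, False)
demo_optimizer = torch.optim.Adam(
    (p for p in demo_model.parameters() if p.requires_grad), lr=1e-4
)

# One training step on random data
x, y = torch.randn(4, 8), torch.randn(4, 1)
loss = nn.functional.mse_loss(demo_model(x), y)
demo_optimizer.zero_grad()
loss.backward()
demo_optimizer.step()
```

Calling `set_trainable(demo_model.volume_predictor, True)` (and rebuilding the optimizer) would then train the whole pipeline end to end, usually with a small learning rate so the pretrained volume predictor is not disrupted.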

In [None]:
combined_model = CombinedModel(vol_model, mlp_model).float().to(device)

tmp = next(iter(e2e_val_loader))
with torch.no_grad():
    output = combined_model(tmp['X'].to(device))
print(output.shape)

Assess performance of end-to-end model on the validation data:

In [None]:
e2e_loss_fn = nn.MSELoss()
test(e2e_val_loader, combined_model, e2e_loss_fn, age_norm_stats)

In [None]:
true_ages_list = []
predicted_ages_list = []

# Loop through all batches of the validation data
combined_model.eval()
for batch in e2e_val_loader:
    with torch.no_grad():
        inputs = batch['X'].to(device)
        true_ages = batch['y'].cpu().numpy().reshape(-1)
        predicted_ages = combined_model(inputs).cpu().numpy().reshape(-1)
    true_ages_list.extend(true_ages)
    predicted_ages_list.extend(predicted_ages)

# Denormalize the true and predicted ages
true_ages_list = np.array(true_ages_list) * age_norm_stats['std'] + age_norm_stats['mean']
predicted_ages_list = np.array(predicted_ages_list) * age_norm_stats['std'] + age_norm_stats['mean']

plt.figure(figsize=(8, 8))
plt.scatter(true_ages_list, predicted_ages_list, alpha=0.6)
plt.title('True Age vs Predicted Age')
plt.xlabel('True Age')
plt.ylabel('Predicted Age')
plt.plot([min(true_ages_list), max(true_ages_list)], [min(true_ages_list), max(true_ages_list)], color='red')  # Identity line (perfect prediction, y = x)
plt.grid(True)
plt.tight_layout()
plt.show()
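For the clinical use case in the introduction, the quantity of interest is the gap between predicted and chronological age (often called the brain-age delta). The sketch below, using a hypothetical `age_gap_summary` helper, reports the mean gap (systematic bias), its spread, and its correlation with true age. Regression models tend to over-predict young subjects and under-predict old ones, which shows up as a negative correlation; this may be worth checking and discussing in your report.

```python
import numpy as np


def age_gap_summary(true_ages, predicted_ages):
    """Summarise the predicted-minus-true age gap (the 'brain-age delta')."""
    true_ages = np.asarray(true_ages, dtype=float)
    predicted_ages = np.asarray(predicted_ages, dtype=float)
    delta = predicted_ages - true_ages
    return {
        "mean_delta": delta.mean(),          # > 0: model over-predicts on average
        "std_delta": delta.std(ddof=1),      # spread of the gap across patients
        # Pearson correlation between the gap and chronological age;
        # a strongly negative value suggests shrinkage towards the mean age
        "corr_delta_age": np.corrcoef(true_ages, delta)[0, 1],
    }


# Toy example: predictions shrunk halfway towards the mean age of 50
toy_true = np.array([20.0, 40.0, 60.0, 80.0])
toy_pred = 50.0 + 0.5 * (toy_true - 50.0)
summary = age_gap_summary(toy_true, toy_pred)
print(summary)
```

In the notebook, `age_gap_summary(true_ages_list, predicted_ages_list)` would apply it to the denormalised validation predictions from the cell above.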

#### 🚧 **Exercise 2.2** 🚧

Analyse your results and consider writing about the following in your report:

- Justification for your volume prediction model structure/design
- Discussion of results from the volume prediction model
- Discussion of results from the MLP regression model
- Discussion of results from the end-to-end combined model
- Discussion of how reasonable the performance is (and what the limitations may be and why)

You should provide evidence to back up your discussions and any conclusions. This may include showing performance on held-out validation/test sets (using the metrics you chose previously), and other statistical tests or visualisations. You do not have to limit yourself to the visualisations/analysis already implemented in the notebook!

<Insert report here>

## Part 3: Improving upon the Baseline

Here you should make three meaningful attempts to improve upon the baseline model.

Start by analysing the performance of the baseline model and propose a hypothesis for where the model could be improved. The hypothesis could align to one of:
* Architecture: Would a different NN architecture be better? Maybe try training a deep learning model end to end, rather than first predicting volumes then predicting age.
* Hyperparameters: Is the learning rate set correctly? Should early stopping or other forms of regularization be used?
* Auxiliary losses: Can you use the extra information in the training data to provide richer training signals to the model?
* Data augmentation: Can simple augmentations improve performance? (This can be especially helpful when the dataset is small!)
* Skewed dataset: Are there techniques that can be used to account for the negative effects of a skewed/imbalanced dataset?
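As an illustration of the skewed-dataset idea, one common technique is to weight each training sample by the inverse frequency of its age bin, so that under-represented ages contribute more to the loss (or are drawn more often by a sampler). The sketch below computes such weights with NumPy; the 10-year bin width and the `inverse_frequency_weights` helper are assumptions for illustration.

```python
import numpy as np


def inverse_frequency_weights(ages, bin_width=10.0):
    """Weight each sample by 1 / (number of samples in its age bin).

    Weights are rescaled to have mean 1, so the overall loss scale is unchanged.
    """
    ages = np.asarray(ages, dtype=float)
    # Bin edges from the youngest age upwards in steps of `bin_width`
    edges = np.arange(ages.min(), ages.max() + bin_width, bin_width)
    bin_idx = np.digitize(ages, edges)  # bin index for every sample
    counts = np.bincount(bin_idx)       # number of samples in each bin
    weights = 1.0 / counts[bin_idx]
    return weights / weights.mean()


# Toy example: three young patients and one old patient
toy_ages = np.array([20.0, 22.0, 24.0, 70.0])
weights = inverse_frequency_weights(toy_ages)
print(weights)
```

Each bin now carries the same total weight. The weights could either multiply a per-sample loss, or be passed to PyTorch's `torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)` and given to the `DataLoader` via `sampler=`.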

Now implement a new model based on your hypothesis!

Iterate through this procedure until you have proposed 3 hypotheses and developed 3 models.

Don’t worry if an experiment does not produce the intended results; write about why you think it didn’t produce those results! Note that, since this is deep learning, there may not always be an obvious explanation. In these cases, as long as your initial hypothesis was valid and you have made attempts to find an explanation (where possible), you will not lose marks.

If you feel a set of experiments is leading you down a dead end, don’t worry! Write about why you feel that line of enquiry is not working, take a few steps back (even if that means going back to the baseline model) and start again for your next hypothesis. Failed experiments often yield interesting and insightful results!

**Further guidelines for the hypotheses:**

* Scope: Ensure your hypothesis is not too limited in scope. For example, simply changing the learning rate value once would be insufficient. Instead, you could try a sweep over learning rates.
* Grounded reasoning: Your reasoning and justifications should be grounded in what you have learned in the course materials/lectures/tutorials. 
* Evidence-backed conclusions: If you draw conclusions, make sure you present suitably strong evidence. If there is not enough evidence to make a strong conclusion, acknowledge this.
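For example, a learning-rate sweep could be organised as a loop over a log-spaced grid. In this sketch, `train_and_validate` is a hypothetical stand-in for your own function that trains a fresh model with the given learning rate and returns its validation loss; the toy version here just returns a made-up curve in log10(lr) so the example runs.

```python
import numpy as np


def train_and_validate(lr: float) -> float:
    """Hypothetical placeholder: replace with code that trains a *new* model
    with learning rate `lr` and returns its final validation loss."""
    return (np.log10(lr) + 3.0) ** 2  # toy curve with its minimum at lr = 1e-3


# Five learning rates, evenly spaced on a log scale from 1e-5 to 1e-1
learning_rates = np.logspace(-5, -1, num=5)
results = {float(lr): train_and_validate(lr) for lr in learning_rates}

# Pick the learning rate with the lowest validation loss
best_lr = min(results, key=results.get)
for lr, val_loss in results.items():
    print(f"lr={lr:.0e}  val_loss={val_loss:.3f}")
print(f"Best learning rate: {best_lr:.0e}")
```

A log-spaced grid covers several orders of magnitude with few runs. Repeating each setting over a few random seeds and reporting the mean and standard deviation would give stronger evidence than a single run per value.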

**Marking:**

In the report, marks will primarily be awarded for:
- The quality of your hypotheses, and their justifications (including how you move from one hypothesis/experiment to the next)
- The quality/thoroughness of the experiments you run to test your hypotheses, and your presentation and discussion of the results

You will __not__ be marked on the overall performance of your model. This coursework is designed to test your ability to propose reasonable experiments and to test your understanding of the content of the course.

## 🚧 Exercise 3.1: Hypothesis 1 🚧

When detailing your Hypothesis 1-3 in the report, some points to touch on include:

- Explain your hypothesis and the reasoning behind it (e.g. why do you think this could improve performance, how does it relate to your previous experiment).
- How do you intend to test the hypothesis (i.e. what experiments will you run)?
- What evidence is required to confirm/disprove the hypothesis?
- What do you hope to learn from your experiments?

In [None]:
########################################
# Test Hypothesis 1: INSERT CODE HERE
#
# Feel free to use as many cells as necessary
########################################

When discussing the results of your hypothesis 1-3 experiments, some points to touch on include:

- Are the results as expected?
- How strong are the conclusions you can draw, based on the evidence you have collected?
- Any interesting findings or potential interesting followup experiments.

## 🚧 Exercise 3.2: Hypothesis 2 🚧

In [None]:
##########################################
# Test Hypothesis 2: INSERT CODE HERE
##########################################

## 🚧 Exercise 3.3: Hypothesis 3 🚧

In [None]:
##########################################
# Test Hypothesis 3: INSERT CODE HERE
##########################################

## 🚧 Exercise 3.4: Concluding Discussion 🚧

Now you should have trained and validated 4 models (the baseline and the models for your 3 hypotheses). Below, do a final comparison of the models' performance **using the test data**. This should be the **first and only time** you use the test data!

In your report, you should discuss potential reasons for differences in performance between all 4 models, and conclude which model is best (if any).

In [None]:
################################################
# INSERT CODE HERE
################################################
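A difference in test MAE between two models is only convincing if it is larger than the variation expected from the particular patients that ended up in the test set. One way to quantify this is a paired bootstrap over test patients, sketched below. `paired_bootstrap_mae_diff` is a hypothetical helper, and it assumes you have each model's denormalised predictions for the same test patients in the same order.

```python
import numpy as np


def paired_bootstrap_mae_diff(y_true, pred_a, pred_b, n_boot=10_000, seed=0):
    """Return the observed MAE(A) - MAE(B) and a bootstrap 95% interval for it.

    Patients are resampled with replacement; the same resampled indices are
    used for both models, which keeps the comparison paired.
    """
    rng = np.random.default_rng(seed)
    y_true = np.asarray(y_true, dtype=float)
    err_a = np.abs(np.asarray(pred_a, dtype=float) - y_true)
    err_b = np.abs(np.asarray(pred_b, dtype=float) - y_true)
    n = len(err_a)
    idx = rng.integers(0, n, size=(n_boot, n))  # one row of indices per bootstrap sample
    diffs = err_a[idx].mean(axis=1) - err_b[idx].mean(axis=1)
    observed = err_a.mean() - err_b.mean()
    ci_low, ci_high = np.percentile(diffs, [2.5, 97.5])
    return observed, (ci_low, ci_high)


# Toy example: model A is always 5 years off, model B always 1 year off
toy_ages = np.arange(20.0, 40.0)
observed, (ci_low, ci_high) = paired_bootstrap_mae_diff(
    toy_ages, toy_ages + 5.0, toy_ages + 1.0, n_boot=1000
)
print(f"MAE(A) - MAE(B) = {observed:.2f}, 95% interval [{ci_low:.2f}, {ci_high:.2f}]")
```

If the interval excludes 0, the difference is unlikely to be explained by the choice of test patients alone; if it includes 0, the evidence for one model being better is weak, which is worth acknowledging in the report.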