**Requirements + Installation**

In [None]:
!pip3 install -r requirements.txt

In [None]:
# Global Vars
import os

# When True, the data-loading cell attempts to mount a Colab drive before reading the data.
GENERATE_MATLAB_DATA = False

# Locations of the pickled dataframe and of the saved model.
# The defaults are the original absolute paths; set the MRI_DATA_PATH / MRI_RESULT_PATH
# environment variables to run the notebook on another machine without editing this cell.
DATA_PATH = os.environ.get(
    'MRI_DATA_PATH',
    '/home/cyu/workspace/202312-1-Outcome-Prediction-and-Consciousness-Detection-in-Patients-With-Acute-TBI/data/mri_data_pandas.pkl',
)
RESULT_PATH = os.environ.get(
    'MRI_RESULT_PATH',
    '/home/cyu/workspace/202312-1-Outcome-Prediction-and-Consciousness-Detection-in-Patients-With-Acute-TBI/data/MRI_Model.keras',
)

# True: take the slice window around the middle of each series; False: take the first slices.
CENTERING = True
# NOTE(review): not referenced in the cells visible here - confirm whether it is still needed.
CONDENSED = False
# (height, width, number of slices) of every volume fed to the model.
INPUT_RESOLUTION = (256, 256, 64)

In [None]:
import torch
from torchinfo import summary
from torch.utils.data import DataLoader, TensorDataset 
from vit_pytorch.vit_3d import ViT
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

import numpy as np
import cv2
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm, trange


In [None]:
# load Pandas dataframe
# NOTE(review): `drive` is not imported in the cells visible here, so this branch would raise
# NameError if GENERATE_MATLAB_DATA were True - presumably a Google Colab leftover; confirm.
if GENERATE_MATLAB_DATA:
    drive.mount('/content/drive')
file_name = DATA_PATH
# NOTE(security): read_pickle unpickles the file, which can execute arbitrary code -
# only point DATA_PATH at files from a trusted source.
df_loaded = pd.read_pickle(file_name)

**Pre-Process Data**

In [None]:
def get_patient_list(mat_file):
  """
  Collect the distinct patient IDs found in a loaded matlab structure.

  Input: mat_file - matlab file; each entry of mat_file['image_data'][0]
                    carries its patient ID at position [0][0]
  Output: set of patient IDs (duplicates collapsed)
  """
  return {record[0][0] for record in mat_file['image_data'][0]}


def map_patient_to_img_technique(mat_file, patient_list):
  """
  Map every requested patient to the MRI techniques recorded for them.

  Input: mat_file - matlab file; each entry of mat_file['image_data'][0]
                    carries the patient ID at [0][0] and the technique at [3][0]
         patient_list - set or list of (hashable) patient IDs
  Output: patient_img_technique_map - dict of patients with their MRI image type;
          techniques are listed once each, in order of first appearance, and
          patients without any record map to an empty list
  """
  # One (initially empty) technique list per requested patient, in request order.
  patient_img_technique_map = {patient_id: [] for patient_id in patient_list}
  if not patient_img_technique_map:
    return patient_img_technique_map

  # Single pass over the records instead of re-scanning them once per patient.
  for line in mat_file['image_data'][0]:
    img_techniques = patient_img_technique_map.get(line[0][0])
    if img_techniques is not None and line[3][0] not in img_techniques:
      img_techniques.append(line[3][0])

  return patient_img_technique_map


def rescale_img(img_array, resolution):
  """
  Resize an image to the requested resolution with bicubic interpolation.

  Input: img_array - numpy image array whose first two axes are (height, width)
         resolution - tuple (or list) of (height, width)
  Output: res - rescaled numpy image array; the input object itself is returned
                when it already has the requested height and width
  """
  target = tuple(resolution)
  # Compare only the spatial axes so a list-valued resolution or a trailing
  # channel axis does not force a pointless resize.
  if tuple(img_array.shape[:2]) == target:
    return img_array

  height, width = target
  # cv2.resize takes dsize as (width, height), the reverse of numpy's shape order.
  return cv2.resize(img_array, dsize=(width, height), interpolation=cv2.INTER_CUBIC)


def stack_mri_slices(img_slices):
  """
  Stack image slices depth-wise and append a trailing singleton axis.

  Input: img_slices - numpy array (or sequence) of images to stack

  Output: stacked MRI images; for 2D slices of shape (H, W) the result
          has shape (H, W, number of slices, 1)
  """
  volume = np.dstack(img_slices)
  # None is np.newaxis: insert the extra axis right after the slice axis.
  return volume[:, :, :, None]

# print(df_loaded['Technique'].value_counts())

# rescale images to desired resolution
TARGET_RESOLUTION = (INPUT_RESOLUTION[0], INPUT_RESOLUTION[1])
for idx in df_loaded.index:
  try:
    rescaled = rescale_img(df_loaded.at[idx, 'Data'], TARGET_RESOLUTION)
  except Exception:
    # Best effort: report the row and keep its original image.
    print('failed at: ', idx)
    continue
  # Write through .at so the value lands in df_loaded itself; the chained form
  # df_loaded['Data'][idx] = ... is silently dropped under pandas Copy-on-Write.
  # The assignment sits outside the try so a failed write is not mistaken for a failed rescale.
  df_loaded.at[idx, 'Data'] = rescaled

# Patient cohort: start from every distinct ID in the dataframe, drop the
# excluded patients, then shuffle with a fixed seed so the split is repeatable.
patient_list = df_loaded['Patient_ID'].unique()
patient_list = np.delete(patient_list, 58)      # positional removal of the entry at index 58
for excluded_id in ('02445263', '15816944'):    # removal by patient ID (no-op if absent)
  patient_list = np.delete(patient_list, np.where(patient_list == excluded_id))

np.random.seed(42)
np.random.shuffle(patient_list)

# First 46 shuffled patients train the model, the remainder are held out for testing
# (described as a 70% / 30% split in the original notebook).
train_patients, test_patients = patient_list[:46], patient_list[46:]

# build training/testing dataset

x_data_all_dict = {}
y_data_all_dict = {}

num_slices = INPUT_RESOLUTION[2]

# MRI techniques tried for each patient, in order of preference.
PRIMARY_TECHNIQUE = 'SWAN'
FALLBACK_TECHNIQUE = 'Ax DWI Asset'
# Maps the 'Designator' outcome string to the binary class label.
OUTCOME_TO_LABEL = {'responsive': 1, 'unresponsive': 0}


def _extract_volume(patient_df, technique, n_slices, centered):
  """
  Stack n_slices images of one technique into a single volume.

  Input: patient_df - rows of df_loaded belonging to one patient
         technique - value of the 'Technique' column to select
         n_slices - number of slices the volume must contain
         centered - True: window around the middle slice; False: first n_slices
  Output: volume as returned by stack_mri_slices
  Raises: ValueError if the technique has fewer than n_slices images
          (including none at all), or if the slices cannot be stacked
  """
  slices = patient_df.loc[patient_df['Technique'] == technique, 'Data'].to_numpy()
  num_mri_slices = slices.shape[0]
  if num_mri_slices < n_slices:
    # Without this check a centred window would get a negative start index and
    # silently wrap around, yielding a volume with the wrong number of slices.
    raise ValueError(
        '{} has {} slice(s), {} required'.format(technique, num_mri_slices, n_slices))

  start_idx = num_mri_slices // 2 - n_slices // 2 if centered else 0
  end_idx = start_idx + n_slices
  print('start: ', start_idx, 'end: ', end_idx)
  return stack_mri_slices(slices[start_idx:end_idx])


def _build_volumes(patient_ids):
  """
  Build {patient_id: volume}, preferring PRIMARY_TECHNIQUE and falling back
  to FALLBACK_TECHNIQUE when the primary series is missing or unusable.
  """
  volumes = {}
  for patient_id in patient_ids:
    patient_df = df_loaded.loc[df_loaded['Patient_ID'] == patient_id]
    try:
      volume = _extract_volume(patient_df, PRIMARY_TECHNIQUE, num_slices, CENTERING)
    except ValueError:
      print('no data: ', patient_id)
      try:
        volume = _extract_volume(patient_df, FALLBACK_TECHNIQUE, num_slices, CENTERING)
      except ValueError as err:
        raise ValueError(
            'patient {}: no usable MRI volume ({})'.format(patient_id, err)) from err
    volumes[patient_id] = volume
  return volumes


def _build_labels(patient_ids):
  """
  Build {patient_id: label} from the first 'Designator' value of each patient
  (1 for 'responsive', 0 for 'unresponsive').
  """
  labels = {}
  for patient_id in patient_ids:
    outcome = df_loaded.loc[df_loaded['Patient_ID'] == patient_id, 'Designator'].iloc[0]
    if outcome not in OUTCOME_TO_LABEL:
      # Fail here with context instead of a bare KeyError when the lists are assembled.
      raise ValueError('patient {}: unexpected outcome {!r}'.format(patient_id, outcome))
    labels[patient_id] = OUTCOME_TO_LABEL[outcome]
  return labels


# TRAINING data
x_data_train_dict = _build_volumes(train_patients)
y_data_train_dict = _build_labels(train_patients)

# TESTING data
x_data_test_dict = _build_volumes(test_patients)
y_data_test_dict = _build_labels(test_patients)

#print(len(y_data_test_dict))
# print(y_data_train_dict)

# print(patient_mri_data.shape)
# print(patient_df['Technique'].value_counts())

#print('x data train shape:')
#for patient in x_data_train_dict:
#  print(x_data_train_dict[patient].shape)

#print('x data test shape:')
#for patient in x_data_test_dict:
#  print(x_data_test_dict[patient].shape)

# Flatten the per-patient dicts into parallel lists, keeping the patient order of each split.
x_train = [x_data_train_dict[pid] for pid in train_patients]
y_train = [y_data_train_dict[pid] for pid in train_patients]

x_test = [x_data_test_dict[pid] for pid in test_patients]
y_test = [y_data_test_dict[pid] for pid in test_patients]

**Load Data into Pytorch Data Structures**

In [None]:
# Pick the compute device (GPU when available) and report it.
print(torch.cuda.device_count())
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
device_label = f"({torch.cuda.get_device_name(device)})" if cuda_available else ""
print("Using device: ", device, device_label)

# Axis permutation applied to the stacked volumes: the trailing axis moves to
# position 1 and the slice axis to position 2, i.e. (N, H, W, slices, 1) -> (N, 1, slices, H, W).
CHANNELS_FIRST_AXES = (0, 4, 3, 1, 2)

# --- training split ---
x_train = np.asarray(x_train)
print(x_train.shape)
x_train_T = np.transpose(x_train, CHANNELS_FIRST_AXES)
print(x_train_T.shape)
y_train = np.asarray(y_train)
print(y_train.shape)

# --- testing split ---
x_test = np.asarray(x_test)
print(x_test.shape)
x_test_T = np.transpose(x_test, CHANNELS_FIRST_AXES)
print(x_test_T.shape)
y_test = np.asarray(y_test)
print(y_test.shape)

# Whole splits are moved to the device up front; labels are cast to int64.
tensor_x_train = torch.Tensor(x_train_T).to(device)
tensor_y_train = torch.Tensor(y_train).to(device).long()
train_dataset = TensorDataset(tensor_x_train, tensor_y_train)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)   # one volume per step

tensor_x_test = torch.Tensor(x_test_T).to(device)
tensor_y_test = torch.Tensor(y_test).to(device).long()
test_dataset = TensorDataset(tensor_x_test, tensor_y_test)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

print('Training set has {} instances'.format(len(train_dataset)))
print('Validation set has {} instances'.format(len(test_dataset)))

**Define 3-D Vision Transformer Model**

In [None]:
# Hyper-parameters of the 3-D Vision Transformer, collected in one place.
vit_config = dict(
    image_size=INPUT_RESOLUTION[0],   # image size
    frames=INPUT_RESOLUTION[2],       # number of frames
    image_patch_size=16,              # image patch size
    frame_patch_size=4,               # frame patch size
    num_classes=2,
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1,
    channels=1,
)
vit_model = ViT(**vit_config).to(device)

summary(model=vit_model)

**Training Loop**

In [None]:
# Training Parameters
N_EPOCHS = 50
LR = 0.005

# Per-epoch history (plain Python floats), consumed by the plotting cell.
losses = []
accuracy_vals = []

optimizer = Adam(vit_model.parameters(), lr=LR)
criterion = CrossEntropyLoss().to(device)
n_batches = len(train_loader)
for epoch in trange(N_EPOCHS, desc="Training"):
    # Make sure dropout is active even if an evaluation cell switched the model to eval mode.
    vit_model.train()
    epoch_loss = 0.0
    epoch_accuracy = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = vit_model(x)
        loss = criterion(y_hat, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running means over the epoch; .item() keeps the history free of tensors.
        epoch_loss += loss.item() / n_batches
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        epoch_accuracy += acc.item() / n_batches

    losses.append(epoch_loss)
    accuracy_vals.append(epoch_accuracy)

    print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {epoch_loss:.2f}, accuracy: {epoch_accuracy:.2f}")


In [None]:
# Plot accuracy/loss
# float() turns any 0-dim tensors in the history into plain numbers for printing and plotting.
accuracy_history = [float(val) for val in accuracy_vals]
loss_history = [float(val) for val in losses]
print('accuracy vals: {}'.format(accuracy_history))
print('loss vals: {}'.format(loss_history))

# summarize history for accuracy
plt.plot(accuracy_history)
plt.ylim([0, 1])
plt.title('MRI model accuracy, 3-D ViT')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train'], loc='upper left')
plt.show()

# summarize history for loss
# NOTE(review): cross-entropy is not bounded by 1, so values above 1 fall outside this axis range.
plt.plot(loss_history)
plt.ylim([0, 1])
plt.title('MRI model loss, 3-D ViT')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train'], loc='upper left')
plt.show()

**Test Loop**

In [None]:
# Evaluate in eval mode so dropout is disabled and the reported numbers are deterministic;
# the previous mode is restored afterwards so a later training run is unaffected.
was_training = vit_model.training
vit_model.eval()
try:
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        n_test_batches = len(test_loader)
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = vit_model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.item() / n_test_batches

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).item()
            total += len(x)
        # An empty test set yields nan instead of a ZeroDivisionError.
        test_accuracy = correct / total * 100 if total else float('nan')
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {test_accuracy:.2f}%")
finally:
    vit_model.train(was_training)