# Parcellation
After pre-processing is finished, we parcellate the data with the A424 atlas and process it with different normalization methods.

In [1]:
import os
import math
from random import randint, seed
import sys
import numpy as np
import pandas as pd
from tqdm import tqdm
import pickle

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error

import torch
import torch.nn as nn
import torch.nn.functional as F

from datasets import load_from_disk, concatenate_datasets
from transformers import ViTImageProcessor, ViTMAEConfig
from brainlm_mae.modeling_vit_mae_with_padding import ViTMAEForPreTraining 
from brainlm_mae.replace_vitmae_attn_with_flash_attn import replace_vitmae_attn_with_flash_attn
from utils.utils import convert_fMRIvols_to_A424, process_datasets

In [None]:
# Locations used by the parcellation step. Replace the placeholders with real paths.
raw_data_dir = "/path/to/raw_fMRI_data"
# Make sure this directory exists.
save_data_dir = "/path/to/a424_fMRI_data"

# Settings dictionary handed to process_datasets().
args = dict(
    # Path to directory containing dat files, A424 coordinates file, and A424 excel sheet.
    ts_data_dir="/path/to/a424_fMRI_data",
    # The directory where you want to save the output arrow datasets.
    processed_data_dir=os.path.join(save_data_dir, "processed"),
    dataset_name="xxx",
    metadata_path="path/to/metadata.csv",
)

# Step 1: convert fMRI volumes to A424 time-series data.
convert_fMRIvols_to_A424(data_path=raw_data_dir, output_path=save_data_dir)

# Step 2: process the datasets; the output directory is taken from the settings above.
process_datasets(args, args["processed_data_dir"])

# Training Pipeline
## Setup args

In [2]:
from datetime import datetime

# Hyper-parameters read by the dataset-loading and training cells.
moving_window_len = 200  # number of time points per sampled window
kfold = 5                # number of cross-validation folds
batch_size = 8
data_path = "/path/to/a424_fMRI_data/processed"

# Every run writes into its own time-stamped sub-directory of "output/".
output_path = "output/"
now = datetime.now()
dt_string = now.strftime("%Y-%m-%d-%H_%M_%S")
output_path = os.path.join(output_path, dt_string)
# exist_ok=True replaces the separate "exists?" check, which could race with
# another process creating the same directory between the check and the call.
os.makedirs(output_path, exist_ok=True)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Loading dataset

In [3]:
import random

def preprocess_fmri(examples, recording_col_name, variable_of_interest_col_name="Response", moving_window_len=200):
    """
    Turn one dataset row into a model-input sample.

    A window of ``moving_window_len`` consecutive time points is cut from the
    recording at a random start position, and the X/Y/Z coordinates of every
    row of the window are looked up in the module-level ``coords_ds`` table,
    which must be defined before this function is first called.

    Args:
        examples: Mapping for one row. It is modified in place: the keys
            "signal_vectors", "xyz_vectors", "brain_network" and "label" are
            added (or overwritten).
        recording_col_name: Key of the recording; its value is converted to a
            2-D float32 tensor whose second axis is sliced into the window.
        variable_of_interest_col_name: Key of the label column. The value is
            indexed with ``[0]``, so it is expected to be a sequence.
        moving_window_len: Length of the window along the second axis.

    Returns:
        The same ``examples`` mapping with the four keys listed above set.

    Raises:
        ValueError: If the recording has fewer than ``moving_window_len``
            points along its second axis.
    """
    raw_label = examples[variable_of_interest_col_name][0]
    brain_net = examples["Brain_Network"]
    # A missing (NaN) label is encoded as -1; anything else is cast to int.
    label = -1 if math.isnan(raw_label) else int(raw_label)
    label = torch.tensor(label, dtype=torch.int64)

    signal_vector = torch.tensor(examples[recording_col_name], dtype=torch.float32)

    # Fail with an explicit message instead of the opaque "empty range" error
    # that random.randint raises when the upper bound is negative.
    num_timepoints = signal_vector.shape[1]
    if num_timepoints < moving_window_len:
        raise ValueError(
            f"Recording has {num_timepoints} time points, "
            f"fewer than moving_window_len={moving_window_len}"
        )

    # Choose random starting index, take window of moving_window_len points for each region
    start_idx = random.randint(0, num_timepoints - moving_window_len)
    end_idx = start_idx + moving_window_len
    signal_window = signal_vector[:, start_idx:end_idx]

    # Collect one (X, Y, Z) coordinate triple per row of the window.
    window_xyz_list = []
    for brain_region_idx in range(signal_window.shape[0]):
        # Fetch the coordinate row once instead of once per axis.
        region_coords = coords_ds[brain_region_idx]
        xyz = torch.tensor(
            [region_coords["X"], region_coords["Y"], region_coords["Z"]],
            dtype=torch.float32,
        )
        window_xyz_list.append(xyz)
    window_xyz_list = torch.stack(window_xyz_list)

    # Store the model inputs on the row. The leading axis added by unsqueeze(0)
    # gives both tensors a size-1 first dimension.
    examples["signal_vectors"] = signal_window.unsqueeze(0)
    examples["xyz_vectors"] = window_xyz_list.unsqueeze(0)
    examples["brain_network"] = np.array(brain_net)
    examples["label"] = label
    return examples

class fMRIDataset(torch.utils.data.Dataset):
    """
    Map-style dataset that caches every row of ``dataset`` and serves a freshly
    sampled window of the recording on each access.

    Args:
        dataset: Object exposing ``num_rows`` and integer indexing that returns
            one row as a mapping.
        recording_col_name: Key of the recording inside each row.
        variable_of_interest_col_name: Key of the label inside each row.
        moving_window_len: Window length forwarded to ``preprocess_fmri``.

    Attributes set in ``__init__``:
        features: list of the cached rows, in dataset order.
        labels: list of the raw label values, in the same order.
    """

    def __init__(self, dataset, recording_col_name, variable_of_interest_col_name, moving_window_len=200):
        self.dataset = dataset
        self.recording_col_name = recording_col_name
        self.variable_of_interest_col_name = variable_of_interest_col_name
        self.moving_window_len = moving_window_len
        self.features, self.labels = self._load_data()

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return self.dataset.num_rows

    def _load_data(self):
        """Read every row of ``self.dataset`` once and return ``(rows, labels)``."""
        features = []
        labels = []
        for recording_idx in range(self.dataset.num_rows):
            # Read from the dataset given to the constructor, not from a
            # notebook-level global, so the class works for any dataset.
            row = self.dataset[recording_idx]
            features.append(row)
            labels.append(row[self.variable_of_interest_col_name])
        return features, labels

    def __getitem__(self, idx):
        """Return the model-input dict for row ``idx`` with a newly sampled window."""
        # preprocess_fmri writes new keys into the mapping it receives; hand it
        # a shallow copy so the cached row is not modified on every access.
        example = dict(self.features[idx])
        processed_example = preprocess_fmri(
            example,
            self.recording_col_name,
            # Forward the configured label column instead of relying on the
            # function's default.
            variable_of_interest_col_name=self.variable_of_interest_col_name,
            moving_window_len=self.moving_window_len,
        )

        signal = processed_example["signal_vectors"].squeeze(0)
        return {
            "signal_vectors": signal,
            "xyz_vectors": processed_example["xyz_vectors"].squeeze(0),
            # "input_ids" carries the same signal tensor under a second key.
            "input_ids": processed_example["signal_vectors"].squeeze(0),
            "labels": processed_example["label"].squeeze(0),
            "brain_network": processed_example["brain_network"]
        }
    


def collate_fn(example):
    """
    Combine a list of per-sample dicts into a single batch dict.

    Every entry of ``example`` must provide the keys "signal_vectors",
    "xyz_vectors", "labels" (tensors) and "brain_network" (array-like). Each
    field is stacked along a new leading axis. The stacked signals are also
    returned under "input_ids".
    """
    def _stacked(key):
        # Stack the tensor stored under `key` across all samples.
        return torch.stack([sample[key] for sample in example])

    # "brain_network" is converted to a tensor per sample before stacking.
    network_tensors = [torch.tensor(sample["brain_network"]) for sample in example]

    batch = {}
    batch["signal_vectors"] = _stacked("signal_vectors")
    batch["xyz_vectors"] = _stacked("xyz_vectors")
    batch["input_ids"] = _stacked("signal_vectors")
    batch["labels"] = _stacked("labels")
    batch["brain_network"] = torch.stack(network_tensors)
    return batch

In [None]:
from datasets import load_from_disk, concatenate_datasets
# Load the processed recordings and the brain-region coordinate table from disk.
recordings_dir = os.path.join(data_path, "data")
coords_dir = os.path.join(data_path, "Brain_Region_Coordinates")
train_ds = load_from_disk(recordings_dir)
# coords_ds is read as a module-level global by preprocess_fmri().
coords_ds = load_from_disk(coords_dir)
concat_ds = concatenate_datasets([train_ds])

# Column names: the recording to window and the label to predict.
col_name = "Raw_Recording"
variable_of_interest_col_name = "Response"
dataset = fMRIDataset(
    dataset=concat_ds,
    recording_col_name=col_name,
    variable_of_interest_col_name="Response",
    moving_window_len=moving_window_len,
)

## Model definition

In [None]:
from brainlm_mae.modeling_brainlm import BrainLMForPretraining
class MultimodalfMRI(nn.Module):
    """
    Classifier that fuses a precomputed CLS embedding with image-backbone
    features computed from the raw signals.

    Args:
        resnet: Module applied to a ``(batch, 3, R, R)`` float tensor that
            returns one feature vector per sample.
        fused_dim: Input width of the linear classifier. It must equal the
            width of ``vit_cls_token`` plus the width of the ``resnet`` output.
            Defaults to 512.
        num_classes: Number of output logits. Defaults to 2.
    """

    def __init__(self, resnet, fused_dim=512, num_classes=2):
        super().__init__()
        self.resnet = resnet
        self.classifier = nn.Linear(fused_dim, num_classes)

    def forward(self, signal_vectors, vit_cls_token):
        """
        Return classification logits of shape ``(batch, num_classes)``.

        Args:
            signal_vectors: 3-D tensor ``(batch, R, T)``.
            vit_cls_token: 2-D tensor ``(batch, D)`` concatenated in front of
                the backbone features.
        """
        # Per-sample Gram matrix S @ S^T of shape (batch, R, R), replicated over
        # three channels so it can be fed to an image backbone.
        gram = torch.bmm(signal_vectors, signal_vectors.permute(0, 2, 1))
        resnet_input = gram.unsqueeze(1).repeat(1, 3, 1, 1).float()
        resnet_features = self.resnet(resnet_input)
        cls_token = torch.cat([vit_cls_token, resnet_features], dim=1)
        logits = self.classifier(cls_token)
        return logits
    
# Load the pretrained BrainLM weights and place the model on the chosen device.
model = BrainLMForPretraining.from_pretrained("pretrained_models/brainlm")
model.to(device)

# Freeze every parameter: this model is not updated by the optimizer below.
for frozen_weight in model.parameters():
    frozen_weight.requires_grad = False

# Set the mask ratio to zero on both the embedding module and its config.
# NOTE(review): presumably this disables patch masking so the encoder sees the
# full input; confirm against the BrainLM embedding implementation.
vit_embeddings = model.vit.embeddings
for mask_holder in (vit_embeddings, vit_embeddings.config):
    mask_holder.mask_ratio = 0.0

## Start training

In [None]:

from sklearn.model_selection import StratifiedKFold
from torchvision.models import resnet18
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, balanced_accuracy_score
from sklearn.metrics import confusion_matrix, matthews_corrcoef


# Stratified k-fold training of the ResNet + CLS-token classifier on top of the
# frozen encoder `model`. One row of validation metrics is appended to
# `metric_df` per (fold, epoch), and the table is written to results.csv after
# every fold.
metric_df = pd.DataFrame(columns=["kfold", "epoch", "F1", "Accuracy", "BACC", "ROC AUC", "Recall", "Precision", "Sensitivity", "Specificity", "MCC", "TP", "FP", "TN", "FN"])
num_epochs = 100  # epochs per fold; also the length of the cosine LR schedule
criterion = nn.CrossEntropyLoss()  # stateless, so one instance serves all batches
for i, (train_idx, val_idx) in enumerate(StratifiedKFold(n_splits=kfold, shuffle=False).split(dataset.features, dataset.labels)):

    # A new, randomly initialised ResNet-18 is built for every fold. Its final
    # layer is replaced so that it emits a 256-dimensional feature vector.
    model_resnet = resnet18(pretrained=False)
    model_resnet.fc = nn.Sequential(
        nn.Flatten(),
        nn.Linear(model_resnet.fc.in_features, 256),
    )
    model_resnet = MultimodalfMRI(model_resnet).to(device)
    optimizer = torch.optim.AdamW(model_resnet.parameters(), lr=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-7)

    train_set = torch.utils.data.Subset(dataset, train_idx)
    val_set = torch.utils.data.Subset(dataset, val_idx)
    trainloader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )
    valloader = torch.utils.data.DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    # NOTE(review): best_mcc starts at 0, so "best_mcc.pt" is only written once
    # a validation MCC strictly above 0 is reached. The checkpoint file names do
    # not contain the fold index, so each fold overwrites the previous fold's
    # files.
    best_mcc = 0
    best_epoch = 0
    model.eval()  # the frozen encoder is only used for inference
    loop = tqdm(range(num_epochs))
    for epoch in loop:
        loop.set_description(f"Fold: {i} Epoch {epoch}")

        # ---- training pass ----
        model_resnet.train()
        for example in trainloader:
            # Reset gradients for every batch. Resetting only once per epoch
            # would make each optimizer step use the sum of the gradients of
            # all batches seen so far in that epoch.
            optimizer.zero_grad()
            with torch.no_grad():
                encoder_output = model.vit(
                    signal_vectors=example["signal_vectors"].to(device),
                    xyz_vectors=example["xyz_vectors"].to(device),
                    output_attentions=True,
                    output_hidden_states=True
                )
            # The token at index 0 of the last hidden state is used as the CLS embedding.
            logits = model_resnet(example["signal_vectors"].to(device), encoder_output.last_hidden_state[:,0,:])
            loss = criterion(logits, example["labels"].to(device))
            loss.backward()
            optimizer.step()

        # ---- validation pass ----
        model_resnet.eval()

        losses = []
        logits_list = []
        labels_list = []
        for example in valloader:
            with torch.no_grad():
                encoder_output = model.vit(
                    signal_vectors=example["signal_vectors"].to(device),
                    xyz_vectors=example["xyz_vectors"].to(device),
                    output_attentions=True,
                    output_hidden_states=True
                )
                logits = model_resnet(example["signal_vectors"].to(device), encoder_output.last_hidden_state[:,0,:])
                loss = criterion(logits, example["labels"].to(device))
                losses.append(loss.item())
                logits_list.append(logits.detach().cpu().numpy())
                labels_list.append(example["labels"].detach().cpu().numpy())
        logits_list = np.concatenate(logits_list, axis=0)
        labels_list = np.concatenate(labels_list)
        preds_list = np.argmax(logits_list, axis=1)

        acc = accuracy_score(labels_list, preds_list)
        bacc = balanced_accuracy_score(labels_list, preds_list)
        f1 = f1_score(labels_list, preds_list, average='macro')
        # Rank samples by the logit margin (class 1 minus class 0). For a
        # two-logit softmax, P(class 1) = sigmoid(margin), so the margin orders
        # samples exactly like the predicted probability; the raw class-1 logit
        # alone does not.
        roc_auc = roc_auc_score(labels_list, logits_list[:, 1] - logits_list[:, 0])
        # Pin the class order so the matrix is always 2x2, even when a class is
        # absent from both labels and predictions in this fold.
        cm = confusion_matrix(labels_list, preds_list, labels=[0, 1])
        cm_percent = cm / cm.sum(axis=1)[:, np.newaxis]
        mcc = matthews_corrcoef(labels_list, preds_list)
        tp, fp, tn, fn = cm[1, 1], cm[0, 1], cm[0, 0], cm[1, 0]

        # The 1e-9 term guards against division by zero.
        recall = tp/(tp+fn+1e-9)
        precision = tp/(tp+fp+1e-9)
        specificity = tn/(tn+fp+1e-9)
        # Same order as the columns of metric_df; "Recall" and "Sensitivity"
        # are the same quantity.
        metrics = [
            i, epoch, f1, acc, bacc, roc_auc,
            recall, precision, recall, specificity,
            mcc, tp, fp, tn, fn,
        ]
        metric_df.loc[len(metric_df)] = metrics

        improved = mcc > best_mcc
        if improved:
            best_mcc = mcc
            best_epoch = epoch
        loop.set_postfix_str(f"Best MCC: {best_mcc:.4f}")

        # The same state is written to "last_model.pt" every epoch, and also to
        # "best_mcc.pt" whenever the validation MCC improved.
        save_dict = {
            "epoch": epoch,
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "model_resnet": model_resnet.state_dict(),
            "best_mcc": best_mcc,
            "best_epoch": best_epoch
        }
        if improved:
            torch.save(save_dict, os.path.join(output_path, "best_mcc.pt"))
        torch.save(save_dict, os.path.join(output_path, "last_model.pt"))
        scheduler.step()
    metric_df.to_csv(os.path.join(output_path, "results.csv"))