In [None]:
import numpy as np
import os
import torch
from PIL import Image
from tqdm import tqdm
import sys
import matplotlib.pyplot as plt
import json
from tqdm import tqdm
from scipy.stats import pearsonr as corr

### TODO
- Improve the training setup (e.g. an accelerator library such as Hugging Face Accelerate, a learning-rate scheduler, etc.)

In [None]:
print('Pulling NSD webdataset data...')
# Note: using "voxel" naming even though we use vertices here... makes it easier porting over MindEye lingo
import math
import random

import webdataset as wds

wds_root = "/fsx/proj-medarc/fmri/natural-scenes-dataset/algonauts_data/wds"
train_url = wds_root + "/subj01_{3..98}.tar"  # shard range brace-expanded by webdataset
val_url = wds_root + "/subj01_{0..2}.tar"
meta_url = wds_root + "/metadata_subj01.json"

# Use a context manager so the metadata file handle is closed.
with open(meta_url) as meta_file:
    metadata = json.load(meta_file)

num_val = 300  # samples held out for validation (shards 0..2)
num_train = metadata['total'] - num_val
batch_size = 32
num_devices = 1
seed = 42

print('Prepping train and validation dataloaders...')


def my_split_by_node(urls):
    """Identity node splitter: every node sees every shard."""
    return urls


num_workers = 10

global_batch_size = batch_size * num_devices
num_batches = math.floor(num_train / global_batch_size)
# with_epoch() limits each worker's iterator, so spread the epoch's batches over the workers.
num_worker_batches = math.floor(num_batches / num_workers)

# Shared sample spec for both splits: output key -> file extension inside each sample,
# and the order in which tensors are emitted per batch.
wds_renames = dict(
    images="jpg;png",
    voxels="vert.npy",
    clip_latent="clip_emb_final.npy",
    clip_last_hidden="clip_emb_hidden.npy",
    imagebind_latent="imagebind_final.npy",
    imagebind_hidden="imagebind_hidden.npy",
)
wds_fields = ("voxels", "images", "clip_latent", "clip_last_hidden", "imagebind_latent", "imagebind_hidden")

train_data = (
    wds.WebDataset(train_url, resampled=False)
    .shuffle(500, initial=500, rng=random.Random(seed))
    .decode("torch")
    .rename(**wds_renames)
    .to_tuple(*wds_fields)
    .batched(batch_size, partial=False)
    .with_epoch(num_worker_batches)
)

train_dl = torch.utils.data.DataLoader(train_data, num_workers=num_workers,
                        batch_size=None, shuffle=False, persistent_workers=True)

num_workers_val = 1
# Validation is served as one full-size batch so metrics cover every held-out sample.
val_batch_size = num_val
num_batches_val = math.ceil(num_val / val_batch_size)
num_worker_batches_val = math.ceil(num_batches_val / num_workers_val)
print("validation: num_worker_batches", num_worker_batches_val)

val_data = (
    wds.WebDataset(val_url, resampled=False, nodesplitter=my_split_by_node)
    .decode("torch")
    .rename(**wds_renames)
    .to_tuple(*wds_fields)
    # partial=True: never silently drop held-out samples if the count differs from num_val.
    .batched(val_batch_size, partial=True)
)

val_dl = torch.utils.data.DataLoader(val_data, num_workers=num_workers_val,
                    batch_size=None, shuffle=False, persistent_workers=True)

In [None]:
# Peek at the first *training* batch to record tensor shapes (used later to size the models).
# Note: the next cell reassigns these *_shape names from the validation loader.
first_train_batch = next(iter(train_dl), None)
if first_train_batch is None:
    # Empty loader: leave every shape as None, matching the no-batch case.
    voxels_shape = images_shape = clip_latent_shape = None
    clip_last_hidden_shape = imagebind_latent_shape = imagebind_hidden_shape = None
else:
    voxels, images, clip_latent, clip_last_hidden, imagebind_latent, imagebind_hidden = first_train_batch
    voxels_shape = voxels.shape
    images_shape = images.shape
    clip_latent_shape = clip_latent.shape
    clip_last_hidden_shape = clip_last_hidden.shape
    imagebind_latent_shape = imagebind_latent.shape
    imagebind_hidden_shape = imagebind_hidden.shape

print('Train dataloader shapes:')
print('voxels', voxels_shape)
print('images', images_shape)
print('clip_latent', clip_latent_shape)
print('clip_last_hidden', clip_last_hidden_shape)
print('imagebind_latent', imagebind_latent_shape)
print('imagebind_hidden', imagebind_hidden_shape)


In [None]:
# Inspect the first validation batch; every shape stays None if the loader yields nothing.
voxels_shape = images_shape = clip_latent_shape = None
clip_last_hidden_shape = imagebind_latent_shape = imagebind_hidden_shape = None

first_val_batch = next(iter(val_dl), None)
if first_val_batch is not None:
    voxels, images, clip_latent, clip_last_hidden, imagebind_latent, imagebind_hidden = first_val_batch
    (voxels_shape, images_shape, clip_latent_shape,
     clip_last_hidden_shape, imagebind_latent_shape, imagebind_hidden_shape) = (
        voxels.shape, images.shape, clip_latent.shape,
        clip_last_hidden.shape, imagebind_latent.shape, imagebind_hidden.shape)

print('Val dataloader shapes:')
for field_name, field_shape in (
    ('voxels', voxels_shape),
    ('images', images_shape),
    ('clip_latent', clip_latent_shape),
    ('clip_last_hidden', clip_last_hidden_shape),
    ('imagebind_latent', imagebind_latent_shape),
    ('imagebind_hidden', imagebind_hidden_shape),
):
    print(field_name, field_shape)

In [None]:

# Default hyperparameters, read as globals by train_the_model_and_get_predictions;
# the training cells further down override learning_rate / num_epochs / device.
learning_rate = 0.001
num_epochs = 10
alpha = 1e-8  # NOTE(review): not referenced by any visible code in this notebook
device = torch.device("cuda:5") if torch.cuda.is_available() else torch.device("cpu")

# Define the model
class MLP(torch.nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: number of input features per sample.
        output_size: number of predicted values per sample.
        hidden_size: width of the hidden layer (default 5000).
    """

    def __init__(self, input_size, output_size, hidden_size=5000):
        super().__init__()
        # Attribute names are part of the state_dict keys used for checkpoints; keep them stable.
        self.linear_in = torch.nn.Linear(input_size, hidden_size)
        self.linear_hid = torch.nn.Linear(hidden_size, output_size)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.linear_in(x))
        return self.linear_hid(hidden)



# Where each feature map lives in a dataloader batch tuple
# (voxels, images, clip_latent, clip_last_hidden, imagebind_latent, imagebind_hidden),
# and whether its per-sample features must be flattened before entering the model.
_MAP_NAME_TO_BATCH_FIELD = {
    'clip_latent': (2, False),
    'clip_last_hidden': (3, True),
    'imagebind_latent': (4, False),
    'imagebind_last_hidden': (5, True),
}


def _batch_to_inputs_and_labels(map_name, batch, run_device):
    """Return float (inputs, labels) tensors for `map_name` from one batch, moved to `run_device`.

    Labels are the voxel tensor averaged over dim 1 (presumably repeated presentations
    of each image -- TODO confirm). Train and validation both call this helper so they
    use exactly the same target definition.
    """
    field_index, flatten = _MAP_NAME_TO_BATCH_FIELD[map_name]
    # TODO: I'm just taking the mean, Pauls said if you use single records you get better performance
    labels = batch[0].to(run_device).float().mean(dim=1)
    inputs = batch[field_index].to(run_device).float().squeeze(1)
    if flatten:
        # Hidden-state maps keep extra per-sample axes; collapse everything but the batch dim.
        inputs = inputs.reshape(inputs.shape[0], -1)
    return inputs, labels


def _mse_checked(outputs, labels):
    """MSE loss that rejects mismatched shapes instead of letting mse_loss broadcast them."""
    if outputs.shape != labels.shape:
        raise ValueError('model output shape {} does not match target shape {}'.format(
            tuple(outputs.shape), tuple(labels.shape)))
    return torch.nn.functional.mse_loss(outputs, labels)  # TODO: Try contrastive loss or any other loss


def train_the_model_and_get_predictions(map_name, model, save=True, train_loader=None,
                                        val_loader=None, lr=None, epochs=None, run_device=None):
    """Train `model` to regress fMRI vertices from one feature map; return best-epoch val predictions.

    Args:
        map_name: 'clip_latent', 'clip_last_hidden', 'imagebind_latent' or
            'imagebind_last_hidden'; selects which batch tensor is the model input.
        model: torch module whose output must have the same shape as the targets.
        save: if True, write the best-validation weights to
            'mlpmodels/best_model_<map_name>.ckpt' (directory is created if missing).
        train_loader, val_loader: iterables of 6-tuples
            (voxels, images, clip_latent, clip_last_hidden, imagebind_latent, imagebind_hidden);
            default to the notebook globals `train_dl` / `val_dl`.
        lr, epochs, run_device: default to the notebook globals
            `learning_rate`, `num_epochs`, `device`.

    Returns:
        (preds, actuals, model): numpy predictions and targets concatenated over ALL
        validation batches of the epoch with the lowest mean validation loss (empty
        arrays if no finite validation loss was ever observed), and the model as it
        stands after the final epoch (not necessarily the best-validation weights).

    Raises:
        ValueError: for an unknown map_name, or if model outputs and targets differ in shape.
    """
    if map_name not in _MAP_NAME_TO_BATCH_FIELD:
        raise ValueError('Invalid map_name')

    # Fall back to notebook-level globals so existing call sites keep working.
    train_loader = train_dl if train_loader is None else train_loader
    val_loader = val_dl if val_loader is None else val_loader
    lr = learning_rate if lr is None else lr
    epochs = num_epochs if epochs is None else epochs
    run_device = device if run_device is None else run_device

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.to(run_device)

    best_loss = float('inf')
    preds, actuals = [], []

    for epoch in range(epochs):
        # Training pass (the model predicts all vertices of both hemispheres at once).
        model.train()
        epoch_loss = 0.0
        n_train_batches = 0
        for batch in train_loader:
            inputs, labels = _batch_to_inputs_and_labels(map_name, batch, run_device)
            loss = _mse_checked(model(inputs), labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            n_train_batches += 1
        # Average over the batches actually seen rather than a precomputed global count.
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, epochs, epoch_loss / max(n_train_batches, 1)))

        # Validation pass: gather every batch before comparing with the best epoch so far.
        model.eval()
        val_loss_sum = 0.0
        n_val_samples = 0
        epoch_preds, epoch_actuals = [], []
        with torch.no_grad():
            for batch in val_loader:
                inputs, labels = _batch_to_inputs_and_labels(map_name, batch, run_device)
                outputs = model(inputs)
                loss = _mse_checked(outputs, labels)
                # Weight by batch size so a short final batch does not skew the mean.
                val_loss_sum += loss.item() * labels.shape[0]
                n_val_samples += labels.shape[0]
                epoch_preds.append(outputs.cpu().numpy())
                epoch_actuals.append(labels.cpu().numpy())

        if n_val_samples == 0:
            continue
        val_loss = val_loss_sum / n_val_samples
        print('Epoch [{}/{}], Test Loss: {:.4f}'.format(epoch + 1, epochs, val_loss))

        if val_loss < best_loss:
            best_loss = val_loss
            if save:
                os.makedirs('mlpmodels', exist_ok=True)
                torch.save(model.state_dict(), 'mlpmodels/best_model_{}.ckpt'.format(map_name))
            preds = np.concatenate(epoch_preds, axis=0)
            actuals = np.concatenate(epoch_actuals, axis=0)

    return np.asarray(preds), np.asarray(actuals), model




In [None]:
# Train an MLP on flattened CLIP last-hidden-state features, then score it on the
# validation set with per-vertex Pearson correlation, separately for each hemisphere.

learning_rate = 1e-5
num_epochs = 30
device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu") # Change for the gpu you're using

# Vertex columns are ordered [left hemisphere | right hemisphere]; this is the LH count.
NUM_LH_VERTICES = 20544

model = MLP(clip_last_hidden_shape[-1] * clip_last_hidden_shape[-2], voxels_shape[-1])
preds, actuals, model = train_the_model_and_get_predictions('clip_last_hidden', model, save=False)

y_val_left_pred = preds[:, :NUM_LH_VERTICES]
y_val_right_pred = preds[:, NUM_LH_VERTICES:]
lh_fmri_val = actuals[:, :NUM_LH_VERTICES]
rh_fmri_val = actuals[:, NUM_LH_VERTICES:]

# Correlate each predicted LH vertex with its ground-truth vertex across validation samples.
lh_correlation = np.zeros(y_val_left_pred.shape[1])
for v in tqdm(range(y_val_left_pred.shape[1])):
    lh_correlation[v] = corr(y_val_left_pred[:, v], lh_fmri_val[:, v])[0]

# Same for the RH vertices.
rh_correlation = np.zeros(y_val_right_pred.shape[1])
for v in tqdm(range(y_val_right_pred.shape[1])):
    rh_correlation[v] = corr(y_val_right_pred[:, v], rh_fmri_val[:, v])[0]

print("Score: ", "\nRight:", rh_correlation.mean(),"\nLeft:", lh_correlation.mean())

In [None]:
# Train an MLP on flattened ImageBind last-hidden-state features, then score it on the
# validation set with per-vertex Pearson correlation, separately for each hemisphere.

learning_rate = 1e-5
num_epochs = 30
device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu") # Change for the gpu you're using

# Vertex columns are ordered [left hemisphere | right hemisphere]; this is the LH count.
NUM_LH_VERTICES = 20544

# The input width must come from the ImageBind hidden-state shape: 'imagebind_last_hidden'
# feeds the flattened ImageBind tensor to the model, not the CLIP one.
model = MLP(imagebind_hidden_shape[-1] * imagebind_hidden_shape[-2], voxels_shape[-1])
preds, actuals, model = train_the_model_and_get_predictions('imagebind_last_hidden', model, save=False)

y_val_left_pred = preds[:, :NUM_LH_VERTICES]
y_val_right_pred = preds[:, NUM_LH_VERTICES:]
lh_fmri_val = actuals[:, :NUM_LH_VERTICES]
rh_fmri_val = actuals[:, NUM_LH_VERTICES:]

# Correlate each predicted LH vertex with its ground-truth vertex across validation samples.
lh_correlation = np.zeros(y_val_left_pred.shape[1])
for v in tqdm(range(y_val_left_pred.shape[1])):
    lh_correlation[v] = corr(y_val_left_pred[:, v], lh_fmri_val[:, v])[0]

# Same for the RH vertices.
rh_correlation = np.zeros(y_val_right_pred.shape[1])
for v in tqdm(range(y_val_right_pred.shape[1])):
    rh_correlation[v] = corr(y_val_right_pred[:, v], rh_fmri_val[:, v])[0]

print("Score: ", "\nRight:", rh_correlation.mean(),"\nLeft:", lh_correlation.mean())