In [1]:
import numpy as np
import os
import torch
from PIL import Image
from tqdm import tqdm
import sys
import matplotlib.pyplot as plt
import json
from tqdm import tqdm
from scipy.stats import pearsonr as corr

### TODO:
- Improve the training setup (e.g. accelerator support, a learning-rate scheduler, etc.)

In [2]:
print('Pulling NSD webdataset data...')
# Note: using "voxel" naming even though we use vertices here... makes it easier porting over MindEye lingo

# Shard locations: shards 3..98 are used for training, shards 0..2 for validation.
train_url = "/fsx/proj-medarc/fmri/natural-scenes-dataset/algonauts_data/wds/subj01_{3..98}.tar"
val_url = "/fsx/proj-medarc/fmri/natural-scenes-dataset/algonauts_data/wds/subj01_{0..2}.tar"
meta_url = "/fsx/proj-medarc/fmri/natural-scenes-dataset/algonauts_data/wds/metadata_subj01.json"

# Read the metadata through a context manager so the file handle is closed
# as soon as parsing finishes instead of being left to the garbage collector.
with open(meta_url) as meta_file:
    metadata = json.load(meta_file)

# The validation split has a fixed size; everything else counted in the metadata is training data.
num_val = 300
num_train = metadata['total'] - num_val
batch_size = 32
num_devices = 1
seed = 42

print('Prepping train and validation dataloaders...')
import math
import random
import webdataset as wds
def my_split_by_node(urls):
    """Identity node splitter: hand the shard URLs back untouched.

    Passed as ``nodesplitter`` to the validation ``wds.WebDataset`` —
    presumably so that no per-node shard splitting is applied there.
    """
    return urls

num_workers = 10

# How each webdataset sample's files map to named fields, and the order in which
# those fields are emitted as a tuple. Shared by the train and validation
# pipelines so the two cannot drift apart.
sample_fields = dict(
    images="jpg;png",
    voxels="vert.npy",
    clip_latent="clip_emb_final.npy",
    clip_last_hidden="clip_emb_hidden.npy",
    imagebind_latent="imagebind_final.npy",
    imagebind_hidden="imagebind_hidden.npy",
)
batch_fields = ("voxels", "images", "clip_latent", "clip_last_hidden", "imagebind_latent", "imagebind_hidden")

# Training: the epoch length is expressed per DataLoader worker, hence the
# division of the total batch count by num_workers.
global_batch_size = batch_size * num_devices
num_batches = math.floor(num_train / global_batch_size)
num_worker_batches = math.floor(num_batches / num_workers)

train_data = wds.WebDataset(train_url, resampled=False)\
    .shuffle(500, initial=500, rng=random.Random(seed))\
    .decode("torch")\
    .rename(**sample_fields)\
    .to_tuple(*batch_fields)\
    .batched(batch_size, partial=False)\
    .with_epoch(num_worker_batches)

train_dl = torch.utils.data.DataLoader(train_data, num_workers=num_workers,
                        batch_size=None, shuffle=False, persistent_workers=True)

# Validation: the whole split is served as a single batch by a single worker.
global_batch_size = batch_size
num_workers_val = 1
val_batch_size = num_val

# Derive the batch counts from the batch size and worker count the validation
# loader actually uses (not the training ones), so the printed number matches
# what val_dl yields.
num_batches_val = math.ceil(num_val / val_batch_size)
num_worker_batches_val = math.ceil(num_batches_val / num_workers_val)
print("validation: num_worker_batches", num_worker_batches_val)

val_data = wds.WebDataset(val_url, resampled=False, nodesplitter=my_split_by_node)\
    .decode("torch")\
    .rename(**sample_fields)\
    .to_tuple(*batch_fields)\
    .batched(val_batch_size, partial=False)

val_dl = torch.utils.data.DataLoader(val_data, num_workers=num_workers_val,
                    batch_size=None, shuffle=False, persistent_workers=True)

Pulling NSD webdataset data...
Prepping train and validation dataloaders...
validation: num_worker_batches 1


In [3]:
# Pull a single batch from the training loader and remember the shape of every field.
# The shapes stay None if the loader yields nothing.
voxels_shape = images_shape = clip_latent_shape = None
clip_last_hidden_shape = imagebind_latent_shape = imagebind_hidden_shape = None
for voxels, images, clip_latent, clip_last_hidden, imagebind_latent, imagebind_hidden in tqdm(train_dl):
    voxels_shape, images_shape = voxels.shape, images.shape
    clip_latent_shape, clip_last_hidden_shape = clip_latent.shape, clip_last_hidden.shape
    imagebind_latent_shape, imagebind_hidden_shape = imagebind_latent.shape, imagebind_hidden.shape
    break  # one batch is all we need

print('Train dataloader shapes:')
for field_name, field_shape in (
    ('voxels', voxels_shape),
    ('images', images_shape),
    ('clip_latent', clip_latent_shape),
    ('clip_last_hidden', clip_last_hidden_shape),
    ('imagebind_latent', imagebind_latent_shape),
    ('imagebind_hidden', imagebind_hidden_shape),
):
    print(field_name, field_shape)


0it [00:05, ?it/s]

Train dataloader shapes:
voxels torch.Size([32, 3, 39548])
images torch.Size([32, 3, 425, 425])
clip_latent torch.Size([32, 1, 768])
clip_last_hidden torch.Size([32, 1, 257, 768])
imagebind_latent torch.Size([32, 1, 1024])
imagebind_hidden torch.Size([32, 1, 257, 1280])





In [4]:
# Pull a single batch from the validation loader and report the shape of every field.
# NOTE: this reuses (and therefore overwrites) the *_shape variables that were
# filled from the training batch; later cells only read the trailing feature
# dimensions from them.
voxels_shape, images_shape, clip_latent_shape, clip_last_hidden_shape, imagebind_latent_shape, imagebind_hidden_shape = None, None, None, None, None, None
for voxels, images, clip_latent, clip_last_hidden, imagebind_latent, imagebind_hidden in val_dl:
    voxels_shape = voxels.shape
    images_shape = images.shape
    clip_latent_shape = clip_latent.shape
    clip_last_hidden_shape = clip_last_hidden.shape
    imagebind_latent_shape = imagebind_latent.shape
    imagebind_hidden_shape = imagebind_hidden.shape
    break  # one batch is all we need

# The header previously said "Train" although the batch comes from val_dl.
print('Validation dataloader shapes:')
print('voxels', voxels_shape)
print('images', images_shape)
print('clip_latent', clip_latent_shape)
print('clip_last_hidden', clip_last_hidden_shape)
print('imagebind_latent', imagebind_latent_shape)
print('imagebind_hidden', imagebind_hidden_shape)

Train dataloader shapes:
voxels torch.Size([300, 3, 39548])
images torch.Size([300, 3, 425, 425])
clip_latent torch.Size([300, 1, 768])
clip_last_hidden torch.Size([300, 1, 257, 768])
imagebind_latent torch.Size([300, 1, 1024])
imagebind_hidden torch.Size([300, 1, 257, 1280])


In [5]:

# Default hyperparameters; the training function reads learning_rate, num_epochs
# and device as globals, so later cells may rebind them before training.
learning_rate = 1e-3
alpha = 1e-8
num_epochs = 10

# Use GPU index 5 when CUDA is present, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:5")
else:
    device = torch.device("cpu")

# Define the model
class MLP(torch.nn.Module):
    """Two stacked linear layers mapping an embedding to a vector of vertex responses.

    There is no activation between the layers, so the composed map is linear
    with a ``hidden_size`` bottleneck.

    Args:
        input_size: Number of input features.
        output_size: Number of output values.
        alpha: Weight applied to the squared-parameter penalty.
        hidden_size: Width of the intermediate layer.
    """

    def __init__(self, input_size, output_size, alpha=1.0, hidden_size=3000):
        super().__init__()
        self.alpha = alpha
        self.linear_in = torch.nn.Linear(input_size, hidden_size)
        self.linear_hid = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Apply the input layer followed directly by the output layer."""
        hidden = self.linear_in(x)
        return self.linear_hid(hidden)

    def l2_regularization(self):
        """Return ``alpha`` times the sum of squares of every parameter (weights and biases)."""
        squared_norm = sum((weights.pow(2).sum() for weights in self.parameters()), 0.0)
        return self.alpha * squared_norm


def train_the_model_and_get_predictions(map_name, model, save = True):
    """Train `model` to predict vertex responses from one embedding and return its best validation output.

    Reads the notebook globals ``train_dl``, ``val_dl``, ``device``,
    ``learning_rate`` and ``num_epochs``. Each dataloader batch is the tuple
    (voxels, images, clip_latent, clip_last_hidden, imagebind_latent,
    imagebind_hidden). The regression target is ``voxels`` averaged over dim 1.

    Args:
        map_name: Which embedding to use as input: 'clip_latent',
            'clip_last_hidden', 'imagebind_latent' or 'imagebind_hidden'
            ('imagebind_last_hidden' is accepted as an alias).
        model: A ``torch.nn.Module`` that also provides ``l2_regularization()``.
        save: If True, write the state dict to
            ``mlpmodels/best_model_<map_name>.ckpt`` whenever the validation
            loss improves.

    Returns:
        ``(preds, actuals, model)``: numpy arrays with the predictions and
        targets for the full validation set at the epoch with the lowest
        validation loss, and the model. The returned model holds the weights
        of the final epoch, which is not necessarily the best epoch.

    Raises:
        ValueError: If `map_name` is not one of the accepted names.
    """
    # Position of each selectable embedding inside a dataloader batch tuple.
    batch_index_by_name = {
        'clip_latent': 2,
        'clip_last_hidden': 3,
        'imagebind_latent': 4,
        'imagebind_hidden': 5,
        'imagebind_last_hidden': 5,
    }
    if map_name not in batch_index_by_name:
        raise ValueError('Invalid map_name')
    input_index = batch_index_by_name[map_name]

    def prepare(batch):
        """Move one batch to the device and return (flattened inputs, averaged targets)."""
        # TODO: I'm just taking the mean, Pauls said if you use single records you get better performance
        labels = batch[0].to(device).float().mean(dim=1)
        inputs = batch[input_index].to(device).float().squeeze(1)
        # Flatten whatever remains after the batch dimension (a no-op for the latent embeddings).
        return inputs.reshape(inputs.shape[0], -1), labels

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    preds = []
    actuals = []
    best_loss = float('inf')

    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        train_loss_sum = 0.0
        train_batches = 0
        for batch in train_dl:
            inputs, labels = prepare(batch)

            # Forward pass
            outputs = model(inputs)
            loss = torch.nn.functional.mse_loss(outputs, labels) + model.l2_regularization() # TODO: Try contrastive loss or any other loss

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss_sum += loss.item()
            train_batches += 1

        # Average over the batches actually seen rather than a precomputed global count.
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, train_loss_sum / max(train_batches, 1)))

        model.eval()
        val_loss_sum = 0.0
        val_samples = 0
        epoch_preds = []
        epoch_actuals = []
        with torch.no_grad():
            for batch in val_dl:
                inputs, labels = prepare(batch)

                # Forward pass
                outputs = model(inputs)
                loss = torch.nn.functional.mse_loss(outputs, labels)

                # Weight each batch by its size so the epoch loss is a per-sample mean.
                batch_len = labels.shape[0]
                val_loss_sum += loss.item() * batch_len
                val_samples += batch_len
                epoch_preds.append(outputs.cpu().numpy())
                epoch_actuals.append(labels.cpu().numpy())

        test_loss = val_loss_sum / val_samples if val_samples else float('nan')

        # Compare only after the whole validation set has been seen, and keep
        # the predictions for every validation batch of the best epoch.
        if val_samples and test_loss < best_loss:
            best_loss = test_loss
            if save:
                os.makedirs('mlpmodels', exist_ok=True)
                torch.save(model.state_dict(), 'mlpmodels/best_model_{}.ckpt'.format(map_name))
            preds = np.concatenate(epoch_preds, axis=0)
            actuals = np.concatenate(epoch_actuals, axis=0)

        print('Epoch [{}/{}], Test Loss: {:.4f}'.format(epoch+1, num_epochs, test_loss))

    preds = np.array(preds)
    actuals = np.array(actuals)

    return preds, actuals, model




In [6]:
# Fit an MLP on the flattened CLIP last-hidden embedding, then score its validation predictions

# These globals are read by train_the_model_and_get_predictions
num_epochs = 30
learning_rate = 0.001
if torch.cuda.is_available():
    device = torch.device("cuda:5") # Change for the gpu you're using
else:
    device = torch.device("cpu")

model = MLP(clip_last_hidden_shape[-1] * clip_last_hidden_shape[-2], voxels_shape[-1], alpha = 1e-8)
preds, actuals, model = train_the_model_and_get_predictions('clip_last_hidden', model, save = False)

# The first 20544 columns are split off as the left hemisphere, the rest as the right
y_val_left_pred, y_val_right_pred = preds[:, :20544], preds[:, 20544:]
lh_fmri_val, rh_fmri_val = actuals[:, :20544], actuals[:, 20544:]

# Pearson correlation between predicted and measured response, one value per LH vertex
lh_vertex_count = y_val_left_pred.shape[1]
lh_correlation = np.fromiter(
    (corr(y_val_left_pred[:, v], lh_fmri_val[:, v])[0] for v in tqdm(range(lh_vertex_count))),
    dtype=np.float64,
    count=lh_vertex_count,
)

# Same per-vertex correlation for the RH vertices
rh_vertex_count = y_val_right_pred.shape[1]
rh_correlation = np.fromiter(
    (corr(y_val_right_pred[:, v], rh_fmri_val[:, v])[0] for v in tqdm(range(rh_vertex_count))),
    dtype=np.float64,
    count=rh_vertex_count,
)

print("Score: ", "\nRight:", rh_correlation.mean(),"\nLeft:", lh_correlation.mean())

Epoch [1/30], Loss: 0.4620
Epoch [1/30], Test Loss: 0.4479
Epoch [2/30], Loss: 0.3430
Epoch [2/30], Test Loss: 0.4568
Epoch [3/30], Loss: 0.2724
Epoch [3/30], Test Loss: 0.4498
Epoch [4/30], Loss: 0.2199
Epoch [4/30], Test Loss: 0.4400
Epoch [5/30], Loss: 0.1921
Epoch [5/30], Test Loss: 0.4387
Epoch [6/30], Loss: 0.1789
Epoch [6/30], Test Loss: 0.4444
Epoch [7/30], Loss: 0.1767
Epoch [7/30], Test Loss: 0.4657
Epoch [8/30], Loss: 0.1851
Epoch [8/30], Test Loss: 0.4649
Epoch [9/30], Loss: 0.2045
Epoch [9/30], Test Loss: 0.5216
Epoch [10/30], Loss: 0.2182
Epoch [10/30], Test Loss: 0.4690
Epoch [11/30], Loss: 0.1970
Epoch [11/30], Test Loss: 0.4977
Epoch [12/30], Loss: 0.1905
Epoch [12/30], Test Loss: 0.4871
Epoch [13/30], Loss: 0.2066
Epoch [13/30], Test Loss: 0.6574
Epoch [14/30], Loss: 0.2571
Epoch [14/30], Test Loss: 0.9533
Epoch [15/30], Loss: 0.3754
Epoch [15/30], Test Loss: 1.0909
Epoch [16/30], Loss: 0.6796
Epoch [16/30], Test Loss: 0.9593
Epoch [17/30], Loss: 1.2183
Epoch [17/30],

100%|██████████| 20544/20544 [00:01<00:00, 17612.34it/s]
100%|██████████| 19004/19004 [00:01<00:00, 18574.72it/s]

Score:  
Right: 0.38833248187880753 
Left: 0.3788003623493953





In [10]:
# Fit an MLP on the ImageBind latent embedding, then score its validation predictions.
# train_the_model_and_get_predictions requires a model instance, so build one
# sized from the recorded batch shapes (features in, vertices out).
model = MLP(imagebind_latent_shape[-1], voxels_shape[-1], alpha = alpha)
# `save` defaults to True, which writes checkpoints under mlpmodels/; make sure the folder exists.
os.makedirs('mlpmodels', exist_ok=True)
preds, actuals, model = train_the_model_and_get_predictions('imagebind_latent', model)

# The first 20544 columns are split off as the left hemisphere, the rest as the right
y_val_right_pred = preds[:, 20544:]
y_val_left_pred = preds[:, :20544]

rh_fmri_val = actuals[:, 20544:]
lh_fmri_val = actuals[:, :20544]

# Empty correlation array of shape: (LH vertices)
lh_correlation = np.zeros(y_val_left_pred.shape[1])
# Correlate each predicted LH vertex with the corresponding ground truth vertex
for v in tqdm(range(y_val_left_pred.shape[1])):
    lh_correlation[v] = corr(y_val_left_pred[:,v], lh_fmri_val[:,v])[0]

# Empty correlation array of shape: (RH vertices)
rh_correlation = np.zeros(y_val_right_pred.shape[1])
# Correlate each predicted RH vertex with the corresponding ground truth vertex
for v in tqdm(range(y_val_right_pred.shape[1])):
    rh_correlation[v] = corr(y_val_right_pred[:,v], rh_fmri_val[:,v])[0]

# Report both hemispheres (the LH score was computed but never shown before)
print("Score: ", "\nRight:", rh_correlation.mean(),"\nLeft:", lh_correlation.mean())

Epoch [1/38], Loss: 0.4616
(300, 39548)
Epoch [1/38], Test Loss: 0.4705
Epoch [2/38], Loss: 0.4328
(300, 39548)
Epoch [2/38], Test Loss: 0.4508
Epoch [3/38], Loss: 0.4177
(300, 39548)
Epoch [3/38], Test Loss: 0.4396
Epoch [4/38], Loss: 0.4085
(300, 39548)
Epoch [4/38], Test Loss: 0.4327
Epoch [5/38], Loss: 0.4017
(300, 39548)
Epoch [5/38], Test Loss: 0.4282
Epoch [6/38], Loss: 0.3971
(300, 39548)
Epoch [6/38], Test Loss: 0.4247
Epoch [7/38], Loss: 0.3942
(300, 39548)
Epoch [7/38], Test Loss: 0.4222
Epoch [8/38], Loss: 0.3912
(300, 39548)
Epoch [8/38], Test Loss: 0.4207
Epoch [9/38], Loss: 0.3896
(300, 39548)
Epoch [9/38], Test Loss: 0.4189
Epoch [10/38], Loss: 0.3885
(300, 39548)
Epoch [10/38], Test Loss: 0.4179
Epoch [11/38], Loss: 0.3868
(300, 39548)
Epoch [11/38], Test Loss: 0.4169
Epoch [12/38], Loss: 0.3857
(300, 39548)
Epoch [12/38], Test Loss: 0.4159
Epoch [13/38], Loss: 0.3851
(300, 39548)
Epoch [13/38], Test Loss: 0.4155
Epoch [14/38], Loss: 0.3840
(300, 39548)
Epoch [14/38], 

100%|██████████| 20544/20544 [00:02<00:00, 9140.29it/s]
100%|██████████| 19004/19004 [00:01<00:00, 9769.29it/s] 

  0.41466480265972





In [7]:
# Fit an MLP on the flattened CLIP last-hidden embedding, then score its validation predictions.
# train_the_model_and_get_predictions requires a model instance, so build one
# sized from the recorded batch shapes (flattened features in, vertices out).
model = MLP(clip_last_hidden_shape[-1] * clip_last_hidden_shape[-2], voxels_shape[-1], alpha = alpha)
# `save` defaults to True, which writes checkpoints under mlpmodels/; make sure the folder exists.
os.makedirs('mlpmodels', exist_ok=True)
preds, actuals, model = train_the_model_and_get_predictions('clip_last_hidden', model)

# The first 20544 columns are split off as the left hemisphere, the rest as the right
y_val_right_pred = preds[:, 20544:]
y_val_left_pred = preds[:, :20544]

rh_fmri_val = actuals[:, 20544:]
lh_fmri_val = actuals[:, :20544]

# Empty correlation array of shape: (LH vertices)
lh_correlation = np.zeros(y_val_left_pred.shape[1])
# Correlate each predicted LH vertex with the corresponding ground truth vertex
for v in tqdm(range(y_val_left_pred.shape[1])):
    lh_correlation[v] = corr(y_val_left_pred[:,v], lh_fmri_val[:,v])[0]

# Empty correlation array of shape: (RH vertices)
rh_correlation = np.zeros(y_val_right_pred.shape[1])
# Correlate each predicted RH vertex with the corresponding ground truth vertex
for v in tqdm(range(y_val_right_pred.shape[1])):
    rh_correlation[v] = corr(y_val_right_pred[:,v], rh_fmri_val[:,v])[0]

# Report both hemispheres (the LH score was computed but never shown before)
print("Score: ", "\nRight:", rh_correlation.mean(),"\nLeft:", lh_correlation.mean())

Epoch [1/38], Loss: 234.9764
(300, 39548)
Epoch [1/38], Test Loss: 21.0682
Epoch [2/38], Loss: 29.2692
(300, 39548)
Epoch [2/38], Test Loss: 0.9949
Epoch [3/38], Loss: 0.7374
(300, 39548)
Epoch [3/38], Test Loss: 0.7212
Epoch [4/38], Loss: 0.5355
(300, 39548)
Epoch [4/38], Test Loss: 0.6501
Epoch [5/38], Loss: 0.4738
(300, 39548)
Epoch [5/38], Test Loss: 0.6129
Epoch [6/38], Loss: 0.4417
(300, 39548)
Epoch [6/38], Test Loss: 0.5825
Epoch [7/38], Loss: 0.4173
(300, 39548)
Epoch [7/38], Test Loss: 0.5575
Epoch [8/38], Loss: 0.3944
(300, 39548)
Epoch [8/38], Test Loss: 0.5360
Epoch [9/38], Loss: 0.3761
(300, 39548)
Epoch [9/38], Test Loss: 0.5171
Epoch [10/38], Loss: 0.3588
(300, 39548)
Epoch [10/38], Test Loss: 0.5022
Epoch [11/38], Loss: 0.3443
(300, 39548)
Epoch [11/38], Test Loss: 0.4858
Epoch [12/38], Loss: 0.3299
(300, 39548)
Epoch [12/38], Test Loss: 0.4736
Epoch [13/38], Loss: 0.3181
(300, 39548)
Epoch [13/38], Test Loss: 0.4621
Epoch [14/38], Loss: 0.3074
(300, 39548)
Epoch [14/3

100%|██████████| 20544/20544 [00:01<00:00, 19968.88it/s]
100%|██████████| 19004/19004 [00:00<00:00, 20571.22it/s]

  0.530479848655317



