# Import packages & functions

In [1]:
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm import tqdm
import webdataset as wds
import gc

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from accelerate import Accelerator

os.chdir("/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/src")

# SDXL unCLIP requires code from https://github.com/Stability-AI/generative-models/tree/main
sys.path.append('generative_models/')
import sgm
from models import Clipper
from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder # bigG embedder

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

# custom functions #
from utils import *

In [None]:
s = 1
data_path = ""
snr_threshold = 0.60
betas = create_snr_betas(s, data_path, threshold = snr_threshold)
betas = torch.Tensor(betas).to("cpu").to(data_type)

In [8]:
#create_whole_region_unnormalized(subject = 1, include_heldout=True, mask_nsd_general=False)

Loading raw scanning session data: 100%|██████████| 40/40 [12:20<00:00, 18.50s/it]


In [9]:
#create_whole_region_normalized(subject = 1, include_heldout=True, mask_nsd_general=False)

torch.Size([27000, 238508]) torch.Size([30000, 238508])


In [12]:
# Stack load_nsd
current_directory = os.getcwd()
subject = 1
beta_file = f"{current_directory}/data/preprocessed_data/subject01/whole_brain_include_heldout.pt"
#beta_file = "/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/src/data/preprocessed_data/subject1/nsd_general_include_heldout.pt"
x = torch.load(beta_file).requires_grad_(False).to("cpu")
stim_descriptions = pd.read_csv("/".join(current_directory.split("/")[:-1]) + "/dataset/nsddata/experiments/nsd/nsd_stim_info_merged.csv", index_col=0)
subj_train = stim_descriptions[(stim_descriptions['subject{}'.format(subject)] != 0) & (stim_descriptions['shared1000'] == False)]
subj_test = stim_descriptions[(stim_descriptions['subject{}'.format(subject)] != 0) & (stim_descriptions['shared1000'] == True)]
test_trials = []
test_sessions = []
x_train = torch.zeros((9000, 3, x.shape[1])).to("cpu")
pbar = tqdm(desc="loading samples", total=x.shape[0])

# Collect the non-test data for the training set
for i in range(subj_train.shape[0]):
    for j in range(3):
        scanId = subj_train.iloc[i]['subject{}_rep{}'.format(subject, j)] - 1
        if(scanId < x.shape[0]):
            x_train[i, j, :] = x[scanId]
            pbar.update() 
            
            
x_train.shape
# betas = [voxels, reps, samples]
# torch.Size([9000, 3, 238508])

loading samples:  89%|████████▉ | 26731/30000 [00:04<00:00, 7239.41it/s]

torch.Size([9000, 3, 238508])

In [13]:
# SNR Calulation
snr, signal, noise = calculate_snr(x_train)
print(f"SNR: {snr}")
print(f"SIGNAL: {signal}")
print(f"NOISE: {noise}")

loading samples:  90%|█████████ | 27000/30000 [00:19<00:00, 7239.41it/s]

SNR: tensor([0.3515, 0.3387, 0.3285,  ..., 0.3468, 0.3357, 0.3433])
SIGNAL: tensor([0.3452, 0.3369, 0.3301,  ..., 0.3422, 0.3349, 0.3399])
NOISE: tensor([0.9822, 0.9946, 1.0048,  ..., 0.9867, 0.9976, 0.9902])


In [4]:
torch.min(snr)

tensor(nan)

In [7]:
torch.max(snr)

tensor(2.2785)

loading samples:  90%|█████████ | 27000/30000 [00:19<00:00, 13445.59it/s]

In [42]:
condition = snr > .35
snr_tensor = torch.where(condition, x, torch.tensor(0.0))
print(snr_tensor.shape)

snr_tensor_no_zeros = (snr_tensor != 0.0).any(dim=0)

# Filter out the zero columns
filtered_tensor = snr_tensor[:, snr_tensor_no_zeros]

print(filtered_tensor)

torch.Size([30000, 238508])
tensor([[ 6.3150e-01, -6.0554e-01, -4.2650e-01,  ...,  4.3917e-01,
          1.9096e-03,  3.7853e-01],
        [ 8.3171e-01,  1.0849e-01,  3.8990e-01,  ..., -3.4662e-01,
         -1.0818e+00, -1.3847e-01],
        [ 6.6918e-01,  6.7614e-01,  1.4183e+00,  ..., -4.8620e-01,
         -6.9387e-01,  1.2925e+00],
        ...,
        [-2.2420e-01, -1.0643e+00, -3.4066e-01,  ..., -1.1105e-01,
         -2.0496e+00, -5.7585e-01],
        [ 2.1926e-01, -1.7885e+00, -6.4900e-01,  ...,  3.0040e-01,
         -1.5414e+00, -1.3365e+00],
        [ 8.7046e-01, -9.6213e-01, -1.7773e-01,  ..., -4.3053e-01,
         -7.4684e-01, -2.9001e-01]])


In [43]:
filtered_tensor.shape

torch.Size([30000, 102624])

In [None]:
train_data = {}
train_dl = {}
num_voxels = {}
voxels = {}
for s in subj_list:
    print(f"Training with {num_sessions} sessions")
    if s < 9:
        if multi_subject:
            train_url = f"{data_path}/wds/subj{s:02d}/train/" + "{0.." + f"{nsessions_allsubj[s-1]-1}" + "}.tar"
        else:
            train_url = f"{data_path}/wds/subj{s:02d}/train/" + "{0.." + f"{num_sessions-1}" + "}.tar"
        print(train_url)
        
        train_data[f'subj{s:02d}'] = wds.WebDataset(train_url,resampled=True,nodesplitter=my_split_by_node)\
                            .shuffle(750, initial=1500, rng=random.Random(42))\
                            .decode("torch")\
                            .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
                            .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
        train_dl[f'subj{s:02d}'] = torch.utils.data.DataLoader(train_data[f'subj{s:02d}'], batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)
        betas = create_snr_betas(s, data_path, threshold = snr_threshold)
        betas = torch.Tensor(betas).to("cpu").to(data_type)
        num_voxels_list.append(betas[0].shape[-1])
        num_voxels[f'subj{s:02d}'] = betas[0].shape[-1]
        voxels[f'subj{s:02d}'] = betas
    elif s < 12:
        train_url = ""
        test_url = ""
        betas, images, _, _ = utils.load_imageryrf(subject=int(s-8), mode=mode, mask=True, stimtype="object", average=False, nest=False, split=True)
        betas = torch.where(torch.isnan(betas), torch.zeros_like(betas), betas)
        betas = betas.to("cpu").to(data_type)
        num_voxels_list.append(betas[0].shape[-1])
        num_voxels[f'subj{s:02d}'] = betas[0].shape[-1]
        num_nan_values = torch.sum(torch.isnan(betas))
        print("Number of NaN values in betas:", num_nan_values.item())
        indices = torch.randperm(len(betas))
        shuffled_betas = betas[indices]
        shuffled_images = images[indices]
        train_data[f'subj{s:02d}'] = torch.utils.data.TensorDataset(shuffled_betas, shuffled_images)
        train_dl[f'subj{s:02d}'] = torch.utils.data.DataLoader(train_data[f'subj{s:02d}'], batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)
        
        
    # elif s < 15:
    #     betas, images = utils.load_imageryrf(subject=int(s-11), mode="imagery", mask=True, stimtype="object", average=False, nest=False)
    #     betas = torch.where(torch.isnan(betas), torch.zeros_like(betas), betas)
    #     betas = betas.to("cpu").to(data_type)
    #     num_voxels_list.append(betas[0].shape[-1])
    #     num_voxels[f'subj{s:02d}'] = betas[0].shape[-1]
        
    #     indices = torch.randperm(len(betas))
    #     shuffled_betas = betas[indices]
    #     shuffled_images = images[indices]
    #     train_data[f'subj{s:02d}'] = torch.utils.data.TensorDataset(shuffled_betas, shuffled_images)
    #     train_dl[f'subj{s:02d}'] = torch.utils.data.DataLoader(train_data[f'subj{s:02d}'], batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)
    print(f"num_voxels for subj{s:02d}: {num_voxels[f'subj{s:02d}']}")

print("Loaded all subj train dls and betas!\n")

# Validate only on one subject (doesn't support ImageryRF)
if multi_subject: 
    subj = subj_list[0] # cant validate on the actual held out person so picking first in subj_list
if not new_test: # using old test set from before full dataset released (used in original MindEye paper)
    if subj==3:
        num_test=2113
    elif subj==4:
        num_test=1985
    elif subj==6:
        num_test=2113
    elif subj==8:
        num_test=1985
    else:
        num_test=2770
    test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
elif new_test: # using larger test set from after full dataset released
    if subj==3:
        num_test=2371
    elif subj==4:
        num_test=2188
    elif subj==6:
        num_test=2371
    elif subj==8:
        num_test=2188
    else:
        num_test=3000
    test_url = f"{data_path}/wds/subj0{subj}/new_test/" + "0.tar"
print(test_url)
if subj < 9:
    test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
                        .shuffle(750, initial=1500, rng=random.Random(42))\
                        .decode("torch")\
                        .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
                        .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
    test_dl = torch.utils.data.DataLoader(test_data, batch_size=num_test, shuffle=False, drop_last=True, pin_memory=True)
else:
    _, _, betas, images = utils.load_imageryrf(subject=int(subj-8), mode=mode, mask=True, stimtype="object", average=False, nest=True, split=True)
    num_test = len(betas)
    betas = torch.where(torch.isnan(betas), torch.zeros_like(betas), betas)
    betas = betas.to("cpu").to(data_type)
    num_nan_values = torch.sum(torch.isnan(betas))
    print("Number of NaN values in test betas:", num_nan_values.item())
    test_data = torch.utils.data.TensorDataset(betas, images)
    test_dl = torch.utils.data.DataLoader(test_data, batch_size=num_test, shuffle=False, drop_last=True, pin_memory=True)
print(f"Loaded test dl for subj{subj}!\n")

seq_len = seq_past + 1 + seq_future
print(f"currently using {seq_len} seq_len (chose {seq_past} past behav and {seq_future} future behav)")