# Import packages & functions

In [1]:
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm import tqdm
import webdataset as wds
import gc

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from accelerate import Accelerator

# Switch the working directory to the project's src/ folder. The relative
# sys.path entry and the project imports below are resolved after this call,
# so the order of these statements matters.
os.chdir("/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/src")

# SDXL unCLIP requires code from https://github.com/Stability-AI/generative-models/tree/main
# NOTE: this path is relative, so it depends on the os.chdir() call above.
sys.path.append('generative_models/')
import sgm
from models import Clipper
from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder # bigG embedder

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

# custom functions #
# NOTE(review): star import -- helpers called later in this notebook without an
# explicit import (e.g. create_snr_betas, calculate_snr) presumably come from
# here; verify against utils.
from utils import *

In [2]:
# Settings handed to create_snr_betas (not defined in this file).
s = 1                       # subject number
data_type = torch.float16   # dtype requested for the betas
data_path = "/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset"
snr_threshold = 0.65        # passed through as the `threshold` argument

betas = create_snr_betas(
    subject=s,
    data_type=data_type,
    data_path=data_path,
    threshold=snr_threshold,
)
# Left as the final expression so the notebook displays the resulting shape.
betas.shape

torch.Size([30000, 3931])

In [None]:
#create_whole_region_unnormalized(subject = 1, include_heldout=True, mask_nsd_general=False)

In [None]:
#create_whole_region_normalized(subject = 1, include_heldout=True, mask_nsd_general=False)

In [None]:
# Stack load_nsd
# Load the preprocessed betas for one subject and stack the repeated
# presentations of every non-shared1000 (training) image into x_train.
import pandas as pd  # explicit import: pd was otherwise only reachable via a star import

current_directory = os.getcwd()
subject = 1
# NOTE(review): the directory name below hard-codes "subject01" independently
# of `subject`; keep the two in sync when switching subjects.
beta_file = f"{current_directory}/data/preprocessed_data/subject01/whole_brain_include_heldout.pt"
#beta_file = "/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/src/data/preprocessed_data/subject1/nsd_general_include_heldout.pt"
# map_location="cpu" deserializes straight onto the CPU, even when the tensor
# was saved from a GPU, instead of loading to the saved device and copying.
x = torch.load(beta_file, map_location="cpu").requires_grad_(False)

# The stimulus table lives in <parent of cwd>/dataset/...
stim_info_path = os.path.dirname(current_directory) + "/dataset/nsddata/experiments/nsd/nsd_stim_info_merged.csv"
stim_descriptions = pd.read_csv(stim_info_path, index_col=0)

# Rows for images this subject has a non-zero entry for, split by shared1000.
shown_to_subject = stim_descriptions[f'subject{subject}'] != 0
subj_train = stim_descriptions[shown_to_subject & (stim_descriptions['shared1000'] == False)]
subj_test = stim_descriptions[shown_to_subject & (stim_descriptions['shared1000'] == True)]
test_trials = []
test_sessions = []

n_reps = 3
n_trials = x.shape[0]
# Size the output from the data rather than a hard-coded 9000 rows.
x_train = torch.zeros((subj_train.shape[0], n_reps, x.shape[1]), device="cpu")

# Pull the per-repetition trial IDs out of the frame once instead of building
# a row Series on every inner iteration. IDs in the table are 1-based, so
# subtract 1 to obtain row indices into x.
rep_columns = [f'subject{subject}_rep{j}' for j in range(n_reps)]
scan_ids = subj_train[rep_columns].to_numpy() - 1
# Only IDs inside [0, n_trials) refer to a row of x. The lower bound prevents
# a 0 entry (index -1) from silently copying the last row of x.
in_range = (scan_ids >= 0) & (scan_ids < n_trials)

# Collect the non-test data for the training set
with tqdm(desc="loading samples", total=int(in_range.sum())) as pbar:
    for i, row in enumerate(scan_ids):
        for j, scan_id in enumerate(row):
            if in_range[i, j]:
                x_train[i, j, :] = x[int(scan_id)]
                pbar.update()

x_train.shape
# x_train = [samples, reps, voxels]
# torch.Size([9000, 3, 238508])

In [None]:
# SNR calculation on the stacked training betas (calculate_snr is not defined
# in this file).
snr, signal, noise = calculate_snr(x_train)

# Report the three returned values, one labelled line each.
for label, value in (("SNR", snr), ("SIGNAL", signal), ("NOISE", noise)):
    print(f"{label}: {value}")

In [None]:
# Smallest value anywhere in snr (shown as the cell output).
snr.min()

In [None]:
# Largest value anywhere in snr (shown as the cell output).
snr.max()

In [None]:
# Boolean mask of the entries of snr above the 0.35 cutoff.
condition = torch.gt(snr, .35)

# Copy of x with every entry zeroed where the mask is False.
# NOTE(review): this materializes a full-size masked copy before the column
# selection below; confirm memory headroom when x is large.
snr_tensor = torch.where(condition, x, torch.tensor(0.0))
print(snr_tensor.shape)

# True for each column that still contains at least one non-zero value.
snr_tensor_no_zeros = snr_tensor.ne(0.0).any(dim=0)

# Keep only those columns; the all-zero ones are dropped.
filtered_tensor = snr_tensor[:, snr_tensor_no_zeros]

print(filtered_tensor.shape)

In [None]:
# Display the shape of the SNR-filtered tensor as the cell output.
filtered_tensor.size()