# Import packages & functions

In [65]:
import os
import sys
import json
import io
import argparse
import numpy as np
import time
import random
import string
import h5py
from tqdm import tqdm
import webdataset as wds
from PIL import Image
import pandas as pd
import nibabel as nib
import hashlib
from collections import defaultdict

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

# custom functions #
seed = 0
import utils

if utils.is_interactive():
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    %autoreload 2 

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Data

### betas

In [66]:
if utils.is_interactive():
    # NSDflat_large_gsrFalse__gpFalse_visualTrue
    # NSDflat_large_gsrFalse__visualTrue_RAW
    hdf5_path = '/weka/proj-fmri/paulscotti/fMRI-foundation-model/src/checkpoints/NSDflat_large_gsrFalse__gpTrue_visualTrue/epoch99/test.h5'
else:
    hdf5_path = os.getenv('hdf5_path')
    
print(f"hdf5_path: {hdf5_path}")
model_name = hdf5_path.split('/test.h5')[0].split('checkpoints')[-1].replace("/","")

data_h5 = h5py.File(f'{hdf5_path}', 'r')
print(data_h5.keys())

hdf5_path: /weka/proj-fmri/paulscotti/fMRI-foundation-model/src/checkpoints/NSDflat_large_gsrFalse__gpTrue_visualTrue/epoch99/test.h5
<KeysViewHDF5 ['events', 'features', 'key', 'run', 'ses', 'start', 'sub']>


In [67]:
# Load integer datasets
sub = data_h5['sub'][:]        # Subject IDs
ses = data_h5['ses'][:]        # Session IDs
run = data_h5['run'][:]        # Run IDs
start = data_h5['start'][:]    # Start indices

# Load string datasets and decode them
key = data_h5['key'][:]
key = [k.decode('utf-8') if isinstance(k, bytes) else k for k in key]

# Load and deserialize the 'events' dataset
events_raw = data_h5['events'][:]
events = [json.loads(e.decode('utf-8')) if isinstance(e, bytes) else json.loads(e) for e in events_raw]

# Access the 'features' dataset without loading it into memory
features = data_h5['features']

num_TRs_per_image = 1  
TR_delay = 3            
data = [] 
image_NSD73K_indices = [] 

utils.seed_everything(seed)

for sub_val in [1]:
    print(f"\nProcessing Subject: {sub_val}")
    
    # Get indices where sub == sub_val
    indices_sub = np.where(sub == sub_val)[0]
    
    # Get unique sessions for this subject
    unique_sess = np.unique(ses[indices_sub])
    
    # Iterate over each session with a progress bar
    for sess_val in tqdm(unique_sess, desc=f"Processing sessions for sub {sub_val}"):
        # Get indices for current session
        indices_sess = indices_sub[ses[indices_sub] == sess_val]
        
        # Get unique runs within this session
        unique_runs = np.unique(run[indices_sess])
        
        # Iterate over each run
        for run_val in unique_runs:
            # Get indices for current run
            indices_run = indices_sess[run[indices_sess] == run_val]
            
            # Find events where start == 0
            indices_start0 = indices_run[start[indices_run] == 0]
            
            if len(indices_start0) == 0:
                # No events found for this run
                print(f"  Run {run_val}: No events found (start == 0). Skipping.")
                continue
            
            # Assuming events are consistent within a run, take the first occurrence
            events_list = events[indices_start0[0]]
            
            # Extract timepoints and nsd_ids from events
            timepoints = [event['index'] for event in events_list]
            nsd_ids = [event['nsd_id'] - 1 for event in events_list]  # Adjusting nsd_id as per original code
            
            # Append nsd_ids to the main list
            image_NSD73K_indices.extend(nsd_ids)
            
            # Iterate over each timepoint to extract sliding windows
            for time in timepoints:
                sliding_windows = []
                
                for i in range(num_TRs_per_image):
                    # Calculate the adjusted time with delay
                    time_ = time + i + TR_delay
                    
                    # Find the index where start == time_
                    indices_time = indices_run[start[indices_run] == time_]
                    
                    if len(indices_time) == 0:
                        # Handle missing data: Skip this sliding window
                        print(f"    Time {time_}: No feature found. Skipping this sliding window.")
                        break  # Exit the inner loop if any TR is missing
                    else:
                        # Access feature by index without loading the entire dataset
                        sliding_window = features[indices_time[0]]
                        sliding_windows.append(sliding_window)
                
                # Only append if the required number of TRs were found
                if len(sliding_windows) == num_TRs_per_image:
                    # Concatenate sliding windows if more than one TR per image
                    if num_TRs_per_image > 1:
                        sliding_window_array = np.concatenate(sliding_windows)
                    else:
                        sliding_window_array = sliding_windows[0]
                    
                    data.append(sliding_window_array)
                else:
                    print(f"    Time {time}: Incomplete sliding window. Skipping.")
    
# Convert the collected data and indices into NumPy arrays
data = np.array(data)
image_NSD73K_indices = np.array(image_NSD73K_indices)

print("\nProcessing Complete.")
print("Data Shape:", data.shape)
print("Image NSD73K Indices Length:", len(image_NSD73K_indices))

# Close the HDF5 file to free up resources
data_h5.close()


Processing Subject: 1


Processing sessions for sub 1: 100%|████████████| 40/40 [00:11<00:00,  3.47it/s]


Processing Complete.
Data Shape: (30000, 1024)
Image NSD73K Indices Length: 30000





In [68]:
n_sessions = len(np.unique(ses)) 
n_runs = len(np.unique(run))
n_TRs = np.sum((ses == 1) & (run == 2))
n_features = len(sliding_window)
print(f"n_sessions: {n_sessions}")
print(f"n_runs: {n_runs}")
print(f"n_TRs (ses=1 & run=2): {n_TRs}")
print(f"n_features: {n_features}")

n_sessions: 40
n_runs: 13
n_TRs (ses=1 & run=2): 285
n_features: 1024


### Images

In [69]:
# Load 73k NSD images
data_path = "/weka/proj-medarc/shared/mindeyev2_dataset"
f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
images = f['images'][:]
images = torch.Tensor(images).to("cpu")

In [70]:
images = images[image_NSD73K_indices]
print(images.shape)

torch.Size([30000, 3, 224, 224])


In [71]:
vox = data
print(vox.shape)
print(images.shape)

(30000, 1024)
torch.Size([30000, 3, 224, 224])


In [72]:
shared1000_indices = np.load("/weka/proj-medarc/shared/mindeyev2_dataset/shared1000.npy")
train_or_test = np.array([not shared1000_indices[im] for im in image_NSD73K_indices])

In [73]:
all_indices = np.arange(len(images))
print(len(all_indices))
train_image_indices = all_indices[train_or_test]
test_image_indices = all_indices[~train_or_test]
print(len(train_image_indices), len(test_image_indices))

30000
27000 3000


In [74]:
# # Using Connor's old NSD-Flat that uses GLM betas, for comparison
# directory = '/weka/proj-medarc/shared/Betas-NSD-Flat/data'
# train_parquet_files = []
# test_parquet_files = []
# for filename in os.listdir(directory):
#     if (filename.endswith(".parquet")) and ("train" in filename):
#         train_parquet_files.append(os.path.join(directory, filename))
#     elif (filename.endswith(".parquet")) and ("test" in filename):
#         test_parquet_files.append(os.path.join(directory, filename))

# train_nsd_ids = []
# train_vox = []
# flatmask = None
# for file in tqdm(train_parquet_files):
#     df = pd.read_parquet(file)
#     df_filtered = df[df['subject'] == 'subj01']
#     if len(df_filtered)>0:
#         if flatmask is None:
#             flatmask = np.array(Image.open(io.BytesIO(df_filtered['activity'][0]['bytes']))) - 127
#             flatmask[flatmask!=0] = 1
        
#         train_nsd_ids.extend(df_filtered['nsd_id'].values)
#         for d in df_filtered['activity']:
#             pixels = (np.array(Image.open(io.BytesIO(d['bytes'])))[flatmask.astype(np.bool)] / 255)
#             train_vox.append(pixels)
# train_nsd_ids = np.array(train_nsd_ids)
# train_vox = np.array(train_vox)
# print("==Train==")
# print(train_nsd_ids.shape)
# print(train_vox.shape)

# test_nsd_ids = []
# test_vox = []
# for file in tqdm(test_parquet_files):
#     df = pd.read_parquet(file)
#     df_filtered = df[df['subject'] == 'subj01']
#     if len(df_filtered)>0:
#         test_nsd_ids.extend(df_filtered['nsd_id'].values)
#         for d in df_filtered['activity']:
#             pixels = (np.array(Image.open(io.BytesIO(d['bytes'])))[flatmask.astype(np.bool)] / 255)
#             test_vox.append(pixels)
# test_nsd_ids = np.array(test_nsd_ids)
# test_vox = np.array(test_vox)

# print("==Test==")
# print(test_nsd_ids.shape)
# print(test_vox.shape)

# # discard same-image repeats for test set
# unique_ids, first_indices = np.unique(test_nsd_ids, return_index=True)
# sorted_indices = np.sort(first_indices)
# test_nsd_ids = train_nsd_ids[sorted_indices]
# test_vox = train_vox[sorted_indices]

# # # group same-image repeats for test set
# # unique_ids, inverse_indices = np.unique(test_nsd_ids, return_inverse=True)
# # num_unique_ids = len(unique_ids)

# # sum_vox = np.zeros((num_unique_ids, test_vox.shape[1]))
# # counts = np.zeros(num_unique_ids)

# # np.add.at(sum_vox, inverse_indices, test_vox)
# # np.add.at(counts, inverse_indices, 1)

# # vox_aggregated = sum_vox / counts[:, np.newaxis]

# # test_nsd_ids = unique_ids
# # test_vox = vox_aggregated
# # print("   after grouping:")
# print(test_nsd_ids.shape)
# print(test_vox.shape)

# # Converting nsd_ids to actual images, and converting to torch tensors
# test_images = images[test_nsd_ids]
# images = images[train_nsd_ids]

# train_mean = np.mean(train_vox,axis=0)
# train_std = np.std(train_vox,axis=0)

# vox = utils.zscore(train_vox,train_mean=train_mean,train_std=train_std)
# test_vox = utils.zscore(test_vox,train_mean=train_mean,train_std=train_std)

# vox = torch.Tensor(vox)
# test_vox = torch.Tensor(test_vox)

# train_image_indices = np.arange(len(vox))
# test_image_indices = np.arange(len(test_vox))

# print("\n ready!")
# print(vox.shape, images.shape)
# print(test_vox.shape, test_images.shape)

# model_name = "Betas_NSD_Flat_testing"

In [75]:
train_mean = np.mean(vox[train_image_indices],axis=0)
train_std = np.std(vox[train_image_indices],axis=0)

vox = utils.zscore(vox,train_mean=train_mean,train_std=train_std)
print("inputs have been zscored according to training set")

images = torch.Tensor(images)
vox = torch.Tensor(vox)

In [76]:
# discard same-image repeats for test set
test_images_flat = images[test_image_indices].flatten(1).numpy()
hashes = [hashlib.sha256(im.tobytes()).hexdigest() for im in test_images_flat]
    
unique_ids, first_indices = np.unique(hashes, return_index=True)
sorted_indices = np.sort(first_indices)
test_images = images[sorted_indices]
test_vox = vox[sorted_indices]

# new test_image_indices
test_image_indices = np.arange(len(test_images))

print(test_images.shape, test_vox.shape)

torch.Size([1000, 3, 224, 224]) (1000, 1024)


In [77]:
# # Group same-image repeats in test set

# test_images = images[test_image_indices]
# test_vox = vox[test_image_indices]
# print(test_images.shape, test_vox.shape)

# test_images_flat = images[test_image_indices].flatten(1).numpy()
# hashes = [hashlib.sha256(im.tobytes()).hexdigest() for im in test_images_flat]
# hash_to_indices = defaultdict(list)
# for idx, img_hash in enumerate(hashes):
#     hash_to_indices[img_hash].append(idx)

In [78]:
# num_unique_images = len(hash_to_indices)
# new_test_images = torch.zeros((num_unique_images, 3, 3, images.shape[-1], images.shape[-1]))  # [1000, 3, 3, 224, 224]
# new_test_vox = torch.zeros((num_unique_images, 3, vox.shape[-1]))  # [1000, 3, vox.shape[-1]]

# # Map hashes to indices
# for new_idx, (img_hash, indices) in enumerate(hash_to_indices.items()):
#     imgs = test_images[indices]  # Shape: [3, 3, 256, 256]
#     datas = test_vox[indices]   # Shape: [3, 100]

#     # Assign to new tensors
#     new_test_images[new_idx] = imgs
#     new_test_vox[new_idx] = datas

# # Replace old with new tensors
# test_images = new_test_images
# test_vox = new_test_vox
# del new_test_images, new_test_vox
# print(test_images.shape, test_vox.shape)

# # new test_image_indices
# test_image_indices = np.arange(len(test_images))

In [79]:
### Multi-GPU config ###
from accelerate import Accelerator, DeepSpeedPlugin

local_rank = os.getenv('RANK')
if local_rank is None: 
    local_rank = 0
else:
    local_rank = int(local_rank)
print("LOCAL RANK ", local_rank)  

data_type = torch.float32 # change depending on your mixed_precision

accelerator = Accelerator(split_batches=False)# mixed_precision="fp16") # ['no', 'fp8', 'fp16', 'bf16']
if utils.is_interactive(): # set batch size here if using interactive notebook instead of submitting job
    global_batch_size = batch_size = 24
else:
    batch_size = int(os.environ["BATCH_SIZE"])
    global_batch_size = os.environ["GLOBAL_BATCH_SIZE"]

LOCAL RANK  0


In [80]:
print("PID of this process =",os.getpid())
device = accelerator.device
print("device:",device)
world_size = accelerator.state.num_processes
distributed = not accelerator.state.distributed_type == 'NO'
num_devices = torch.cuda.device_count()
print("global_batch_size", global_batch_size)
if num_devices==0 or not distributed: num_devices = 1
num_workers = num_devices
print(accelerator.state)

# set data_type to match your mixed precision (automatically set based on deepspeed config)
if accelerator.mixed_precision == "bf16":
    data_type = torch.bfloat16
elif accelerator.mixed_precision == "fp16":
    data_type = torch.float16
else:
    data_type = torch.float32

print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size, "data_type =", data_type)
print = accelerator.print # only print if local_rank=0

PID of this process = 3653494
device: cuda
global_batch_size 24
Distributed environment: DistributedType.NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: no

distributed = False num_devices = 1 local rank = 0 world size = 1 data_type = torch.float32


## Configurations

In [81]:
# if running this interactively, can specify jupyter_args here for argparser to use
if utils.is_interactive():
    print("model_name:", model_name)
    
    # global_batch_size and batch_size should already be defined in the above cells
    # other variables can be specified in the following string:
    jupyter_args = f"--data_path=/weka/proj-medarc/shared/mindeyev2_dataset \
                    --no-multi_subject --subj=1 --batch_size={batch_size} \
                    --hidden_dim=1024 --clip_scale=1. \
                    --no-blurry_recon --blur_scale=.5 \
                    --no-use_prior --prior_scale=30 \
                    --n_blocks=4 --max_lr=3e-4 --mixup_pct=.33 --num_epochs=15 --no-use_image_aug \
                    --ckpt_interval=999 --no-ckpt_saving --no-wandb_log --new_test"# \
                    #--multisubject_ckpt=../../train_logs/multisubject_subj01_1024hid_nolow_300ep_seed0"
    print(jupyter_args)
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

model_name: NSDflat_large_gsrFalse__gpTrue_visualTrueepoch99
--data_path=/weka/proj-medarc/shared/mindeyev2_dataset                     --no-multi_subject --subj=1 --batch_size=24                     --hidden_dim=1024 --clip_scale=1.                     --no-blurry_recon --blur_scale=.5                     --no-use_prior --prior_scale=30                     --n_blocks=4 --max_lr=3e-4 --mixup_pct=.33 --num_epochs=15 --no-use_image_aug                     --ckpt_interval=999 --no-ckpt_saving --no-wandb_log --new_test
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [82]:
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--data_path", type=str, default="/weka/proj-fmri/shared/natural-scenes-dataset",
    help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
    "--subj",type=int, default=1, choices=[1,2,3,4,5,6,7,8],
    help="Validate on which subject?",
)
parser.add_argument(
    "--multisubject_ckpt", type=str, default=None,
    help="Path to pre-trained multisubject model to finetune a single subject from. multisubject must be False.",
)
parser.add_argument(
    "--num_sessions", type=int, default=0,
    help="Number of training sessions to include (if multi_subject, this variable doesnt matter)",
)
parser.add_argument(
    "--use_prior",action=argparse.BooleanOptionalAction,default=False,
    help="whether to train diffusion prior (True) or just rely on retrieval part of the pipeline (False)",
)
parser.add_argument(
    "--batch_size", type=int, default=32,
    help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser",
)
parser.add_argument(
    "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
    help="whether to log to wandb",
)
parser.add_argument(
    "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
    help="if not using wandb and want to resume from a ckpt",
)
parser.add_argument(
    "--wandb_project",type=str,default="stability",
    help="wandb project name",
)
parser.add_argument(
    "--mixup_pct",type=float,default=.33,
    help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
)
parser.add_argument(
    "--low_mem",action=argparse.BooleanOptionalAction,default=False,
    help="whether to preload images to cpu to speed things up but consume more memory",
)
parser.add_argument(
    "--blurry_recon",action=argparse.BooleanOptionalAction,default=True,
    help="whether to output blurry reconstructions",
)
parser.add_argument(
    "--blur_scale",type=float,default=.5,
    help="multiply loss from blurry recons by this number",
)
parser.add_argument(
    "--clip_scale",type=float,default=1.,
    help="multiply contrastive loss by this number",
)
parser.add_argument(
    "--prior_scale",type=float,default=30,
    help="multiply diffusion prior loss by this",
)
parser.add_argument(
    "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
    help="whether to use image augmentation",
)
parser.add_argument(
    "--num_epochs",type=int,default=120,
    help="number of epochs of training",
)
parser.add_argument(
    "--multi_subject",action=argparse.BooleanOptionalAction,default=False,
)
parser.add_argument(
    "--new_test",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--n_blocks",type=int,default=2,
)
parser.add_argument(
    "--hidden_dim",type=int,default=1024,
)
parser.add_argument(
    "--seq_past",type=int,default=0,
)
parser.add_argument(
    "--seq_future",type=int,default=0,
)
parser.add_argument(
    "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
)
parser.add_argument(
    "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--ckpt_interval",type=int,default=5,
    help="save backup ckpt and reconstruct every x epochs",
)
parser.add_argument(
    "--seed",type=int,default=42,
)
parser.add_argument(
    "--max_lr",type=float,default=3e-4,
)

if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

# create global variables without the args prefix
for attribute_name in vars(args).keys():
    globals()[attribute_name] = getattr(args, attribute_name)
    
# seed all random functions
utils.seed_everything(seed)

outdir = os.path.abspath(f'../../train_logs/{model_name}')
if not os.path.exists(outdir) and ckpt_saving:
    os.makedirs(outdir,exist_ok=True)
    
if use_image_aug or blurry_recon:
    import kornia
    import kornia.augmentation as K
    from kornia.augmentation.container import AugmentationSequential
if use_image_aug:
    img_augment = AugmentationSequential(
        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.3),
        same_on_batch=False,
        data_keys=["input"],
    )
    # Define the blurring augmentations
    blur_augment = K.RandomGaussianBlur(kernel_size=(21, 21), sigma=(51.0, 51.0), p=1.)
    
if multi_subject:
    subj_list = np.arange(1,9)
    subj_list = subj_list[subj_list != subj]
else:
    subj_list = [subj]

print("subj_list", subj_list, "num_sessions", num_sessions)

subj_list [1] num_sessions 0


## Prep data, models, and dataloaders

### Creating wds dataloader, preload betas and all 73k possible images

In [83]:
def my_split_by_node(urls): return urls
num_voxels_list = []

if multi_subject:
    nsessions_allsubj=np.array([40, 40, 32, 30, 40, 32, 40, 30])
    num_samples_per_epoch = (750*40) // num_devices 
else:
    # num_samples_per_epoch = (750*num_sessions) // num_devices 
    num_samples_per_epoch = len(train_image_indices)

print("dividing batch size by subj_list, which will then be concatenated across subj during training...") 
batch_size = batch_size // len(subj_list)

num_iterations_per_epoch = num_samples_per_epoch // (batch_size*len(subj_list))

print("batch_size =", batch_size, "num_iterations_per_epoch =",num_iterations_per_epoch, "num_samples_per_epoch =",num_samples_per_epoch)

dividing batch size by subj_list, which will then be concatenated across subj during training...
batch_size = 24 num_iterations_per_epoch = 1125 num_samples_per_epoch = 27000


In [84]:
train_data = {}
train_dl = {}

train_data[f'subj0{subj}'] = torch.utils.data.TensorDataset(torch.tensor(train_image_indices))

test_data = torch.utils.data.TensorDataset(torch.tensor(test_image_indices))

In [85]:
num_voxels = {}
voxels = {}
for s in subj_list:
    print(f"Training with {num_sessions} sessions")
    train_dl = torch.utils.data.DataLoader(train_data[f'subj0{s}'], batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True)

    num_voxels_list.append(vox[0].shape[-1])
    num_voxels[f'subj0{s}'] = vox[0].shape[-1]
    voxels[f'subj0{s}'] = vox
    print(f"num_voxels for subj0{s}: {num_voxels[f'subj0{s}']}")

print("Loaded all subj train dls and vox!\n")

# Validate only on one subject
if multi_subject: 
    subj = subj_list[0] # cant validate on the actual held out person so picking first in subj_list
test_dl = torch.utils.data.DataLoader(test_data, batch_size=1000, shuffle=False, drop_last=True, pin_memory=True)

print(f"Loaded test dl for subj{subj}!\n")

Training with 0 sessions
num_voxels for subj01: 1024
Loaded all subj train dls and vox!

Loaded test dl for subj1!



## Load models

### CLIP image embeddings  model

In [86]:
## USING OpenCLIP ViT-bigG ###
sys.path.append('mindeye_utils/')
import mindeye_utils.generative_models.sgm
from mindeye_utils.generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder
from mindeye_utils.generative_models.sgm.models.diffusion import DiffusionEngine
from omegaconf import OmegaConf

try:
    print(clip_img_embedder)
except:
    # # last hidden
    # clip_img_embedder = FrozenOpenCLIPImageEmbedder(
    #     arch="ViT-bigG-14",
    #     version="laion2b_s39b_b160k",
    #     output_tokens=True,
    #     only_tokens=True,
    # )
    # clip_img_embedder.to(device)
    # clip_seq_dim = 256
    # clip_emb_dim = 1664

    # final
    clip_img_embedder = FrozenOpenCLIPImageEmbedder(
        arch="ViT-bigG-14",
        version="laion2b_s39b_b160k",
        output_tokens=False,
        only_tokens=False,
    )
    clip_img_embedder.to(device)
    clip_seq_dim = 1
    clip_emb_dim = 1280

# ## USING OPEN AI CLIP ViT-L ###
# import clip
# try:
#     print(clip_model)
# except:
#     clip_model, preprocess = clip.load("ViT-L/14", device=device)
#     preprocess = transforms.Compose([
#         transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
#         transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
#                              std=[0.26862954, 0.26130258, 0.27577711]),
#     ])
# def clip_img_embedder(image):
#     preproc_img = preprocess(image)
#     return clip_model.encode_image(preproc_img)
# clip_seq_dim = 1
# clip_emb_dim = 768

FrozenOpenCLIPImageEmbedder(
  (model): CLIP(
    (visual): VisionTransformer(
      (conv1): Conv2d(3, 1664, kernel_size=(14, 14), stride=(14, 14), bias=False)
      (patch_dropout): Identity()
      (ln_pre): LayerNorm((1664,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-47): 48 x ResidualAttentionBlock(
            (ln_1): LayerNorm((1664,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=1664, out_features=1664, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((1664,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=1664, out_features=8192, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=8192, out_features=1664, bias=True)
            )
            (ls_2): Identity()


### MindEye modules

In [87]:
class MindEyeModule(nn.Module):
    def __init__(self):
        super(MindEyeModule, self).__init__()
    def forward(self, x):
        return x
        
model = MindEyeModule()
model

MindEyeModule()

In [88]:
class RidgeRegression(torch.nn.Module):
    # make sure to add weight_decay when initializing optimizer
    def __init__(self, input_sizes, out_features, seq_len=1): 
        super(RidgeRegression, self).__init__()
        self.seq_len = seq_len
        self.out_features = out_features
        self.linears = torch.nn.ModuleList([
                torch.nn.Linear(input_size, out_features) for input_size in input_sizes
            ])
    def forward(self, x, subj_idx=0):
        out = torch.cat([self.linears[subj_idx](x[:,seq]).unsqueeze(1) for seq in range(self.seq_len)], dim=1)
        return out
        
model.ridge = RidgeRegression(num_voxels_list, out_features=hidden_dim)
utils.count_params(model.ridge)
utils.count_params(model)

# test on subject 1 with fake data
b = torch.randn((2,1,num_voxels_list[0]))
print(b.shape, model.ridge(b,0).shape)

param counts:
1,049,600 total
1,049,600 trainable
param counts:
1,049,600 total
1,049,600 trainable
torch.Size([2, 1, 1024]) torch.Size([2, 1, 1024])


In [89]:
from functools import partial
class BrainNetwork(nn.Module):
    def __init__(self, h=4096, in_dim=15724, out_dim=768, seq_len=1, n_blocks=n_blocks, drop=.15, 
                 clip_size=768):
        super().__init__()
        self.seq_len = seq_len
        self.h = h
        self.clip_size = clip_size
        
        self.mixer_blocks1 = nn.ModuleList([
            self.mixer_block1(h, drop) for _ in range(n_blocks)
        ])
        self.mixer_blocks2 = nn.ModuleList([
            self.mixer_block2(seq_len, drop) for _ in range(n_blocks)
        ])
        
        # Output linear layer
        self.backbone_linear = nn.Linear(h * seq_len, out_dim, bias=True) 
        if clip_scale>0:
            self.clip_proj = self.projector(clip_size, clip_size, h=clip_size)
            
    def projector(self, in_dim, out_dim, h=2048):
        return nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.GELU(),
            nn.Linear(in_dim, h),
            nn.LayerNorm(h),
            nn.GELU(),
            nn.Linear(h, h),
            nn.LayerNorm(h),
            nn.GELU(),
            nn.Linear(h, out_dim)
        )
    
    def mlp(self, in_dim, out_dim, drop):
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(out_dim, out_dim),
        )
    
    def mixer_block1(self, h, drop):
        return nn.Sequential(
            nn.LayerNorm(h),
            self.mlp(h, h, drop),  # Token mixing
        )

    def mixer_block2(self, seq_len, drop):
        return nn.Sequential(
            nn.LayerNorm(seq_len),
            self.mlp(seq_len, seq_len, drop)  # Channel mixing
        )
        
    def forward(self, x):
        # make empty tensors
        c,b = torch.Tensor([0.]), torch.Tensor([[0.],[0.]])
        
        # Mixer blocks
        residual1 = x
        residual2 = x.permute(0,2,1)
        for block1, block2 in zip(self.mixer_blocks1,self.mixer_blocks2):
            x = block1(x) + residual1
            residual1 = x
            x = x.permute(0,2,1)
            
            x = block2(x) + residual2
            residual2 = x
            x = x.permute(0,2,1)
            
        x = x.reshape(x.size(0), -1)
        backbone = self.backbone_linear(x).reshape(len(x), -1, self.clip_size)
        if clip_scale>0:
            c = self.clip_proj(backbone)
        
        return backbone, c, b

model.backbone = BrainNetwork(h=hidden_dim, in_dim=hidden_dim, seq_len=1, 
                          clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim)
utils.count_params(model.backbone)
utils.count_params(model)

# test that the model works on some fake data
b = torch.randn((2,1,hidden_dim))
print("b.shape",b.shape)

backbone_, clip_, blur_ = model.backbone(b)
print(backbone_.shape, clip_.shape, blur_[0].shape, blur_[1].shape)

param counts:
14,643,736 total
14,643,736 trainable
param counts:
15,693,336 total
15,693,336 trainable
b.shape torch.Size([2, 1, 1024])
torch.Size([2, 1, 1280]) torch.Size([2, 1, 1280]) torch.Size([1]) torch.Size([1])


### Adding diffusion prior + unCLIP if use_prior=True

In [90]:
if use_prior:
    from models import *

    # setup diffusion prior network
    out_dim = clip_emb_dim
    depth = 6
    dim_head = 52
    heads = clip_emb_dim//52 # heads * dim_head = clip_emb_dim
    timesteps = 100

    prior_network = PriorNetwork(
            dim=out_dim,
            depth=depth,
            dim_head=dim_head,
            heads=heads,
            causal=False,
            num_tokens = clip_seq_dim,
            learned_query_mode="pos_emb"
        )

    model.diffusion_prior = BrainDiffusionPrior(
        net=prior_network,
        image_embed_dim=out_dim,
        condition_on_text_encodings=False,
        timesteps=timesteps,
        cond_drop_prob=0.2,
        image_embed_scale=None,
    )
    
    utils.count_params(model.diffusion_prior)
    utils.count_params(model)

### Setup optimizer / lr / ckpt saving

In [58]:
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

opt_grouped_parameters = [
    {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]
# model.backbone.requires_grad_(False)

if use_prior:
    opt_grouped_parameters.extend([
        {'params': [p for n, p in model.diffusion_prior.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
        {'params': [p for n, p in model.diffusion_prior.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ])

optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)

if lr_scheduler_type == 'linear':
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        total_iters=int(np.floor(num_epochs*num_iterations_per_epoch)),
        last_epoch=-1
    )
elif lr_scheduler_type == 'cycle':
    if num_iterations_per_epoch==0:
        num_iterations_per_epoch=1
    total_steps=int(np.floor(num_epochs*num_iterations_per_epoch))
    print("total_steps", total_steps)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, 
        max_lr=max_lr,
        total_steps=total_steps,
        final_div_factor=1000,
        last_epoch=-1, pct_start=2/num_epochs
    )
    
def save_ckpt(tag):
    ckpt_path = outdir+f'/{tag}.pth'
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        torch.save({
            'epoch': epoch,
            'model_state_dict': unwrapped_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'train_losses': losses,
            'test_losses': test_losses,
            'lrs': lrs,
            }, ckpt_path)
    print(f"\n---saved {outdir}/{tag} ckpt!---\n")

def load_ckpt(tag,load_lr=True,load_optimizer=True,load_epoch=True,strict=True,outdir=outdir,multisubj_loading=False): 
    print(f"\n---loading {outdir}/{tag}.pth ckpt---\n")
    checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
    state_dict = checkpoint['model_state_dict']
    if multisubj_loading: # remove incompatible ridge layer that will otherwise error
        state_dict.pop('ridge.linears.0.weight',None)
    model.load_state_dict(state_dict, strict=strict)
    if load_epoch:
        globals()["epoch"] = checkpoint['epoch']
        print("Epoch",epoch)
    if load_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if load_lr:
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    del checkpoint

print("\nDone with model preparations!")
num_params = utils.count_params(model)

total_steps 16875

Done with model preparations!
param counts:
88,110,616 total
88,110,616 trainable


# WandB

In [59]:
if utils.is_interactive():
    print("Running inside interactive notebook. Disabling wandb and ckpt saving...")
    wandb_log = False
    ckpt_saving = False
if local_rank==0 and wandb_log: # only use main process for wandb logging
    import wandb
    wandb_project = 'found_mindeye'
    print(f"wandb {wandb_project} run {model_name}")
    # need to configure wandb beforehand in terminal with "wandb init"!
    wandb_config = {
      "model_name": model_name,
      "global_batch_size": global_batch_size,
      "batch_size": batch_size,
      "num_epochs": num_epochs,
      "num_sessions": num_sessions,
      "num_params": num_params,
      "clip_scale": clip_scale,
      "prior_scale": prior_scale,
      "blur_scale": blur_scale,
      "use_image_aug": use_image_aug,
      "max_lr": max_lr,
      "mixup_pct": mixup_pct,
      "num_samples_per_epoch": num_samples_per_epoch,
      "ckpt_interval": ckpt_interval,
      "ckpt_saving": ckpt_saving,
      "seed": seed,
      "distributed": distributed,
      "num_devices": num_devices,
      "world_size": world_size,
    }
    print("wandb_config:\n",wandb_config)
    print("wandb_id:",model_name)
    wandb.init(
        id=model_name,
        project=wandb_project,
        name=model_name,
        config=wandb_config,
        resume="allow",
    )
else:
    wandb_log = False

Running inside interactive notebook. Disabling wandb and ckpt saving...


# Train the model

In [60]:
epoch = 0
losses, test_losses, lrs = [], [], []
best_test_loss = 1e9
torch.cuda.empty_cache()

In [61]:
# load multisubject stage1 ckpt if set
if multisubject_ckpt is not None and not resume_from_ckpt:
    load_ckpt("last",outdir=multisubject_ckpt,load_lr=False,load_optimizer=False,load_epoch=False,strict=False,multisubj_loading=True)

In [62]:
# train_dls = [train_dl[f'subj0{s}'] for s in subj_list]

model, optimizer, train_dl, lr_scheduler = accelerator.prepare(model, optimizer, train_dl, lr_scheduler)
# leaving out test_dl since we will only have local_rank 0 device do evals

In [63]:
print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
mse = nn.MSELoss()
l1 = nn.L1Loss()
soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))

for epoch in tqdm(range(epoch,num_epochs)):
    model.train()

    fwd_percent_correct = 0.
    bwd_percent_correct = 0.
    test_fwd_percent_correct = 0.
    test_bwd_percent_correct = 0.
    
    recon_cossim = 0.
    test_recon_cossim = 0.
    recon_mse = 0.
    test_recon_mse = 0.

    loss_clip_total = 0.
    loss_blurry_total = 0.
    loss_blurry_cont_total = 0.
    test_loss_clip_total = 0.
    
    loss_prior_total = 0.
    test_loss_prior_total = 0.

    blurry_pixcorr = 0.
    test_blurry_pixcorr = 0. 

    # you now have voxel_iters and image_iters with num_iterations_per_epoch batches each
    for train_i, behav in enumerate(train_dl):  
        with torch.cuda.amp.autocast(dtype=data_type):
            optimizer.zero_grad()
            loss = 0.
            
            behav = behav[0]

            image = images[behav.long().cpu()].to(device)
            voxel = vox[behav.long().cpu()]
            voxel = torch.Tensor(voxel).unsqueeze(1).to(device)

            if use_image_aug: 
                image = img_augment(image)

            clip_target = clip_img_embedder(image)
            if clip_target.ndim == 2: clip_target = clip_target.unsqueeze(1)
            assert not torch.any(torch.isnan(clip_target))

            if epoch < int(mixup_pct * num_epochs):
                voxel, perm, betas, select = utils.mixco(voxel)

            voxel_ridge = model.ridge(voxel,0) #[model.ridge(voxel_list[si],si) for si,s in enumerate(subj_list)]
            # voxel_ridge = torch.cat(voxel_ridge_list, dim=0)

            backbone, clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge) #voxel)#voxel_ridge)

            if clip_scale>0:
                clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)

            if use_prior:
                loss_prior, prior_out = model.diffusion_prior(text_embed=backbone, image_embed=clip_target)
                loss_prior_total += loss_prior.item()
                loss_prior *= prior_scale
                loss += loss_prior

                recon_cossim += nn.functional.cosine_similarity(prior_out, clip_target).mean().item()
                recon_mse += mse(prior_out, clip_target).item()

            if clip_scale>0:
                if epoch < int(mixup_pct * num_epochs):                
                    loss_clip = utils.mixco_nce(
                        clip_voxels_norm,
                        clip_target_norm,
                        temp=.006,
                        perm=perm, betas=betas, select=select)
                else:
                    epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
                    loss_clip = utils.soft_clip_loss(
                        clip_voxels_norm,
                        clip_target_norm,
                        temp=epoch_temp)

                loss_clip_total += loss_clip.item()
                loss_clip *= clip_scale
                loss += loss_clip

            if blurry_recon:     
                image_enc_pred, transformer_feats = blurry_image_enc_

                image_enc = autoenc.encode(2*image-1).latent_dist.mode() * 0.18215
                loss_blurry = l1(image_enc_pred, image_enc)
                loss_blurry_total += loss_blurry.item()

                if epoch < int(mixup_pct * num_epochs):
                    image_enc_shuf = image_enc[perm]
                    betas_shape = [-1] + [1]*(len(image_enc.shape)-1)
                    image_enc[select] = image_enc[select] * betas[select].reshape(*betas_shape) + \
                        image_enc_shuf[select] * (1 - betas[select]).reshape(*betas_shape)

                image_norm = (image - mean)/std
                image_aug = (blur_augs(image) - mean)/std
                _, cnx_embeds = cnx(image_norm)
                _, cnx_aug_embeds = cnx(image_aug)

                cont_loss = utils.soft_cont_loss(
                    nn.functional.normalize(transformer_feats.reshape(-1, transformer_feats.shape[-1]), dim=-1),
                    nn.functional.normalize(cnx_embeds.reshape(-1, cnx_embeds.shape[-1]), dim=-1),
                    nn.functional.normalize(cnx_aug_embeds.reshape(-1, cnx_embeds.shape[-1]), dim=-1),
                    temp=0.2)
                loss_blurry_cont_total += cont_loss.item()

                loss += (loss_blurry + 0.1*cont_loss) * blur_scale #/.18215

            if clip_scale>0:
                # forward and backward top 1 accuracy        
                labels = torch.arange(len(clip_voxels_norm)).to(clip_voxels_norm.device) 
                fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1).item()
                bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1).item()

            if blurry_recon:
                with torch.no_grad():
                    # only doing pixcorr eval on a subset of the samples per batch because its costly & slow to compute autoenc.decode()
                    random_samps = np.random.choice(np.arange(len(image)), size=len(image)//5, replace=False)
                    blurry_recon_images = (autoenc.decode(image_enc_pred[random_samps]/0.18215).sample/ 2 + 0.5).clamp(0,1)
                    pixcorr = utils.pixcorr(image[random_samps], blurry_recon_images)
                    blurry_pixcorr += pixcorr.item()
            
            utils.check_loss(loss)
            accelerator.backward(loss)
            optimizer.step()

            losses.append(loss.item())
            lrs.append(optimizer.param_groups[0]['lr'])

            if lr_scheduler_type is not None:
                lr_scheduler.step()
                
            if train_i >= num_iterations_per_epoch-1:
                break
                
    model.eval()
    if local_rank==0:
        with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type):   
            for test_i, behav in enumerate(test_dl):
                loss=0.
                behav = behav[0]
            
                # image = test_images[behav.long().cpu()][:,0].to(device)
                # voxel = test_vox[behav.long().cpu()].mean(1)

                image = test_images[behav.long().cpu()].to(device)
                voxel = test_vox[behav.long().cpu()]
                    
                voxel = torch.Tensor(voxel).unsqueeze(1).to(device)
            
                assert len(image) == 1000 #300
            
                clip_target = clip_img_embedder(image.float())
                if clip_target.ndim == 2: clip_target = clip_target.unsqueeze(1)
            
                voxel_ridge = model.ridge(voxel,0)
    
                backbone, clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)#voxel) #voxel_ridge)
    
                if clip_scale>0:
                    clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
                    clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
                
                # for some evals, only doing a subset of the samples per batch because of computational cost
                if use_prior or blurry_recon:
                    random_samps = np.random.choice(np.arange(len(image)), size=len(image)//5, replace=False)
                
                if use_prior:
                    loss_prior, contaminated_prior_out = model.diffusion_prior(text_embed=backbone[random_samps], image_embed=clip_target[random_samps])
                    test_loss_prior_total += loss_prior.item()
                    loss_prior *= prior_scale
                    loss += loss_prior
                        
                if clip_scale>0:
                    loss_clip = utils.soft_clip_loss(
                        clip_voxels_norm,
                        clip_target_norm,
                        temp=.006)
    
                    test_loss_clip_total += loss_clip.item()
                    loss_clip = loss_clip * clip_scale
                    loss += loss_clip
    
                if blurry_recon:
                    image_enc_pred, _ = blurry_image_enc_
                    blurry_recon_images = (autoenc.decode(image_enc_pred[random_samps]/0.18215).sample / 2 + 0.5).clamp(0,1)
                    pixcorr = utils.pixcorr(image[random_samps], blurry_recon_images)
                    test_blurry_pixcorr += pixcorr.item()
    
                if clip_scale>0:
                    # forward and backward top 1 accuracy        
                    labels = torch.arange(len(clip_voxels_norm)).to(clip_voxels_norm.device) 
                    test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1).item()
                    test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1).item()
                
                utils.check_loss(loss)                
                test_losses.append(loss.item())

            # if utils.is_interactive(): clear_output(wait=True)
            print("---")

            # assert (test_i+1) == 1
            logs = {"train/loss": np.mean(losses[-(train_i+1):]),
                "test/loss": np.mean(test_losses[-(test_i+1):]),
                "train/lr": lrs[-1],
                "train/num_steps": len(losses),
                "test/num_steps": len(test_losses),
                "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
                "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
                "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
                "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
                "train/loss_clip_total": loss_clip_total / (train_i + 1),
                "train/loss_blurry_total": loss_blurry_total / (train_i + 1),
                "train/loss_blurry_cont_total": loss_blurry_cont_total / (train_i + 1),
                "test/loss_clip_total": test_loss_clip_total / (test_i + 1),
                "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
                "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
                "train/recon_cossim": recon_cossim / (train_i + 1),
                "test/recon_cossim": test_recon_cossim / (test_i + 1),
                "train/recon_mse": recon_mse / (train_i + 1),
                "test/recon_mse": test_recon_mse / (test_i + 1),
                "train/loss_prior": loss_prior_total / (train_i + 1),
                "test/loss_prior": test_loss_prior_total / (test_i + 1),
                }

            # if finished training, save jpg recons if they exist
            if (epoch == num_epochs-1) or (epoch % ckpt_interval == 0):
                if blurry_recon:    
                    image_enc = autoenc.encode(2*image[:4]-1).latent_dist.mode() * 0.18215
                    # transform blurry recon latents to images and plot it
                    fig, axes = plt.subplots(1, 8, figsize=(10, 4))
                    jj=-1
                    for j in [0,1,2,3]:
                        jj+=1
                        axes[jj].imshow(utils.torch_to_Image((autoenc.decode(image_enc[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
                        axes[jj].axis('off')
                        jj+=1
                        axes[jj].imshow(utils.torch_to_Image((autoenc.decode(image_enc_pred[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
                        axes[jj].axis('off')

                    plt.show()

            print(logs)

            if wandb_log: wandb.log(logs)
            
    # Save model checkpoint and reconstruct
    if (ckpt_saving) and (epoch % ckpt_interval == 0):
        save_ckpt(f'last')

    # wait for other GPUs to catch up if needed
    accelerator.wait_for_everyone()
    torch.cuda.empty_cache()

print(f"{model_name}")
print("\n===Finished!===\n")
if ckpt_saving:
    save_ckpt(f'last')

NSDflat_large_gsrFalse__visualTrue_RAWepoch99 starting with epoch 0 / 15


  with torch.cuda.amp.autocast(dtype=data_type):
  with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type):
  7%|██▊                                        | 1/15 [04:14<59:29, 254.93s/it]

---
{'train/loss': np.float64(2.791205801539951), 'test/loss': np.float64(4.756008148193359), 'train/lr': 0.00015589942434296661, 'train/num_steps': 1125, 'test/num_steps': 1, 'train/fwd_pct_correct': 0.23281482119692695, 'train/bwd_pct_correct': 0.19677778357598516, 'test/test_fwd_pct_correct': 0.11800000816583633, 'test/test_bwd_pct_correct': 0.05000000074505806, 'train/loss_clip_total': 2.791205801539951, 'train/loss_blurry_total': 0.0, 'train/loss_blurry_cont_total': 0.0, 'test/loss_clip_total': 4.756008148193359, 'train/blurry_pixcorr': 0.0, 'test/blurry_pixcorr': 0.0, 'train/recon_cossim': 0.0, 'test/recon_cossim': 0.0, 'train/recon_mse': 0.0, 'test/recon_mse': 0.0, 'train/loss_prior': 0.0, 'test/loss_prior': 0.0}


 13%|█████▋                                     | 2/15 [08:30<55:21, 255.53s/it]

---
{'train/loss': np.float64(1.4240021441777546), 'test/loss': np.float64(4.202593803405762), 'train/lr': 0.0003, 'train/num_steps': 2250, 'test/num_steps': 2, 'train/fwd_pct_correct': 0.4621481602589289, 'train/bwd_pct_correct': 0.40488889989587995, 'test/test_fwd_pct_correct': 0.16200000047683716, 'test/test_bwd_pct_correct': 0.08700000494718552, 'train/loss_clip_total': 1.4240021441777546, 'train/loss_blurry_total': 0.0, 'train/loss_blurry_cont_total': 0.0, 'test/loss_clip_total': 4.202593803405762, 'train/blurry_pixcorr': 0.0, 'test/blurry_pixcorr': 0.0, 'train/recon_cossim': 0.0, 'test/recon_cossim': 0.0, 'train/recon_mse': 0.0, 'test/recon_mse': 0.0, 'train/loss_prior': 0.0, 'test/loss_prior': 0.0}


 20%|████████▌                                  | 3/15 [12:45<51:00, 255.03s/it]

---
{'train/loss': np.float64(1.0604153644243877), 'test/loss': np.float64(3.6940548419952393), 'train/lr': 0.0002956414469630032, 'train/num_steps': 3375, 'test/num_steps': 3, 'train/fwd_pct_correct': 0.5515926084253523, 'train/bwd_pct_correct': 0.5193703854613834, 'test/test_fwd_pct_correct': 0.22700001299381256, 'test/test_bwd_pct_correct': 0.14100000262260437, 'train/loss_clip_total': 1.0604153644243877, 'train/loss_blurry_total': 0.0, 'train/loss_blurry_cont_total': 0.0, 'test/loss_clip_total': 3.6940548419952393, 'train/blurry_pixcorr': 0.0, 'test/blurry_pixcorr': 0.0, 'train/recon_cossim': 0.0, 'test/recon_cossim': 0.0, 'train/recon_mse': 0.0, 'test/recon_mse': 0.0, 'train/loss_prior': 0.0, 'test/loss_prior': 0.0}


 27%|███████████▍                               | 4/15 [17:02<46:56, 256.01s/it]

---
{'train/loss': np.float64(0.9127149018181695), 'test/loss': np.float64(3.3637735843658447), 'train/lr': 0.0002828190911118275, 'train/num_steps': 4500, 'test/num_steps': 4, 'train/fwd_pct_correct': 0.5798518682718277, 'train/bwd_pct_correct': 0.5678889056576623, 'test/test_fwd_pct_correct': 0.2770000100135803, 'test/test_bwd_pct_correct': 0.17400000989437103, 'train/loss_clip_total': 0.9127149018181695, 'train/loss_blurry_total': 0.0, 'train/loss_blurry_cont_total': 0.0, 'test/loss_clip_total': 3.3637735843658447, 'train/blurry_pixcorr': 0.0, 'test/blurry_pixcorr': 0.0, 'train/recon_cossim': 0.0, 'test/recon_cossim': 0.0, 'train/recon_mse': 0.0, 'test/recon_mse': 0.0, 'train/loss_prior': 0.0, 'test/loss_prior': 0.0}


 33%|██████████████▎                            | 5/15 [21:13<42:19, 253.94s/it]

---
{'train/loss': np.float64(0.5830798175070021), 'test/loss': np.float64(3.328577756881714), 'train/lr': 0.0002622781211611761, 'train/num_steps': 5625, 'test/num_steps': 5, 'train/fwd_pct_correct': 0.8310370570818583, 'train/bwd_pct_correct': 0.7681852051417033, 'test/test_fwd_pct_correct': 0.2850000262260437, 'test/test_bwd_pct_correct': 0.16100001335144043, 'train/loss_clip_total': 0.5830798175070021, 'train/loss_blurry_total': 0.0, 'train/loss_blurry_cont_total': 0.0, 'test/loss_clip_total': 3.328577756881714, 'train/blurry_pixcorr': 0.0, 'test/blurry_pixcorr': 0.0, 'train/recon_cossim': 0.0, 'test/recon_cossim': 0.0, 'train/recon_mse': 0.0, 'test/recon_mse': 0.0, 'train/loss_prior': 0.0, 'test/loss_prior': 0.0}


 40%|█████████████████▏                         | 6/15 [25:25<37:59, 253.32s/it]

---
{'train/loss': np.float64(0.4280092833373282), 'test/loss': np.float64(3.0443568229675293), 'train/lr': 0.00023521230362119294, 'train/num_steps': 6750, 'test/num_steps': 6, 'train/fwd_pct_correct': 0.8765926127433776, 'train/bwd_pct_correct': 0.8232222409778172, 'test/test_fwd_pct_correct': 0.3150000274181366, 'test/test_bwd_pct_correct': 0.2290000170469284, 'train/loss_clip_total': 0.4280092833373282, 'train/loss_blurry_total': 0.0, 'train/loss_blurry_cont_total': 0.0, 'test/loss_clip_total': 3.0443568229675293, 'train/blurry_pixcorr': 0.0, 'test/blurry_pixcorr': 0.0, 'train/recon_cossim': 0.0, 'test/recon_cossim': 0.0, 'train/recon_mse': 0.0, 'test/recon_mse': 0.0, 'train/loss_prior': 0.0, 'test/loss_prior': 0.0}


 47%|████████████████████                       | 7/15 [29:41<33:54, 254.28s/it]

---
{'train/loss': np.float64(0.33272055599093436), 'test/loss': np.float64(2.7025370597839355), 'train/lr': 0.00020319460542705803, 'train/num_steps': 7875, 'test/num_steps': 7, 'train/fwd_pct_correct': 0.9050370572937859, 'train/bwd_pct_correct': 0.8607407600084941, 'test/test_fwd_pct_correct': 0.3500000238418579, 'test/test_bwd_pct_correct': 0.2510000169277191, 'train/loss_clip_total': 0.33272055599093436, 'train/loss_blurry_total': 0.0, 'train/loss_blurry_cont_total': 0.0, 'test/loss_clip_total': 2.7025370597839355, 'train/blurry_pixcorr': 0.0, 'test/blurry_pixcorr': 0.0, 'train/recon_cossim': 0.0, 'test/recon_cossim': 0.0, 'train/recon_mse': 0.0, 'test/recon_mse': 0.0, 'train/loss_prior': 0.0, 'test/loss_prior': 0.0}


 53%|██████████████████████▉                    | 8/15 [34:00<29:50, 255.81s/it]

---
{'train/loss': np.float64(0.2585171893917852), 'test/loss': np.float64(2.4041671752929688), 'train/lr': 0.00016808577881821687, 'train/num_steps': 9000, 'test/num_steps': 8, 'train/fwd_pct_correct': 0.9225185389518737, 'train/bwd_pct_correct': 0.8917777981228299, 'test/test_fwd_pct_correct': 0.4050000309944153, 'test/test_bwd_pct_correct': 0.28600001335144043, 'train/loss_clip_total': 0.2585171893917852, 'train/loss_blurry_total': 0.0, 'train/loss_blurry_cont_total': 0.0, 'test/loss_clip_total': 2.4041671752929688, 'train/blurry_pixcorr': 0.0, 'test/blurry_pixcorr': 0.0, 'train/recon_cossim': 0.0, 'test/recon_cossim': 0.0, 'train/recon_mse': 0.0, 'test/recon_mse': 0.0, 'train/loss_prior': 0.0, 'test/loss_prior': 0.0}


 60%|█████████████████████████▊                 | 9/15 [38:13<25:29, 254.88s/it]

---
{'train/loss': np.float64(0.1952740568055047), 'test/loss': np.float64(2.1908323764801025), 'train/lr': 0.00013192622118178308, 'train/num_steps': 10125, 'test/num_steps': 9, 'train/fwd_pct_correct': 0.9409259461826748, 'train/bwd_pct_correct': 0.9189629832373725, 'test/test_fwd_pct_correct': 0.45000001788139343, 'test/test_bwd_pct_correct': 0.3330000042915344, 'train/loss_clip_total': 0.1952740568055047, 'train/loss_blurry_total': 0.0, 'train/loss_blurry_cont_total': 0.0, 'test/loss_clip_total': 2.1908323764801025, 'train/blurry_pixcorr': 0.0, 'test/blurry_pixcorr': 0.0, 'train/recon_cossim': 0.0, 'test/recon_cossim': 0.0, 'train/recon_mse': 0.0, 'test/recon_mse': 0.0, 'train/loss_prior': 0.0, 'test/loss_prior': 0.0}


 67%|████████████████████████████              | 10/15 [42:34<21:23, 256.69s/it]

---
{'train/loss': np.float64(0.1371368974811501), 'test/loss': np.float64(1.8875792026519775), 'train/lr': 9.681739457294188e-05, 'train/num_steps': 11250, 'test/num_steps': 10, 'train/fwd_pct_correct': 0.9594815004666646, 'train/bwd_pct_correct': 0.9402222423553467, 'test/test_fwd_pct_correct': 0.4880000352859497, 'test/test_bwd_pct_correct': 0.4280000329017639, 'train/loss_clip_total': 0.1371368974811501, 'train/loss_blurry_total': 0.0, 'train/loss_blurry_cont_total': 0.0, 'test/loss_clip_total': 1.8875792026519775, 'train/blurry_pixcorr': 0.0, 'test/blurry_pixcorr': 0.0, 'train/recon_cossim': 0.0, 'test/recon_cossim': 0.0, 'train/recon_mse': 0.0, 'test/recon_mse': 0.0, 'train/loss_prior': 0.0, 'test/loss_prior': 0.0}


 73%|██████████████████████████████▊           | 11/15 [46:46<17:01, 255.37s/it]

---
{'train/loss': np.float64(0.09120522881568306), 'test/loss': np.float64(1.761939525604248), 'train/lr': 6.479969637880702e-05, 'train/num_steps': 12375, 'test/num_steps': 11, 'train/fwd_pct_correct': 0.9729629787868923, 'train/bwd_pct_correct': 0.9605926111539205, 'test/test_fwd_pct_correct': 0.5410000085830688, 'test/test_bwd_pct_correct': 0.445000022649765, 'train/loss_clip_total': 0.09120522881568306, 'train/loss_blurry_total': 0.0, 'train/loss_blurry_cont_total': 0.0, 'test/loss_clip_total': 1.761939525604248, 'train/blurry_pixcorr': 0.0, 'test/blurry_pixcorr': 0.0, 'train/recon_cossim': 0.0, 'test/recon_cossim': 0.0, 'train/recon_mse': 0.0, 'test/recon_mse': 0.0, 'train/loss_prior': 0.0, 'test/loss_prior': 0.0}


 80%|█████████████████████████████████▌        | 12/15 [50:57<12:42, 254.10s/it]

---
{'train/loss': np.float64(0.06175109098168711), 'test/loss': np.float64(1.6575655937194824), 'train/lr': 3.773387883882384e-05, 'train/num_steps': 13500, 'test/num_steps': 12, 'train/fwd_pct_correct': 0.9823333454661899, 'train/bwd_pct_correct': 0.9739259412553576, 'test/test_fwd_pct_correct': 0.5700000524520874, 'test/test_bwd_pct_correct': 0.47600001096725464, 'train/loss_clip_total': 0.06175109098168711, 'train/loss_blurry_total': 0.0, 'train/loss_blurry_cont_total': 0.0, 'test/loss_clip_total': 1.6575655937194824, 'train/blurry_pixcorr': 0.0, 'test/blurry_pixcorr': 0.0, 'train/recon_cossim': 0.0, 'test/recon_cossim': 0.0, 'train/recon_mse': 0.0, 'test/recon_mse': 0.0, 'train/loss_prior': 0.0, 'test/loss_prior': 0.0}


 87%|████████████████████████████████████▍     | 13/15 [55:07<08:25, 252.92s/it]

---
{'train/loss': np.float64(0.04394651980925765), 'test/loss': np.float64(1.4168270826339722), 'train/lr': 1.7192908888172444e-05, 'train/num_steps': 14625, 'test/num_steps': 13, 'train/fwd_pct_correct': 0.9875185278786553, 'train/bwd_pct_correct': 0.9818518639140659, 'test/test_fwd_pct_correct': 0.6270000338554382, 'test/test_bwd_pct_correct': 0.5300000309944153, 'train/loss_clip_total': 0.04394651980925765, 'train/loss_blurry_total': 0.0, 'train/loss_blurry_cont_total': 0.0, 'test/loss_clip_total': 1.4168270826339722, 'train/blurry_pixcorr': 0.0, 'test/blurry_pixcorr': 0.0, 'train/recon_cossim': 0.0, 'test/recon_cossim': 0.0, 'train/recon_mse': 0.0, 'test/recon_mse': 0.0, 'train/loss_prior': 0.0, 'test/loss_prior': 0.0}


 93%|███████████████████████████████████████▏  | 14/15 [59:21<04:13, 253.00s/it]

---
{'train/loss': np.float64(0.03444159144993561), 'test/loss': np.float64(1.3119802474975586), 'train/lr': 4.370553036996754e-06, 'train/num_steps': 15750, 'test/num_steps': 14, 'train/fwd_pct_correct': 0.9907407484584384, 'train/bwd_pct_correct': 0.9855185293091668, 'test/test_fwd_pct_correct': 0.6530000567436218, 'test/test_bwd_pct_correct': 0.5630000233650208, 'train/loss_clip_total': 0.03444159144993561, 'train/loss_blurry_total': 0.0, 'train/loss_blurry_cont_total': 0.0, 'test/loss_clip_total': 1.3119802474975586, 'train/blurry_pixcorr': 0.0, 'test/blurry_pixcorr': 0.0, 'train/recon_cossim': 0.0, 'test/recon_cossim': 0.0, 'train/recon_mse': 0.0, 'test/recon_mse': 0.0, 'train/loss_prior': 0.0, 'test/loss_prior': 0.0}


100%|████████████████████████████████████████| 15/15 [1:03:29<00:00, 254.00s/it]

---
{'train/loss': np.float64(0.028494710391490825), 'test/loss': np.float64(1.2958446741104126), 'train/lr': 1.1999999999999998e-08, 'train/num_steps': 16875, 'test/num_steps': 15, 'train/fwd_pct_correct': 0.9926296353340149, 'train/bwd_pct_correct': 0.9884444528685675, 'test/test_fwd_pct_correct': 0.6540000438690186, 'test/test_bwd_pct_correct': 0.5680000185966492, 'train/loss_clip_total': 0.028494710391490825, 'train/loss_blurry_total': 0.0, 'train/loss_blurry_cont_total': 0.0, 'test/loss_clip_total': 1.2958446741104126, 'train/blurry_pixcorr': 0.0, 'test/blurry_pixcorr': 0.0, 'train/recon_cossim': 0.0, 'test/recon_cossim': 0.0, 'train/recon_mse': 0.0, 'test/recon_mse': 0.0, 'train/loss_prior': 0.0, 'test/loss_prior': 0.0}
NSDflat_large_gsrFalse__visualTrue_RAWepoch99

===Finished!===




