In [8]:
import os
import sys
import math
import random
import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn
import torchvision
from torchvision import transforms
from tqdm import tqdm
import PIL
from datetime import datetime
import h5py

import kornia
from kornia.augmentation.container import AugmentationSequential


import webdataset as wds
from info_nce import InfoNCE
import clip
import pandas as pd
from collections import OrderedDict

from utils import *
from model import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device:",device)

def seed_everything(seed=0):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
seed_everything()

device: cpu


In [9]:
sd = torch.load('checkpoints/clip_image_vitB_large_768bs_subj01_best.pth', map_location=torch.device('cpu'))

In [10]:
sd.keys()

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'train_losses', 'val_losses', 'train_topk', 'val_topk', 'lrs'])

In [11]:
sd['model_state_dict'].keys()

odict_keys(['conv.0.weight', 'conv.0.bias', 'lins.0.1.weight', 'lins.0.1.bias', 'lins.0.3.weight', 'lins.0.3.bias', 'lins.0.3.running_mean', 'lins.0.3.running_var', 'lins.0.3.num_batches_tracked', 'lins.0.5.weight', 'lins.0.5.bias', 'lins.0.7.weight', 'lins.0.7.bias', 'lins.0.7.running_mean', 'lins.0.7.running_var', 'lins.0.7.num_batches_tracked', 'lins.1.1.weight', 'lins.1.1.bias', 'lins.1.3.weight', 'lins.1.3.bias', 'lins.1.3.running_mean', 'lins.1.3.running_var', 'lins.1.3.num_batches_tracked', 'lins.1.5.weight', 'lins.1.5.bias', 'lins.1.7.weight', 'lins.1.7.bias', 'lins.1.7.running_mean', 'lins.1.7.running_var', 'lins.1.7.num_batches_tracked', 'lins.2.1.weight', 'lins.2.1.bias', 'lins.2.3.weight', 'lins.2.3.bias', 'lins.2.3.running_mean', 'lins.2.3.running_var', 'lins.2.3.num_batches_tracked', 'lins.2.5.weight', 'lins.2.5.bias', 'lins.2.7.weight', 'lins.2.7.bias', 'lins.2.7.running_mean', 'lins.2.7.running_var', 'lins.2.7.num_batches_tracked', 'lins.3.1.weight', 'lins.3.1.bias', 'l

In [13]:
sd['model_state_dict']['lins.0.7.weight']

tensor([ 0.3976, -0.4134, -0.4364,  ...,  0.3736, -0.4621,  0.4127])

In [15]:
sd['model_state_dict']['lins.0.7.weight'].mean(), sd['model_state_dict']['lins.0.7.weight'].std()

(tensor(-0.0140), tensor(0.4496))

In [14]:
sd['model_state_dict']['lins.4.7.weight']

tensor([ 0.6824, -0.7241, -0.5594,  ..., -0.6872, -0.8073, -0.7000])

In [2]:
# if full_training is True, use large batches and the entire training dataset 
full_training = True
# image augmentation just for the CLIP image model that will be more semantic-focused
train_augs = AugmentationSequential(
    kornia.augmentation.RandomCrop((140, 140), p=0.3),
    kornia.augmentation.Resize((224, 224)),
    kornia.augmentation.RandomHorizontalFlip(p=0.5),
    kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
    kornia.augmentation.RandomGrayscale(p=0.3),
    data_keys=["input"],
    # random_apply = (1,4)
)

print('full_training:',full_training)

model_name = 'clip_image_vitB' # CLIP ViT-L/14 image embeddings
print(f"Using model: {model_name}")
    
if "resnet" in model_name: 
    clip_extractor = Clipper("RN50")
else:
    clip_extractor = Clipper("ViT-L/14", train_transforms=train_augs)
    
if "text" in model_name:     
    image_var = 'trial' 
else:
    image_var = 'images'
print("image_var =", image_var)

full_training: True
Using model: clip_image_vitB
ViT-L/14 cuda


100%|███████████████████████████████████████| 890M/890M [00:17<00:00, 53.2MiB/s]


image_var = images


In [3]:
nsd_path = '../../naturalscenesdataset/webdataset/'

if not full_training: 
    num_devices = 1
    num_workers = 1
    print("num_workers",num_workers)
    batch_size = 16
    print("batch_size",batch_size)
    num_samples = 500 
    global_batch_size = batch_size * num_devices
    print("global_batch_size",global_batch_size)
    num_batches = math.floor(num_samples / global_batch_size)
    num_worker_batches = math.floor(num_batches / num_workers)
    print("num_worker_batches",num_worker_batches)
    train_url = f"{nsd_path}/train/train_subj01_{{0..1}}.tar"
else:
    num_devices = torch.cuda.device_count()
    print("num_devices",num_devices)
    num_workers = num_devices
    print("num_workers",num_workers)
    batch_size = 300
    print("batch_size",batch_size)
    num_samples = 24983 # see metadata.json in webdataset_split folder
    global_batch_size = batch_size * num_devices
    print("global_batch_size",global_batch_size)
    num_batches = math.floor(num_samples / batch_size)
    num_worker_batches = math.floor(num_batches / num_workers)
    print("num_worker_batches",num_worker_batches)
    train_url = f"{nsd_path}/train/train_subj01_{{0..49}}.tar"

train_data = wds.DataPipeline([wds.ResampledShards(train_url),
                    wds.tarfile_to_samples(),
                    wds.shuffle(500,initial=500),
                    wds.decode("torch"),
                    wds.rename(images="jpg;png", voxels="nsdgeneral.npy", embs="sgxl_emb.npy", trial="trial.npy"),
                    wds.to_tuple("voxels", image_var),
                    wds.batched(batch_size, partial=True),
                ]).with_epoch(num_worker_batches)
train_dl = wds.WebLoader(train_data, num_workers=num_workers,
                         batch_size=None, shuffle=False, persistent_workers=True)

# Validation #
num_samples = 492
num_batches = math.ceil(num_samples / global_batch_size)
num_worker_batches = math.ceil(num_batches / num_workers)
print("validation: num_worker_batches",num_worker_batches)

url = f"{nsd_path}/val/val_subj01_0.tar"
val_data = wds.DataPipeline([wds.ResampledShards(url),
                    wds.tarfile_to_samples(),
                    wds.decode("torch"),
                    wds.rename(images="jpg;png", voxels="nsdgeneral.npy", 
                                embs="sgxl_emb.npy", trial="trial.npy"),
                    wds.to_tuple("voxels", image_var),
                    wds.batched(batch_size, partial=True),
                ]).with_epoch(num_worker_batches)
val_dl = wds.WebLoader(val_data, num_workers=num_workers,
                       batch_size=None, shuffle=False, persistent_workers=True)

num_devices 0
num_workers 0
batch_size 300
global_batch_size 0


ZeroDivisionError: division by zero

In [3]:
# check that your data loaders are working
out_dim = 512
for train_i, (voxel, img_input) in enumerate(train_dl):
    print("idx",train_i)
    print("voxel.shape",voxel.shape)
    if "text" in model_name:
        emb = clip_extractor.embed_curated_annotations(subj01_annots[img_input])
    else:
        emb = clip_extractor.embed_image(img_input)
    print("emb.shape",emb.shape)
    out_dim = emb.shape[1]
    print("out_dim", out_dim)
    break

NameError: name 'train_dl' is not defined

In [6]:
brain_net = BrainNetwork(out_dim=512)
EPOCHS = 100
opt = torch.optim.AdamW(brain_net.parameters(), lr=1e-3)
sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=3e-4, 
                                            total_steps=EPOCHS*((24983//300)//num_devices), 
                                            final_div_factor=1000,
                                            last_epoch=-1, pct_start=2/EPOCHS)

nce = InfoNCE()

NameError: name 'num_devices' is not defined

In [7]:
torch.numel(torch.nn.utils.parameters_to_vector(brain_net.parameters()))

133747200

In [4]:
brain_net

BrainNetwork(
  (conv): Sequential(
    (0): Linear(in_features=15742, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.5, inplace=False)
  )
  (lin): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=4096, out_features=4096, bias=True)
      (1): ReLU(inplace=True)
      (2): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.15, inplace=False)
    )
    (1): Sequential(
      (0): Linear(in_features=4096, out_features=4096, bias=True)
      (1): ReLU(inplace=True)
      (2): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.15, inplace=False)
    )
    (2): Sequential(
      (0): Linear(in_features=4096, out_features=4096, bias=True)
      (1): ReLU(inplace=True)
      (2): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_st

In [None]:
epoch = 0

train_losses = []; val_losses = []
train_topk = []; val_topk = []
lrs = []
epoch_logs = []

print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
print(f"num_epochs:{num_epochs} batch_size:{batch_size} lr:{initial_learning_rate}")

if full_training:
    print(f"Will be saving model checkpoints to checkpoints/{model_name}_subj01_epoch#.pth")
else:
    print(f"Warning: not saving model checkpoints")

if not os.path.exists("checkpoints"):
    os.makedirs("checkpoints")
    
pbar = tqdm(range(epoch, 100), ncols=250)
for epoch in pbar:
    brain_net.train()
    similarities = []
    for train_i, (voxel, img_input) in enumerate(train_dl):
        optimizer.zero_grad()
        
        voxel = voxel.to(device)
        
        with torch.cuda.amp.autocast():
            with torch.bo_grad():
                if image_var=='images': # using images
                    emb = clip_extractor.embed_image(img_input)
                else: # using text captions of the images 
                    emb = clip_extractor.embed_curated_annotations(subj01_annots[img_input])

        emb = emb.float() # cast to float32
        emb_ = brain_net(voxel)
            
        if torch.any(torch.isnan(emb_)):
            raise ValueError("NaN found...")
                
        emb_ = nn.functional.normalize(emb_,dim=-1) # l2 normalization on the embeddings
            
        labels = torch.arange(len(emb)).to(device)
        loss_nce = nce(emb_.reshape(len(emb),-1),emb.reshape(len(emb),-1))
        loss_soft = soft_clip_loss(emb_.reshape(len(emb),-1),emb.reshape(len(emb),-1))
        loss = loss_nce + loss_soft
            
        similarities = batchwise_cosine_similarity(emb,emb_)

        percent_correct = topk(similarities,labels,k=1)

        loss.backward()
        optimizer.step()
        sched.step()
        
        train_losses.append(loss.item())
        train_topk.append(percent_correct.item())
        
    brain_net.eval()    
    for val_i, (val_voxel, val_img_input) in enumerate(val_dl):
        with torch.no_grad(): 
            val_voxel = val_voxel.to(device)
            
            with torch.cuda.amp.autocast():
                if image_var=='images': # using images
                    val_emb = clip_extractor.embed_image(val_img_input)
                else: # using text captions of the images 
                    val_emb = clip_extractor.embed_curated_annotations(subj01_annots[val_img_input])

                val_emb_ = brain_net(val_voxel)
                val_emb_ = nn.functional.normalize(val_emb_,dim=-1) # l2 normalization on the embeddings
            
                labels = torch.arange(len(val_emb)).to(device)

                val_loss = nce(val_emb_.reshape(len(val_emb),-1),val_emb.reshape(len(val_emb),-1))

                val_similarities = batchwise_cosine_similarity(val_emb,val_emb_)

                val_percent_correct = topk(val_similarities,labels,k=1)
                
            val_losses.append(val_loss.item())
            val_topk.append(val_percent_correct.item())
                
    if epoch%5==4 and full_training:
        print(f'saving checkpoints/{model_name}_subj01_epoch{epoch+1}.pth...')
        if (using_ddp==False) or (using_ddp==True and local_rank==0):
            state_dict = brain_net.state_dict()
            if using_ddp: # if using DDP, convert DDP to non-DDP before saving
                state_dict = brain_net.module.state_dict()
            torch.save({
                'epoch': epoch,
                'model_state_dict': state_dict,
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'val_losses': val_losses,
                'train_topk': train_topk,
                'val_topk': val_topk,
                'lrs': lrs,
                }, f'checkpoints/{model_name}_subj01_epoch{epoch}.pth')
        if using_ddp:
            dist.barrier() # this tells the other gpus wait for the first gpu to finish saving the model
            
    lrs.append(optimizer.param_groups[0]['lr'])
    
    # logging the average results across batches for current epoch
    logs = OrderedDict(
        loss=np.mean(train_losses[-(train_i+1):]),
        topk=np.mean(train_topk[-(train_i+1):]),
        val_loss=np.mean(val_losses[-(val_i+1):]),
        val_topk=np.mean(val_topk[-(val_i+1):]),
        lr=lrs[-1],
    )
    pbar.set_postfix(**logs)
    epoch_logs.append(logs)
    if full_training:
        pd.DataFrame(epoch_logs).to_csv(f'checkpoints/{model_name}_subj01.epoch-logs.csv')
    
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))