In [4]:
# # Code to convert this notebook to .py if you want to run it via command line or with Slurm
# from subprocess import call
# command = "jupyter nbconvert Recon_Evaluation.ipynb --to python"
# call(command,shell=True)

[NbConvertApp] Converting notebook Recon_Evaluation.ipynb to python
[NbConvertApp] Writing 16306 bytes to Recon_Evaluation.py


0

In [1]:
import os
import cv2
import sys
import json
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from tqdm import tqdm
from datetime import datetime
import webdataset as wds
import PIL

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
local_rank = 0
print("device:",device)

import utils
from models import Clipper, OpenClipper, BrainNetwork, BrainDiffusionPrior, BrainDiffusionPriorOld, Voxel2StableDiffusionModel, VersatileDiffusionPriorNetwork

if utils.is_interactive():
    %load_ext autoreload
    %autoreload 2

seed=42
utils.seed_everything(seed=seed)

subj = 2 #note: we only trained subjects 1 2 5 7, since they have data across full sessions
if not utils.is_interactive():
    subj = int(sys.argv[1])
    print("Subj = ",subj)
else:
    print("Interactive. Subj = ",subj)
if subj == 1:
    num_voxels = 15724
elif subj == 2:
    num_voxels = 14278
elif subj == 3:
    num_voxels = 15226
elif subj == 4:
    num_voxels = 13153
elif subj == 5:
    num_voxels = 13039
elif subj == 6:
    num_voxels = 17907
elif subj == 7:
    num_voxels = 12682
elif subj == 8:
    num_voxels = 14386
print("subj",subj,"num_voxels",num_voxels)

print("PID of this process=",os.getpid())

device: cuda
Interactive. Subj =  2
subj 2 num_voxels 14278


In [2]:
data_path = "/fsx/proj-medarc/fmri/natural-scenes-dataset"
val_url = f"{data_path}/webdataset_avg_split/test/test_subj0{subj}_" + "{0..1}.tar"
meta_url = f"{data_path}/webdataset_avg_split/metadata_subj0{subj}.json"
num_train = 8559 + 300
num_val = 982
batch_size = val_batch_size = 1
voxels_key = 'nsdgeneral.npy' # 1d inputs

val_data = wds.WebDataset(val_url, resampled=False)\
    .decode("torch")\
    .rename(images="jpg;png", voxels=voxels_key, trial="trial.npy", coco="coco73k.npy", reps="num_uniques.npy")\
    .to_tuple("voxels", "images", "coco")\
    .batched(val_batch_size, partial=False)

val_dl = torch.utils.data.DataLoader(val_data, batch_size=None, shuffle=False)

# check that your data loader is working
for val_i, (voxel, img_input, coco) in enumerate(val_dl):
    print("idx",val_i)
    print("voxel.shape",voxel.shape)
    print("img_input.shape",img_input.shape)
    break

idx 0
voxel.shape torch.Size([1, 3, 14278])
img_input.shape torch.Size([1, 3, 256, 256])


## Load autoencoder

In [3]:
from diffusers.models.vae import Decoder
class Voxel2StableDiffusionModel(torch.nn.Module):
    def __init__(self, in_dim=15724, h=4096, n_blocks=4, use_cont=False, ups_mode='4x'):
        super().__init__()
        self.lin0 = nn.Sequential(
            nn.Linear(in_dim, h, bias=False),
            nn.LayerNorm(h),
            nn.SiLU(inplace=True),
            nn.Dropout(0.5),
        )

        self.mlp = nn.ModuleList([
            nn.Sequential(
                nn.Linear(h, h, bias=False),
                nn.LayerNorm(h),
                nn.SiLU(inplace=True),
                nn.Dropout(0.25)
            ) for _ in range(n_blocks)
        ])
        self.ups_mode = ups_mode
        if ups_mode=='4x':
            self.lin1 = nn.Linear(h, 16384, bias=False)
            self.norm = nn.GroupNorm(1, 64)
            
            self.upsampler = Decoder(
                in_channels=64,
                out_channels=4,
                up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
                block_out_channels=[64, 128, 256],
                layers_per_block=1,
            )

            if use_cont:
                self.maps_projector = nn.Sequential(
                    nn.Conv2d(64, 512, 1, bias=False),
                    nn.GroupNorm(1,512),
                    nn.ReLU(True),
                    nn.Conv2d(512, 512, 1, bias=False),
                    nn.GroupNorm(1,512),
                    nn.ReLU(True),
                    nn.Conv2d(512, 512, 1, bias=True),
                )
            else:
                self.maps_projector = nn.Identity()
        
        if ups_mode=='8x':  # prev best
            self.lin1 = nn.Linear(h, 16384, bias=False)
            self.norm = nn.GroupNorm(1, 256)
            
            self.upsampler = Decoder(
                in_channels=256,
                out_channels=4,
                up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
                block_out_channels=[64, 128, 256, 256],
                layers_per_block=1,
            )
            self.maps_projector = nn.Identity()
        
        if ups_mode=='16x':
            self.lin1 = nn.Linear(h, 8192, bias=False)
            self.norm = nn.GroupNorm(1, 512)
            
            self.upsampler = Decoder(
                in_channels=512,
                out_channels=4,
                up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D", "UpDecoderBlock2D"],
                block_out_channels=[64, 128, 256, 256, 512],
                layers_per_block=1,
            )
            self.maps_projector = nn.Identity()

    def forward(self, x, return_transformer_feats=False):
        x = self.lin0(x)
        residual = x
        for res_block in self.mlp:
            x = res_block(x)
            x = x + residual
            residual = x
        x = x.reshape(len(x), -1)
        x = self.lin1(x)  # bs, 4096

        if self.ups_mode == '4x':
            side = 16
        if self.ups_mode == '8x':
            side = 8
        if self.ups_mode == '16x':
            side = 4
        
        # decoder
        x = self.norm(x.reshape(x.shape[0], -1, side, side).contiguous())
        if return_transformer_feats:
            return self.upsampler(x), self.maps_projector(x).flatten(2).permute(0,2,1)
        return self.upsampler(x)

voxel2sd = Voxel2StableDiffusionModel(in_dim=num_voxels)

model_name = f'autoencoder_subj0{subj}_4x_locont_no_reconst/test'
ckpt_path = f'/fsx/proj-medarc/fmri/fMRI-reconstruction-NSD/train_logs/models/{model_name}/epoch120.pth'
checkpoint = torch.load(ckpt_path, map_location=device)
state_dict = checkpoint['model_state_dict']

# state_dict = torch.load('../train_logs/autoencoder/last.pth', 
#                         map_location='cpu')
# print(state_dict["epoch"])
# state_dict = state_dict["model_state_dict"]

voxel2sd.load_state_dict(state_dict,strict=False)
voxel2sd.eval()
voxel2sd.to(device)
print("Loaded low-level model!")

Loaded low-level model!


## Load Versatile Diffusion model

In [4]:
out_dim = 257 * 768
clip_extractor = Clipper("ViT-L/14", hidden_state=True, norm_embs=True, device=device)
voxel2clip_kwargs = dict(in_dim=num_voxels,out_dim=out_dim)
voxel2clip = BrainNetwork(**voxel2clip_kwargs)
voxel2clip.requires_grad_(False)
voxel2clip.eval()

out_dim = 768
depth = 6
dim_head = 64
heads = 12 # heads * dim_head = 12 * 64 = 768
timesteps = 100

prior_network = VersatileDiffusionPriorNetwork(
        dim=out_dim,
        depth=depth,
        dim_head=dim_head,
        heads=heads,
        causal=False,
        learned_query_mode="pos_emb"
    )

diffusion_prior = BrainDiffusionPrior(
    net=prior_network,
    image_embed_dim=out_dim,
    condition_on_text_encodings=False,
    timesteps=timesteps,
    cond_drop_prob=0.2,
    image_embed_scale=None,
    voxel2clip=voxel2clip,
)

# model_name = "prior_257_subj01"
# outdir = f'../train_logs/{model_name}'
# ckpt_path = os.path.join(outdir, f'last.pth')

model_name = f'prior_257_final_subj0{subj}_bimixco_softclip_byol'
# if not utils.is_interactive():
#     model_name = sys.argv[2]
print("Model name:",model_name)
ckpt_path = f'/fsx/proj-medarc/fmri/fMRI-reconstruction-NSD/train_logs/models/{model_name}/epoch239.pth'

print("ckpt_path",ckpt_path)
checkpoint = torch.load(ckpt_path, map_location=device)
state_dict = checkpoint['model_state_dict']
print("EPOCH: ",checkpoint['epoch'])
diffusion_prior.load_state_dict(state_dict,strict=False)
diffusion_prior.eval().to(device)
diffusion_priors = [diffusion_prior]
pass

ViT-L/14 cuda
Model name: prior_257_final_subj02_bimixco_softclip_byol
ckpt_path /fsx/proj-medarc/fmri/fMRI-reconstruction-NSD/train_logs/models/prior_257_final_subj02_bimixco_softclip_byol/epoch239.pth
EPOCH:  239


In [5]:
# CLS model
out_dim = 768
voxel2clip_kwargs = dict(in_dim=num_voxels,out_dim=out_dim)
voxel2clip_cls = BrainNetwork(**voxel2clip_kwargs)
voxel2clip_cls.requires_grad_(False)
voxel2clip_cls.eval()

diffusion_prior_cls = BrainDiffusionPriorOld.from_pretrained(
    # kwargs for DiffusionPriorNetwork
    dict(),
    # kwargs for DiffusionNetwork
    dict(
        condition_on_text_encodings=False,
        timesteps=1000,
        voxel2clip=voxel2clip_cls,
    ),
    voxel2clip_path=None,
)

model_name_cls = f"final_subj0{subj}_1x768"
outdir = f'../train_logs/{model_name_cls}'
ckpt_path = os.path.join(outdir, f'last.pth')
print("ckpt_path",ckpt_path)
checkpoint = torch.load(ckpt_path, map_location=device)
state_dict = checkpoint['model_state_dict']
print("EPOCH: ",checkpoint['epoch'])
diffusion_prior_cls.load_state_dict(state_dict,strict=False)
diffusion_prior_cls.eval().to(device)
pass

ckpt_path ../train_logs/final_subj02_1x768/last.pth
EPOCH:  299


In [6]:
from diffusers import VersatileDiffusionDualGuidedPipeline, UniPCMultistepScheduler
from diffusers.models import DualTransformer2DModel
vd_cache_dir = '/fsx/proj-medarc/fmri/cache/models--shi-labs--versatile-diffusion/snapshots/2926f8e11ea526b562cd592b099fcf9c2985d0b7'
try:
    vd_pipe =  VersatileDiffusionDualGuidedPipeline.from_pretrained(
            # "lambdalabs/sd-image-variations-diffusers",
            vd_cache_dir,
            safety_checker=None,
            requires_safety_checker=False,
        ).to('cpu').to(torch.float16)
except:
    vd_cache_dir = 'specify a path to save the pretrained Versatile Diffusion model'
    print("vd_cache_dir", vd_cache_dir)
    vd_pipe =  VersatileDiffusionDualGuidedPipeline.from_pretrained(
            "lambdalabs/sd-image-variations-diffusers",
            safety_checker=None,
            requires_safety_checker=False,
            cache_dir = vd_cache_dir,
        )
vd_pipe.image_unet.eval()
vd_pipe.vae.eval()
vd_pipe.image_unet.requires_grad_(False)
vd_pipe.vae.requires_grad_(False)

vd_pipe.scheduler = UniPCMultistepScheduler.from_pretrained(vd_cache_dir, subfolder="scheduler")
num_inference_steps = 20

# Set weighting of Dual-Guidance 
text_image_ratio = .0 # .5 means equally weight text and image, 0 means use only image
for name, module in vd_pipe.image_unet.named_modules():
    if isinstance(module, DualTransformer2DModel):
        module.mix_ratio = text_image_ratio
        for i, type in enumerate(("text", "image")):
            if type == "text":
                module.condition_lengths[i] = 77
                module.transformer_index_for_condition[i] = 1  # use the second (text) transformer
            else:
                module.condition_lengths[i] = 257
                module.transformer_index_for_condition[i] = 0  # use the first (image) transformer

unet = vd_pipe.image_unet.to(device)
vae = vd_pipe.vae.to(device)
noise_scheduler = vd_pipe.scheduler
img_variations = False

Keyword arguments {'safety_checker': None, 'requires_safety_checker': False} are not expected by VersatileDiffusionDualGuidedPipeline and will be ignored.


## Load Image Variations model

In [6]:
# # CLS model
# out_dim = 768
# voxel2clip_kwargs = dict(in_dim=num_voxels,out_dim=out_dim)
# voxel2clip = BrainNetwork(**voxel2clip_kwargs)
# voxel2clip.requires_grad_(False)
# voxel2clip.eval()

# diffusion_prior = BrainDiffusionPriorOld.from_pretrained(
#     # kwargs for DiffusionPriorNetwork
#     dict(),
#     # kwargs for DiffusionNetwork
#     dict(
#         condition_on_text_encodings=False,
#         timesteps=1000,
#         voxel2clip=voxel2clip,
#     ),
#     voxel2clip_path=None,
# )

# model_name = "final_subj01_1x768"
# outdir = f'../train_logs/{model_name}'
# ckpt_path = os.path.join(outdir, f'last.pth')
# print("ckpt_path",ckpt_path)
# checkpoint = torch.load(ckpt_path, map_location=device)
# state_dict = checkpoint['model_state_dict']
# print("EPOCH: ",checkpoint['epoch'])
# diffusion_prior.load_state_dict(state_dict,strict=False)
# diffusion_prior.eval().to(device)
# pass

In [7]:
# from diffusers import AutoencoderKL, UNet2DConditionModel, UniPCMultistepScheduler

# sd_cache_dir = '/fsx/home-paulscotti/.cache/huggingface/diffusers/models--lambdalabs--sd-image-variations-diffusers/snapshots/a2a13984e57db80adcc9e3f85d568dcccb9b29fc'
# unet = UNet2DConditionModel.from_pretrained(sd_cache_dir,subfolder="unet").to(device)

# unet.eval() # dont want to train model
# unet.requires_grad_(False) # dont need to calculate gradients

# vae = AutoencoderKL.from_pretrained(sd_cache_dir,subfolder="vae").to(device)
# vae.eval()
# vae.requires_grad_(False)

# noise_scheduler = UniPCMultistepScheduler.from_pretrained(sd_cache_dir, subfolder="scheduler")
# num_inference_steps = 20

# img_variations = True

# Reconstruct one-at-a-time

In [7]:
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

recons_per_sample = 16
img2img = True
img2img_strength = .85
retrieve = False
plotting = False
saving = True
verbose = False
imsize = 512

if img_variations:
    guidance_scale = 7.5
else:
    guidance_scale = 3.5
    
ind_include = np.arange(num_val)

for img2img_strength in [.85, .7, .5, .3, 0]:
    print("img2img_strength",img2img_strength)
    all_brain_recons = None
    if img2img_strength == 0:
        img2img = False
    else:
        img2img = True
    for val_i, (voxel, img, coco) in enumerate(tqdm(val_dl,total=len(ind_include))):
        if val_i<np.min(ind_include):
            continue
        voxel = torch.mean(voxel,axis=1).to(device)
        # voxel = voxel[:,0].to(device)
        with torch.no_grad():
            if img2img:
                ae_preds = voxel2sd(voxel.float())
                blurry_recons = vd_pipe.vae.decode(ae_preds.to(device).half()/0.18215).sample / 2 + 0.5

                if val_i==0:
                    plt.imshow(utils.torch_to_Image(blurry_recons))
                    plt.show()

                # blurry_recons = PIL.Image.open(f"blurry_recons/{coco.item()}.png").convert('RGB')
                # blurry_recons = transforms.PILToTensor()(blurry_recons)
                # blurry_recons = transforms.Resize((512,512))(blurry_recons)
                # blurry_recons = (blurry_recons.float() / 255)[None]
            else:
                blurry_recons = None

            grid, brain_recons, laion_best_picks, recon_img = utils.reconstruction(
                img, voxel,
                clip_extractor, unet, vae, noise_scheduler,
                voxel2clip_cls = diffusion_prior_cls.voxel2clip,
                diffusion_priors = diffusion_priors,
                text_token = None,
                img_lowlevel = blurry_recons,
                num_inference_steps = num_inference_steps,
                n_samples_save = batch_size,
                recons_per_sample = recons_per_sample,
                guidance_scale = guidance_scale,
                img2img_strength = .85, # 0=fully rely on img_lowlevel, 1=not doing img2img
                timesteps_prior = 100,
                seed = seed,
                retrieve = retrieve,
                plotting = plotting,
                img_variations = img_variations,
                verbose = verbose,
            )

            if plotting:
                plt.show()
                # grid.savefig(f'evals/{model_name}_{val_i}.png')
                # plt.close()

            brain_recons = brain_recons[:,laion_best_picks.astype(np.int8)]
            
            if all_brain_recons is None:
                all_brain_recons = brain_recons
                all_images = img
            else:
                all_brain_recons = torch.vstack((all_brain_recons,brain_recons))
                all_images = torch.vstack((all_images,img))

        if val_i>=np.max(ind_include):
            break

    all_brain_recons = all_brain_recons.view(-1,3,imsize,imsize)
    print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

    if saving:
        # torch.save(all_images,'evals/all_images')

        # print("BLURRY RECONS. CHECK THE -5 INDEX")
        # torch.save(all_brain_recons,f'evals/{model_name[:-5]}_brain_recons_full_img2img{img2img_strength}')

        torch.save(all_brain_recons,f'evals/{model_name}_brain_recons_full_img2img{img2img_strength}')
    print(f'evals/{model_name}_brain_recons_full_img2img{img2img_strength} done!')
print("DONE!")
sys.exit(0)

2023-05-07 21:32:31


  0%|                                                                                                     | 0/982 [00:31<?, ?it/s]


NameError: name 'err' is not defined