In [38]:
%load_ext autoreload
%autoreload 2

import os
import sys
import json
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from tqdm import tqdm
from datetime import datetime

# Evaluate on the GPU when one is present, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
local_rank = 0
print("device:",device)

import utils
seed=42
utils.seed_everything(seed=seed)

# Load CLIP extractor
# Alternative extractor:
# from models import OpenClipper
# clip_extractor = OpenClipper("ViT-H-14", device=device)
# imsize = 768

from models import Clipper
clip_extractor = Clipper("ViT-L/14", hidden_state=False, norm_embs=True, device=device)
imsize = 512

# Name of the model whose saved reconstructions are evaluated below.
model_name = 'autoencoder_subj01_16x' #'prior_257_final_subj01_bimixco_softclip_byol'

# NOTE: torch.load unpickles the file, so only load evaluation files you trust.
# map_location='cpu' lets files that were saved from a GPU session be read on a
# CPU-only machine too; the tensors are moved to `device` explicitly further down.
all_brain_recons = torch.load(f'evals/{model_name}_brain_recons', map_location='cpu') #v2cranking img2img
all_images = torch.load('evals/all_images', map_location='cpu')

# Other saved outputs that can be evaluated instead (swap in for the loads above):
# all_brain_recons = torch.load('evals/all_brain_recons')
# all_brain_recons = torch.load('evals/all_brain_retrievals')
# all_brain_recons = torch.load('evals/all_brain_imgtext_recons')
# all_brain_recons = torch.load('evals/autoencoder_recons')
# all_brain_recons = torch.load('evals/blurry_recons')
# all_clip_recons = torch.load('evals/all_clip_recons')
# all_laion_picks = torch.load('evals/all_laion_picks')

# model_name = 'prior_257_final_subj01_bimixco_softclip_byol'
# all_brain_recons = torch.load(f'evals/{model_name}_laion_retrievals') #v2cranking img2img

# all_brain_recons = torch.load(f'evals/braindiffuser_brain_recons')
# all_brain_recons = torch.load(f'evals/braindiffuser_brain_recons_no_vdvae')
# all_brain_recons = torch.load(f'evals/braindiffuser_vdvae')
# all_images = torch.Tensor(np.load('brain-diffuser/data/processed_data/subj01/nsd_test_stim_sub1.npy')/255.).permute(0,3,1,2)

print(all_images.shape)
print(all_brain_recons.shape)

# Every metric below compares all_images[i] with all_brain_recons[i], so the
# two tensors must contain the same number of items.
assert len(all_images) == len(all_brain_recons), (
    f"expected paired tensors, got {len(all_images)} images and "
    f"{len(all_brain_recons)} reconstructions"
)

all_images = all_images.to(device)
# Match the ground-truth dtype and clamp reconstructions into [0, 1].
all_brain_recons = all_brain_recons.to(device).to(all_images.dtype).clamp(0,1)

# all_images = transforms.Resize((425,425))(all_images)
# all_brain_recons = transforms.Resize((425,425))(all_brain_recons)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
device: cuda
ViT-L/14 cuda
torch.Size([982, 3, 256, 256])
torch.Size([982, 3, 512, 512])


In [1]:
# for ii in [0,5,13,15,18,8,45,77,114,159,225,342,394,451,467,500,531,609,776,882,916]:
#     print(ii)
#     display(utils.torch_to_Image(all_images[ii]))
#     display(utils.torch_to_Image(all_brain_recons[ii]))

# 2-Way Identification

In [48]:
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

@torch.no_grad()
def two_way_identification(all_brain_recons, all_images, model, preprocess, feature_layer=None, return_avg=True):
    """Score how often a reconstruction matches its own image better than another image.

    Features are extracted for every reconstruction and every ground-truth
    image, and the Pearson correlation between each (image, reconstruction)
    pair of feature vectors is computed. For reconstruction ``j`` a
    comparison against image ``i`` counts as a success when
    ``corr(image_i, recon_j) < corr(image_j, recon_j)``.

    Args:
        all_brain_recons: iterable of reconstructions; each item is passed
            through ``preprocess`` and the results are stacked into a batch.
        all_images: iterable of ground-truth images, index-aligned with
            ``all_brain_recons``.
        model: callable applied to the stacked batch.
        preprocess: callable applied to each individual item.
        feature_layer: key used to index the model output; ``None`` means the
            model output is used directly.
        return_avg: if true, return the mean success count divided by
            ``len(all_images) - 1``; otherwise return the per-reconstruction
            success counts together with ``len(all_images) - 1``.

    Returns:
        A scalar score, or the tuple ``(success_counts, len(all_images) - 1)``.
    """
    n_images = len(all_images)
    n_distractors = n_images - 1

    def _features(batch):
        # Preprocess item by item, run the whole batch through the model and
        # flatten everything but the batch dimension into one vector per item.
        stacked = torch.stack([preprocess(item) for item in batch], dim=0)
        outputs = model(stacked.to(device))
        if feature_layer is not None:
            outputs = outputs[feature_layer]
        return outputs.float().flatten(1).cpu().numpy()

    recon_feats = _features(all_brain_recons)
    image_feats = _features(all_images)

    # corrcoef stacks the image rows on top of the reconstruction rows; the
    # upper-right quadrant holds corr(image_i, recon_j).
    cross_corr = np.corrcoef(image_feats, recon_feats)[:n_images, n_images:]
    matched_corr = np.diag(cross_corr)

    # Column j: number of images that correlate less with recon j than its own image does.
    wins_per_recon = np.sum(cross_corr < matched_corr, 0)

    if not return_avg:
        return wins_per_recon, n_distractors
    return np.mean(wins_per_recon) / n_distractors

## PixCorr

In [49]:
# PixCorr: Pearson correlation between the flattened pixels of each
# ground-truth image and its reconstruction, averaged over all pairs.
preprocess = transforms.Compose([
    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
])

# Flatten images while keeping the batch dimension.
# reshape is used for both tensors; unlike view it does not require the
# resized tensor to be contiguous.
all_images_flattened = preprocess(all_images).reshape(len(all_images), -1).cpu()
all_brain_recons_flattened = preprocess(all_brain_recons).reshape(len(all_brain_recons), -1).cpu()

print(all_images_flattened.shape)
print(all_brain_recons_flattened.shape)

# Use the number of loaded images instead of a hard-coded dataset size so the
# cell stays correct when a different evaluation set is loaded.
n_images = len(all_images_flattened)

corrsum = 0
for i in tqdm(range(n_images)):
    # corrcoef returns a 2x2 matrix; the off-diagonal entry is the correlation
    # between the image and its reconstruction.
    corrsum += np.corrcoef(all_images_flattened[i], all_brain_recons_flattened[i])[0][1]
corrmean = corrsum / n_images

print(corrmean)

torch.Size([982, 541875])
torch.Size([982, 541875])


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 982/982 [00:07<00:00, 136.35it/s]

0.4462545378979282





## SSIM

In [36]:
# SSIM between each ground-truth image and its reconstruction, computed on
# grayscale versions of the images resized to 425; the mean score is printed.
# For the choice of SSIM settings see
# https://github.com/zijin-gu/meshconv-decoding/issues/3
from skimage.color import rgb2gray
from skimage.metrics import structural_similarity as ssim

preprocess = transforms.Compose([
    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR), 
])

# Convert the images to grayscale with rgb2gray. The permute moves the channel
# axis last before the conversion.
img_gray = rgb2gray(preprocess(all_images).permute((0,2,3,1)).cpu())
recon_gray = rgb2gray(preprocess(all_brain_recons).permute((0,2,3,1)).cpu())
print("converted, now calculating ssim...")

# NOTE(review): multichannel=True is passed although the inputs are grayscale,
# and newer scikit-image releases replace `multichannel` with `channel_axis`.
# Presumably these arguments are kept to stay comparable with the protocol in
# the linked issue. Confirm against the installed scikit-image version before
# changing them, since the reported score depends on this exact call.
ssim_score=[]
for im,rec in tqdm(zip(img_gray,recon_gray),total=len(all_images)):
    ssim_score.append(ssim(rec, im, multichannel=True, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1.0))
print(np.mean(ssim_score))

converted, now calculating ssim...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 982/982 [00:14<00:00, 65.53it/s]

0.4886934943681578





### AlexNet

In [37]:
from torchvision.models import alexnet, AlexNet_Weights
# Pretrained AlexNet reduced to a feature extractor exposing the two nodes
# that are evaluated below.
alex_weights = AlexNet_Weights.IMAGENET1K_V1
alex_model = create_feature_extractor(
    alexnet(weights=alex_weights),
    return_nodes=['features.4','features.11'],
).to(device)
alex_model.eval().requires_grad_(False)

# Resize plus channel normalisation; see alex_weights.transforms() for the
# reference preprocessing.
preprocess = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# (printed label, feature node) pairs, evaluated in this order.
alexnet_layers = (
    ('early, AlexNet(2)', 'features.4'),
    ('mid, AlexNet(5)', 'features.11'),
)

for layer, feature_node in alexnet_layers:
    print(f"\n---{layer}---")
    all_per_correct = two_way_identification(
        all_brain_recons.to(device).float(), all_images,
        alex_model, preprocess, feature_node,
    )
    print(f"2-way Percent Correct: {np.mean(all_per_correct):.4f}")


---early, AlexNet(2)---
2-way Percent Correct: 0.8397

---mid, AlexNet(5)---
2-way Percent Correct: 0.7677


### InceptionV3

In [None]:
from torchvision.models import inception_v3, Inception_V3_Weights
# Pretrained InceptionV3 reduced to a feature extractor that returns the
# 'avgpool' node.
weights = Inception_V3_Weights.DEFAULT
inception_model = create_feature_extractor(
    inception_v3(weights=weights),
    return_nodes=['avgpool'],
).to(device)
inception_model.eval().requires_grad_(False)

# Resize plus channel normalisation; see weights.transforms() for the
# reference preprocessing.
preprocess = transforms.Compose([
    transforms.Resize(342, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Two-way identification on the 'avgpool' features.
all_per_correct = two_way_identification(
    all_brain_recons,
    all_images,
    model=inception_model,
    preprocess=preprocess,
    feature_layer='avgpool',
)

print(f"2-way Percent Correct: {np.mean(all_per_correct):.4f}")

### CLIP

In [None]:
import clip
# clip.load returns (model, transform); only the model is kept because the
# transform used here is built explicitly right below.
clip_model = clip.load("ViT-L/14", device=device)[0]

# Tensor-based preprocessing: resize to 224 and normalise each channel with
# the constants written out below.
preprocess = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711]),
])

# feature_layer=None: the output of encode_image is used directly (final layer).
all_per_correct = two_way_identification(
    all_brain_recons,
    all_images,
    model=clip_model.encode_image,
    preprocess=preprocess,
    feature_layer=None,
)
print(f"2-way Percent Correct: {np.mean(all_per_correct):.4f}")

### EfficientNet

In [None]:
from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights
# Pretrained EfficientNet-B1 reduced to a feature extractor that returns the
# 'avgpool' node.
weights = EfficientNet_B1_Weights.DEFAULT
eff_model = create_feature_extractor(
    efficientnet_b1(weights=weights),
    return_nodes=['avgpool'],
).to(device)
eff_model.eval().requires_grad_(False)

# Resize plus channel normalisation; see weights.transforms() for the
# reference preprocessing.
preprocess = transforms.Compose([
    transforms.Resize(255, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# One flattened 'avgpool' feature vector per ground-truth image (gt) and per
# reconstruction (fake).
gt = eff_model(preprocess(all_images))['avgpool']
gt = gt.flatten(1).cpu().numpy()
fake = eff_model(preprocess(all_brain_recons))['avgpool']
fake = fake.flatten(1).cpu().numpy()

# Correlation distance between each image and its own reconstruction,
# averaged over all pairs.
eff_distances = np.array([
    sp.spatial.distance.correlation(gt_vec, fake_vec)
    for gt_vec, fake_vec in zip(gt, fake)
])
print("Distance:",eff_distances.mean())

### SwAV

In [None]:
# ResNet-50 fetched through torch.hub from the facebookresearch/swav
# repository, reduced to a feature extractor that returns the 'avgpool' node.
swav_model = create_feature_extractor(
    torch.hub.load('facebookresearch/swav:main', 'resnet50'),
    return_nodes=['avgpool'],
).to(device)
swav_model.eval().requires_grad_(False)

# Resize to 224 and normalise each channel with the constants below.
preprocess = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# One flattened 'avgpool' feature vector per ground-truth image (gt) and per
# reconstruction (fake).
gt = swav_model(preprocess(all_images))['avgpool']
gt = gt.flatten(1).cpu().numpy()
fake = swav_model(preprocess(all_brain_recons))['avgpool']
fake = fake.flatten(1).cpu().numpy()

# Correlation distance between each image and its own reconstruction,
# averaged over all pairs.
swav_distances = np.array([
    sp.spatial.distance.correlation(gt_vec, fake_vec)
    for gt_vec, fake_vec in zip(gt, fake)
])
print("Distance:",swav_distances.mean())