In [None]:
# # Code to convert this notebook to .py if you want to run it via command line or with Slurm
# from subprocess import call
# command = "jupyter nbconvert Reconstruction_Metrics.ipynb --to python"
# call(command,shell=True)

In [None]:
import os
import sys
import json
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import make_grid
from tqdm import tqdm
from datetime import datetime
import argparse

# Compute device used by every later cell.
# NOTE(review): whenever CUDA is available this hard-codes the third GPU ('cuda:2');
# on machines with fewer than three GPUs the later .to(device) calls will fail — confirm intended.
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
# NOTE(review): local_rank is not referenced anywhere else in this notebook as shown.
local_rank = 0
print("device:",device)

# Project-local helper module (seeding and interactive-session detection are used below).
import utils
seed=42
# Seed the RNGs through the project helper; presumably covers numpy/torch/random — confirm in utils.
utils.seed_everything(seed=seed)

# Inside Jupyter, reload edited modules (e.g. utils) automatically before each cell runs.
if utils.is_interactive():
    %load_ext autoreload
    %autoreload 2

# from models import Clipper
# clip_extractor = Clipper("ViT-L/14", hidden_state=False, norm_embs=True, device=device)
# Initial image size; later cells reassign imsize = 256 before building grids and metrics.
imsize = 512

In [None]:
import os
from PIL import Image
import torch
import numpy as np

# Evaluation split (True -> test split, False -> validation split) and EEG subject id.
istest = True
sub = 4
# Generated images for this subject live in one folder per category under source_dir;
# the packed tensor and category listing are written to target_dir.
source_dir = f"../../Generation/{'generated_imgs' if istest else 'val_generated_imgs'}/{sub}"
target_dir = '../../Generation/generated_imgs_tensor'

# Pack every generated (reconstructed) image under source_dir into a single tensor file,
# and write a text file mapping each image file name to the category of its folder.
os.makedirs(target_dir, exist_ok=True)

tensor_list = []        # one (1, H, W, 3) tensor per image, in folder/file sort order
image_categories = {}   # image file name -> category name
category_set = set()    # unique category names

# Folders are visited in sorted order so the stacking order is deterministic.
for folder_name in sorted(os.listdir(source_dir)):
    folder_path = os.path.join(source_dir, folder_name)

    # Only real, non-hidden sub-directories are category folders. Stray files (or
    # entries such as `.ipynb_checkpoints`) must not leak into the category set.
    if folder_name.startswith('.') or not os.path.isdir(folder_path):
        continue

    # The category is assumed to be the last space-separated token of the folder name.
    category_name = folder_name.split(' ')[-1]
    category_set.add(category_name)

    # Sort the image files to ensure consistent order.
    for image_name in sorted(os.listdir(folder_path)):
        image_path = os.path.join(folder_path, image_name)
        # Skip hidden files and nested directories; Image.open would raise on them.
        if image_name.startswith('.') or not os.path.isfile(image_path):
            continue

        # NOTE(review): keyed by bare file name, so equally named files in different
        # folders overwrite each other in categories.txt — confirm names are unique.
        image_categories[image_name] = category_name

        with Image.open(image_path) as img:
            # Force 3-channel RGB so every tensor has the same shape for torch.cat and
            # matches the 3-channel layout assumed by the metric cells.
            tensor = torch.tensor(np.array(img.convert('RGB'))).unsqueeze(0)
        tensor_list.append(tensor)

if not tensor_list:
    raise RuntimeError(f"No generated images found under {source_dir!r}")

# Concatenate all tensors along the 0th dimension -> (N, H, W, 3).
all_tensors = torch.cat(tensor_list, dim=0)

# Save the combined tensor
combined_tensor_path = os.path.join(target_dir, "all_images.pt")
torch.save(all_tensors, combined_tensor_path)

# Write "<image name>: <category>" lines, sorted by image name.
categories_path = os.path.join(target_dir, "categories.txt")
with open(categories_path, 'w') as f:
    for image_name, category_name in sorted(image_categories.items()):
        f.write(f"{image_name}: {category_name}\n")

# Print out the sorted category list
category_list = sorted(category_set)
print(category_list)

In [None]:
import os
from PIL import Image
import torch
import numpy as np
from einops import rearrange

splits_path = "../../../datasets/block_splits_by_image_all.pth"
# splits_path = "../../../datasets/block_splits_by_image_single.pth"
eeg_signals_path = "../../../datasets/eeg_5_95_std.pth"
# eeg_signals_path = "../../../datasets/eeg_signals_raw_with_mean_std.pth"

# loaded["images"] holds the image names; loaded["dataset"] holds the per-trial records.
loaded = torch.load(eeg_signals_path)
images = loaded["images"]
# Keep only the trials recorded from subject `sub`; the split indices below index into this list.
data = [loaded['dataset'][i] for i in range(len(loaded['dataset'])) if loaded['dataset'][i]['subject'] == sub]
split_name = 'test' if istest else 'val'
split_idx = torch.load(splits_path)["splits"][0][split_name]
# Drop indices past this subject's trials and trials whose EEG size along dim 1 is outside [450, 600].
split_idx = [i for i in split_idx if i < len(data) and 450 <= data[i]["eeg"].size(1) <= 600]
# Ground-truth image names for the selected trials and their files (<root>/<prefix before '_'>/<name>.JPEG).
images = [images[data[i]["image"]] for i in split_idx]
image_paths = [os.path.join('../../../datasets/imageNet_images/', image.split('_')[0], image + '.JPEG') for image in images]

target_dir = '../../../datasets/test_images_tensor' if istest else '../../../datasets/val_images_tensor'
os.makedirs(target_dir, exist_ok=True)

tensor_list = []        # one (1, 512, 512, 3) uint8 tensor per image
image_categories = {}   # image name -> label written to categories.txt
category_set = set()    # unique labels

img_transform = transforms.Compose([
    transforms.Resize((512, 512)),
])
# NOTE(review): images are stacked in sorted-path order; the reconstructions are stacked in
# their folder-name order — confirm both orders correspond trial-for-trial.
for image_path in sorted(image_paths):
    image_name = image_path.split('/')[-1].split('.')[0]
    # NOTE(review): the "category" recorded here is the image file stem itself, not the
    # parent directory (os.path.basename(os.path.dirname(image_path))) — confirm intent.
    category_name = image_name
    category_set.add(category_name)
    image_categories[image_name] = category_name
    # Close the source file deterministically. Using `Image.open(p).convert('RGB')` as the
    # context manager only closed the converted copy and left the opened file to the GC.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert('RGB')
    # (H, W, C) -> (C, H, W) for Resize, then back to (H, W, C) with a leading batch dim.
    tensor = img_transform(torch.tensor(np.array(img)).permute(2, 0, 1)).permute(1, 2, 0).unsqueeze(0)
    tensor_list.append(tensor)

if not tensor_list:
    raise RuntimeError(f"No ground-truth images selected for subject {sub}; check the split filters and paths.")

# Concatenate all tensors along the 0th dimension -> (N, 512, 512, 3).
all_tensors = torch.cat(tensor_list, dim=0)

# Save the combined tensor
combined_tensor_path = os.path.join(target_dir, "all_images.pt")
torch.save(all_tensors, combined_tensor_path)

# Write "<image name>: <label>" lines, sorted by image name.
categories_path = os.path.join(target_dir, "categories.txt")
with open(categories_path, 'w') as f:
    for image_name, category_name in sorted(image_categories.items()):
        f.write(f"{image_name}: {category_name}\n")

# Print out the sorted label list
category_list = sorted(category_set)
print(category_list)

# Configurations

In [None]:
# # if running this interactively, can specify jupyter_args here for argparser to use
# if utils.is_interactive():
#     # Example use
#     jupyter_args = "--recon_path=prior_257_final_subj01_bimixco_softclip_byol_brain_recons_full_img2img0.85_16samples.pt"
    
#     jupyter_args = jupyter_args.split()
#     print(jupyter_args)

In [None]:
# parser = argparse.ArgumentParser(description="Model Training Configuration")
# parser.add_argument(
#     "--recon_path", type=str,
#     help="path to reconstructed/retrieved outputs",
# )
# parser.add_argument(
#     "--all_images_path", type=str, default="all_images.pt",
#     help="path to ground truth outputs",
# )

# if utils.is_interactive():
#     args = parser.parse_args(jupyter_args)
# else:
#     args = parser.parse_args()

# # create global variables without the args prefix
# for attribute_name in vars(args).keys():
#     globals()[attribute_name] = getattr(args, attribute_name)

In [None]:
# Packed reconstructions and the ground-truth stack for the same split.
recon_path = '../../Generation/generated_imgs_tensor/all_images.pt'
if istest:
    all_images_path = '../../../datasets/test_images_tensor/all_images.pt'
else:
    all_images_path = '../../../datasets/val_images_tensor/all_images.pt'
# Keep every `recon_stride`-th reconstruction (1 = all), e.g. when several samples were
# generated per trial. This replaces the former hard-coded `[::1]` slice.
recon_stride = 1
all_brain_recons = torch.load(recon_path)[::recon_stride]
all_images = torch.load(all_images_path)

# Every metric below pairs images by position, so the two stacks must line up one-to-one.
if len(all_images) != len(all_brain_recons):
    raise ValueError(
        f"{len(all_brain_recons)} reconstructions vs {len(all_images)} ground-truth images; "
        "they must be paired one-to-one"
    )

# Both files are packed as (N, H, W, C); reorder to (N, C, H, W) and move each stack
# to the device exactly once.
all_images = all_images.permute(0, 3, 1, 2).to(device)
all_brain_recons = all_brain_recons.permute(0, 3, 1, 2).to(device)
print(all_images.shape)
print(all_brain_recons.shape)

import os
from PIL import Image
import torch

# Directory that receives the paired ground-truth / reconstruction PNGs for inspection.
save_path = '../../testandrecon'
if not os.path.exists(save_path):
    os.makedirs(save_path)


def save_images(tensor, prefix, path):
    """Write each (C, H, W) image of `tensor` to `path` as "<index>_<prefix>.png".

    Values are cast straight to uint8 without rescaling, so `tensor` is expected to
    already hold pixel values in [0, 255].
    """
    for idx, chw in enumerate(tensor):
        hwc = chw.cpu().numpy().transpose(1, 2, 0)  # channels-last for PIL
        Image.fromarray(hwc.astype('uint8')).save(os.path.join(path, f"{idx}_{prefix}.png"))


# Dump both stacks: ground truth as *_test.png, reconstructions as *_recon.png.
for batch, tag in ((all_images, 'test'), (all_brain_recons, 'recon')):
    save_images(batch, tag, save_path)



In [None]:
# Quick value-range sanity check on the reconstructions (mean taken in float32).
for label, value in (
    ("Minimum:", all_brain_recons.min()),
    ("Maximum:", all_brain_recons.max()),
    ("Mean:", all_brain_recons.to(dtype=torch.float32).mean()),
):
    print(label, value.item())

In [None]:
# print(all_images)

In [None]:

# fig, axs = plt.subplots(ncols=200, squeeze=False,figsize=(64, 64))
# # plt.rcParams["savefig.bbox"] = 'tight'
# # fig, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=figsize)
# for i in range(200):
#     img = all_images[i].detach()
#     img = transforms.ToPILImage()(img/255.0)
#     axs[0, i].imshow(np.asarray(img))
#     axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


# for i in range(200):
#     img = all_brain_recons[i].detach()
#     print(img)
#     img = transforms.ToPILImage()(img/255.0)
#     axs[1, i].imshow(np.asarray(img))
#     axs[1, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
# # plt.subplots_adjust(wspace=0.1, hspace=0.001)
# # fig.tight_layout()
# # plt.subplots_adjust(wspace =0, hspace =0)#调整子图间距
# plt.savefig('Reconstruction.png')
# plt.show()


# Display reconstructions next to ground truth images

In [None]:
# # all_interleaved = all_interleaved.transpose(1, 3)
# all_interleaved.shape
# # all_interleaved

In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import numpy as np

# Assuming all_interleaved is a tensor with shape [N, C, H, W]
# where N is the number of images, C is the number of channels,
# H is the height, and W is the width of the images.

# Function to show a single image from all_interleaved
def show_single_image(img_tensor, figsize=(5, 5)):
    """Display one (C, H, W) image tensor holding values in [0, 255], with axes hidden."""
    # Scale to [0, 1] so ToPILImage interprets the float tensor correctly.
    pil_img = transforms.ToPILImage()(img_tensor.detach() / 255.0)
    plt.figure(figsize=figsize)
    plt.imshow(np.asarray(pil_img))
    plt.axis('off')
    plt.show()

# Show the first image from all_interleaved
# show_single_image(all_interleaved[0, :, :, :])

In [None]:
# all_interleaved[0].to(device) - all_images[0]

In [None]:
# Downsample both stacks to 256x256 for the comparison grids and all later metrics.
imsize = 256
all_images = transforms.Resize((imsize,imsize))(all_images)
all_brain_recons = transforms.Resize((imsize,imsize))(all_brain_recons)
# Display order: trials from last to first. Derived from the stack actually indexed
# below rather than from split_idx, so the cell also works after loading from disk.
ind = np.flip(np.arange(len(all_images)))
print(ind)

# Interleave pairs: row 2k is the ground truth of trial ind[k], row 2k+1 its reconstruction.
# (The old loop printed the whole first image tensor on every iteration.)
all_interleaved = torch.zeros(len(ind)*2,3,imsize,imsize).to(device)
for k, t in enumerate(ind):
    all_interleaved[2 * k] = all_images[t].float()
    all_interleaved[2 * k + 1] = all_brain_recons[t].float()


plt.rcParams["savefig.bbox"] = 'tight'
def show(imgs,figsize):
    """Plot one image tensor, or a list of them, side by side in a single row.

    Each image is a (C, H, W) tensor holding values in [0, 255]; it is scaled to [0, 1]
    before conversion to PIL. Tick marks and labels are hidden.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=figsize)
    for i, img in enumerate(imgs):
        pil_img = transforms.ToPILImage()(img.detach() / 255.0)
        axs[0, i].imshow(np.asarray(pil_img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    # Render once, after every panel is drawn (this used to run inside the loop,
    # which rendered the figure after the first panel only).
    plt.show()

# 10 images per row = 5 (ground truth, reconstruction) pairs.
grid = make_grid(all_interleaved, nrow=10, padding=2)
show(grid,figsize=(64,64))

In [None]:
# import torch
# from torchvision import transforms
# from torchvision.utils import make_grid
# import matplotlib.pyplot as plt
# import numpy as np

# # Assume all_images and all_brain_recons are defined and are lists or datasets of images
# batch_size = 10  # Adjust batch_size depending on your GPU memory
# num_batches = int(np.ceil(len(ind) / batch_size))
# ind = np.flip(np.array([112,119,101,44,159,22,173,174,175,189]))
# for batch_idx in range(num_batches):
#     batch_start = batch_idx * batch_size
#     batch_end = min(batch_start + batch_size, len(ind))
#     batch_indices = ind[batch_start:batch_end]

#     all_interleaved = torch.zeros(len(batch_indices)*2, 3, imsize, imsize)

#     icount = 0
#     for t in batch_indices:
#         img = transforms.Resize((imsize, imsize))(all_images[t])
#         recon = transforms.Resize((imsize, imsize))(all_brain_recons[t])

#         all_interleaved[icount] = img
#         all_interleaved[icount + 1] = recon
#         icount += 2

#     # Show or save the processed images here
#     grid = make_grid(all_interleaved, nrow=10, padding=2)
#     show(grid, figsize=(20, 16))

#     # Optional: Clear cache if you are still facing memory issues
#     if torch.cuda.is_available():
#         torch.cuda.empty_cache()

# 2-Way Identification

In [None]:
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

@torch.no_grad()
def two_way_identification(all_brain_recons, all_images, model, preprocess, feature_layer=None, return_avg=True):
    """2-way identification accuracy of reconstructions in a model's feature space.

    Each image is preprocessed individually, the results are stacked into one batch on
    `device` and embedded with `model`. For reconstruction j, every other ground-truth
    image i counts as a success when corr(real_i, recon_j) < corr(real_j, recon_j).

    Args:
        all_brain_recons: reconstructed images, position-aligned with `all_images`.
        all_images: ground-truth images.
        model: callable on a batch; returns a tensor, or a mapping of layer name to
            tensor when `feature_layer` is given.
        preprocess: per-image transform applied before stacking.
        feature_layer: key selecting the features from the model output, or None.
        return_avg: if True return the mean success fraction, otherwise the raw counts.

    Returns:
        A float in [0, 1] when `return_avg`, else `(success_cnt, len(all_images) - 1)`,
        where `success_cnt[j]` counts the successes of reconstruction j.
    """
    def _embed(images):
        batch = torch.stack([preprocess(x) for x in images], dim=0).to(device)
        return model(batch)

    pred_out = _embed(all_brain_recons)
    real_out = _embed(all_images)
    if feature_layer is not None:
        pred_out, real_out = pred_out[feature_layer], real_out[feature_layer]
    preds = pred_out.float().flatten(1).cpu().numpy()
    reals = real_out.float().flatten(1).cpu().numpy()

    n = len(all_images)
    # Cross block of the joint correlation matrix: rows = real images, columns = reconstructions.
    cross = np.corrcoef(reals, preds)[:n, n:]
    own = np.diag(cross)
    # `own` broadcasts along rows, so column j is compared with its own match own[j];
    # the diagonal never counts because it is compared with itself.
    success_cnt = (cross < own).sum(axis=0)

    if not return_avg:
        return success_cnt, n - 1
    return np.mean(success_cnt) / (n - 1)

# 50-way top-1 classification

In [None]:
from torchmetrics.functional import accuracy
from torchvision.models import ViT_H_14_Weights, vit_h_14
@torch.no_grad()
def n_way_top_k_acc(pred, class_id, n_way, num_trials=40, top_k=1):
    """Estimate the n-way top-k accuracy of a single prediction vector.

    In each trial, `n_way - 1` distractor classes are drawn without replacement (via the
    global NumPy RNG) from all classes except `class_id`. The trial is correct when the
    target's score is among the `top_k` largest of the `n_way` candidate scores.

    Args:
        pred: 1-D tensor of per-class scores (e.g. softmax probabilities).
        class_id: index of the target class in `pred`.
        n_way: candidates per trial, including the target.
        num_trials: number of random trials.
        top_k: a trial succeeds when the target ranks within the top `top_k`.

    Returns:
        (mean, std) of the per-trial 0/1 outcomes.
    """
    pick_range = [i for i in np.arange(len(pred)) if i != class_id]
    acc_list = []
    for _ in range(num_trials):
        idxs_picked = np.random.choice(pick_range, n_way - 1, replace=False)
        # The target sits at position 0 of the candidate vector.
        pred_picked = torch.cat([pred[class_id].unsqueeze(0), pred[idxs_picked]])
        # Top-k membership computed directly instead of via torchmetrics' `accuracy`, whose
        # functional form requires a `task=` argument in newer releases.
        k = min(top_k, pred_picked.numel())
        top_idx = torch.topk(pred_picked, k).indices
        acc_list.append(float((top_idx == 0).any()))
    return np.mean(acc_list), np.std(acc_list)

@torch.no_grad()
def get_n_way_top_k_acc(pred_imgs, ground_truth, n_way, num_trials, top_k, device, return_std=False):
    """Score each predicted image by n-way top-k classification against its ground truth.

    A ViT-H/14 classifier (torchvision DEFAULT weights) is loaded on every call. For each
    (prediction, ground truth) pair, the classifier's argmax class on the ground-truth
    image is the target, and the prediction's softmax output is scored against it with
    `n_way_top_k_acc`.

    Args:
        pred_imgs: iterable of predicted (reconstructed) image tensors.
        ground_truth: iterable of ground-truth image tensors, aligned with `pred_imgs`.
            NOTE(review): inputs go through `weights.transforms()`; float tensors are
            presumably expected in [0, 1] (integer tensors get rescaled) — confirm.
        n_way, num_trials, top_k: forwarded to `n_way_top_k_acc`.
        device: device for the classifier and inputs.
        return_std: also return the per-image standard deviations.

    Returns:
        List of per-image mean accuracies, or `(means, stds)` when `return_std` is True.
    """
    weights = ViT_H_14_Weights.DEFAULT
    classifier = vit_h_14(weights=weights)
    preprocess = weights.transforms()
    classifier = classifier.to(device).eval()

    means, stds = [], []
    for pred_img, gt_img in zip(pred_imgs, ground_truth):
        pred_batch = preprocess(pred_img).unsqueeze(0).to(device)
        gt_batch = preprocess(gt_img).unsqueeze(0).to(device)
        # Target class: the classifier's most likely label for the ground-truth image.
        target = classifier(gt_batch).squeeze(0).softmax(0).argmax().item()
        pred_probs = classifier(pred_batch).squeeze(0).softmax(0).detach()

        mean_acc, std_acc = n_way_top_k_acc(pred_probs, target, n_way, num_trials, top_k)
        means.append(mean_acc)
        stds.append(std_acc)

    return (means, stds) if return_std else means
# get_n_way_top_k_acc takes (pred_imgs, ground_truth): reconstructions first, real images
# second, so the target class comes from the real image and the reconstruction is scored.
with torch.no_grad():
    result = get_n_way_top_k_acc(all_brain_recons, all_images, 50, 10, 1, device, return_std=False)
np.mean(result)

In [None]:
# result

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torchvision import transforms

# Image size for the comparison grid.
imsize = 256
# Resize both stacks to imsize x imsize.
all_images = transforms.Resize((imsize, imsize))(all_images)
all_brain_recons = transforms.Resize((imsize, imsize))(all_brain_recons)

# Trials whose 50-way top-1 accuracy is above 0.5.
indices = [i for i, val in enumerate(result) if val > 0.5]
print("Indices with result > 0.5:", indices)

# Interleave pairs: row 2k is the ground truth of trial indices[k], row 2k+1 its reconstruction.
all_interleaved = torch.zeros(len(indices) * 2, 3, imsize, imsize).to(device)
for k, t in enumerate(indices):
    all_interleaved[2 * k] = all_images[t].float()
    all_interleaved[2 * k + 1] = all_brain_recons[t].float()

plt.rcParams["savefig.bbox"] = 'tight'

def show(imgs, figsize):
    """Plot one (C, H, W) image tensor in [0, 255], or a list of them, in a single row."""
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=figsize)
    for i, img in enumerate(imgs):
        pil_img = transforms.ToPILImage()(img.detach() / 255.0)
        axs[0, i].imshow(np.asarray(pil_img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()

# Avoid calling make_grid / plotting with an empty batch when no trial qualifies.
if indices:
    grid = make_grid(all_interleaved, nrow=10, padding=2)
    show(grid, figsize=(64, 64))
else:
    print("No trials with result > 0.5 to display.")


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torchvision import transforms

# Image size for the comparison grid.
imsize = 256
# Resize both stacks to imsize x imsize.
all_images = transforms.Resize((imsize, imsize))(all_images)
all_brain_recons = transforms.Resize((imsize, imsize))(all_brain_recons)

# Trials whose 50-way top-1 accuracy is at most 0.5. Using <= (not <) makes this grid the
# exact complement of the "> 0.5" grid; accuracies are multiples of 1/num_trials, so a
# score of exactly 0.5 is common and previously appeared in neither grid.
indices = [i for i, val in enumerate(result) if val <= 0.5]
print("Indices with result <= 0.5:", indices)

# Interleave pairs: row 2k is the ground truth of trial indices[k], row 2k+1 its reconstruction.
all_interleaved = torch.zeros(len(indices) * 2, 3, imsize, imsize).to(device)
for k, t in enumerate(indices):
    all_interleaved[2 * k] = all_images[t].float()
    all_interleaved[2 * k + 1] = all_brain_recons[t].float()

plt.rcParams["savefig.bbox"] = 'tight'

def show(imgs, figsize):
    """Plot one (C, H, W) image tensor in [0, 255], or a list of them, in a single row."""
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=figsize)
    for i, img in enumerate(imgs):
        pil_img = transforms.ToPILImage()(img.detach() / 255.0)
        axs[0, i].imshow(np.asarray(pil_img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()

# Avoid calling make_grid / plotting with an empty batch when no trial qualifies.
if indices:
    grid = make_grid(all_interleaved, nrow=10, padding=2)
    show(grid, figsize=(64, 64))
else:
    print("No trials with result <= 0.5 to display.")


## PixCorr

In [None]:
# Pixel-level correlation: bilinearly resize to 425x425, then compare raw pixel values.
preprocess = transforms.Compose([
    transforms.Resize((425, 425), interpolation=transforms.InterpolationMode.BILINEAR),
])

# One flattened row of pixels per image, moved to the CPU for NumPy.
all_images_flattened = preprocess(all_images).reshape(len(all_images), -1).cpu()
all_brain_recons_flattened = preprocess(all_brain_recons).reshape(len(all_brain_recons), -1).cpu()

print(all_images_flattened.shape)
print(all_brain_recons_flattened.shape)

# Mean Pearson correlation between each ground-truth image and its reconstruction.
corrsum = sum(
    np.corrcoef(all_images_flattened[i].numpy(), all_brain_recons_flattened[i].numpy())[0][1]
    for i in tqdm(range(len(all_images_flattened)))
)
corrmean = corrsum / len(all_images_flattened)

pixcorr = corrmean
print(pixcorr)

## SSIM

In [None]:
# see https://github.com/zijin-gu/meshconv-decoding/issues/3
from skimage.color import rgb2gray
# Imported under a distinct name: the mean score below is stored in `ssim` (read by the
# summary table), and reusing that name for the function broke re-running this cell.
from skimage.metrics import structural_similarity as ssim_fn

preprocess = transforms.Compose([
    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
])

# Convert to grayscale; rgb2gray expects channels-last arrays, so permute to (N, H, W, C).
# NOTE(review): data_range=1.0 below assumes rgb2gray yields values in [0, 1], which holds
# for uint8 input — confirm if this cell is run after the stacks were cast to float.
img_gray = rgb2gray(preprocess(all_images).permute((0,2,3,1)).cpu().numpy())
recon_gray = rgb2gray(preprocess(all_brain_recons).permute((0,2,3,1)).cpu().numpy())
print("converted, now calculating ssim...")

ssim_score = []
for im, rec in tqdm(zip(img_gray, recon_gray), total=len(all_images)):
    # Each image is a single 2-D grayscale plane, so no channel axis is passed;
    # `multichannel=True` made the width axis act as channels (and is deprecated in
    # favour of `channel_axis`).
    ssim_score.append(ssim_fn(rec, im, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1.0))

ssim = np.mean(ssim_score)
print(ssim)

### AlexNet

In [None]:
from torchvision.models import alexnet, AlexNet_Weights
alex_weights = AlexNet_Weights.IMAGENET1K_V1

# AlexNet feature extractor exposing an early ('features.4') and a mid ('features.11') stage.
alex_model = create_feature_extractor(
    alexnet(weights=alex_weights),
    return_nodes=['features.4', 'features.11'],
).to(device)
alex_model.eval().requires_grad_(False)

# see alex_weights.transforms(): resize, scale [0, 255] -> [0, 1], ImageNet normalisation.
preprocess = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.Lambda(lambda x: x.float()/255),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# Both stacks become float tensors on the metric device (later cells rely on this too).
all_images = all_images.to(device).float()
all_brain_recons = all_brain_recons.to(device).float()

# 2-way identification at each extracted layer, reported in order: early then mid.
alex_scores = {}
for layer, node in (('early, AlexNet(2)', 'features.4'), ('mid, AlexNet(5)', 'features.11')):
    print(f"\n---{layer}---")
    all_per_correct = two_way_identification(all_brain_recons.to(device).float(), all_images,
                                             alex_model, preprocess, node)
    alex_scores[node] = np.mean(all_per_correct)
    print(f"2-way Percent Correct: {alex_scores[node]:.4f}")
alexnet2 = alex_scores['features.4']
alexnet5 = alex_scores['features.11']

### InceptionV3

In [None]:
from torchvision.models import inception_v3, Inception_V3_Weights
weights = Inception_V3_Weights.DEFAULT
# Inception-v3 truncated at its global average pool; the pooled features are the embedding.
inception_model = create_feature_extractor(
    inception_v3(weights=weights),
    return_nodes=['avgpool'],
).to(device)
inception_model.eval().requires_grad_(False)

# see weights.transforms(): resize to 342, scale [0, 255] -> [0, 1], ImageNet normalisation.
preprocess = transforms.Compose([
    transforms.Resize(342, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.Lambda(lambda x: x.float()/255.0),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# Both stacks as float tensors on the metric device.
all_images, all_brain_recons = all_images.to(device).float(), all_brain_recons.to(device).float()

all_per_correct = two_way_identification(all_brain_recons, all_images,
                                         inception_model, preprocess, 'avgpool')
inception = np.mean(all_per_correct)
print(f"2-way Percent Correct: {inception:.4f}")

### CLIP

In [None]:
import clip
# CLIP ViT-L/14; its returned preprocessing is replaced below by a tensor transform
# that uses CLIP's normalisation constants.
clip_model, preprocess = clip.load("ViT-L/14", device=device)

# Resize to 224, scale [0, 255] -> [0, 1], normalise with CLIP's channel statistics.
preprocess = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.Lambda(lambda x: x.float()/255.0),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711]),
])
# Both stacks as float tensors on the metric device.
all_images, all_brain_recons = all_images.to(device).float(), all_brain_recons.to(device).float()

# The image-encoder output (final layer) is the embedding, so no feature layer is selected.
all_per_correct = two_way_identification(all_brain_recons, all_images,
                                         clip_model.encode_image, preprocess, None)
clip_ = np.mean(all_per_correct)
print(f"2-way Percent Correct: {clip_:.4f}")

### EfficientNet

In [None]:
# Explicit submodule import so `sp.spatial.distance` resolves even on SciPy versions
# that do not lazily load submodules.
import scipy.spatial.distance
from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights
weights = EfficientNet_B1_Weights.DEFAULT
# EfficientNet-B1 truncated at its global average pool.
eff_model = create_feature_extractor(efficientnet_b1(weights=weights), 
                                    return_nodes=['avgpool']).to(device)
eff_model.eval().requires_grad_(False)

# see weights.transforms()
preprocess = transforms.Compose([
    transforms.Resize(255, interpolation=transforms.InterpolationMode.BILINEAR),
    # Scale [0, 255] pixels to [0, 1] before ImageNet normalisation, as every other
    # feature-space metric in this notebook does; without it Normalize saw raw 0-255 values.
    transforms.Lambda(lambda x: x.float()/255.0),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

with torch.no_grad():
    # Pooled features of the ground truth, one flattened row per image.
    gt = eff_model(torch.stack([preprocess(img) for img in all_images.to(device)], dim=0))['avgpool']
    gt = gt.reshape(len(gt),-1).cpu().numpy()

    # Pooled features of the reconstructions, one flattened row per image.
    fake = eff_model(torch.stack([preprocess(recon) for recon in all_brain_recons.to(device)], dim=0))['avgpool']
    fake = fake.reshape(len(fake),-1).cpu().numpy()

# Mean correlation distance between paired feature vectors (lower = more similar).
effnet = np.array([sp.spatial.distance.correlation(gt[i],fake[i]) for i in range(len(gt))]).mean()
print("Distance:",effnet)

### SwAV

In [None]:
# Ensure all_images and all_brain_recons are tensors on the correct device and are floating-point
# Make sure both stacks are float tensors on the metric device before the SwAV cell.
all_images, all_brain_recons = all_images.to(device).float(), all_brain_recons.to(device).float()

In [None]:
# SwAV ResNet-50 fetched via torch.hub; pooled ('avgpool') features are the embedding.
swav_model = create_feature_extractor(
    torch.hub.load('facebookresearch/swav:main', 'resnet50'),
    return_nodes=['avgpool'],
).to(device)
swav_model.eval().requires_grad_(False)

# Resize to 224, scale [0, 255] -> [0, 1], ImageNet normalisation.
preprocess = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.Lambda(lambda x: x.float()/255.0),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# Both stacks as float tensors on the metric device.
all_images = all_images.to(device).float()
all_brain_recons = all_brain_recons.to(device).float()

# Embed each whole stack in one batch (ground truth first), one flattened row per image.
gt, fake = (
    swav_model(preprocess(batch))['avgpool'].reshape(len(batch), -1).cpu().numpy()
    for batch in (all_images, all_brain_recons)
)

# Mean correlation distance between paired feature vectors.
swav = np.mean([sp.spatial.distance.correlation(gt[i], fake[i]) for i in range(len(gt))])
print("Distance:",swav)

# Display in table

In [None]:
# Create a dictionary to store variable names and their corresponding values
data = {
    "Metric": ["PixCorr", "SSIM", "AlexNet(2)", "AlexNet(5)", "InceptionV3", "CLIP", "EffNet-B", "SwAV", "Classification"],
    "Value": [pixcorr, ssim, alexnet2, alexnet5, inception, clip_, effnet, swav, np.mean(result)],
}

df = pd.DataFrame(data)
print(df.to_string(index=False))

if not utils.is_interactive():
    # save table to txt file
    df.to_csv(f'{recon_path[:-3]}.csv', sep='\t', index=False)