In [1]:
print('importing modules')
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm import tqdm
import webdataset as wds

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from accelerate import Accelerator, DeepSpeedPlugin

from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder
from models import GNet8_Encoder

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

# custom functions #
import utils

### Multi-GPU config ###
local_rank = os.getenv('RANK')
if local_rank is None: 
    local_rank = 0
else:
    local_rank = int(local_rank)
print("LOCAL RANK ", local_rank)  

accelerator = Accelerator(split_batches=False, mixed_precision="fp16") # ['no', 'fp8', 'fp16', 'bf16']

print("PID of this process =",os.getpid())
device = accelerator.device
print("device:",device)
world_size = accelerator.state.num_processes
distributed = not accelerator.state.distributed_type == 'NO'
num_devices = torch.cuda.device_count()
if num_devices==0 or not distributed: num_devices = 1
num_workers = num_devices
print(accelerator.state)

print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)
print = accelerator.print # only print if local_rank=0

importing modules


Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


LOCAL RANK  0
PID of this process = 1915809
device: cuda
Distributed environment: DistributedType.NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16

distributed = False num_devices = 1 local rank = 0 world size = 1


In [2]:
# Load embedding model (last hidden layer)
try:
    print(clip_img_embedder)
except:
    clip_img_embedder = FrozenOpenCLIPImageEmbedder(
        arch="ViT-bigG-14",
        version="laion2b_s39b_b160k",
        output_tokens=True,
        only_tokens=True,
    )
    clip_img_embedder.to(device)
clip_seq_dim = 256
clip_emb_dim = 1664

## Load embedding model (last layer)
#     clip_img_embedder = FrozenOpenCLIPImageEmbedder(
#         arch="ViT-bigG-14",
#         version="laion2b_s39b_b160k",
#         output_tokens=False,
#         only_tokens=False,
#     )
#     clip_img_embedder.to(device)
# clip_seq_dim = 1
# clip_emb_dim = 1280

In [3]:
plot_all = False
compute_circular = False  # for the circular tests looking at image similarity in clip space (without any brain data involved)
saving = True

# if running this interactively, can specify jupyter_args here for argparser to use
if utils.is_interactive():
    model_name = f"sub-001_ses-01_bs24_MST_paul_MSTsplit_random_seed_0"
    eval_dir = f"/scratch/gpfs/ri4541/MindEyeV2/src/mindeyev2/evals/{model_name}"
    if ("remove" in model_name and "random" in model_name) or "ses-04" in model_name:
        all_recons_path = f"{eval_dir}/all_recons.pt"
    elif "paul" in model_name:
        all_recons_path = f"evals/{model_name}/{model_name}_all_recons.pt"
    else:
        all_recons_path = f"{eval_dir}/{model_name}_all_recons.pt" 

    data_path = "/scratch/gpfs/ri4541/MindEyeV2/src/mindeyev2"
    print("model_name:", model_name)

    jupyter_args = f"--model_name={model_name} --data_path={data_path} --all_recons_path={all_recons_path} --eval_dir={eval_dir}"
    print(jupyter_args)
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

model_name: sub-001_ses-01_bs24_MST_paul_MSTsplit_random_seed_0
--model_name=sub-001_ses-01_bs24_MST_paul_MSTsplit_random_seed_0 --data_path=/scratch/gpfs/ri4541/MindEyeV2/src/mindeyev2 --all_recons_path=evals/sub-001_ses-01_bs24_MST_paul_MSTsplit_random_seed_0/sub-001_ses-01_bs24_MST_paul_MSTsplit_random_seed_0_all_recons.pt --eval_dir=/scratch/gpfs/ri4541/MindEyeV2/src/mindeyev2/evals/sub-001_ses-01_bs24_MST_paul_MSTsplit_random_seed_0


In [4]:
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--model_name", type=str, default="testing",
    help="name of model, used for ckpt saving and wandb logging (if enabled)",
)
parser.add_argument(
    "--data_path", type=str, default="/weka/proj-fmri/shared/mindeyev2_dataset",
    help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
    "--all_recons_path", type=str,
    help="Path to where all_recons.pt is stored",
)

parser.add_argument(
    "--eval_dir", type=str,
    help="Path to where evaluations should be stored",
)

parser.add_argument(
    "--seed",type=int,default=42,
)
if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

# create global variables without the args prefix
for attribute_name in vars(args).keys():
    globals()[attribute_name] = getattr(args, attribute_name)
    
# seed all random functions
utils.seed_everything(seed)

# Evals

In [5]:
if ("remove" in model_name and "random" in model_name) or "ses-04" in model_name:
    all_images = torch.load(f"{eval_dir}/all_images.pt")
    all_clipvoxels = torch.load(f"{eval_dir}/all_clipvoxels.pt")
    all_predcaptions = torch.load(f"{eval_dir}/all_predcaptions.pt")
    all_unrefinedrecons = torch.load(f"{eval_dir}/all_recons.pt")
elif "ses-01" in model_name and "paul" in model_name:
    all_images = torch.load(f"evals/{model_name}/{model_name}_all_images.pt")
    all_clipvoxels = torch.load(f"evals/{model_name}/{model_name}_all_clipvoxels.pt")
    all_predcaptions = torch.load(f"evals/{model_name}/{model_name}_all_predcaptions.pt")
    all_unrefinedrecons = torch.load(f"evals/{model_name}/{model_name}_all_recons.pt")
else:
    all_images = torch.load(f"{eval_dir}/{model_name}_all_images.pt") 
    all_clipvoxels = torch.load(f"{eval_dir}/{model_name}_all_clipvoxels.pt") 
    all_predcaptions = torch.load(f"{eval_dir}/{model_name}_all_predcaptions.pt") 
    all_unrefinedrecons = torch.load(f"{eval_dir}/{model_name}_all_recons.pt") 

print(all_images.shape)
print("all_recons_path:", all_recons_path)
all_recons = torch.load(all_recons_path)

# all_blurryrecons = torch.load(f"{eval_dir}/all_blurryrecons.pt")

torch.Size([100, 3, 256, 256])
all_recons_path: evals/sub-001_ses-01_bs24_MST_paul_MSTsplit_random_seed_0/sub-001_ses-01_bs24_MST_paul_MSTsplit_random_seed_0_all_recons.pt


In [6]:
# if "ses-01" in model_name:
#     paul_all_images = torch.load(f"evals/sub-001_ses-01_bs24_MST_paul_MSTsplit/sub-001_ses-01_bs24_MST_paul_MSTsplit_all_images.pt").to('cpu')
#     paul_all_clipvoxels = torch.load(f"evals/sub-001_ses-01_bs24_MST_paul_MSTsplit/sub-001_ses-01_bs24_MST_paul_MSTsplit_all_clipvoxels.pt").to('cpu')
#     paul_all_recons = torch.load(f"evals/sub-001_ses-01_bs24_MST_paul_MSTsplit/sub-001_ses-01_bs24_MST_paul_MSTsplit_all_recons.pt").to('cpu')
#     # paul_all_prior_out = torch.load(f"evals/sub-001_ses-01_bs24_MST_paul_MSTsplit/sub-001_ses-01_bs24_MST_paul_MSTsplit_all_prior_out.pt").to('cpu')
#     # all_images = torch.load(f"{eval_dir}/all_images.pt") 
#     print(paul_all_images.shape, all_images.shape)
#     print(paul_all_clipvoxels.shape, all_clipvoxels.shape)
#     print(torch.eq(paul_all_clipvoxels, all_clipvoxels))
#     # assert torch.allclose(paul_all_images, all_images)

In [7]:
# for i in range(100):
#     # print(torch.allclose(paul_all_images[i], all_images[i]))
#     pass

In [8]:
# num_images = paul_all_images.size(0)
# rows = 10  # Number of rows for the grid
# cols = 10  # Number of columns for the grid

# fig, axes = plt.subplots(rows, cols * 2, figsize=(80, 40))

# for i in range(num_images):
#     row = i // cols
#     col = (i % cols) * 2  # Adjust for side-by-side
    
#     # Plot correct image
#     ax_correct = axes[row, col]
#     ax_correct.imshow(paul_all_recons[i].permute(1, 2, 0).cpu().numpy())
#     ax_correct.axis('off')
#     ax_correct.set_title(f"Correct {i}")
    
#     # Plot modified image
#     ax_modified = axes[row, col + 1]
#     ax_modified.imshow(all_recons[i].permute(1, 2, 0).cpu().numpy())
#     ax_modified.axis('off')
#     ax_modified.set_title(f"Modified {i}")

# plt.tight_layout()
# plt.show()

In [9]:
model_name_plus_suffix = all_recons_path.split('/')[-1]
print(model_name_plus_suffix)
print(all_images.shape, all_recons.shape)

sub-001_ses-01_bs24_MST_paul_MSTsplit_random_seed_0_all_recons.pt
torch.Size([100, 3, 256, 256]) torch.Size([100, 3, 256, 256])


In [10]:
if "MST" in model_name:
    if ("remove" in model_name and "random" in model_name) or "ses-04" in model_name or "rishab" in model_name:
        MST_ID = np.load(f"{eval_dir}/MST_ID.npy")
        MST_pairmate_indices = np.load(f"{eval_dir}/MST_pairmate_indices.npy")
    elif "paul" in model_name:
        MST_ID = np.load(f"evals/{model_name}/{model_name}_MST_ID.npy")
        MST_pairmate_indices = np.array(utils.find_paired_indices(torch.Tensor(MST_ID)))
        # print(MST_pairmate_indices)
    else:
        MST_ID = np.load(f"{eval_dir}/{model_name}_MST_ID.npy") 
        MST_pairmate_indices = np.load(f"{eval_dir}/{model_name}_MST_pairmate_indices.npy") 

    # pairs = utils.find_paired_indices(torch.Tensor(MST_ID))
    # if "close_to_MST" in model_name or ("remove" in model_name and "random" in model_name) or "ses-0" in model_name:
    #     pairs = np.array(pairs[:-1])  # index out the placeholder
    # pairs = np.array(pairs)
    # if "ses-0" in model_name:
    #     if "ses-01" in model_name or "ses-04" in model_name:
    #         print(pairs.shape)
    #         assert pairs.shape == (49,2)
    #     else:
    #         assert pairs.shape == (50,3)
    # else:
    #     assert pairs.shape == (100,3)
    # print(pairs)
    # repeats_in_test = torch.load(f"{eval_dir}/repeats_in_test.pt")
    # test_image_indices = torch.load(f"{eval_dir}/test_image_indices.pt")
    all_unique_images = all_images[MST_pairmate_indices.flatten()]
    all_unique_clipvoxels = all_clipvoxels[MST_pairmate_indices.flatten()]

    print(model_name, MST_pairmate_indices.shape, all_unique_images.shape, all_unique_clipvoxels.shape)

sub-001_ses-01_bs24_MST_paul_MSTsplit_random_seed_0 (50, 2) torch.Size([100, 3, 256, 256]) torch.Size([100, 256, 1664])


In [11]:
# visualize all unique images
if plot_all:
    # Plot all the MST images and pairmates
    import textwrap
    def wrap_title(title, wrap_width):
        return "\n".join(textwrap.wrap(title, wrap_width))

    size = int(np.ceil(MST_pairmate_indices.shape[0]/2))  # helps determine size of plot
    fig, axes = plt.subplots(size, 4, figsize=(15, size*4))
    jj=-1; kk=0;
    for i, j in enumerate(all_unique_images):
        jj+=1
        axes[kk][jj].imshow(utils.torch_to_Image(j))
        axes[kk][jj].axis('off')
        if jj==3: 
            kk+=1; jj=-1

    fig.tight_layout()
    # plt.savefig('figures/MST_2_pairmates_10-01')
    plt.show()

In [12]:
# if plot_all:
#     # create full grid of recon comparisons
#     from PIL import Image

#     imsize = 150
#     if all_images.shape[-1] != imsize:
#         all_images = transforms.Resize((imsize,imsize))(all_images).float()
#     if all_recons.shape[-1] != imsize:
#         all_recons = transforms.Resize((imsize,imsize))(all_recons).float()

#     num_images = all_recons.shape[0]
#     num_rows = (2 * num_images + 9) // 10

#     # Interleave tensors
#     merged = torch.stack([val for pair in zip(all_images, all_recons) for val in pair], dim=0)

#     # Calculate grid size
#     grid = torch.zeros((num_rows * 10, 3, all_recons.shape[-1], all_recons.shape[-1]))

#     # Populate the grid
#     grid[:2*num_images] = merged
#     grid_images = [transforms.functional.to_pil_image(grid[i]) for i in range(num_rows * 10)]

#     # Create the grid image
#     grid_image = Image.new('RGB', (all_recons.shape[-1]*10, all_recons.shape[-1] * num_rows))  # 10 images wide

#     # Paste images into the grid
#     for i, img in enumerate(grid_images):
#         grid_image.paste(img, (all_recons.shape[-1] * (i % 10), all_recons.shape[-1] * (i // 10)))
#     grid_image
#     # grid_image.save(f"{model_name_plus_suffix[:-3]}_1000recons.png")

In [13]:
imsize = 256
if all_images.shape[-1] != imsize:
    all_images = transforms.Resize((imsize,imsize))(all_images).float()
if all_recons.shape[-1] != imsize:
    all_recons = transforms.Resize((imsize,imsize))(all_recons).float()
try:
    if all_blurryrecons.shape[-1] != imsize:
        all_blurryrecons = transforms.Resize((imsize,imsize))(all_blurryrecons).float()
except:
    pass

if "enhanced" in model_name_plus_suffix:
    try:
        all_recons = all_recons*.75 + all_blurryrecons*.25
        print("weighted averaging to improve low-level evals")
    except:
        pass

In [14]:
# visualize some images with recons and captions
if plot_all:
    assert np.all(all_images.shape == all_recons.shape)
    import textwrap
    def wrap_title(title, wrap_width):
        return "\n".join(textwrap.wrap(title, wrap_width))

    fig, axes = plt.subplots(3, 4, figsize=(10, 8))
    jj=-1; kk=0;
    for j in np.array([0,1,2,3,4,5]):
        jj+=1
        # print(kk,jj)
        axes[kk][jj].imshow(utils.torch_to_Image(all_images[j]))
        axes[kk][jj].axis('off')
        jj+=1
        axes[kk][jj].imshow(utils.torch_to_Image(all_recons[j]))
        axes[kk][jj].axis('off')
        axes[kk][jj].set_title(wrap_title(str(all_predcaptions[[j]]),wrap_width=30), fontsize=8)
        if jj==3: 
            kk+=1; jj=-1

    fig.tight_layout()
    # plt.savefig('figures/recon_09-26')
    plt.show()

# Retrieval eval (chance =  1/100)

In [15]:
from scipy import stats

all_fwd_acc = []
all_bwd_acc = []

assert len(all_unique_images) == len(all_unique_clipvoxels)  

all_percent_correct_fwds, all_percent_correct_bwds = [], []

with torch.cuda.amp.autocast(dtype=torch.float16):
    all_emb = clip_img_embedder(all_unique_images.to(device)).float() # CLIP-Image
    all_emb_ = all_unique_clipvoxels # CLIP-Brain

    # flatten if necessary
    all_emb = all_emb.reshape(len(all_emb),-1).to(device)
    all_emb_ = all_emb_.reshape(len(all_emb_),-1).to(device)

    # l2norm 
    all_emb = nn.functional.normalize(all_emb,dim=-1)
    all_emb_ = nn.functional.normalize(all_emb_,dim=-1)

    all_labels = torch.arange(len(all_emb)).to(device)
    all_bwd_sim = utils.batchwise_cosine_similarity(all_emb,all_emb_)  # clip, brain
    all_fwd_sim = utils.batchwise_cosine_similarity(all_emb_,all_emb)  # brain, clip

    if "ses-0" not in model_name or "ses-01" in model_name or "ses-04" in model_name:
        assert len(all_fwd_sim) == 100
        assert len(all_bwd_sim) == 100
    else:
        assert len(all_fwd_sim) == 50
        assert len(all_bwd_sim) == 50
    
    all_percent_correct_fwds = utils.topk(all_fwd_sim, all_labels, k=1).item()
    all_percent_correct_bwds = utils.topk(all_bwd_sim, all_labels, k=1).item()

all_fwd_acc.append(all_percent_correct_fwds)
all_bwd_acc.append(all_percent_correct_bwds)

all_fwd_sim = np.array(all_fwd_sim.cpu())
all_bwd_sim = np.array(all_bwd_sim.cpu())

print(f"overall fwd percent_correct: {all_fwd_acc[0]:.4f}")
print(f"overall bwd percent_correct: {all_bwd_acc[0]:.4f}")

overall fwd percent_correct: 0.7100
overall bwd percent_correct: 0.7000


In [16]:
if "ses-0" not in model_name:
    from scipy import stats

    fwd_acc = []
    bwd_acc = []
    fwd_sim_halves = []
    bwd_sim_halves = []

    assert len(all_unique_images) == len(all_unique_clipvoxels)  

    for i in range(2):  # since this is a 2-session model, we evaluate on the test set corresponding to each session and report both separately for better comparison to single-session models
        percent_correct_fwds, percent_correct_bwds = [], []
        # percent_correct_fwd, percent_correct_bwd = None, None

        if i==0:  
            all_unique_images_half = all_unique_images[:int(len(all_unique_images)/2)]
            all_unique_clipvoxels_half = all_unique_clipvoxels[:int(len(all_unique_clipvoxels)/2)]
        elif i==1:
            all_unique_images_half = all_unique_images[int(len(all_unique_images)/2):]
            all_unique_clipvoxels_half = all_unique_clipvoxels[int(len(all_unique_clipvoxels)/2):]


        with torch.cuda.amp.autocast(dtype=torch.float16):
            emb = clip_img_embedder(all_unique_images_half.to(device)).float() # CLIP-Image
            emb_ = all_unique_clipvoxels_half # CLIP-Brain

            # flatten if necessary
            emb = emb.reshape(len(emb),-1).to(device)
            emb_ = emb_.reshape(len(emb_),-1).to(device)

            # l2norm 
            emb = nn.functional.normalize(emb,dim=-1)
            emb_ = nn.functional.normalize(emb_,dim=-1)

            labels = torch.arange(len(emb)).to(device)
            bwd_sim = utils.batchwise_cosine_similarity(emb,emb_)  # clip, brain
            fwd_sim = utils.batchwise_cosine_similarity(emb_,emb)  # brain, clip

            assert len(fwd_sim) == 50
            assert len(bwd_sim) == 50

            # percent_correct_fwds = np.append(percent_correct_fwds, utils.topk(fwd_sim, labels, k=1).item())
            # percent_correct_bwds = np.append(percent_correct_bwds, utils.topk(bwd_sim, labels, k=1).item())
            percent_correct_fwds = utils.topk(fwd_sim, labels, k=1).item()
            percent_correct_bwds = utils.topk(bwd_sim, labels, k=1).item()

        # percent_correct_fwd = np.mean(percent_correct_fwds)
        # fwd_sd = np.std(percent_correct_fwds) / np.sqrt(len(percent_correct_fwds))
        # fwd_ci = stats.norm.interval(0.95, loc=percent_correct_fwd, scale=fwd_sd)

        # percent_correct_bwd = np.mean(percent_correct_bwds)
        # bwd_sd = np.std(percent_correct_bwds) / np.sqrt(len(percent_correct_bwds))
        # bwd_ci = stats.norm.interval(0.95, loc=percent_correct_bwd, scale=bwd_sd)

        fwd_acc.append(percent_correct_fwds)
        bwd_acc.append(percent_correct_bwds)

        fwd_sim = np.array(fwd_sim.cpu())
        bwd_sim = np.array(bwd_sim.cpu())
        fwd_sim_halves.append(fwd_sim)
        bwd_sim_halves.append(bwd_sim)

    print(f"ses-02 fwd percent_correct: {fwd_acc[0]:.4f}; ses-03 fwd percent_correct: {fwd_acc[1]:.4f}")
    print(f"ses-02 bwd percent_correct: {bwd_acc[0]:.4f}; ses-03 bwd percent_correct: {bwd_acc[1]:.4f} ")

In [17]:
if compute_circular:
    if "ses-0" not in model_name:  # we're in a multisession model, assumed ses-02 and ses-03 for now
        fwd_acc_circular = []
        fwd_sim_halves_circular = []

        assert len(all_unique_images) == len(all_unique_clipvoxels)  

        for i in range(2):  # since this is a 2-session model, we evaluate on the test set corresponding to each session and report both separately for better comparison to single-session models
            percent_correct_fwds_circular = []
            # percent_correct_fwd_circular = None

            if i==0:  
                all_unique_images_half_circular = all_unique_images[:int(len(all_unique_images)/2)]
                all_unique_clipvoxels_half_circular = all_unique_clipvoxels[:int(len(all_unique_clipvoxels)/2)]
            elif i==1:
                all_unique_images_half_circular = all_unique_images[int(len(all_unique_images)/2):]
                all_unique_clipvoxels_half_circular = all_unique_clipvoxels[int(len(all_unique_clipvoxels)/2):]

            with torch.cuda.amp.autocast(dtype=torch.float16):
                emb_circular = clip_img_embedder(all_unique_images_half_circular.to(device)).float() # CLIP-Image

                # flatten if necessary
                emb_circular = emb_circular.reshape(len(emb_circular),-1).to(device)

                # l2norm 
                emb_circular = nn.functional.normalize(emb_circular,dim=-1)

                labels_circular = torch.arange(len(emb_circular)).to(device)
                fwd_sim_circular = utils.batchwise_cosine_similarity(emb_circular,emb_circular)  # clip, clip


                if "ses-0" in model_name:
                    assert len(fwd_sim_circular) == 25
                else:
                    assert len(fwd_sim_circular) == 50

                # percent_correct_fwds_circular = np.append(percent_correct_fwds_circular, utils.topk(fwd_sim_circular, labels_circular, k=1).item())
                percent_correct_fwds_circular = utils.topk(fwd_sim_circular, labels_circular, k=1).item()


            # percent_correct_fwd_circular = np.mean(percent_correct_fwds_circular)
            # fwd_sd_circular = np.std(percent_correct_fwds_circular) / np.sqrt(len(percent_correct_fwds_circular))
            # fwd_ci_circular = stats.norm.interval(0.95, loc=percent_correct_fwd_circular, scale=fwd_sd_circular)

            fwd_acc_circular.append(percent_correct_fwds_circular)

            fwd_sim_circular = np.array(fwd_sim_circular.cpu())
            fwd_sim_halves_circular.append(fwd_sim_circular)

        print(f"ses-02 fwd percent_correct: {fwd_acc_circular[0]:.4f}; ses-03 fwd percent_correct: {fwd_acc_circular[1]:.4f}")
    
    else:  # single session model
        fwd_acc_circular = []

        assert len(all_unique_images) == len(all_unique_clipvoxels)  

        percent_correct_fwds_circular = []
        # percent_correct_fwd_circular = None

        with torch.cuda.amp.autocast(dtype=torch.float16):
            emb_circular = clip_img_embedder(all_unique_images.to(device)).float() # CLIP-Image

            # flatten if necessary
            emb_circular = emb_circular.reshape(len(emb_circular),-1).to(device)

            # l2norm 
            emb_circular = nn.functional.normalize(emb_circular,dim=-1)

            labels_circular = torch.arange(len(emb_circular)).to(device)
            fwd_sim_circular = utils.batchwise_cosine_similarity(emb_circular,emb_circular)  # clip, clip


            if "ses-01" in model_name:
                assert len(fwd_sim_circular) == 100
            else:
                assert len(fwd_sim_circular) == 50

            # percent_correct_fwds_circular = np.append(percent_correct_fwds_circular, utils.topk(fwd_sim_circular, labels_circular, k=1).item())
            percent_correct_fwds_circular = utils.topk(fwd_sim_circular, labels_circular, k=1).item()


        # percent_correct_fwd_circular = np.mean(percent_correct_fwds_circular)
        # fwd_sd_circular = np.std(percent_correct_fwds_circular) / np.sqrt(len(percent_correct_fwds_circular))
        # fwd_ci_circular = stats.norm.interval(0.95, loc=percent_correct_fwd_circular, scale=fwd_sd_circular)

        fwd_acc_circular.append(percent_correct_fwds_circular)

        fwd_sim_circular = np.array(fwd_sim_circular.cpu())

        print(f"session fwd percent_correct (circular): {fwd_acc_circular[0]:.4f}")

In [18]:
if compute_circular:
    # print(utils.topk(torch.Tensor(fwd_sim_halves[1]).to(device), labels, k=1).item())
    # ses02_top1 = torch.argsort(torch.Tensor(fwd_sim_halves[0]).to(device),axis=1)[:,-1] == labels  # from utils.topk()
    # ses03_top1 = torch.argsort(torch.Tensor(fwd_sim_halves[1]).to(device),axis=1)[:,-1] == labels
    # top1_results = torch.cat((ses02_top1, ses03_top1))
    # incorrect_idx = torch.argwhere(top1_results == False)[:,0]
    # print(incorrect_idx)

    # confirm that the lure is behind the target 80% of the time in CLIP last hidden layer embeddings

    use_fwd_sim = False
    use_first_half = False

    all_top_sims = []  # len(fwd_sim_halves[0]); for each image, contains the similarity to the top-n choices until it gets the correct answer. If len 1, top-1 is correct
    all_pairmate_sims = []  # len(fwd_sim_halves[0]); the similarity of each image to its pairmate 
    all_chose_lures = []  # len(fwd_sim_halves[0]); True for each top-n choice if the lure was predicted to be more similar to the target 
    if "ses-0" not in model_name:
        sim_halves = fwd_sim_halves if use_fwd_sim else bwd_sim_halves
        sim_halves = sim_halves[0] if use_first_half else sim_halves[1]
    else:
        sim_halves = all_fwd_sim if use_fwd_sim else all_bwd_sim
    for i, img in enumerate(sim_halves):
        if i%2==0:
            idx_to_pairmate = 1
        elif i%2==1:
            idx_to_pairmate = -1

        order = img.argsort()[::-1]
        # print(order)
        top_sim = []
        chose_lure = []
        for idx in order:
            sim = img[idx]
            pairmate_sim = img[i+idx_to_pairmate]
            top_sim.append(sim) 
            chose_lure.append((idx, sim <= pairmate_sim))
            # print(i, idx, img[idx], img[i+idx_to_pairmate])
            if idx == i:
                break

        all_top_sims.append(top_sim)
        all_pairmate_sims.append(pairmate_sim)
        all_chose_lures.append(chose_lure)

    # print(all_top_sims)
    # print()
    # print(all_pairmate_sims)
    # print()
    # print(all_chose_lures)

    where_chose_pairmate = []
    for idx, i in enumerate(all_chose_lures):
        for value in i:
            # print(value[1])
            if value[1] == True:
                # print(idx, i, end='\n')
                where_chose_pairmate.append(idx)
                break

    # where_chose_pairmate  # trials where the pairmate was chosen ahead of the target

In [19]:
# top-n predictions using CLIP brain embeddings

if plot_all:
    use_fwd_sim = True
    top_n = 10  # how many of the top n images to display
    print("Given Brain embedding, find correct Image embedding")
    fig, ax = plt.subplots(nrows=len(all_unique_images), ncols=top_n+1, figsize=(top_n*2,len(all_unique_images)*2))
    for trial in range(len(all_unique_images)):
        ax[trial, 0].imshow(utils.torch_to_Image(all_unique_images[trial]))
        ax[trial, 0].set_title("original\nimage")
        ax[trial, 0].axis("off")
        for attempt in range(top_n):
            if trial < 50:
                if "ses-0" not in model_name:
                    sim_half = fwd_sim_halves[0] if use_fwd_sim else bwd_sim_halves[0]
                    unique_imgs_to_plot = all_unique_images[:int(len(all_unique_images)/2)]
                    # unique_clipvoxels_to_plot = all_unique_clipvoxels[:int(len(all_unique_clipvoxels)/2)]
                else:
                    sim_half = all_fwd_sim if use_fwd_sim else all_bwd_sim
                    unique_imgs_to_plot = all_unique_images
                    # unique_clipvoxels_to_plot = all_unique_clipvoxels
                which = np.flip(np.argsort(sim_half[trial]))[attempt]

            elif trial >= 50:
                if "ses-0" not in model_name:
                    sim_halves = fwd_sim_halves[1] if use_fwd_sim else bwd_sim_halves[1]
                    unique_imgs_to_plot = all_unique_images[int(len(all_unique_images)/2):]
                    # unique_clipvoxels_to_plot = all_unique_clipvoxels[int(len(all_unique_clipvoxels)/2):]
                else:
                    sim_halves = all_fwd_sim if use_fwd_sim else all_bwd_sim
                    unique_imgs_to_plot = all_unique_images
                    # unique_clipvoxels_to_plot = all_unique_clipvoxels
                which = np.flip(np.argsort(sim_half[trial-50]))[attempt]

            ax[trial, attempt+1].imshow(utils.torch_to_Image(unique_imgs_to_plot[which]))
            ax[trial, attempt+1].set_title(f"Top {attempt+1}")
            ax[trial, attempt+1].axis("off")
    fig.tight_layout()
    # plt.savefig('figures/retrieval_top10')
    plt.show()

In [20]:
# similarity of each unique MST image to all others using CLIP image embeddings only (top-1 is guaranteed to be correct)
# uses last hidden layer (which may not match as well as the last layer to human semantic judgments)
if plot_all and compute_circular:
    print("Given Brain embedding, find correct Image embedding")
    top_n = 10  # how many of the top n images to display
    fig, ax = plt.subplots(nrows=len(all_unique_images), ncols=top_n+1, figsize=(top_n*2,len(all_unique_images)*2))
    for trial in range(len(all_unique_images)):
        ax[trial, 0].imshow(utils.torch_to_Image(all_unique_images[trial]))
        ax[trial, 0].set_title("original\nimage")
        ax[trial, 0].axis("off")
        for attempt in range(10):
            if trial < 50:
                if "ses-0" not in model_name:
                    sim_half_circular = fwd_sim_halves_circular[0]
                    unique_imgs_to_plot_circular = all_unique_images[:int(len(all_unique_images)/2)]
                else:
                    sim_half_circular = fwd_sim_circular
                    unique_imgs_to_plot_circular = all_unique_images
                which_circular = np.flip(np.argsort(sim_half_circular[trial]))[attempt]

            elif trial >= 50:
                if "ses-0" not in model_name:
                    sim_halves_circular = fwd_sim_halves_circular[1]
                    unique_imgs_to_plot_circular = all_unique_images[int(len(all_unique_images)/2):]
                else:
                    sim_halves_circular = all_fwd_sim_circular
                    unique_imgs_to_plot_circular = all_unique_images
                which_circular = np.flip(np.argsort(sim_half_circular[trial-50]))[attempt]

            ax[trial, attempt+1].imshow(utils.torch_to_Image(unique_imgs_to_plot_circular[which_circular]))
            ax[trial, attempt+1].set_title(f"Top {attempt+1}")
            ax[trial, attempt+1].axis("off")
    fig.tight_layout()
    # plt.savefig('figures/circular_top10')
    plt.show()

## MST Paired Retrieval (chance = 50%)

In [21]:
if compute_circular:
    all_top_sims_circular = []  # len(fwd_sim_halves[0]); for each image, contains the similarity to the top-n choices until it gets the correct answer. If len 1, top-1 is correct
    all_pairmate_sims_circular = []  # len(fwd_sim_halves[0]); the similarity of each image to its pairmate 
    all_chose_lures_circular = []  # len(fwd_sim_halves[0]); True for each top-n choice if the lure was predicted to be more similar to the target 

    if "ses-0" not in model_name:
        first_half = True
        sim_halves_circular = fwd_sim_halves_circular
        sim_halves_circular = sim_halves_circular[0] if use_first_half else sim_halves_circular[1]
    else:
        sim_halves_circular = all_fwd_sim_circular
    for i, img in enumerate(sim_halves_circular):
        if i%2==0:
            idx_to_pairmate = 1
        elif i%2==1:
            idx_to_pairmate = -1

        order_circular = img.argsort()[::-1]
        # print(order)
        top_sim_circular = []
        chose_lure_circular = []
        for idx in order_circular:
            sim_circular = img[idx]
            pairmate_sim_circular = img[i+idx_to_pairmate]
            top_sim_circular.append(sim_circular) 
            chose_lure_circular.append((idx, sim_circular <= pairmate_sim_circular))
            # print(i, idx, img[idx], img[i+idx_to_pairmate])
            if idx == i:
                break

        all_top_sims_circular.append(top_sim_circular)
        all_pairmate_sims_circular.append(pairmate_sim_circular)
        all_chose_lures_circular.append(chose_lure_circular)

    bot_half = (all_pairmate_sims_circular < np.median(all_pairmate_sims_circular))[::2]  # every other one since the sims are symmetric
    top_half = (all_pairmate_sims_circular > np.median(all_pairmate_sims_circular))[::2]

    binary_acc = []
    for i,(a,b) in enumerate(tqdm(MST_pairmate_indices,total=len(MST_pairmate_indices))):
        # print(i,a,b)
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                emb_a = nn.functional.normalize(clip_img_embedder(all_images[[a]].to(device)).float().flatten(1),dim=-1)
                emb_b = nn.functional.normalize(clip_img_embedder(all_images[[b]].to(device)).float().flatten(1),dim=-1)
                emb_v = nn.functional.normalize(all_clipvoxels[[a]].flatten(1),dim=-1).to(device)

                a_sim = utils.pairwise_cosine_similarity(emb_v, emb_a).item()
                b_sim = utils.pairwise_cosine_similarity(emb_v, emb_b).item()

                binary_acc.append(a_sim > b_sim)

        with torch.no_grad():
            with torch.cuda.amp.autocast():
                emb_a = nn.functional.normalize(clip_img_embedder(all_images[[a]].to(device)).float().flatten(1),dim=-1)
                emb_b = nn.functional.normalize(clip_img_embedder(all_images[[b]].to(device)).float().flatten(1),dim=-1)
                emb_v = nn.functional.normalize(all_clipvoxels[[b]].flatten(1),dim=-1).to(device)

                a_sim = utils.pairwise_cosine_similarity(emb_v, emb_a).item()
                b_sim = utils.pairwise_cosine_similarity(emb_v, emb_b).item()

                binary_acc.append(a_sim < b_sim)

    assert len(binary_acc) == 50
    mst_score = np.mean(binary_acc)
    print(f"session score: {np.mean(binary_acc):.4f} ± {np.std(binary_acc):.4f}")


In [22]:
# test = np.sort(MST_pairmate_indices, axis=1)
# test

In [23]:
# model_name

In [24]:
# paul_all_images = torch.load(f"evals/sub-001_ses-01_bs24_MST_paul_MSTsplit/sub-001_ses-01_bs24_MST_paul_MSTsplit_all_images.pt").to('cpu')
# paul_all_clipvoxels = torch.load(f"evals/sub-001_ses-01_bs24_MST_paul_MSTsplit/sub-001_ses-01_bs24_MST_paul_MSTsplit_all_clipvoxels.pt").to('cpu')
# paul_all_recons = torch.load(f"evals/sub-001_ses-01_bs24_MST_paul_MSTsplit/sub-001_ses-01_bs24_MST_paul_MSTsplit_all_recons.pt").to('cpu')
# paul_all_prior_out = torch.load(f"evals/sub-001_ses-01_bs24_MST_paul_MSTsplit/sub-001_ses-01_bs24_MST_paul_MSTsplit_all_prior_out.pt").to('cpu')
# print(paul_all_images.shape, all_images.shape)
# # assert torch.eq(paul_all_images, all_images)

In [25]:
# print(paul_all_images.shape, all_images.shape)
# torch.eq(paul_all_images, all_images)

In [26]:
if "MST" in model_name:
    if "ses-0" not in model_name:
        print('assuming ses-02+ses-03 multisession model')
        mst_score = []
        for half in range(2):
            binary_acc = []
            if half==0:
                MST_pairmate_indices_half = MST_pairmate_indices[:int(len(MST_pairmate_indices)/2)]
            elif half==1:
                MST_pairmate_indices_half = MST_pairmate_indices[int(len(MST_pairmate_indices)/2):]
            for i,(a,b) in enumerate(tqdm(MST_pairmate_indices_half,total=len(MST_pairmate_indices_half))):
                # print(i,a,b)
                with torch.no_grad():
                    with torch.cuda.amp.autocast():
                        emb_a = nn.functional.normalize(clip_img_embedder(all_images[[a]].to(device)).float().flatten(1),dim=-1)
                        emb_b = nn.functional.normalize(clip_img_embedder(all_images[[b]].to(device)).float().flatten(1),dim=-1)
                        emb_v = nn.functional.normalize(all_clipvoxels[[a]].flatten(1),dim=-1).to(device)

                        a_sim = utils.pairwise_cosine_similarity(emb_v, emb_a).item()
                        b_sim = utils.pairwise_cosine_similarity(emb_v, emb_b).item()

                        binary_acc.append(a_sim > b_sim)

                with torch.no_grad():
                    with torch.cuda.amp.autocast():
                        emb_a = nn.functional.normalize(clip_img_embedder(all_images[[a]].to(device)).float().flatten(1),dim=-1)
                        emb_b = nn.functional.normalize(clip_img_embedder(all_images[[b]].to(device)).float().flatten(1),dim=-1)
                        emb_v = nn.functional.normalize(all_clipvoxels[[b]].flatten(1),dim=-1).to(device)

                        a_sim = utils.pairwise_cosine_similarity(emb_v, emb_a).item()
                        b_sim = utils.pairwise_cosine_similarity(emb_v, emb_b).item()

                        binary_acc.append(a_sim < b_sim)

            assert len(binary_acc) == 50  # don't want to average across both sessions; make sure it resets
            print(f"ses-0{half+2} score: {np.mean(binary_acc):.4f} ± {np.std(binary_acc):.4f}")
            mst_score.append((np.mean(binary_acc),np.std(binary_acc)))

        # print(mst_score)
    else:
        print('assuming single session')
        binary_acc = []
        for i,(a,b) in enumerate(tqdm(MST_pairmate_indices,total=len(MST_pairmate_indices))):
            # print(i,a,b)
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    emb_a = nn.functional.normalize(clip_img_embedder(all_images[[a]].to(device)).float().flatten(1),dim=-1)
                    emb_b = nn.functional.normalize(clip_img_embedder(all_images[[b]].to(device)).float().flatten(1),dim=-1)
                    emb_v = nn.functional.normalize(all_clipvoxels[[a]].flatten(1),dim=-1).to(device)

                    a_sim = utils.pairwise_cosine_similarity(emb_v, emb_a).item()
                    b_sim = utils.pairwise_cosine_similarity(emb_v, emb_b).item()

                    binary_acc.append(a_sim > b_sim)

            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    emb_a = nn.functional.normalize(clip_img_embedder(all_images[[a]].to(device)).float().flatten(1),dim=-1)
                    emb_b = nn.functional.normalize(clip_img_embedder(all_images[[b]].to(device)).float().flatten(1),dim=-1)
                    emb_v = nn.functional.normalize(all_clipvoxels[[b]].flatten(1),dim=-1).to(device)

                    a_sim = utils.pairwise_cosine_similarity(emb_v, emb_a).item()
                    b_sim = utils.pairwise_cosine_similarity(emb_v, emb_b).item()

                    binary_acc.append(a_sim < b_sim)
                    
        # assert len(binary_acc) == 50
        mst_score = np.mean(binary_acc)
        print(f"session score: {np.mean(binary_acc):.4f} ± {np.std(binary_acc):.4f}")


assuming single session


100%|██████████| 50/50 [00:04<00:00, 10.23it/s]

session score: 0.8900 ± 0.3129





## 2-way identification

In [27]:
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

@torch.no_grad()
def two_way_identification(all_recons, all_images, model, preprocess, feature_layer=None, return_avg=True):
    preds = model(torch.stack([preprocess(recon) for recon in all_recons], dim=0).to(device))
    reals = model(torch.stack([preprocess(indiv) for indiv in all_images], dim=0).to(device))
    if feature_layer is None:
        preds = preds.float().flatten(1).cpu().numpy()
        reals = reals.float().flatten(1).cpu().numpy()
    else:
        preds = preds[feature_layer].float().flatten(1).cpu().numpy()
        reals = reals[feature_layer].float().flatten(1).cpu().numpy()

    r = np.corrcoef(reals, preds)
    r = r[:len(all_images), len(all_images):]
    congruents = np.diag(r)

    success = r < congruents
    success_cnt = np.sum(success, 0)

    if return_avg:
        perf = np.mean(success_cnt) / (len(all_images)-1)
        return perf
    else:
        return success_cnt, len(all_images)-1
    
all_recons=all_recons.to(device)
all_images=all_images.to(device)

## PixCorr

In [28]:
preprocess = transforms.Compose([
    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
])

# Flatten images while keeping the batch dimension
all_images_flattened = preprocess(all_images).reshape(len(all_images), -1).cpu()
all_recons_flattened = preprocess(all_recons).view(len(all_recons), -1).cpu()

print(all_images_flattened.shape)
print(all_recons_flattened.shape)

corr_stack = []

corrsum = 0
for i in tqdm(range(len(all_images))):
    corrcoef = np.corrcoef(all_images_flattened[i], all_recons_flattened[i])[0][1]
    if np.isnan(corrcoef):
        print("WARNING: CORRCOEF WAS NAN")
        corrcoef = 0
    corrsum += corrcoef
    corr_stack.append(corrcoef)
corrmean = corrsum / len(all_images)

pixcorr = corrmean
print(pixcorr)



torch.Size([100, 541875])
torch.Size([100, 541875])


100%|██████████| 100/100 [00:00<00:00, 775.96it/s]

0.19027211547192407





In [29]:
# print(all_images.shape)
# print(all_images_flattened.shape)
# print(all_recons.shape)
# print(all_recons_flattened.shape)
# len(all_images)

## SSIM

In [30]:
# see https://github.com/zijin-gu/meshconv-decoding/issues/3
from skimage.color import rgb2gray
from skimage.metrics import structural_similarity as ssim

preprocess = transforms.Compose([
    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR), 
])

# convert image to grayscale with rgb2grey
img_gray = rgb2gray(preprocess(all_images).permute((0,2,3,1)).cpu())
recon_gray = rgb2gray(preprocess(all_recons).permute((0,2,3,1)).cpu())
print("converted, now calculating ssim...")

ssim_score=[]
for im,rec in tqdm(zip(img_gray,recon_gray),total=len(all_images)):
    ssim_score.append(ssim(rec, im, multichannel=True, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1.0))

ssim = np.mean(ssim_score)
print(ssim)



converted, now calculating ssim...


100%|██████████| 100/100 [00:00<00:00, 120.46it/s]

0.3447461015516486





## AlexNet

In [31]:
from torchvision.models import alexnet, AlexNet_Weights
alex_weights = AlexNet_Weights.IMAGENET1K_V1

alex_model = create_feature_extractor(alexnet(weights=alex_weights), return_nodes=['features.4','features.11']).to(device)
alex_model.eval().requires_grad_(False).to(device)

# see alex_weights.transforms()
preprocess = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

layer = 'early, AlexNet(2)'
print(f"\n---{layer}---")
all_per_correct = two_way_identification(all_recons, all_images, 
                                                          alex_model, preprocess, 'features.4')
alexnet2 = np.mean(all_per_correct)
print(f"2-way Percent Correct: {alexnet2:.4f}")

layer = 'mid, AlexNet(5)'
print(f"\n---{layer}---")
all_per_correct = two_way_identification(all_recons, all_images, 
                                                          alex_model, preprocess, 'features.11')
alexnet5 = np.mean(all_per_correct)
print(f"2-way Percent Correct: {alexnet5:.4f}")

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,



---early, AlexNet(2)---
2-way Percent Correct: 0.8010

---mid, AlexNet(5)---
2-way Percent Correct: 0.8542


## InceptionV3

In [32]:
from torchvision.models import inception_v3, Inception_V3_Weights
weights = Inception_V3_Weights.DEFAULT
inception_model = create_feature_extractor(inception_v3(weights=weights), 
                                           return_nodes=['avgpool']).to(device)
inception_model.eval().requires_grad_(False).to(device)

# see weights.transforms()
preprocess = transforms.Compose([
    transforms.Resize(342, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

all_per_correct = two_way_identification(all_recons, all_images,
                                        inception_model, preprocess, 'avgpool')
        
inception = np.mean(all_per_correct)
print(f"2-way Percent Correct: {inception:.4f}")



2-way Percent Correct: 0.7205


## CLIP

In [33]:
import clip
clip_model, preprocess = clip.load("ViT-L/14", device=device)

preprocess = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711]),
])

all_per_correct = two_way_identification(all_recons, all_images,
                                        clip_model.encode_image, preprocess, None) # final layer
clip_ = np.mean(all_per_correct)
print(f"2-way Percent Correct: {clip_:.4f}")

2-way Percent Correct: 0.7387


## Efficient Net

In [34]:
import scipy as sp
from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights
weights = EfficientNet_B1_Weights.DEFAULT
eff_model = create_feature_extractor(efficientnet_b1(weights=weights), 
                                    return_nodes=['avgpool'])
eff_model.eval().requires_grad_(False).to(device)

# see weights.transforms()
preprocess = transforms.Compose([
    transforms.Resize(255, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

gt = eff_model(preprocess(all_images))['avgpool']
gt = gt.reshape(len(gt),-1).cpu().numpy()
fake = eff_model(preprocess(all_recons))['avgpool']
fake = fake.reshape(len(fake),-1).cpu().numpy()

effnet_nomean = np.array([sp.spatial.distance.correlation(gt[i],fake[i]) for i in range(len(gt))])
effnet = effnet_nomean.mean()
print("Distance:",effnet)

Distance: 0.8597363629795817


## SwAV

In [35]:
swav_model = torch.hub.load('facebookresearch/swav:main', 'resnet50')
swav_model = create_feature_extractor(swav_model, 
                                    return_nodes=['avgpool'])
swav_model.eval().requires_grad_(False).to(device)

preprocess = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

gt = swav_model(preprocess(all_images))['avgpool']
gt = gt.reshape(len(gt),-1).cpu().numpy()
fake = swav_model(preprocess(all_recons))['avgpool']
fake = fake.reshape(len(fake),-1).cpu().numpy()

swav_nomean = np.array([sp.spatial.distance.correlation(gt[i],fake[i]) for i in range(len(gt))])
swav = swav_nomean.mean()
print("Distance:",swav)

Using cache found in /home/ri4541/.cache/torch/hub/facebookresearch_swav_main


Distance: 0.5120611452768613


In [36]:
#[pixcorr, ssim, alexnet2, alexnet5, inception, clip_, effnet, swav, percent_correct_fwd, percent_correct_bwd]
import pandas as pd
# pd.options.display.float_format = '{:.2%}'.format
# pd.reset_option('all')
if "ses-0" not in model_name:
    df = pd.DataFrame(["metrics", alexnet2, alexnet5, inception, clip_, effnet, swav, fwd_acc, bwd_acc, mst_score]).to_string(index=False)
else:
    df = pd.DataFrame(["metrics", alexnet2, alexnet5, inception, clip_, effnet, swav, all_fwd_acc[0], all_bwd_acc[0], mst_score]).to_string(index=False)
print(df)
# print(model_name_plus_suffix)
final_evals_path = f"{eval_dir}/final_evals"
if saving:
    with open(final_evals_path, 'w') as f:
        f.write(df)
    print('saved final evals!')

        0
  metrics
  0.80101
 0.854242
 0.720505
 0.738687
 0.859736
 0.512061
     0.71
      0.7
     0.89


FileNotFoundError: [Errno 2] No such file or directory: '/scratch/gpfs/ri4541/MindEyeV2/src/mindeyev2/evals/sub-001_ses-01_bs24_MST_paul_MSTsplit_random_seed_0/final_evals'

In [None]:
with open(final_evals_path, 'r') as f:
    for line in f:
        print(line, end='')
