In [1]:
import math
from collections import OrderedDict
import torch
from torch import nn
from torchvision import transforms
from torchmetrics.functional import pearson_corrcoef, spearman_corrcoef
import torch.utils.model_zoo
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import torch.nn.functional as F
import time
from tqdm.auto import tqdm
import h5py
import random
import clip
from scipy.stats import spearmanr, pearsonr
from neurora.rdm_corr import rdm_correlation_spearman
import matplotlib.pyplot as plt
from neurora.stuff import clusterbased_permutation_1d_2sided
from scipy.stats import ttest_rel, ttest_1samp

# Device every model and input tensor in this notebook is moved to; a CUDA GPU is assumed.
device = 'cuda'

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [None]:
def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device).

    Also forces cuDNN into deterministic mode and disables its autotuner so that
    repeated runs pick the same convolution algorithms.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic, cudnn_backend.benchmark = True, False

In [None]:
class Flatten(nn.Module):

    """
    Reshape a tensor to (batch, -1), keeping the first dimension and merging all
    remaining ones, e.g. right before an nn.Linear layer.
    """

    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)


class Identity(nn.Module):

    """
    Pass-through layer: returns its input object unchanged. Registering it under a
    name (e.g. 'output') gives that point of the network an addressable attribute.
    """

    def forward(self, x):
        result = x
        return result


class CORblock_S(nn.Module):

    """
    Recurrent CORnet-S block.

    A 1x1 ``conv_input`` maps ``in_channels`` to ``out_channels``. The same
    bottleneck (1x1 conv -> 3x3 conv -> 1x1 conv, each followed by BatchNorm) is
    then applied ``times`` times with a residual connection. Pass 0 halves the
    spatial size: the 3x3 conv runs with stride 2, and the residual goes through
    the strided 1x1 ``skip`` conv and ``norm_skip``. Later passes keep the size
    and add the identity. The conv weights are shared by all passes, while each
    pass t has its own BatchNorm layers ``norm1_t``, ``norm2_t`` and ``norm3_t``.

    Args:
        in_channels: channels of the block input.
        out_channels: channels of the block output.
        times: number of passes through the bottleneck. Must be >= 1, otherwise
            ``forward`` never assigns ``output``.
    """

    scale = 4  # scale of the bottleneck convolution channels

    def __init__(self, in_channels, out_channels, times=1):
        super().__init__()

        self.times = times

        # 1x1 projection to out_channels, applied once before the recurrent passes
        self.conv_input = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        # strided 1x1 residual path, used only on the first pass (t == 0)
        self.skip = nn.Conv2d(out_channels, out_channels,
                              kernel_size=1, stride=2, bias=False)
        self.norm_skip = nn.BatchNorm2d(out_channels)

        # bottleneck expansion: out_channels -> out_channels * scale
        self.conv1 = nn.Conv2d(out_channels, out_channels * self.scale,
                               kernel_size=1, bias=False)
        self.nonlin1 = nn.ReLU(inplace=True)

        # 3x3 conv; its stride is rewritten in forward() on every pass
        self.conv2 = nn.Conv2d(out_channels * self.scale, out_channels * self.scale,
                               kernel_size=3, stride=2, padding=1, bias=False)
        self.nonlin2 = nn.ReLU(inplace=True)

        # bottleneck projection back to out_channels
        self.conv3 = nn.Conv2d(out_channels * self.scale, out_channels,
                               kernel_size=1, bias=False)
        self.nonlin3 = nn.ReLU(inplace=True)

        self.output = Identity()  # for an easy access to this block's output

        # need BatchNorm for each time step for training to work well
        for t in range(self.times):
            setattr(self, f'norm1_{t}', nn.BatchNorm2d(out_channels * self.scale))
            setattr(self, f'norm2_{t}', nn.BatchNorm2d(out_channels * self.scale))
            setattr(self, f'norm3_{t}', nn.BatchNorm2d(out_channels))

    def forward(self, inp):
        """Run the ``times`` recurrent passes and return the last pass's output."""
        x = self.conv_input(inp)

        for t in range(self.times):
            if t == 0:
                # first pass downsamples: strided skip conv on the residual path
                # and a strided 3x3 conv (stride is mutated on the shared conv2)
                skip = self.norm_skip(self.skip(x))
                self.conv2.stride = (2, 2)
            else:
                # later passes keep the spatial size and use the identity as skip
                skip = x
                self.conv2.stride = (1, 1)

            x = self.conv1(x)
            x = getattr(self, f'norm1_{t}')(x)  # per-pass BatchNorm
            x = self.nonlin1(x)

            x = self.conv2(x)
            x = getattr(self, f'norm2_{t}')(x)
            x = self.nonlin2(x)

            x = self.conv3(x)
            x = getattr(self, f'norm3_{t}')(x)

            x += skip  # in-place residual add
            x = self.nonlin3(x)
            output = self.output(x)

        return output


def CORnet_S():
    """Build the CORnet-S network as ``V1 -> V2 -> V4 -> IT -> decoder``.

    Sub-module names must stay as they are: the pretrained checkpoint is loaded
    into this model later in the notebook via ``load_state_dict``. After
    construction:
    - every Conv2d weight is re-drawn from N(0, sqrt(2 / n)), with
      n = kernel_h * kernel_w * out_channels;
    - every BatchNorm2d gets weight 1 and bias 0;
    - nn.Linear is intentionally left at PyTorch's default initialisation,
      matching how the network was originally trained.

    Returns:
        nn.Sequential with children V1, V2, V4, IT and decoder.
    """
    # V1 is a plain conv stack rather than a CORblock_S (custom, to save GPU memory).
    v1_layers = OrderedDict()
    v1_layers['conv1'] = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
    v1_layers['norm1'] = nn.BatchNorm2d(64)
    v1_layers['nonlin1'] = nn.ReLU(inplace=True)
    v1_layers['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    v1_layers['conv2'] = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
    v1_layers['norm2'] = nn.BatchNorm2d(64)
    v1_layers['nonlin2'] = nn.ReLU(inplace=True)
    v1_layers['output'] = Identity()

    # Layers are created in the same order as the reference implementation so the
    # default (constructor) initialisation consumes the RNG identically.
    stages = OrderedDict()
    stages['V1'] = nn.Sequential(v1_layers)
    stages['V2'] = CORblock_S(64, 128, times=2)
    stages['V4'] = CORblock_S(128, 256, times=4)
    stages['IT'] = CORblock_S(256, 512, times=2)
    stages['decoder'] = nn.Sequential(OrderedDict([
        ('avgpool', nn.AdaptiveAvgPool2d(1)),
        ('flatten', Flatten()),
        ('linear', nn.Linear(512, 1000)),
        ('output', Identity()),
    ]))
    model = nn.Sequential(stages)

    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            fan = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
            module.weight.data.normal_(0, math.sqrt(2. / fan))
        elif isinstance(module, nn.BatchNorm2d):
            module.weight.data.fill_(1)
            module.bias.data.zero_()

    return model

class Encoder(nn.Module):
    """CORnet backbone plus a linear read-out built from its four stage outputs.

    Each stage output (V1, V2, V4, IT) is flattened, projected to 128 units by its
    own Linear + ReLU, and the four 128-d vectors are concatenated (512-d) and
    mapped to ``n_output`` values by ``fc``.

    Args:
        realnet: the CORnet model. Its stages are reached through
            ``realnet.module`` when present (a ``DataParallel`` wrapper, as used in
            this notebook), otherwise on ``realnet`` itself.
        n_output: size of the predicted feature vector.
    """
    def __init__(self, realnet, n_output):
        super(Encoder, self).__init__()

        # CORnet
        self.realnet = realnet

        # fully connected read-outs; input sizes are the flattened stage outputs
        # for 224x224 images: 64*56*56, 128*28*28, 256*14*14, 512*7*7
        self.fc_v1 = nn.Linear(200704, 128)
        self.fc_v2 = nn.Linear(100352, 128)
        self.fc_v4 = nn.Linear(50176, 128)
        self.fc_it = nn.Linear(25088, 128)
        self.fc = nn.Linear(512, n_output)
        self.activation = nn.ReLU()

    def forward(self, imgs):
        """Return ``(outputs, features)``.

        ``outputs`` is the backbone's decoder output; ``features`` is the
        ``(N, n_output)`` read-out from the four stage outputs.

        The backbone is evaluated once: the stage outputs are kept and the IT
        output is passed through ``decoder``, which is the same chain the full
        backbone applies. Previously the whole network ran twice per call; in
        training mode BatchNorm running statistics are now updated once per call
        instead of twice.
        """
        backbone = getattr(self.realnet, 'module', self.realnet)

        N = len(imgs)
        v1_outputs = backbone.V1(imgs)  # N * 64 * 56 * 56
        v2_outputs = backbone.V2(v1_outputs)  # N * 128 * 28 * 28
        v4_outputs = backbone.V4(v2_outputs)  # N * 256 * 14 * 14
        it_outputs = backbone.IT(v4_outputs)  # N * 512 * 7 * 7
        outputs = backbone.decoder(it_outputs)

        v1_features = self.activation(self.fc_v1(v1_outputs.view(N, -1)))
        v2_features = self.activation(self.fc_v2(v2_outputs.view(N, -1)))
        v4_features = self.activation(self.fc_v4(v4_outputs.view(N, -1)))
        it_features = self.activation(self.fc_it(it_outputs.view(N, -1)))
        features = torch.cat((v1_features, v2_features, v4_features, it_features), dim=1)
        features = self.fc(features)

        return outputs, features

def cal_rdm(v):
    """Representational dissimilarity matrix (1 - Pearson r) between the rows of ``v``.

    Args:
        v: 2-D tensor (or array-like) of shape (n_items, n_features); row i is the
            flattened response to item i. It is not modified.

    Returns:
        ``np.ndarray`` of shape (n_items, n_items), dtype float64, symmetric, with
        ``rdm[i, j] = 1 - pearson(v[i], v[j])`` and a zero diagonal. Entries
        involving a constant row are NaN because the correlation is undefined.
    """
    # Float64 working copy: never aliases the caller's tensor, so the in-place
    # ops below are safe, and detach() accepts inputs that require grad.
    x = torch.as_tensor(v).detach().to(dtype=torch.float64, copy=True)
    # All pairwise Pearson correlations at once: centre each row and scale it to
    # unit norm, then every correlation is a dot product.
    x -= x.mean(dim=1, keepdim=True)
    x /= x.norm(dim=1, keepdim=True)
    corr = x @ x.T
    rdm = (1.0 - corr).cpu().numpy()
    np.fill_diagonal(rdm, 0.0)
    return rdm

# Preprocessing applied to every stimulus image:
# - resize to 224x224;
# - convert to a float tensor in [0, 1];
# - normalise each channel with the standard ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

### Model-EEG Similarity

In [None]:
class Data4Model(torch.utils.data.Dataset):
    """Image / z-scored EEG pairs for one THINGS-EEG subject and split.

    Args:
        state: split name used in the file names (e.g. 'training' or 'test').
        sub_index: 1-based subject number. It selects both the subject's EEG file
            and row ``sub_index - 1`` of the overall mean/std arrays.
        transform: optional callable applied to each RGB PIL image. When None,
            the PIL image itself is returned.
    """
    def __init__(self, state='training', sub_index=1, transform=None):

        super(Data4Model, self).__init__()

        self.imgs = np.load('GetData/'+state+'_imgpaths.npy').tolist()

        mean = np.load('GetData/preprocessed_mean_overall.npy')
        std = np.load('GetData/preprocessed_std_overall.npy')
        eeg = np.load('preprocessed_eeg_data/sub-'+str(sub_index).zfill(2)+'_'+state+'.npy')
        # z-score with this subject's statistics
        self.eeg = (eeg-mean[sub_index-1])/std[sub_index-1]
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, item):
        """Return ``(image, eeg)``; ``eeg`` is the item's EEG row as a float32 tensor."""
        # Close the file right away; convert() returns a loaded copy that outlives it.
        with Image.open(self.imgs[item]) as img:
            imgs = img.convert('RGB')
        if self.transform is not None:
            imgs = self.transform(imgs)
        eeg = torch.tensor(self.eeg[item]).float()

        return imgs, eeg

# RSA for CORnet: RDMs of the pretrained CORnet-S V1/V2/V4/IT outputs over the
# 200 THINGS-EEG test images, stored as cornet_rdms[layer].
cornet_rdms = np.zeros([4, 200, 200])

# Only the images are used here; subject 1 is just needed to build the dataset.
test_dataset = Data4Model(state='test', sub_index=1, transform=transform)
test_data_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

cornet = CORnet_S().to(device)
cornet = torch.nn.DataParallel(cornet)
url = 'https://s3.amazonaws.com/cornet-models/cornet_s-1d3f7974.pth'
ckpt_data = torch.utils.model_zoo.load_url(url)
cornet.load_state_dict(ckpt_data['state_dict'])

cornet.eval()

# One flattened stage output per row (C*H*W for 224x224 inputs).
v1 = torch.zeros([200, 200704])
v2 = torch.zeros([200, 100352])
v4 = torch.zeros([200, 50176])
it = torch.zeros([200, 25088])

# no_grad: without it, copying grad-tracking outputs into v1..it records an
# autograd graph that keeps every image's activations alive.
with torch.no_grad():
    for index, (imgs, _) in enumerate(test_data_loader):
        imgs = imgs.to(device)

        imgv1 = cornet.module.V1(imgs)
        imgv2 = cornet.module.V2(imgv1)
        imgv4 = cornet.module.V4(imgv2)
        imgit = cornet.module.IT(imgv4)
        v1[index] = imgv1.flatten()
        v2[index] = imgv2.flatten()
        v4[index] = imgv4.flatten()
        it[index] = imgit.flatten()

cornet_rdms[0] = cal_rdm(v1)
cornet_rdms[1] = cal_rdm(v2)
cornet_rdms[2] = cal_rdm(v4)
cornet_rdms[3] = cal_rdm(it)

np.save('RSA/CORnet_THINGS_EEG_test_rdms.npy', cornet_rdms)

# Spearman correlation between each subject's EEG RDM at each of 60 time points
# and each CORnet layer RDM, stored as cornet_corrs[subject, layer, time].
cornet_corrs = np.zeros([10, 4, 60])

# Reload the RDMs saved above from disk.
cornet_rdms = np.load('RSA/CORnet_THINGS_EEG_test_rdms.npy')

for subject in range(10):
    # Time indices 10..69 of this subject's EEG RDM series.
    # NOTE(review): the latency these indices correspond to is not shown here.
    subject_eeg_rdms = np.load(f'RSA/THINGS_EEG_test/eegrdms_sub{subject + 1:02d}.npy')[10:70]
    for t in range(60):
        eeg_rdm = subject_eeg_rdms[t]
        for layer in range(4):
            cornet_corrs[subject, layer, t] = rdm_correlation_spearman(eeg_rdm, cornet_rdms[layer])[0]

np.save('RSA/CORnet_THINGS_EEG_test_RSA_corrs.npy', cornet_corrs)

# RSA for ReAlnets: for each subject's EEG-aligned CORnet-S, RDMs of its
# V1/V2/V4/IT outputs over the 200 THINGS-EEG test images,
# stored as realnet_rdms[subject, layer].
realnet_rdms = np.zeros([10, 4, 200, 200])

for sub in range(10):

    test_dataset = Data4Model(state='test', sub_index=sub+1, transform=transform)
    test_data_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

    # load_state_dict is strict by default, so encoder.pt must supply every
    # backbone parameter and buffer. Loading the ImageNet CORnet checkpoint
    # first would only be overwritten, so that step is skipped.
    realnet = torch.nn.DataParallel(CORnet_S().to(device))
    encoder = Encoder(realnet, 340).to(device)
    weights = torch.load('weights/ReAlnet/sub-'+str(sub+1).zfill(2)+'/encoder.pt', map_location=device)
    encoder.load_state_dict(weights)

    encoder.eval()
    backbone = encoder.realnet.module

    v1 = torch.zeros([200, 200704])
    v2 = torch.zeros([200, 100352])
    v4 = torch.zeros([200, 50176])
    it = torch.zeros([200, 25088])

    # no_grad: avoids building (and retaining) an autograd graph across images.
    with torch.no_grad():
        for index, (imgs, _) in enumerate(test_data_loader):
            imgs = imgs.to(device)

            imgv1 = backbone.V1(imgs)
            imgv2 = backbone.V2(imgv1)
            imgv4 = backbone.V4(imgv2)
            imgit = backbone.IT(imgv4)
            v1[index] = imgv1.flatten()
            v2[index] = imgv2.flatten()
            v4[index] = imgv4.flatten()
            it[index] = imgit.flatten()

    realnet_rdms[sub, 0] = cal_rdm(v1)
    realnet_rdms[sub, 1] = cal_rdm(v2)
    realnet_rdms[sub, 2] = cal_rdm(v4)
    realnet_rdms[sub, 3] = cal_rdm(it)

    print(sub+1)

np.save('RSA/ReAlnet_EEG_THINGS_EEG_test_rdms.npy', realnet_rdms)

# Spearman correlation between each subject's EEG RDM at each of 60 time points
# and that same subject's ReAlnet layer RDMs: realnet_corrs[subject, layer, time].
realnet_corrs = np.zeros([10, 4, 60])

# Reload the RDMs saved above from disk.
realnet_rdms = np.load('RSA/ReAlnet_EEG_THINGS_EEG_test_rdms.npy')

for subject in range(10):
    # Time indices 10..69 of this subject's EEG RDM series.
    subject_eeg_rdms = np.load(f'RSA/THINGS_EEG_test/eegrdms_sub{subject + 1:02d}.npy')[10:70]
    for t in range(60):
        eeg_rdm = subject_eeg_rdms[t]
        for layer in range(4):
            realnet_corrs[subject, layer, t] = rdm_correlation_spearman(eeg_rdm, realnet_rdms[subject, layer])[0]

np.save('RSA/ReAlnet_EEG_THINGS_EEG_test_RSA_corrs.npy', realnet_corrs)

### Model-fMRI Similarity

In [None]:
class Data4Model_shen_fmri(torch.utils.data.Dataset):
    """Images of the Shen fMRI test set (images only, no brain data).

    Args:
        transform: optional callable applied to each RGB PIL image. When None,
            the PIL image itself is returned.
    """
    def __init__(self, transform=None):

        super(Data4Model_shen_fmri, self).__init__()

        self.imgs = np.load('GetData/Shen_fMRI_test_imgpaths.npy').tolist()
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, item):
        # Close the file right away; convert() returns a loaded copy that outlives it.
        with Image.open(self.imgs[item]) as img:
            imgs = img.convert('RGB')
        if self.transform is not None:
            imgs = self.transform(imgs)

        return imgs
    
# RSA for CORnet on the Shen fMRI test images.
# RDMs over the 50 images are computed for the pretrained CORnet-S V1/V2/V4/IT
# outputs and for the decoder's average-pooled IT features (5 layers in total).
cornet_rdms = np.zeros([5, 50, 50])

test_dataset = Data4Model_shen_fmri(transform=transform)
test_data_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

cornet = CORnet_S().to(device)
cornet = torch.nn.DataParallel(cornet)
url = 'https://s3.amazonaws.com/cornet-models/cornet_s-1d3f7974.pth'
ckpt_data = torch.utils.model_zoo.load_url(url)
cornet.load_state_dict(ckpt_data['state_dict'])

cornet.eval()

v1 = torch.zeros([50, 200704])
v2 = torch.zeros([50, 100352])
v4 = torch.zeros([50, 50176])
it = torch.zeros([50, 25088])
avgpool = torch.zeros([50, 512])

# no_grad: avoids building (and retaining) an autograd graph across images.
with torch.no_grad():
    for index, imgs in enumerate(test_data_loader):
        imgs = imgs.to(device)

        imgv1 = cornet.module.V1(imgs)
        imgv2 = cornet.module.V2(imgv1)
        imgv4 = cornet.module.V4(imgv2)
        imgit = cornet.module.IT(imgv4)
        imgavgpool = cornet.module.decoder.avgpool(imgit)
        v1[index] = imgv1.flatten()
        v2[index] = imgv2.flatten()
        v4[index] = imgv4.flatten()
        it[index] = imgit.flatten()
        avgpool[index] = imgavgpool.flatten()

cornet_rdms[0] = cal_rdm(v1)
cornet_rdms[1] = cal_rdm(v2)
cornet_rdms[2] = cal_rdm(v4)
cornet_rdms[3] = cal_rdm(it)
cornet_rdms[4] = cal_rdm(avgpool)

print(cornet_rdms[0, :4, :4])

np.save('RSA/CORnet_Shen_fMRI_test_rdms.npy', cornet_rdms)

# Spearman correlation between each fMRI RDM and each CORnet layer RDM:
# cornet_corrs[i, j, k] pairs fmri_rdms[i, j] with model layer k.
# NOTE(review): `fmri_rdms` is never created in this notebook and must be loaded
# before this step. It is indexed as fmri_rdms[i, j] with i < 3 and j < 5, each
# entry presumably a 50x50 RDM (subject x ROI?) — TODO add the load and confirm.
if 'fmri_rdms' not in globals():
    raise NameError("fmri_rdms is not defined: load the Shen fMRI RDMs (indexed [3, 5] -> 50x50) before computing model-fMRI RSA")

cornet_corrs = np.zeros([3, 5, 5])

# Load from RSA/, the directory the CORnet fMRI RDMs are saved to.
cornet_rdms = np.load('RSA/CORnet_Shen_fMRI_test_rdms.npy')

for i in range(3):
    for j in range(5):
        fmri_rdm = fmri_rdms[i, j]
        for k in range(5):
            cornet_rdm = cornet_rdms[k]
            cornet_corrs[i, j, k] = rdm_correlation_spearman(fmri_rdm, cornet_rdm)[0]

np.save('RSA/CORnet_Shen_fMRI_test_RSA_corrs.npy', cornet_corrs)

# RSA for ReAlnet on the Shen fMRI test images.
# For each subject's EEG-aligned CORnet-S, RDMs over the 50 images are computed
# for the V1/V2/V4/IT outputs and for the average-pooled IT features.
realnet_rdms = np.zeros([10, 5, 50, 50])

# The image set does not depend on the subject, so build the loader once.
test_dataset = Data4Model_shen_fmri(transform=transform)
test_data_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

for sub in range(10):

    # load_state_dict is strict by default, so encoder.pt must supply every
    # backbone parameter and buffer. Loading the ImageNet CORnet checkpoint
    # first would only be overwritten, so that step is skipped.
    realnet = torch.nn.DataParallel(CORnet_S().to(device))
    encoder = Encoder(realnet, 340).to(device)
    weights = torch.load('weights/ReAlnet/sub-'+str(sub+1).zfill(2)+'/encoder.pt', map_location=device)
    encoder.load_state_dict(weights)

    encoder.eval()
    backbone = encoder.realnet.module

    v1 = torch.zeros([50, 200704])
    v2 = torch.zeros([50, 100352])
    v4 = torch.zeros([50, 50176])
    it = torch.zeros([50, 25088])
    avgpool = torch.zeros([50, 512])

    # no_grad: avoids building (and retaining) an autograd graph across images.
    with torch.no_grad():
        for index, imgs in enumerate(test_data_loader):
            imgs = imgs.to(device)

            imgv1 = backbone.V1(imgs)
            imgv2 = backbone.V2(imgv1)
            imgv4 = backbone.V4(imgv2)
            imgit = backbone.IT(imgv4)
            imgavgpool = backbone.decoder.avgpool(imgit)
            v1[index] = imgv1.flatten()
            v2[index] = imgv2.flatten()
            v4[index] = imgv4.flatten()
            it[index] = imgit.flatten()
            avgpool[index] = imgavgpool.flatten()

    realnet_rdms[sub, 0] = cal_rdm(v1)
    realnet_rdms[sub, 1] = cal_rdm(v2)
    realnet_rdms[sub, 2] = cal_rdm(v4)
    realnet_rdms[sub, 3] = cal_rdm(it)
    realnet_rdms[sub, 4] = cal_rdm(avgpool)

    print(realnet_rdms[sub, 0, :4, :4])

    print(sub+1)

# Saved under RSA/ (like every other RDM file); the correlation step loads it from there.
np.save('RSA/ReAlnet_EEG_Shen_fMRI_test_rdms.npy', realnet_rdms)

# Spearman correlation between each fMRI RDM and each subject's ReAlnet layer RDMs:
# realnet_corrs[i, j, sub, k] pairs fmri_rdms[i, j] with ReAlnet `sub`, layer k.
# NOTE(review): `fmri_rdms` is never created in this notebook and must be loaded
# before this step. It is indexed as fmri_rdms[i, j] with i < 3 and j < 5, each
# entry presumably a 50x50 RDM (subject x ROI?) — TODO add the load and confirm.
if 'fmri_rdms' not in globals():
    raise NameError("fmri_rdms is not defined: load the Shen fMRI RDMs (indexed [3, 5] -> 50x50) before computing model-fMRI RSA")

realnet_corrs = np.zeros([3, 5, 10, 5])

realnet_rdms = np.load('RSA/ReAlnet_EEG_Shen_fMRI_test_rdms.npy')

for i in range(3):
    for j in range(5):
        fmri_rdm = fmri_rdms[i, j]

        for sub in range(10):
            for k in range(5):
                realnet_rdm = realnet_rdms[sub, k]
                realnet_corrs[i, j, sub, k] = rdm_correlation_spearman(fmri_rdm, realnet_rdm)[0]

np.save('RSA/ReAlnet_Shen_fMRI_test_RSA_corrs.npy', realnet_corrs)