In [1]:
import numpy as np
import mne
import pandas as pd
import yaml
import os
import matplotlib.pyplot as plt
import copy
import scipy.stats
import multiprocessing
import time
import tqdm
import pickle
from scipy import stats


import torch
from torch.utils.data import DataLoader
from torch import nn
from torchvision.models import resnet18, resnet50
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision import transforms
from pytorch_utils.pytorch_utils import ToJpeg, ToOpponentChannel, collate_fn, record_activations, evaluate
from oads_access.oads_access import OADS_Access, OADSImageDataset

In [2]:
# Root locations: the project directory and the folder with per-subject EEG data.
project_path = r"/home/c14271389"
server_path = r"/home/c14271389/FMG-folder"

# Files/directories that live directly under the project directory.
statistics_path = os.path.join(project_path, "Stimuli")
dic_path = os.path.join(project_path, "EventsID_Dictionary.csv")

# Path templates kept for reference (not evaluated in this cell):
# epoch_path = os.path.join(server_path, subject_name, "Preprocessed epochs", subject_name + "-OC&CSD-AutoReject-epo.fif")
# Model_path = os.path.join(project_path, "Res-net", "resnet50", "rgb", "best_model_23-03-23-17-14-41.pth")

In [3]:
# Open the OADS dataset and fetch its train/validation/test id split
# (use_crops=False is passed through to the accessor unchanged).
oads = OADS_Access(basedir='/home/Public/Datasets/oads', n_processes=12)
train_ids, val_ids, test_ids = oads.get_train_val_test_split_indices(use_crops=False)


# Load the per-subject stimulus listing; it is indexed below with keys such as
# "sub_20" -- presumably each entry is that subject's list of stimulus paths.
# SECURITY NOTE(review): yaml.UnsafeLoader can instantiate arbitrary Python
# objects while parsing, so this file must come from a trusted source. If it
# only contains plain lists/dicts/strings, switch to yaml.safe_load.
with open(os.path.join(statistics_path, "eeg_oads_stimulus_filenames.yml"), 'rb') as f:
    subjects = yaml.load(f, Loader=yaml.UnsafeLoader)

In [36]:
# Stimulus paths listed for both sub_20 and sub_35 (sorted, unique values).
# NOTE(review): only these two subjects are intersected here; the subjects in
# between are assumed to have been shown the same shared set -- confirm.
first_subject_images = np.array(subjects["sub_20"])
last_subject_images = np.array(subjects["sub_35"])
shared_images = np.intersect1d(first_subject_images, last_subject_images)
shared_images.shape
# 'Stimuli\\faba7a50277d08c9.tiff' in subjects["sub_29"]

(259,)

In [37]:
def _stimulus_key(stimulus_path):
    """Reduce one backslash-separated stimulus path to a bare image name.

    Paths shorter than 37 characters use their second path component, longer
    ones their third. The last five characters of that component are dropped
    (presumably the ".tiff" extension -- confirm against the YAML listing).
    """
    parts = stimulus_path.split('\\')
    component = parts[1] if len(stimulus_path) < 37 else parts[2]
    return component[:-5]


def _correlation_distance_rdms(patterns):
    """Build one correlation-distance RDM per timepoint.

    Parameters
    ----------
    patterns : np.ndarray, shape (n_images, n_channels, n_times)
        Trial-averaged response of every image.

    Returns
    -------
    np.ndarray, shape (n_times, n_images, n_images)
        ``1 - Pearson r`` between the channel patterns of every image pair,
        with an exactly zero diagonal and exact symmetry.
    """
    n_images, _, n_times = patterns.shape
    rdms = np.zeros((n_times, n_images, n_images))
    if n_images < 2:
        # Nothing to correlate; a single image only has its zero self-distance.
        return rdms
    for t in range(n_times):
        # One corrcoef call replaces n_images**2 / 2 separate pearsonr calls;
        # the values agree up to floating-point rounding.
        distances = 1.0 - np.corrcoef(patterns[:, :, t])
        # Mirror the strict upper triangle so the matrix is exactly symmetric
        # and the diagonal is exactly 0.0, as in the pairwise construction.
        upper = np.triu(distances, k=1)
        rdms[t] = upper + upper.T
    return rdms


def shared(argx):
    """Compute and save the time-resolved RDM of one subject on the shared images.

    Reads the subject's preprocessed epochs, keeps the CSD channels listed in
    ``posterior_lobe``, averages the trials of every shared image and stores,
    for every timepoint, the matrix of ``1 - Pearson r`` between the channel
    patterns of all image pairs.

    Relies on the notebook globals ``server_path``, ``project_path``,
    ``dic_path``, ``shared_images``, ``train_ids``, ``val_ids`` and
    ``test_ids``.

    Parameters
    ----------
    argx : int
        Subject number; the subject folder/file prefix is ``"sub_" + str(argx)``.

    Returns
    -------
    None
        The result, an array of shape (n_times, n_images, n_images), is written
        to ``<project_path>/Noise_Ceiling/<subject>_RDM(shared).npy``.
    """
    subject_name = "sub_" + str(argx)
    epoch_path = os.path.join(server_path, subject_name, "Preprocessed epochs", subject_name + "-OC&CSD-AutoReject-epo.fif")

    # Load the epochs and restrict them to the channels of interest.
    # NOTE(review): pick_types()/pick_channels() are reported as legacy by MNE
    # ("New code should use inst.pick(...)"); kept here because pick() may
    # treat missing channels differently -- migrate once that is checked.
    sub_epochs = mne.read_epochs(epoch_path)
    sub_epochs = sub_epochs.pick_types(csd = True)
    # NOTE(review): 'F5' and 'F6' sit in a list named "posterior" -- confirm
    # they are intended.
    posterior_lobe = ['P1', 'P3', 'P5', 'P7', 'P9', 'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2', 'F5', 'F6']
    posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)

    # Translate every epoch's event id into an image name via the dictionary
    # CSV (column 1 = event id, column 0 = file name; last 5 chars stripped).
    events = posterior_epochs.events[:, 2]
    All_Images_df = pd.read_csv(dic_path, header=None)
    All_Images_dict = dict(zip(All_Images_df[1], All_Images_df[0]))
    events_filename = np.array([All_Images_dict[i][:-5] for i in events])

    # Map every usable shared image to the indices of its epochs. An image is
    # usable when it occurs in this subject's events and in the OADS split.
    known_ids = train_ids + test_ids + val_ids      # built once, not per image
    Images_index = dict()
    for stimulus_path in shared_images:
        image_name = _stimulus_key(stimulus_path)
        trial_indices = np.where(events_filename == image_name)[0]
        if (len(trial_indices) > 0) and (image_name in known_ids):
            Images_index[image_name] = trial_indices
    Image_names = list(Images_index.keys())
    # NOTE(review): rows/columns follow Image_names; RDMs of different subjects
    # only line up if every subject ends up with the same image list.

    # Average the trials of each image once, for all timepoints together.
    # NOTE(review): _data is a private MNE attribute; get_data() is the public
    # accessor but may copy the array.
    data = posterior_epochs._data
    if Image_names:
        patterns = np.stack([data[Images_index[name]].mean(axis = 0) for name in Image_names])
    else:
        patterns = np.empty((0, data.shape[1], data.shape[2]))

    sub_RDM = _correlation_distance_rdms(patterns)
    print(f"Shape of subject {subject_name} is: {len(Image_names)}")

    # Make sure the target directory exists (safe when several workers race).
    output_dir = os.path.join(project_path, "Noise_Ceiling")
    os.makedirs(output_dir, exist_ok=True)
    np.save(os.path.join(output_dir, subject_name + "_RDM(shared).npy"), sub_RDM)

In [38]:
# Compute the shared-image RDMs of subjects start_sub..end_sub (inclusive) in
# parallel, one worker process per subject, and report the wall-clock time.
start_sub = 20
end_sub = 35
sub_iter = list(range(start_sub, end_sub + 1))
with multiprocessing.Pool(len(sub_iter)) as p:
    start = time.time()
    p.map(shared, sub_iter)
    print(time.time() - start)

Reading /home/c14271389/FMG-folder/sub_24/Preprocessed epochs/sub_24-OC&CSD-AutoReject-epo.fif ...
Reading /home/c14271389/FMG-folder/sub_20/Preprocessed epochs/sub_20-OC&CSD-AutoReject-epo.fif ...
Reading /home/c14271389/FMG-folder/sub_30/Preprocessed epochs/sub_30-OC&CSD-AutoReject-epo.fif ...
Reading /home/c14271389/FMG-folder/sub_22/Preprocessed epochs/sub_22-OC&CSD-AutoReject-epo.fif ...
Reading /home/c14271389/FMG-folder/sub_26/Preprocessed epochs/sub_26-OC&CSD-AutoReject-epo.fif ...
Reading /home/c14271389/FMG-folder/sub_28/Preprocessed epochs/sub_28-OC&CSD-AutoReject-epo.fif ...
Reading /home/c14271389/FMG-folder/sub_29/Preprocessed epochs/sub_29-OC&CSD-AutoReject-epo.fif ...
Reading /home/c14271389/FMG-folder/sub_27/Preprocessed epochs/sub_27-OC&CSD-AutoReject-epo.fif ...
Reading /home/c14271389/FMG-folder/sub_25/Preprocessed epochs/sub_25-OC&CSD-AutoReject-epo.fif ...
Reading /home/c14271389/FMG-folder/sub_21/Preprocessed epochs/sub_21-OC&CSD-AutoReject-epo.fif ...
Reading /h

  posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)


    Found the data of interest:
        t =     -99.61 ...     400.39 ms
        0 CTF compensation matrices available
Not setting metadata
3470 matching events found
No baseline correction applied
0 projection items activated
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)


    Found the data of interest:
        t =     -99.61 ...     400.39 ms
        0 CTF compensation matrices available
Not setting metadata
3770 matching events found
No baseline correction applied
0 projection items activated
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)


    Found the data of interest:
        t =     -99.61 ...     400.39 ms
        0 CTF compensation matrices available
    Found the data of interest:
        t =     -99.61 ...     400.39 ms
        0 CTF compensation matrices available
    Found the data of interest:
    Found the data of interest:
        t =     -99.61 ...     400.39 ms
        t =     -99.61 ...     400.39 ms
        0 CTF compensation matrices available
        0 CTF compensation matrices available
Not setting metadata
3876 matching events found
No baseline correction applied
0 projection items activated
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)


Not setting metadata
3921 matching events found
No baseline correction applied
0 projection items activated
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Not setting metadata
3968 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
3997 matching events found
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
No baseline correction applied
0 projection items activated
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)


    Found the data of interest:
        t =     -99.61 ...     400.39 ms
        0 CTF compensation matrices available
    Found the data of interest:
        t =     -99.61 ...     400.39 ms
        0 CTF compensation matrices available
    Found the data of interest:
        t =     -99.61 ...     400.39 ms
        0 CTF compensation matrices available
    Found the data of interest:
        t =     -99.61 ...     400.39 ms
        0 CTF compensation matrices available
    Found the data of interest:
        t =     -99.61 ...     400.39 ms
        0 CTF compensation matrices available
    Found the data of interest:
    Found the data of interest:
        t =     -99.61 ...     400.39 ms
        t =     -99.61 ...     400.39 ms
        0 CTF compensation matrices available
        0 CTF compensation matrices available
    Found the data of interest:
        t =     -99.61 ...     400.39 ms
        0 CTF compensation matrices available
    Found the data of interest:
        t =     

  posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)


Not setting metadata
4788 matching events found
No baseline correction applied
0 projection items activated
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Not setting metadata
4747 matching events found
No baseline correction applied
0 projection items activated
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)


Not setting metadata
4650 matching events found
No baseline correction applied
0 projection items activated
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Not setting metadata
4646 matching events found
No baseline correction applied
0 projection items activated
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Not setting metadata
4754 matching events found
No baseline correction applied
0 projection items activated
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Not setting metadata
4741 matching events found
No baseline correction applied
0 projection items activated
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)


Not setting metadata
4766 matching events found
No baseline correction applied
0 projection items activated
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  posterior_epochs = sub_epochs.pick_channels(ch_names=posterior_lobe)


Shape of subject sub_30 is: 211
Shape of subject sub_24 is: 211
Shape of subject sub_34 is: 211
Shape of subject sub_21 is: 211
Shape of subject sub_20 is: 211
Shape of subject sub_31 is: 211
Shape of subject sub_35 is: 211
Shape of subject sub_23 is: 211
Shape of subject sub_33 is: 211
Shape of subject sub_28 is: 211
Shape of subject sub_22 is: 211
Shape of subject sub_29 is: 211
Shape of subject sub_27 is: 211
Shape of subject sub_26 is: 211
Shape of subject sub_32 is: 211
Shape of subject sub_25 is: 211
9161.32809996605


In [39]:
# Read back the saved RDM time course of subjects 20..35 and stack them along
# a new leading subject axis.
rdm_dir = os.path.join(project_path, "Noise_Ceiling")
shared_rdms = np.array([
    np.load(os.path.join(rdm_dir, "sub_" + str(sub_id) + "_RDM(shared).npy"))
    for sub_id in range(20, 36)
])
print(shared_rdms.shape)

(16, 513, 211, 211)


In [44]:
# Noise-ceiling bounds per timepoint, computed from the stacked subject RDMs
# (axes: subject, timepoint, image, image).
#
# For every timepoint each subject's RDM is z-scored over all of its entries.
# Each subject is then compared (as 1 - Pearson r of the flattened matrices)
#   * with the mean z-scored RDM of ALL subjects       -> collected in all_high
#   * with the mean z-scored RDM of the OTHER subjects -> collected in all_low
# and the per-subject values are averaged.
# NOTE(review): the stored numbers are correlation distances (1 - r), so the
# "upper bound" entries are the numerically smaller ones.
# NOTE(review): the flattened matrices include the zero diagonal and both
# triangles; confirm this matches how model RDMs are compared elsewhere.
if shared_rdms.ndim != 4 or shared_rdms.shape[0] < 2:
    raise ValueError("shared_rdms must have shape (subjects, timepoints, images, images) with at least 2 subjects")
n_subjects, n_timepoints = shared_rdms.shape[:2]

all_low = []
all_high = []
for timepoints in range(n_timepoints):
    cur_time = shared_rdms[:, timepoints, :, :]
    # z-score every subject once; both averages below reuse these values.
    zscore = np.array([scipy.stats.zscore(rdm, axis = None) for rdm in cur_time])
    avg_cur = zscore.mean(axis = 0)
    lower_bound = []
    upper_bound = []
    for sub_idx in range(n_subjects):
        cur_sub = zscore[sub_idx]
        # Mean over all subjects except the current one (leave-one-out).
        avg_left_out = np.delete(zscore, sub_idx, axis = 0).mean(axis = 0)

        high_cur = 1 - scipy.stats.pearsonr(avg_cur.flatten(), cur_sub.flatten())[0]
        upper_bound.append(high_cur)
        low_cur = 1 - scipy.stats.pearsonr(avg_left_out.flatten(), cur_sub.flatten())[0]
        lower_bound.append(low_cur)
    all_high.append(np.array(upper_bound).mean())
    all_low.append(np.array(lower_bound).mean())

np.save(os.path.join(project_path, "upper_bound.npy"), np.array(all_high))
np.save(os.path.join(project_path, "lower_bound.npy"), np.array(all_low))

In [43]:
# Debug peek: first ten flattened entries of the two averaged RDMs left over
# from the last iteration of the noise-ceiling loop.
for averaged_rdm in (avg_cur, avg_left_out):
    print(averaged_rdm.flatten()[:10])


[-2.38931816  0.47869604 -0.32381786 -0.26619157  0.10346329  0.24415647
  0.29040974  0.16914886 -0.18938209  0.49685608]
[-2.38931816  0.47869604 -0.32381786 -0.26619157  0.10346329  0.24415647
  0.29040974  0.16914886 -0.18938209  0.49685608]
