In [1]:
import os
# os.chdir('./net2neuro')
from tqdm import tqdm
import pickle
import mne
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
# from simpleconv import SimpleConv
import torch.optim as optim
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [2]:
# Automatically re-import edited modules before each cell executes (IPython autoreload, mode 2).
%load_ext autoreload
%autoreload 2

In [3]:
# Root directory holding the input pickles. Set the CV_PROJECT_DIR environment variable to
# run outside the original cluster account; the default keeps the original location.
base_dir = os.environ.get('CV_PROJECT_DIR', '/scratch/jd5697/cv_project')

In [4]:
def _load_pickle(filename):
    """Unpickle ``filename`` located in ``base_dir`` and close the file handle afterwards.

    NOTE: pickle.load can execute arbitrary code; only use it on trusted project files.
    """
    with open(os.path.join(base_dir, filename), 'rb') as f:
        return pickle.load(f)


# MEG epochs for the train / test / small-test splits (fMRI-subset versions).
valid_epochs_all_train_fmri_subset = _load_pickle('valid_epochs_all_train_fmri_subset.pickle')
valid_epochs_all_test_fmri_subset = _load_pickle('valid_epochs_all_test_fmri_subset.pickle')
valid_epochs_all_test_small_fmri_subset = _load_pickle('valid_epochs_all_test_small_fmri_subset.pickle')

print(len(valid_epochs_all_train_fmri_subset))
print(len(valid_epochs_all_test_fmri_subset))
print(len(valid_epochs_all_test_small_fmri_subset))

24960
9600
400


In [5]:
def merge_into_epochs(valid_epochs, fmri_df):
    """Attach fMRI responses to each epoch's metadata by matching stimulus image paths.

    Left-joins ``fmri_df`` onto ``valid_epochs.metadata``, matching the metadata's
    ``image_path`` column against ``fmri_df['stimulus']``. Every epoch keeps its row;
    epochs without a matching stimulus get NaN in the columns coming from ``fmri_df``.
    ``valid_epochs`` is modified in place.

    Parameters
    ----------
    valid_epochs : mne.Epochs-like
        Object whose ``metadata`` attribute is a DataFrame with an ``image_path`` column.
    fmri_df : pandas.DataFrame
        Table with a ``stimulus`` column; its values must be unique.

    Returns
    -------
    The same ``valid_epochs`` object (returned for convenience; callers may ignore it).

    Raises
    ------
    pandas.errors.MergeError
        If ``fmri_df`` has duplicate ``stimulus`` values. Without this check a duplicate
        key would fan out into extra metadata rows, so the metadata would no longer have
        exactly one row per epoch.
    """
    valid_epochs.metadata = valid_epochs.metadata.merge(
        fmri_df,
        left_on='image_path',
        right_on='stimulus',
        how='left',
        validate='many_to_one',  # fail fast on duplicate stimulus keys in fmri_df
    )
    # Renumber the stored metadata rows 0..n-1 (in place on the epochs' metadata frame).
    valid_epochs.metadata.reset_index(drop=True, inplace=True)
    return valid_epochs

In [6]:
subs = ['01', '02', '03']

# Read each subject's fMRI response table; `dfs[k]` belongs to subject `subs[k]`.
dfs = []
for sub in subs:
    fmri_path = os.path.join(base_dir, f'sub-{sub}_responses_fmri_svd.pkl')
    # NOTE: pickle.load can execute arbitrary code; only load trusted project files.
    with open(fmri_path, 'rb') as f:
        dfs.append(pickle.load(f))

In [7]:
def process_name(name):
    """Rewrite a stimulus path as ``<folder>/<filename>``.

    Only the last ``/``-separated component is kept. The folder is rebuilt from that
    filename by dropping everything from its last underscore onwards, e.g.
    ``'x/y/ice_cream_05s.jpg'`` -> ``'ice_cream/ice_cream_05s.jpg'``.
    A filename with no underscore gets an empty folder (``'plain.jpg'`` -> ``'/plain.jpg'``).
    """
    basename = name.rpartition('/')[2]
    folder, _sep, _tail = basename.rpartition('_')
    return folder + '/' + basename

In [8]:
# Normalise every subject's `stimulus` paths in place so they can be matched against the
# epochs' `image_path` column in `merge_into_epochs`.
for fmri_df in dfs:
    fmri_df['stimulus'] = fmri_df['stimulus'].apply(process_name)

In [9]:
# Pair every MEG training epoch with each subject's fMRI responses. This makes one merged
# copy of the epochs per table in `dfs`, concatenated in `subs` order.
# Guard: the merge adds a `stimulus` column, so if it is already present this cell has
# run before and its input was replaced by the merged (already tripled) epochs.
assert 'stimulus' not in valid_epochs_all_train_fmri_subset.metadata.columns, \
    'train epochs are already merged with fMRI data; reload them from the pickle before re-running this cell'
n_meg_epochs = len(valid_epochs_all_train_fmri_subset)

valid_epochs_train_list = []
for df in dfs:
    valid_epochs = valid_epochs_all_train_fmri_subset.copy()
    merge_into_epochs(valid_epochs, df)
    valid_epochs_train_list.append(valid_epochs)

valid_epochs_all_train_fmri_subset = mne.concatenate_epochs(valid_epochs_train_list)
# A left merge leaves NaN in `stimulus` for epochs whose image has no fMRI row.
assert valid_epochs_all_train_fmri_subset.metadata['stimulus'].isna().sum() == 0
assert len(valid_epochs_all_train_fmri_subset) == len(dfs) * n_meg_epochs
print(len(valid_epochs_all_train_fmri_subset))
print(valid_epochs_all_train_fmri_subset.metadata.head())

Replacing existing metadata with 28 columns
Replacing existing metadata with 28 columns
Replacing existing metadata with 28 columns
Adding metadata with 28 columns
74880 matching events found
No baseline correction applied
74880
   trial_type_x  image_nr  category_nr  exemplar_nr  test_image_nr  \
2           exp      7522          627           10            NaN   
8           exp      6802          567           10            NaN   
9           exp     17050         1421           10            NaN   
10          exp      5266          439           10            NaN   
13          exp      8410          701           10            NaN   

    things_category_nr  things_image_nr  things_exemplar_nr  \
2                627.0           9054.0                10.0   
8                567.0           8217.0                10.0   
9               1421.0          20164.0                10.0   
10               439.0           6348.0                10.0   
13               701.0          100

In [10]:
# Pair every MEG test epoch with each subject's fMRI responses. This makes one merged
# copy of the epochs per table in `dfs`, concatenated in `subs` order.
# Guard: the merge adds a `stimulus` column, so if it is already present this cell has
# run before and its input was replaced by the merged (already tripled) epochs.
assert 'stimulus' not in valid_epochs_all_test_fmri_subset.metadata.columns, \
    'test epochs are already merged with fMRI data; reload them from the pickle before re-running this cell'
n_meg_epochs = len(valid_epochs_all_test_fmri_subset)

valid_epochs_test_list = []
for df in dfs:
    valid_epochs = valid_epochs_all_test_fmri_subset.copy()
    merge_into_epochs(valid_epochs, df)
    valid_epochs_test_list.append(valid_epochs)

valid_epochs_all_test_fmri_subset = mne.concatenate_epochs(valid_epochs_test_list)
# A left merge leaves NaN in `stimulus` for epochs whose image has no fMRI row.
assert valid_epochs_all_test_fmri_subset.metadata['stimulus'].isna().sum() == 0
assert len(valid_epochs_all_test_fmri_subset) == len(dfs) * n_meg_epochs
print(len(valid_epochs_all_test_fmri_subset))
print(valid_epochs_all_test_fmri_subset.metadata.head())

Replacing existing metadata with 28 columns
Replacing existing metadata with 28 columns
Replacing existing metadata with 28 columns
Adding metadata with 28 columns
28800 matching events found
No baseline correction applied
28800
   trial_type_x  image_nr  category_nr  exemplar_nr  test_image_nr  \
3           exp      2950          246           10            NaN   
28          exp     18994         1583           10            NaN   
33          exp      3034          253           10            NaN   
73          exp      1342          112           10            NaN   
74          exp     12682         1057           10            NaN   

    things_category_nr  things_image_nr  things_exemplar_nr  \
3                246.0           3648.0                10.0   
28              1583.0          22357.0                10.0   
33               253.0           3735.0                10.0   
73               112.0           1704.0                10.0   
74              1057.0          150

In [11]:
# Pair every MEG small-test epoch with each subject's fMRI responses. This makes one
# merged copy of the epochs per table in `dfs`, concatenated in `subs` order.
# Guard: the merge adds a `stimulus` column, so if it is already present this cell has
# run before and its input was replaced by the merged (already tripled) epochs.
assert 'stimulus' not in valid_epochs_all_test_small_fmri_subset.metadata.columns, \
    'small test epochs are already merged with fMRI data; reload them from the pickle before re-running this cell'
n_meg_epochs = len(valid_epochs_all_test_small_fmri_subset)

valid_epochs_test_small_list = []
for df in dfs:
    valid_epochs = valid_epochs_all_test_small_fmri_subset.copy()
    merge_into_epochs(valid_epochs, df)
    valid_epochs_test_small_list.append(valid_epochs)

valid_epochs_all_test_small_fmri_subset = mne.concatenate_epochs(valid_epochs_test_small_list)
# A left merge leaves NaN in `stimulus` for epochs whose image has no fMRI row.
assert valid_epochs_all_test_small_fmri_subset.metadata['stimulus'].isna().sum() == 0
assert len(valid_epochs_all_test_small_fmri_subset) == len(dfs) * n_meg_epochs
print(len(valid_epochs_all_test_small_fmri_subset))
print(valid_epochs_all_test_small_fmri_subset.metadata.head())

Replacing existing metadata with 28 columns
Replacing existing metadata with 28 columns
Replacing existing metadata with 28 columns
Adding metadata with 28 columns
1200 matching events found
No baseline correction applied
1200
    trial_type_x  image_nr  category_nr  exemplar_nr  test_image_nr  \
14          test     22479          436           13           31.0   
84          test     22540         1742           13           92.0   
88          test     22534         1542           13           86.0   
90          test     22545         1786           13           97.0   
116         test     22482          514           13           34.0   

     things_category_nr  things_image_nr  things_exemplar_nr  \
14                436.0           6312.0                15.0   
84               1742.0          24559.0                13.0   
88               1542.0          21793.0                15.0   
90               1786.0          25182.0                15.0   
116               514.0   

In [12]:
# Persist the merged MEG+fMRI training epochs. The relative path means the file is written
# to the current working directory, not to `base_dir`.
train_out_path = 'valid_epochs_all_train_meg_fmri_combined.pickle'
with open(train_out_path, 'wb') as fh:
    pickle.dump(valid_epochs_all_train_fmri_subset, fh, pickle.HIGHEST_PROTOCOL)

In [13]:
# Persist the merged MEG+fMRI test epochs. The relative path means the file is written
# to the current working directory, not to `base_dir`.
test_out_path = 'valid_epochs_all_test_meg_fmri_combined.pickle'
with open(test_out_path, 'wb') as fh:
    pickle.dump(valid_epochs_all_test_fmri_subset, fh, pickle.HIGHEST_PROTOCOL)

In [14]:
# Persist the merged MEG+fMRI small-test epochs. The relative path means the file is
# written to the current working directory, not to `base_dir`.
test_small_out_path = 'valid_epochs_all_test_small_meg_fmri_combined.pickle'
with open(test_small_out_path, 'wb') as fh:
    pickle.dump(valid_epochs_all_test_small_fmri_subset, fh, pickle.HIGHEST_PROTOCOL)