# Imports

In [59]:
import transformers
from transformers import CLIPConfig, CLIPModel, CLIPProcessor, CLIPImageProcessor, CLIPTokenizerFast
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
import random
import math
import scipy.io as sio
from scipy import signal
import nibabel as nib
from pathlib import Path
from gensim.models import Word2Vec
import re
import pickle
import pandas as pd
import gzip
import os

In [36]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


# Load word and fMRI data

### Load 3d fMRI data

In [37]:
NUM_SUBJS = 8
subjects_fmri = [] #stores all 8 subject fmri np arrays

fMRI_folder = Path('./doi_10.5061_dryad.gt413__v1')
assert fMRI_folder.exists(), f"Foldder: {fMRI_folder} does not exist."
with open(fMRI_folder / 'fmri_indices', 'rb') as f:
    fmri_indices = pickle.load(f)

for subj_id in range(8):
    print("Subject:",subj_id)
#     fmri_file_name = str(subj_id) + '_masked_2d.npy'
#     fmri = np.load(fMRI_folder / fmri_file_name)  
    fmri_file_name = str(subj_id) + '_smooth_detrend_nifti_4d.nii'
    fmri = nib.load(fMRI_folder / fmri_file_name)
    fmri = np.array(fmri.dataobj)
    assert isinstance(fmri, np.ndarray), f"Imported fmri_scan for subject {subj_id} is not of type numpy.ndarray"
    assert(fmri.ndim) == 4, f"Imported fmri_scan for subject {subj_id} is not 4 dimensional"
    subjects_fmri.append(fmri)

Subject: 0
Subject: 1
Subject: 2
Subject: 3
Subject: 4
Subject: 5
Subject: 6
Subject: 7


# Load words

In [38]:
feature_matrix = np.zeros((5176,195)) #stores the feature vectors as a row for each word
feature_names = [] #stores the names of all features in order
feature_types = {} #stores the types of features and all the names of the features for each type

features = sio.loadmat(fMRI_folder / 'story_features.mat')
feature_count = 0
for feature_type in features['features'][0]:
    feature_types[feature_type[0][0]] = []
    if isinstance(feature_type[1][0], str):
        feature_types[feature_type[0][0]].append(feature_type[1][0])
        feature_names.append(feature_type[1][0])
    else:
        for feature in feature_type[1][0]:
            feature_types[feature_type[0][0]].append(feature[0])
            feature_names.append(feature[0])
    feature_matrix[:, feature_count:feature_count+feature_type[2].shape[1]] = feature_type[2] #adds the (5176xN) feature values to the feature matrix for the current feature group
    feature_count += feature_type[2].shape[1]

In [39]:
words_info = [] #stores tuples of (word, time, features) sorted by time appeared

mat_file = fMRI_folder / 'subject_1.mat' #only looks at the first subject file, somewhere it said all the timings were the same so this should be safe
mat_contents = sio.loadmat(mat_file)
for count, row in enumerate(mat_contents['words'][0]):
    word_value = row[0][0][0][0]
    time = row[1][0][0]
    word_tuple = (word_value, time, feature_matrix[count,:])
    words_info.append(word_tuple)

### Align fMRI scans with sets of 4 words

In [63]:
#for each word, get the next 4 fMRI scans weighted by the gaussian window for each subject
#then save the word fMRI scans for each subject and the words in an pandas file
window = signal.windows.gaussian(16, std=1) #gaussian window for the 4 fMRI scans
subject_words_dict = [{'file_name':[], 'word':[], 'time':[]} for i in range(8)]
for word_count in words_info:
    word = word_count[0]
    time = word_count[1]
    print(word, time)
    fmri_count = 0
    subject_scans = []
    for i in range(1,17):
        delay = 0.5*i #time after word was read
        try:
            curr_fmri_idx = fmri_indices.index((time + delay)/2) #checks if an fMRI scan happens at this time point
            weight = window[int(2*delay)-1]
            for count, subject in enumerate(subjects_fmri):
                if fmri_count == 0:
                    subject_scans.append(weight*subject[:,:,:,curr_fmri_idx])
                else:
                    subject_scans[count] += weight*subject[:,:,:,curr_fmri_idx]
            fmri_count += 1
        except Exception as e:
            #print(e)
            pass
    print(fmri_count)
    if fmri_count == 4:
        for count, subject in enumerate(subjects_fmri):
            #save filename with (word, time) in file
            file_name = "./word_fmris/" + str(count) + "_subject_word_weighted_" + str(time) + ".pt"
            scan = torch.tensor(subject_scans[count])
            with open(file_name, 'wb') as f:
                torch.save(scan, f)
            subject_words_dict[count]['file_name'].append(file_name)
            subject_words_dict[count]['word'].append(word)
            subject_words_dict[count]['time'].append(time)
    for count, subject in enumerate(subjects_fmri):
        df = pd.DataFrame(subject_words_dict[count])
        df.to_csv("./" + str(count) + "_subject_word_fmri_labels.csv", index=False)

Harry 20
4
had 20.5
4
never 21
4
believed 21.5
4
he 22
4
would 22.5
4
meet 23
4
a 23.5
4
boy 24
4
he 24.5
4
hated 25
4
more 25.5
4
than 26
4
Dudley, 26.5
4
but 27
4
that 27.5
4
was 28
4
before 28.5
4
he 29
4
met 29.5
4
Draco 30
4
Malfoy. 30.5
4
Still, 31
4
first-year 31.5
4
Gryffindors 32
4
only 32.5
4
had 33
4
Potions 33.5
4
with 34
4
the 34.5
4
Slytherins, 35
4
so 35.5
4
they 36
4
didn't 36.5
4
have 37
4
to 37.5
4
put 38
4
up 38.5
4
with 39
4
Malfoy 39.5
4
much. 40
4
Or 40.5
4
at 41
4
least, 41.5
4
they 42
4
didn't 42.5
4
until 43
4
they 43.5
4
spotted 44
4
a 44.5
4
notice 45
4
pinned 45.5
4
up 46
4
in 46.5
4
the 47
4
Gryffindor 47.5
4
common 48
4
room 48.5
4
that 49
4
made 49.5
4
them 50
4
all 50.5
4
groan. 51
4
Flying 51.5
4
lessons 52
4
would 52.5
4
be 53
4
starting 53.5
4
on 54
4
Thursday 54.5
4
-- 55
4
and 55.5
4
Gryffindor 56
4
and 56.5
4
Slytherin 57
4
would 57.5
4
be 58
4
learning 58.5
4
together. 59
4
+ 59.5
4
"Typical," 60
4
said 60.5
4
Harry 61
4
darkly. 61.5
4
"Just 62
4


At 349.5
4
three-thirty 350
4
that 350.5
4
afternoon, 351
4
Harry, 351.5
4
Ron, 352
4
and 352.5
4
the 353
4
other 353.5
4
Gryffindors 354
4
hurried 354.5
4
down 355
4
the 355.5
4
front 356
4
steps 356.5
4
onto 357
4
the 357.5
4
grounds 358
4
for 358.5
4
their 359
4
first 359.5
4
flying 360
4
lesson. 360.5
4
It 361
4
was 361.5
4
a 362
4
clear, 362.5
4
breezy 363
4
day, 363.5
4
and 364
4
the 364.5
4
grass 365
4
rippled 365.5
4
under 366
4
their 366.5
4
feet 367
4
as 367.5
4
they 368
4
marched 368.5
4
down 369
4
the 369.5
4
sloping 370
4
lawns 370.5
4
toward 371
4
a 371.5
4
smooth, 372
4
flat 372.5
4
lawn 373
4
on 373.5
4
the 374
4
opposite 374.5
4
side 375
4
of 375.5
4
the 376
4
grounds 376.5
4
to 377
4
the 377.5
4
forbidden 378
4
forest, 378.5
4
whose 379
4
trees 379.5
4
were 380
4
swaying 380.5
4
darkly 381
4
in 381.5
4
the 382
4
distance. 382.5
4
+ 383
4
The 383.5
4
Slytherins 384
4
were 384.5
4
already 385
4
there, 385.5
4
and 386
4
so 386.5
4
were 387
4
twenty 387.5
4
broomsticks 38

+ 706
4
Malfoy 706.5
4
smiled 707
4
nastily. 707.5
4
"I 708
4
think 708.5
4
I'll 709
4
leave 709.5
4
it 710
4
somewhere 710.5
4
for 711
4
Longbottom 711.5
4
to 712
4
find 712.5
4
-- 713
4
how 713.5
4
about 714
4
-- 714.5
4
up 715
4
a 715.5
4
tree?" 716
4
+ 716.5
4
"Give 717
4
it 717.5
4
@here!" 718
4
Harry 718.5
4
yelled, 719
4
but 719.5
4
Malfoy 720
4
had 720.5
4
leapt 721
4
onto 721.5
4
his 722
4
broomstick 722.5
4
and 723
4
taken 723.5
4
off. 724
4
He 724.5
4
hadn't 725
4
been 725.5
4
lying, 726
4
he 726.5
4
@could 727
4
fly 727.5
4
well. 728
4
Hovering 728.5
4
level 729
4
with 729.5
4
the 730
4
topmost 730.5
4
branches 731
4
of 731.5
4
an 732
4
oak 732.5
4
he 733
4
called, 733.5
4
"Come 734
4
and 734.5
4
get 735
4
it, 735.5
4
Potter!" 736
4
+ 736.5
4
Harry 737
4
grabbed 737.5
4
his 738
4
broom. 738.5
4
+ 739
4
@"No!" 739.5
4
shouted 740
4
Hermione 740.5
4
Granger. 741
4
"Madam 741.5
4
Hooch 742
4
told 742.5
4
us 743
4
not 743.5
4
to 744
4
move 744.5
4
-- 745
4
you'll 745.5
4
get 74

as 1033
4
he 1033.5
4
imagined 1034
4
it, 1034.5
4
watching 1035
4
Ron 1035.5
4
and 1036
4
the 1036.5
4
others 1037
4
becoming 1037.5
4
wizards 1038
4
while 1038.5
4
he 1039
4
stumped 1039.5
4
around 1040
4
the 1040.5
4
grounds 1041
4
carrying 1041.5
4
Hagrid's 1042
4
bag. 1042.5
4
+ 1043
4
Professor 1043.5
4
McGonagall 1044
4
stopped 1044.5
4
outside 1045
4
a 1045.5
4
classroom. 1046
4
She 1046.5
4
opened 1047
4
the 1047.5
4
door 1048
4
and 1048.5
4
poked 1049
4
her 1049.5
4
head 1050
4
inside. 1050.5
4
+ 1051
4
"Excuse 1051.5
4
me, 1052
4
Professor 1052.5
4
Flitwick, 1053
4
could 1053.5
4
I 1054
4
borrow 1054.5
4
Wood 1055
4
for 1055.5
4
a 1056
4
moment?" 1056.5
4
+ 1057
4
Wood? 1057.5
4
thought 1058
4
Harry, 1058.5
4
bewildered; 1059
4
was 1059.5
4
Wood 1060
4
a 1060.5
4
cane 1061
4
she 1061.5
4
was 1062
4
going 1062.5
4
to 1063
4
use 1063.5
4
on 1064
4
him? 1064.5
4
But 1065
4
Wood 1065.5
4
turned 1066
4
out 1066.5
4
to 1067
4
be 1067.5
4
a 1068
4
person, 1068.5
4
a 1069
4
burly 10

tell 1333.5
4
you, 1334
4
we're 1334.5
4
going 1335
4
to 1335.5
4
win 1336
4
that 1336.5
4
Quidditch 1337
4
Cup 1337.5
4
for 1338
4
sure 1338.5
4
this 1339
4
year," 1339.5
4
said 1340
4
Fred. 1340.5
4
"We 1341
4
haven't 1341.5
4
won 1342
4
since 1342.5
4
Charlie 1343
4
left, 1343.5
4
but 1344
4
this 1344.5
4
year's 1345
4
team 1345.5
4
is 1346
4
going 1346.5
4
to 1347
4
be 1347.5
4
brilliant. 1348
3
You 1348.5
3
must 1349
3
be 1349.5
3
good, 1350
2
Harry, 1350.5
2
Wood 1351
2
was 1351.5
2
almost 1352
1
skipping 1352.5
1
when 1353
1
he 1353.5
1
told 1354
0
us." 1354.5
0
+ 1355
0
"Anyway, 1355.5
0
we've 1356
0
got 1356.5
0
to 1357
0
go, 1357.5
0
Lee 1358
0
Jordan 1358.5
0
reckons 1359
0
he's 1359.5
0
found 1360
0
a 1360.5
0
new 1361
0
secret 1361.5
0
passageway 1362
0
out 1362.5
0
of 1363
0
the 1363.5
0
school." 1364
0
+ 1364.5
0
"Bet 1365
0
it's 1365.5
0
that 1366
1
one 1366.5
1
behind 1367
1
the 1367.5
1
statue 1368
2
of 1368.5
2
Gregory 1369
2
the 1369.5
2
Smarmy 1370
3
that 1370.5
3


picked 1669
4
up 1669.5
4
their 1670
4
wands, 1670.5
4
and 1671
4
crept 1671.5
4
across 1672
4
the 1672.5
4
tower 1673
4
room, 1673.5
4
down 1674
4
the 1674.5
4
spiral 1675
4
staircase, 1675.5
4
and 1676
4
into 1676.5
4
the 1677
4
Gryffindor 1677.5
4
common 1678
4
room. 1678.5
4
A 1679
4
few 1679.5
4
embers 1680
4
were 1680.5
4
still 1681
4
glowing 1681.5
4
in 1682
4
the 1682.5
4
fireplace, 1683
4
turning 1683.5
4
all 1684
4
the 1684.5
4
armchairs 1685
4
into 1685.5
4
hunched 1686
4
black 1686.5
4
shadows. 1687
4
They 1687.5
4
had 1688
4
almost 1688.5
4
reached 1689
4
the 1689.5
4
portrait 1690
4
hole 1690.5
4
when 1691
4
a 1691.5
4
voice 1692
4
spoke 1692.5
4
from 1693
4
the 1693.5
4
chair 1694
4
nearest 1694.5
4
them, 1695
4
"I 1695.5
4
can't 1696
4
believe 1696.5
4
you're 1697
4
going 1697.5
4
to 1698
4
do 1698.5
4
this, 1699
4
Harry." 1699.5
4
+ 1700
4
A 1700.5
4
lamp 1701
4
flickered 1701.5
4
on. 1702
4
It 1702.5
4
was 1703
4
Hermione 1703.5
4
Granger, 1704
4
wearing 1704.5
4
a 17

high 2004
4
windows. 2004.5
4
At 2005
4
every 2005.5
4
turn 2006
4
Harry 2006.5
4
expected 2007
4
to 2007.5
4
run 2008
4
into 2008.5
4
Filch 2009
4
or 2009.5
4
Mrs. 2010
4
Norris, 2010.5
4
but 2011
4
they 2011.5
4
were 2012
4
lucky. 2012.5
4
They 2013
4
sped 2013.5
4
up 2014
4
a 2014.5
4
staircase 2015
4
to 2015.5
4
the 2016
4
third 2016.5
4
floor 2017
4
and 2017.5
4
tiptoed 2018
4
toward 2018.5
4
the 2019
4
trophy 2019.5
4
room. 2020
4
+ 2020.5
4
Malfoy 2021
4
and 2021.5
4
Crabbe 2022
4
weren't 2022.5
4
there 2023
4
yet. 2023.5
4
The 2024
4
crystal 2024.5
4
trophy 2025
4
cases 2025.5
4
glimmered 2026
4
where 2026.5
4
the 2027
4
moonlight 2027.5
4
caught 2028
4
them. 2028.5
4
Cups, 2029
4
shields, 2029.5
4
plates, 2030
4
and 2030.5
4
statues 2031
4
winked 2031.5
4
silver 2032
4
and 2032.5
4
gold 2033
4
in 2033.5
4
the 2034
4
darkness. 2034.5
4
They 2035
4
edged 2035.5
4
along 2036
4
the 2036.5
4
walls, 2037
4
keeping 2037.5
4
their 2038
4
eyes 2038.5
4
on 2039
4
the 2039.5
4
doors 2040

OF 2308.5
4
BED!" 2309
4
Peeves 2309.5
4
bellowed, 2310
4
"STUDENTS 2310.5
4
OUT 2311
4
OF 2311.5
4
BED 2312
4
DOWN 2312.5
4
THE 2313
4
CHARMS 2313.5
4
CORRIDOR!" 2314
4
+ 2314.5
4
Ducking 2315
4
under 2315.5
4
Peeves, 2316
4
they 2316.5
4
ran 2317
4
for 2317.5
4
their 2318
4
lives, 2318.5
4
right 2319
4
to 2319.5
4
the 2320
4
end 2320.5
4
of 2321
4
the 2321.5
4
corridor 2322
4
where 2322.5
4
they 2323
4
slammed 2323.5
4
into 2324
4
a 2324.5
4
door 2325
4
-- 2325.5
4
and 2326
4
it 2326.5
4
was 2327
4
locked. 2327.5
4
+ 2328
4
"This 2328.5
4
is 2329
4
it!" 2329.5
4
Ron 2330
4
moaned, 2330.5
4
as 2331
4
they 2331.5
4
pushed 2332
4
helplessly 2332.5
4
at 2333
4
the 2333.5
4
door, 2334
4
"We're 2334.5
4
done 2335
4
for! 2335.5
4
This 2336
4
is 2336.5
4
the 2337
4
end!" 2337.5
4
+ 2338
4
They 2338.5
4
could 2339
4
hear 2339.5
4
footsteps, 2340
4
Filch 2340.5
4
running 2341
4
as 2341.5
4
fast 2342
4
as 2342.5
4
he 2343
4
could 2343.5
4
toward 2344
4
Peeves's 2344.5
4
shouts. 2345
4
+ 2345.5


4
you 2611.5
4
see 2612
4
what 2612.5
4
it 2613
4
was 2613.5
4
standing 2614
4
on?" 2614.5
4
+ 2615
4
"The 2615.5
4
floor?" 2616
4
Harry 2616.5
4
suggested. 2617
4
"I 2617.5
4
wasn't 2618
4
looking 2618.5
4
at 2619
4
its 2619.5
4
feet, 2620
4
I 2620.5
4
was 2621
4
too 2621.5
4
busy 2622
4
with 2622.5
4
its 2623
4
heads." 2623.5
4
+ 2624
4
"No, 2624.5
4
@not 2625
4
the 2625.5
4
floor. 2626
4
It 2626.5
4
was 2627
4
standing 2627.5
4
on 2628
4
a 2628.5
4
trapdoor. 2629
4
It's 2629.5
4
obviously 2630
4
guarding 2630.5
4
something." 2631
4
+ 2631.5
4
She 2632
4
stood 2632.5
4
up, 2633
4
glaring 2633.5
4
at 2634
4
them. 2634.5
4
+ 2635
4
"I 2635.5
4
hope 2636
4
you're 2636.5
4
pleased 2637
4
with 2637.5
4
yourselves. 2638
4
We 2638.5
4
could 2639
4
all 2639.5
4
have 2640
4
been 2640.5
4
killed 2641
4
-- 2641.5
4
or 2642
4
worse, 2642.5
4
expelled. 2643
4
Now, 2643.5
4
if 2644
4
you 2644.5
4
don't 2645
4
mind, 2645.5
4
I'm 2646
4
going 2646.5
4
to 2647
4
bed." 2647.5
4
+ 2648
3
Ron 2648.5
3
s

### Create Dataset

In [102]:
from torch.utils.data import Dataset

class SubjectImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, num_words):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.num_words = num_words

    def __len__(self):
        return math.floor(len(self.img_labels)/self.num_words)

    def __getitem__(self, idx):
        full_image = None
        full_word = None
        for word_count in range(self.num_words):
            word_idx = idx*self.num_words + word_count
            img_path = os.path.join(self.img_dir, self.img_labels.iloc[word_idx, 0])
            with open(img_path, 'rb') as f:
                image = torch.load(img_path)
            word = self.img_labels.iloc[word_idx, 1]
            if word_count == 0:
                full_image = torch.unsqueeze(image,0)
                full_word = word
            else:
                full_image = torch.cat((full_image, torch.unsqueeze(image,0)))
                full_word += " " + word
        return full_image, full_word

In [103]:
from torch.utils.data import DataLoader
subject_dataloaders = []
NUM_WORDS = 4
BATCH_SIZE = 32
for i in range(NUM_SUBJS):
    label_filename = "./" + str(count) + "_subject_word_fmri_labels.csv"
    dataset = SubjectImageDataset(label_filename, "", NUM_WORDS)
    train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    subject_dataloaders.append(train_dataloader)

In [106]:
for i in range(NUM_SUBJS):
    print("Subject:", i)
    print("\tLength of DataLoader:", len(subject_dataloaders[i].dataset))
    for j in range(math.floor(len(subject_dataloaders[i].dataset)/BATCH_SIZE)):
        images, text = next(iter(subject_dataloaders[i]))
        print(f"\tImage batch shape: {images.shape}")
        print(f"\tText batch shape: {len(text)}")

Subject: 0
	Length of DataLoader: 1236
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32

	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
Subject: 3
	Length of DataLoader: 1236
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32
	Image batch shape: torch.Size([32, 4, 53, 60, 50])
	Text batch shape: 32

KeyboardInterrupt: 