# Imports

In [57]:
import transformers
from transformers import CLIPConfig, CLIPModel, CLIPProcessor, CLIPImageProcessor, CLIPTokenizerFast
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
import random
import math
import scipy.io as sio
import nibabel as nib
from pathlib import Path

# Load word and fMRI data

### Load 3d fMRI data

In [60]:
NUM_SUBJS = 8
subjects_fmri = []  # one 4-D numpy array per subject, appended in subject-id order

fMRI_folder = Path('./doi_10.5061_dryad.gt413__v1')
assert fMRI_folder.exists(), f"Folder: {fMRI_folder} does not exist."

# Load every subject's NIfTI file and keep its data as an in-memory ndarray.
for subj_id in range(NUM_SUBJS):
    fmri_file_name = str(subj_id) + '_smooth_nifti_4d.nii'
    fmri_img = nib.load(fMRI_folder / fmri_file_name)
    # np.array() turns the image's `dataobj` into a regular numpy array
    fmri = np.array(fmri_img.dataobj)
    assert isinstance(fmri, np.ndarray), f"Imported fmri_scan for subject {subj_id} is not of type numpy.ndarray"
    # the alignment code further down indexes the scans with four indices
    assert fmri.ndim == 4, f"Imported fmri_scan for subject {subj_id} is not 4 dimensional"
    subjects_fmri.append(fmri)

# Load words

In [61]:
feature_names = []  # names of all features, in the same order as the columns of feature_matrix
feature_types = {}  # feature-type name -> list of the feature names belonging to that type

features = sio.loadmat(fMRI_folder / 'story_features.mat')

# Each entry of features['features'][0] describes one feature type:
#   [0][0] -> name of the type
#   [1][0] -> either a single feature name (str) or a sequence of name entries
#   [2]    -> 2-D array of values with one column per feature of that type
feature_blocks = []
for feature_type in features['features'][0]:
    type_name = feature_type[0][0]
    raw_names = feature_type[1][0]
    if isinstance(raw_names, str):
        type_feature_names = [raw_names]
    else:
        type_feature_names = [feature[0] for feature in raw_names]
    feature_types[type_name] = type_feature_names
    feature_names.extend(type_feature_names)
    feature_blocks.append(feature_type[2])

# Concatenate the per-type blocks column-wise instead of writing into a
# pre-allocated array with hard-coded dimensions, so the shape always follows
# the loaded data. float64 matches what np.zeros() produced before.
feature_matrix = np.hstack(feature_blocks).astype(np.float64)  # feature vectors stored as one row per word
feature_count = feature_matrix.shape[1]  # total number of feature columns

In [62]:
# Only the first subject's file is read: the word timings are assumed to be
# identical across subjects, so one file is enough.
mat_file = fMRI_folder / 'subject_1.mat'
mat_contents = sio.loadmat(mat_file)

# One (word, time, features) tuple per word, kept in the order stored in the
# file; the features are the matching row of feature_matrix.
words_info = [
    (entry[0][0][0][0], entry[1][0][0], feature_matrix[idx, :])
    for idx, entry in enumerate(mat_contents['words'][0])
]

### Align fMRI scans with sets of 4 words

In [64]:
subjects_samples = [[] for _ in range(NUM_SUBJS)]  # per subject: list of (fmri_scan, four_words) samples

WORD_INTERVAL = 0.5  # seconds between two consecutive words
TR = 2               # seconds between two consecutive fMRI scans
WINDOW_WORDS = 8     # the 4 input words plus the 4 words that follow them
INPUT_WORDS = 4      # words stored with each sample
SCAN_OFFSET = 2      # extra scans added to the index computed from the word time

word_count = 0
# NOTE(review): with `<`, the last possible 8-word window (starting at
# len(words_info) - 8) is never examined -- confirm whether `<=` was intended.
while word_count < len(words_info) - WINDOW_WORDS:
    # gets the 4 input words and the 4 following words, verifying that all 8
    # were shown back to back (one every WORD_INTERVAL seconds)
    window = words_info[word_count:word_count + WINDOW_WORDS]
    start_time = window[0][1]
    scan_words = [word_info[0] for word_info in window]
    in_sequence = all(
        word_info[1] == start_time + WORD_INTERVAL * i
        for i, word_info in enumerate(window)
    )
    if not in_sequence:
        # some of the words are not in sequence: slide forward by one word
        word_count += 1
        continue

    fmri_time = start_time + 2
    # A sample is only valid when its first word is the first word of a TR,
    # i.e. the time falls exactly on the 2-second scan grid. This is checked on
    # the value itself: the previous isinstance(..., np.int32) test depended on
    # the integer dtype produced by the platform and also accepted odd
    # whole-second times, which the floor division below maps to the wrong scan.
    if fmri_time % TR != 0:
        word_count += 1
        continue
    # since a scan happens every TR seconds, the index is the time divided by TR
    fmri_index = int(fmri_time // TR)

    for count, subject in enumerate(subjects_fmri):
        # adds tuple of (fmri_scan, four words); the scan used is the one at
        # fmri_index + SCAN_OFFSET along the 4th axis
        subjects_samples[count].append(
            (subject[:, :, :, fmri_index + SCAN_OFFSET], scan_words[0:INPUT_WORDS])
        )
    print("Created sample:")
    print("\tScan time:", str(start_time))
    print("\tInput words:", str(scan_words[0:4]))
    # if successful, skip forward to the next set of 4 words
    word_count += INPUT_WORDS

print("Total number of samples:", str(len(subjects_samples[0])))

Created sample:
	Scan time: 20
	Input words: ['Harry', 'had', 'never', 'believed']
Created sample:
	Scan time: 22
	Input words: ['he', 'would', 'meet', 'a']
Created sample:
	Scan time: 24
	Input words: ['boy', 'he', 'hated', 'more']
Created sample:
	Scan time: 26
	Input words: ['than', 'Dudley,', 'but', 'that']
Created sample:
	Scan time: 28
	Input words: ['was', 'before', 'he', 'met']
Created sample:
	Scan time: 30
	Input words: ['Draco', 'Malfoy.', 'Still,', 'first-year']
Created sample:
	Scan time: 32
	Input words: ['Gryffindors', 'only', 'had', 'Potions']
Created sample:
	Scan time: 34
	Input words: ['with', 'the', 'Slytherins,', 'so']
Created sample:
	Scan time: 36
	Input words: ['they', "didn't", 'have', 'to']
Created sample:
	Scan time: 38
	Input words: ['put', 'up', 'with', 'Malfoy']
Created sample:
	Scan time: 40
	Input words: ['much.', 'Or', 'at', 'least,']
Created sample:
	Scan time: 42
	Input words: ['they', "didn't", 'until', 'they']
Created sample:
	Scan time: 44
	Input w

Created sample:
	Scan time: 1716
	Input words: ['snapped,', '"Percy', '--', "he's"]
Created sample:
	Scan time: 1718
	Input words: ['a', 'prefect,', "he'd", 'put']
Created sample:
	Scan time: 1720
	Input words: ['a', 'stop', 'to', 'this."']
Created sample:
	Scan time: 1722
	Input words: ['+', 'Harry', "couldn't", 'believe']
Created sample:
	Scan time: 1724
	Input words: ['anyone', 'could', 'be', 'so']
Created sample:
	Scan time: 1726
	Input words: ['interfering.', '+', '"Come', 'on,"']
Created sample:
	Scan time: 1728
	Input words: ['he', 'said', 'to', 'Ron.']
Created sample:
	Scan time: 1730
	Input words: ['He', 'pushed', 'open', 'the']
Created sample:
	Scan time: 1732
	Input words: ['portrait', 'of', 'the', 'Fat']
Created sample:
	Scan time: 1734
	Input words: ['Lady', 'and', 'climbed', 'through']
Created sample:
	Scan time: 1736
	Input words: ['the', 'hole.', '+', 'Hermione']
Created sample:
	Scan time: 1738
	Input words: ["wasn't", 'going', 'to', 'give']
Created sample:
	Scan time:

# Text encoding

In [86]:
# The tokenizer and the encoder are both fetched through torch.hub from the
# same repository and the same checkpoint name.
_HUB_REPO = 'huggingface/pytorch-transformers'
_CHECKPOINT = 'bert-base-cased'

tokenizer = torch.hub.load(_HUB_REPO, 'tokenizer', _CHECKPOINT)
text_encoder = torch.hub.load(_HUB_REPO, 'model', _CHECKPOINT)

Using cache found in C:\Users\Portable/.cache\torch\hub\huggingface_pytorch-transformers_main
Using cache found in C:\Users\Portable/.cache\torch\hub\huggingface_pytorch-transformers_main


In [87]:

test_text = "Harry had never believed"

# Tokenize the input WITHOUT special tokens: add_special_tokens=False means the
# BERT [CLS] / [SEP] markers are not added around the text.
indexed_tokens = tokenizer.encode(test_text, add_special_tokens=False)
print(indexed_tokens)
print(len(indexed_tokens))

# Convert inputs to a PyTorch tensor; the extra list level adds a leading
# dimension of size 1
tokens_tensor = torch.tensor([indexed_tokens])

# Run the encoder a single time and reuse the result for both prints
with torch.no_grad():
    last_hidden_state = text_encoder(tokens_tensor)['last_hidden_state']
print(last_hidden_state)
print(last_hidden_state.shape)

[3466, 1125, 1309, 2475]
4
tensor([[[ 0.4425, -0.0343, -0.4921,  ..., -0.2058,  1.2232, -0.9892],
         [ 0.1725, -0.1542, -0.3541,  ...,  0.2818,  0.1832,  0.0317],
         [ 0.3708,  0.1688, -0.6183,  ...,  0.6845,  0.5330, -0.0646],
         [ 0.3440, -0.0137, -0.1923,  ...,  0.6708,  0.3056,  0.1424]]])
torch.Size([1, 4, 768])


### Replace words with tokenized versions

In [85]:
# Pair each of subject 0's scans with the token ids of its words (joined into
# one space-separated string and encoded without special tokens).
subject_0_tokenized = []
for scan, scan_words in subjects_samples[0]:
    words = " ".join(scan_words)
    print(words)
    tokens = tokenizer.encode(words, add_special_tokens=False)
    print(tokens)
    tokenized_sample = (scan, tokens)
    subject_0_tokenized.append(tokenized_sample)

Harry had never believed
[3466, 1125, 1309, 2475]
he would meet a
[1119, 1156, 2283, 170]
boy he hated more
[2298, 1119, 5687, 1167]
than Dudley, but that
[1190, 12840, 117, 1133, 1115]
was before he met
[1108, 1196, 1119, 1899]
Draco Malfoy. Still, first-year
[1987, 17312, 18880, 14467, 1183, 119, 4209, 117, 1148, 118, 1214]
Gryffindors only had Potions
[144, 1616, 16274, 8380, 1116, 1178, 1125, 18959, 6126]
with the Slytherins, so
[1114, 1103, 156, 1193, 8420, 4935, 117, 1177]
they didn't have to
[1152, 1238, 112, 189, 1138, 1106]
put up with Malfoy
[1508, 1146, 1114, 18880, 14467, 1183]
much. Or at least,
[1277, 119, 2926, 1120, 1655, 117]
they didn't until they
[1152, 1238, 112, 189, 1235, 1152]
spotted a notice pinned
[6910, 170, 4430, 11973]
up in the Gryffindor
[1146, 1107, 1103, 144, 1616, 16274, 8380]
common room that made
[1887, 1395, 1115, 1189]
them all groan. Flying
[1172, 1155, 13344, 119, 7769]
lessons would be starting
[8497, 1156, 1129, 2547]
on Thursday -- and
[1113, 

# Image encoding

# CLIP from scratch

In [5]:
def unison_shuffled_copies(a, b):
    """Shuffle two equally long sequences with one shared random permutation.

    Both inputs are indexed with the same permutation drawn from numpy's
    global random state, so element i of the first result still corresponds
    to element i of the second.

    Args:
        a: indexable collection supporting `len()` and indexing by an index array.
        b: indexable collection of the same length as `a`.

    Returns:
        Tuple `(a_shuffled, b_shuffled)`.

    Raises:
        AssertionError: if `a` and `b` differ in length.
    """
    n_items = len(a)
    assert n_items == len(b)
    order = np.random.permutation(n_items)
    shuffled_a = a[order]
    shuffled_b = b[order]
    return shuffled_a, shuffled_b

In [8]:
class TestModel(nn.Module):
    """Minimal stand-in encoder consisting of a single linear layer."""

    def __init__(self, input_size, output_size):
        """Create the layer mapping `input_size` features to `output_size` features."""
        super().__init__()
        self.linear = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, samples):
        """Apply the linear layer to `samples` and return the result."""
        encoded = self.linear(samples)
        return encoded

In [47]:
#clip model which simultaneously trains the text and image encoder as well as learning weights
#for projecting encodings of both text and image to the same latent space

#model is based on pseudocode from the original clip paper
class CLIP(nn.Module):
    """Two encoders plus linear projections into one shared embedding space.

    Args:
        textEncoder: module applied to the text input; its output must have
            `text_output_shape` features on the last axis.
        text_output_shape: number of features produced by `textEncoder`.
        imageEncoder: module applied to the image input; its output must have
            `image_output_shape` features on the last axis.
        image_output_shape: number of features produced by `imageEncoder`.
        embedding_dim: size of the shared embedding space.
    """

    def __init__(self, textEncoder, text_output_shape, imageEncoder, image_output_shape, embedding_dim):
        super().__init__()
        self.text_encoder = textEncoder
        self.text_proj = nn.Linear(text_output_shape, embedding_dim)
        self.image_encoder = imageEncoder
        self.image_proj = nn.Linear(image_output_shape, embedding_dim)

    def forward(self, text, image):
        """Encode, project and L2-normalise a batch of text and image inputs.

        Returns:
            Tuple `(text_embed_norm, image_embed_norm)` of unit-length
            embeddings, each with `embedding_dim` features on the last axis.
        """
        #gets encodings of text and images
        text_features = self.text_encoder(text)
        image_features = self.image_encoder(image)

        #projects text and images into latent space
        text_embed = self.text_proj(text_features)
        image_embed = self.image_proj(image_features)

        # F.normalize clamps the norm with a small eps, so an all-zero embedding
        # yields zeros instead of NaN. The embedding axis is the last one
        # (nn.Linear acts on it); for 2-D inputs this is the same as dim=1.
        text_embed_norm = F.normalize(text_embed, dim=-1)
        image_embed_norm = F.normalize(image_embed, dim=-1)

        return text_embed_norm, image_embed_norm

In [48]:
#trains the clip model from scratch
def train_clip(model, text_samples, image_samples, batch_size=10, num_epochs=100, lr=1e-3, temp=0.07):
    """Train `model` with a symmetric contrastive (cross-entropy) loss.

    Every epoch the paired samples are reshuffled together, split into full
    batches of `batch_size` (a trailing partial batch is dropped), and the
    model is optimised with Adam so that the i-th text matches the i-th image
    of the batch. The summed loss and the accuracy of each epoch are printed.

    Args:
        model: module whose forward(text_batch, image_batch) returns
            (text_embeddings, image_embeddings), one row per sample.
        text_samples: text inputs, indexable along the first axis and exposing `.shape`.
        image_samples: image inputs paired row by row with `text_samples`.
        batch_size: number of pairs per batch.
        num_epochs: number of passes over the data.
        lr: Adam learning rate.
        temp: the similarity matrix is multiplied by exp(temp).

    Raises:
        ValueError: if there are fewer samples than `batch_size`, so that no
            full batch can be formed.
    """
    num_batches = text_samples.shape[0] // batch_size
    if num_batches == 0:
        # previously surfaced as a ZeroDivisionError in the accuracy computation
        raise ValueError(
            f"Need at least batch_size={batch_size} samples, got {text_samples.shape[0]}."
        )

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        epoch_loss = 0
        epoch_correct = 0
        epoch_total = 0
        text_samples, image_samples = unison_shuffled_copies(text_samples, image_samples)
        for batch in range(num_batches):
            optimizer.zero_grad()
            #gets embeddings for text and image batches
            start_idx = batch * batch_size
            end_idx = start_idx + batch_size
            text_batch, image_batch = text_samples[start_idx:end_idx], image_samples[start_idx:end_idx]
            text_embed, image_embed = model(text_batch, image_batch)
            #computes pairwise similarity between text (rows) and image (columns) embeddings
            # NOTE(review): with the default temp=0.07 this scale is exp(0.07) ~= 1.07.
            # The CLIP pseudocode exponentiates a log-scale parameter whose initial
            # value gives a scale of 1/0.07 -- confirm which scaling is intended here.
            logits = torch.matmul(text_embed, image_embed.t()) * math.exp(temp)
            #symmetric loss function: the matching pair sits on the diagonal
            # labels are created on the logits' device so the loss also works off-CPU
            labels = torch.arange(logits.shape[0], device=logits.device)
            loss_text = loss_fn(logits, labels)
            loss_image = loss_fn(logits.t(), labels)
            loss = (loss_text + loss_image)/2
            loss.backward()
            epoch_loss += loss.item()
            optimizer.step()
            #compute accuracy: for every image (column), is the best-scoring text the paired one?
            # argmax is taken on the logits directly (softmax does not change it)
            # and outside of autograd, since it is only a metric
            with torch.no_grad():
                winners = torch.argmax(logits, dim=0)
                epoch_correct += (winners == labels).sum().item()
            epoch_total += logits.shape[0]
        print("Epoch:", epoch, "Training Loss:", epoch_loss, "Training Accuracy:", epoch_correct/epoch_total)

In [50]:
#trains a clip model on random input data

# Seed the random generators used by this run so its results can be reproduced:
# torch draws the data and the initial weights, numpy draws the per-epoch
# shuffle inside train_clip (via unison_shuffled_copies).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

text = torch.rand((50, 100))
images = torch.rand((50, 1000))

# NOTE(review): this rebinds `text_encoder`, which the "Text encoding" cell
# bound to the loaded BERT model; neither name is used in the rest of this cell.
text_encoder = TestModel
image_encoder = TestModel

temp=0.07
# text: 100 -> 60 features, images: 1000 -> 200 features, shared embedding size 50
clip_model = CLIP(TestModel(100, 60), 60, TestModel(1000,200), 200, 50)

# debugging aid: stays enabled globally for every later backward pass
torch.autograd.set_detect_anomaly(True)
train_clip(clip_model, text, images, num_epochs=5, temp=temp)

Epoch: 0 Training Loss: 11.540533781051636 Training Accuracy: 0.12
Epoch: 1 Training Loss: 10.717428207397461 Training Accuracy: 0.68
Epoch: 2 Training Loss: 9.243805885314941 Training Accuracy: 0.96
Epoch: 3 Training Loss: 8.572237968444824 Training Accuracy: 0.98
Epoch: 4 Training Loss: 8.109791040420532 Training Accuracy: 1.0


In [51]:
#assesses the performance of the trained clip model on the entire training set
loss_fn = nn.CrossEntropyLoss()

# The sample count comes from the training tensor itself. `test_text` must not
# be used here: when the notebook runs top to bottom it still holds the example
# sentence (a str) from the text-encoding section at this point.
num_train_samples = text.shape[0]

# evaluation only: no autograd graph is needed
with torch.no_grad():
    text_embed, image_embed = clip_model(text, images)
    #computes pairwise similarity between text and image embeddings (same scaling as train_clip)
    logits = torch.matmul(text_embed, image_embed.t()) * math.exp(temp)
    #symmetric loss function
    labels = torch.arange(num_train_samples)
    loss_text = loss_fn(logits, labels)
    loss_image = loss_fn(logits.t(), labels)
    loss = (loss_text + loss_image)/2
    #compute accuracy
    probs = torch.softmax(logits, dim=0)
    winners = torch.argmax(probs, dim=0)
    corrects = (winners == labels)
    total_correct = corrects.sum().float().item()
print("Training Loss:", loss.item(), "Training Accuracy:", total_correct/num_train_samples)

Training Loss: 3.085871696472168 Training Accuracy: 1.0


In [52]:
#assesses the performance of the trained clip model on a test set

loss_fn = nn.CrossEntropyLoss()

# Freshly drawn random tensors with the same shapes as the training data.
# NOTE(review): `test_text` was bound to an example sentence (str) in the
# text-encoding section; from here on it is a tensor.
test_text = torch.rand((50, 100))
test_images = torch.rand((50, 1000))
num_test_samples = test_text.shape[0]

# evaluation only: no autograd graph is needed
with torch.no_grad():
    text_embed, image_embed = clip_model(test_text, test_images)
    #computes pairwise similarity between text and image embeddings (same scaling as train_clip)
    logits = torch.matmul(text_embed, image_embed.t()) * math.exp(temp)
    #symmetric loss function
    labels = torch.arange(num_test_samples)
    loss_text = loss_fn(logits, labels)
    loss_image = loss_fn(logits.t(), labels)
    loss = (loss_text + loss_image)/2
    #compute accuracy
    probs = torch.softmax(logits, dim=0)
    winners = torch.argmax(probs, dim=0)
    corrects = (winners == labels)
    total_correct = corrects.sum().float().item()
print("Testing Loss:", loss.item(), "Testing Accuracy:", total_correct/num_test_samples)

Testing Loss: 3.9154810905456543 Testing Accuracy: 0.08
