In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoConfig


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)


In [3]:
prompt_tuned_summ_path = "./models/summ_pt.pt"
prompt_tuned_qna_path = "./models/qa_pt.pt"
promptuned_trans_model = "./models/translation_best"

### Loading the Models


In [4]:
promptuned_qa_model = torch.load(prompt_tuned_qna_path)
promptuned_summ_model = torch.load(prompt_tuned_summ_path)

In [5]:
from transformers import GPT2Model, GPT2LMHeadModel, GPT2Config
trans_config = AutoConfig.from_pretrained(promptuned_trans_model)
trans_model = GPT2Model.from_pretrained(promptuned_trans_model, config=trans_config)
trans_model.eval()


  return self.fget.__get__(instance, owner)()


GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

### Function for Getting Features

In [6]:
from tqdm import tqdm
promptuned_summ_model = promptuned_summ_model.to("cpu")
promptuned_qa_model = promptuned_qa_model.to("cpu")
trans_model = trans_model.to("cpu")
import numpy as np



def get_summ_feats(encoded_input):
        output = promptuned_summ_model(encoded_input)
        summ_row = 0
        for row in output[1]:
            for loc_row in row:
                # shape of loc row is 1, a, b, c
                # reshape it to 1, b, a*c
                reshaped = loc_row.reshape(1, loc_row.shape[2], loc_row.shape[1]*loc_row.shape[3])
                #print(summ_row)
                if type(summ_row) == int:
                    summ_row = reshaped
                else:
                    summ_row = summ_row + reshaped
        return summ_row
def get_qa_feats(encoded_input):
    output = promptuned_qa_model(encoded_input)
    summ_row = 0
    for row in output[1]:
        for loc_row in row:
            # shape of loc row is 1, a, b, c
            # reshape it to 1, b, a*c
            reshaped = loc_row.reshape(1, loc_row.shape[2], loc_row.shape[1]*loc_row.shape[3])
            #print(summ_row)
            if type(summ_row) == int:
                summ_row = reshaped
            else:
                summ_row = summ_row + reshaped
    return summ_row

def get_trans_feats(encoded_input):
    output = trans_model(encoded_input)
    output = output['last_hidden_state']
    return output

def get_feats(list_sents, max_length):
    summ_reps = []
    qa_reps = []
    trans_reps = []
    tokenizer.pad_token = tokenizer.eos_token
    for sent in tqdm(list_sents, total = len(list_sents)):
        input_s = tokenizer.encode(sent,return_tensors="pt", max_length=max_length, truncation=True, padding="max_length")   
        prompt = [0, 1, 2]
        # convert to tensor
        prompt = torch.tensor(prompt)
        # add batch dimension 1
        prompt = prompt.unsqueeze(0)
        input_s = torch.cat((prompt, input_s), 1)
        summ_feats = get_summ_feats(input_s)
        trans_feats = get_trans_feats(input_s)
        qa_feats = get_qa_feats(input_s)
        # convert shape from 1, x, 768 to x, 768
        summ_feats = summ_feats.squeeze(0)
        trans_feats = trans_feats.squeeze(0)
        qa_feats = qa_feats.squeeze(0)
        summ_feats = np.array(torch.mean(summ_feats, dim=0).detach().numpy())
        trans_feats = np.array(torch.mean(trans_feats, dim=0).detach().numpy())
        qa_feats = np.array(torch.mean(qa_feats, dim=0).detach().numpy())   
        summ_reps.append(summ_feats)
        qa_reps.append(qa_feats)
        trans_reps.append(trans_feats)
    return summ_reps, qa_reps, trans_reps        



### Loading the stim and fmri data

In [7]:
# print cwd
import os
print(os.getcwd())
working_dir = "csai_a5"
# switch to working dir
os.chdir(working_dir)

parent_dir = "assignment5"
stim_file = "stimuli.txt"
subj1 = "subj1.npy"
subj2 = "subj2.npy"
# load the stims file
import regex as re
with open(os.path.join(parent_dir, stim_file), "r") as f:
    stims = f.readlines()
final_stims = []
for item in stims:
    # remove all numbers and punctuation
    item = re.sub(r'[^\w\s]','',item)
    # remove all numbers
    item = re.sub(r'\d+', '', item)
    item = item.lower()
    final_stims.append(item)


/home2/advaith.malladi


### feature extraction

In [8]:
max_len = 0
for sent in final_stims:
    length= len(sent.split())
    if max_len < length:
        max_len  = length
print(max_len)
max_len = max_len + 5
summ_reps, qa_reps, trans_reps = get_feats(final_stims, max_len)

23


100%|█████████████████████████████████████████| 627/627 [00:44<00:00, 14.06it/s]


### Preparing FMRI data

In [9]:
# load fmri data of subj1 and subj2
import numpy as np
subj1_data = np.load(os.path.join(parent_dir, subj1), allow_pickle=True).item()
subj2_data = np.load(os.path.join(parent_dir, subj2), allow_pickle=True).item()

subj1_lang = subj1_data["language"]
subj1_vis = subj1_data["vision"]
subj1_dmn = subj1_data["dmn"]
subj1_task = subj1_data["task"]
subj2_lang = subj2_data["language"]
subj2_vis = subj2_data["vision"]
subj2_dmn = subj2_data["dmn"]
subj2_task = subj2_data["task"]
print(subj1_lang.shape, subj1_vis.shape, subj1_dmn.shape, subj1_task.shape, subj2_lang.shape, subj2_vis.shape, subj2_dmn.shape, subj2_task.shape)

(627, 11437) (627, 33792) (627, 17190) (627, 35120) (627, 10791) (627, 31109) (627, 15070) (627, 30594)


### Brain Decoding

In [10]:
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error
# import R2 as well
from sklearn.metrics import r2_score
def cos_d(y1, y2):
    cos_sim = np.dot(y1, y2)/(np.linalg.norm(y1)*np.linalg.norm(y2))
    cos_d = 1 - cos_sim
    return cos_d
def get_2v2_accuracy(y_pred, y_test):
    N = y_pred.shape[0]
    tot_cnt = 0
    pos_cnt = 0
    for i in range(0, N):
        for j in range(1, N):
            yi_test = y_test[i]
            yi_pred = y_pred[i]
            yj_test = y_test[j]
            yj_pred = y_pred[j]
            score1 = cos_d(yi_test, yi_pred) + cos_d(yj_test, yj_pred)
            score2 = cos_d(yi_test, yj_pred) + cos_d(yj_test, yi_pred)
            if score1 < score2:
                pos_cnt += 1
            tot_cnt += 1
    return pos_cnt/tot_cnt

def get_avg_corr(y_pred, y_test):
    N = y_pred.shape[0]
    tot_corr = 0
    for i in range(0, N):
        corr = np.corrcoef(y_pred[i], y_test[i])[0,1]
        tot_corr += corr
    return tot_corr/N

def decoding_module(X, y):
    kf = KFold(n_splits=5)
    mse = []
    r2 = []
    acc = []
    acc_scores = []
    corr_scores = []
    for train_index, test_index in  kf.split(X) :
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        model = Ridge()
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        get_2v2_accuracy(y_pred, y_test)
        mse.append(mean_squared_error(y_test, y_pred))
        r2.append(r2_score(y_test, y_pred))
        acc_scores.append(get_2v2_accuracy(y_pred, y_test))
        corr_scores.append(get_avg_corr(y_pred, y_test))
    loc_dict = {}
    loc_dict["MSE"] = np.mean(mse)
    loc_dict["R2"] = np.mean(r2)
    loc_dict["Accuracy"] = np.mean(acc_scores)
    loc_dict["Correlation"] = np.mean(corr_scores)
    return loc_dict

    


subj1_decoding_scores = {}
subj2_decoding_scores = {}

In [11]:
list_feats = [summ_reps, qa_reps, trans_reps]
list_fmris = [subj1_lang, subj1_vis, subj1_dmn, subj1_task, subj2_lang, subj2_vis, subj2_dmn, subj2_task]
feats_names = ["prompt-tuned-summarization", "prompt-tuned-question-answering", "prompt-tuned-translation"]
fmri_names = ['subj1_lang', 'subj1_vis', 'subj1_dmn', 'subj1_task', 'subj2_lang', 'subj2_vis', 'subj2_dmn', 'subj2_task']
for i in tqdm(range(len(list_feats)), total = len(list_feats)):
    for j in range(len(list_fmris)):
        label = "decoding " + feats_names[i] + " with " + fmri_names[j]
        reps = list_feats[i]
        fmris = list_fmris[j]
        x = np.array(fmris)
        y = np.array(reps)
        decoding_scores = decoding_module(x, y)
        if j < 4:
            subj1_decoding_scores[label] = decoding_scores
        else:
            subj2_decoding_scores[label] = decoding_scores
        print(decoding_scores)
        
        


  0%|                                                     | 0/3 [00:00<?, ?it/s]

{'MSE': 0.38096174884382417, 'R2': -0.7733994461504705, 'Accuracy': 0.7547911930363543, 'Correlation': 0.9494759504220255}
{'MSE': 0.35968410024048136, 'R2': -0.6744477745129355, 'Accuracy': 0.7529632360471069, 'Correlation': 0.952180624266617}
{'MSE': 0.36020446803740624, 'R2': -0.6721452905994902, 'Accuracy': 0.7181337429595495, 'Correlation': 0.952121002882906}
{'MSE': 0.31262654452950556, 'R2': -0.4470878815744852, 'Accuracy': 0.740579211469534, 'Correlation': 0.9581947149771614}
{'MSE': 0.4139036434330185, 'R2': -0.9271034626312931, 'Accuracy': 0.727136712749616, 'Correlation': 0.9452519543733832}
{'MSE': 0.3945455932026454, 'R2': -0.8340644573931988, 'Accuracy': 0.7316608294930875, 'Correlation': 0.9476600053016186}
{'MSE': 0.3998100960872937, 'R2': -0.8595109074048943, 'Accuracy': 0.7142044034818229, 'Correlation': 0.947060189286445}


 33%|███████████████                              | 1/3 [00:39<01:19, 39.76s/it]

{'MSE': 0.35077004519126753, 'R2': -0.6250790389973478, 'Accuracy': 0.7126539682539683, 'Correlation': 0.9532297534357799}
{'MSE': 0.38096174884382417, 'R2': -0.7733994461504705, 'Accuracy': 0.7547911930363543, 'Correlation': 0.9494759504220255}
{'MSE': 0.35968410024048136, 'R2': -0.6744477745129355, 'Accuracy': 0.7529632360471069, 'Correlation': 0.952180624266617}
{'MSE': 0.36020446803740624, 'R2': -0.6721452905994902, 'Accuracy': 0.7181337429595495, 'Correlation': 0.952121002882906}
{'MSE': 0.31262654452950556, 'R2': -0.4470878815744852, 'Accuracy': 0.740579211469534, 'Correlation': 0.9581947149771614}
{'MSE': 0.4139036434330185, 'R2': -0.9271034626312931, 'Accuracy': 0.727136712749616, 'Correlation': 0.9452519543733832}
{'MSE': 0.3945455932026454, 'R2': -0.8340644573931988, 'Accuracy': 0.7316608294930875, 'Correlation': 0.9476600053016186}
{'MSE': 0.3998100960872937, 'R2': -0.8595109074048943, 'Accuracy': 0.7142044034818229, 'Correlation': 0.947060189286445}


 67%|██████████████████████████████               | 2/3 [01:17<00:38, 38.47s/it]

{'MSE': 0.35077004519126753, 'R2': -0.6250790389973478, 'Accuracy': 0.7126539682539683, 'Correlation': 0.9532297534357799}
{'MSE': 0.20043854411256662, 'R2': -0.584358836080643, 'Accuracy': 0.7093788018433179, 'Correlation': 0.9990300722133689}
{'MSE': 0.1971109401070051, 'R2': -0.5029438282408814, 'Accuracy': 0.7090154633896569, 'Correlation': 0.9990697503991168}
{'MSE': 0.21524041502542998, 'R2': -0.49894688489037414, 'Accuracy': 0.6617681515616999, 'Correlation': 0.9989151142739155}
{'MSE': 0.19265129045602353, 'R2': -0.3191236483584917, 'Accuracy': 0.697053149001536, 'Correlation': 0.9991093356539972}
{'MSE': 0.22888943881375137, 'R2': -0.6966292774174742, 'Accuracy': 0.6606990271377368, 'Correlation': 0.998842165998596}
{'MSE': 0.24006601320970833, 'R2': -0.6141613197891107, 'Accuracy': 0.665368561187916, 'Correlation': 0.9988804338023366}
{'MSE': 0.22820536648596637, 'R2': -0.6668713014982955, 'Accuracy': 0.6383490015360984, 'Correlation': 0.9988361922414836}


100%|█████████████████████████████████████████████| 3/3 [01:55<00:00, 38.46s/it]

{'MSE': 0.21055305811577196, 'R2': -0.4823323738101248, 'Accuracy': 0.6451535074244752, 'Correlation': 0.9989869502257841}





### Brain Encoding

In [12]:
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error
# import R2 as well
from sklearn.metrics import r2_score
def cos_d(y1, y2):
    cos_sim = np.dot(y1, y2)/(np.linalg.norm(y1)*np.linalg.norm(y2))
    cos_d = 1 - cos_sim
    return cos_d
def get_2v2_accuracy(y_pred, y_test):
    N = y_pred.shape[0]
    tot_cnt = 0
    pos_cnt = 0
    for i in range(0, N):
        for j in range(1, N):
            yi_test = y_test[i]
            yi_pred = y_pred[i]
            yj_test = y_test[j]
            yj_pred = y_pred[j]
            score1 = cos_d(yi_test, yi_pred) + cos_d(yj_test, yj_pred)
            score2 = cos_d(yi_test, yj_pred) + cos_d(yj_test, yi_pred)
            if score1 < score2:
                pos_cnt += 1
            tot_cnt += 1
    return pos_cnt/tot_cnt

def get_avg_corr(y_pred, y_test):
    N = y_pred.shape[0]
    tot_corr = 0
    for i in range(0, N):
        corr = np.corrcoef(y_pred[i], y_test[i])[0,1]
        tot_corr += corr
    return tot_corr/N

def encoding_module(X, y):
    kf = KFold(n_splits=5)
    mse = []
    r2 = []
    acc = []
    acc_scores = []
    corr_scores = []
    for train_index, test_index in  kf.split(X) :
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        model = Ridge()
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        get_2v2_accuracy(y_pred, y_test)
        mse.append(mean_squared_error(y_test, y_pred))
        r2.append(r2_score(y_test, y_pred))
        acc_scores.append(get_2v2_accuracy(y_pred, y_test))
        corr_scores.append(get_avg_corr(y_pred, y_test))
    loc_dict = {}
    loc_dict["MSE"] = np.mean(mse)
    loc_dict["R2"] = np.mean(r2)
    loc_dict["Accuracy"] = np.mean(acc_scores)
    loc_dict["Correlation"] = np.mean(corr_scores)
    return loc_dict

    


subj1_encoding_scores = {}
subj2_encoding_scores = {}

In [13]:
list_feats = [summ_reps, qa_reps, trans_reps]
list_fmris = [subj1_lang, subj1_vis, subj1_dmn, subj1_task, subj2_lang, subj2_vis, subj2_dmn, subj2_task]
feats_names = ["prompt-tuned-summarization", "prompt-tuned-question-answering", "prompt-tuned-translation"]
fmri_names = ['subj1_lang', 'subj1_vis', 'subj1_dmn', 'subj1_task', 'subj2_lang', 'subj2_vis', 'subj2_dmn', 'subj2_task']
for i in tqdm(range(len(list_feats)), total = len(list_feats)):
    for j in range(len(list_fmris)):
        label = "encoding " + feats_names[i] + " with " + fmri_names[j]
        reps = list_feats[i]
        fmris = list_fmris[j]
        x = np.array(fmris)
        y = np.array(reps)
        encoding_scores = encoding_module(x, y)
        if j < 4:
            subj1_encoding_scores[label] = encoding_scores
        else:
            subj2_encoding_scores[label] = encoding_scores
        print(encoding_scores)
        
        


  0%|                                                     | 0/3 [00:00<?, ?it/s]

{'MSE': 0.38096174884382417, 'R2': -0.7733994461504705, 'Accuracy': 0.7547911930363543, 'Correlation': 0.9494759504220255}
{'MSE': 0.35968410024048136, 'R2': -0.6744477745129355, 'Accuracy': 0.7529632360471069, 'Correlation': 0.952180624266617}
{'MSE': 0.36020446803740624, 'R2': -0.6721452905994902, 'Accuracy': 0.7181337429595495, 'Correlation': 0.952121002882906}
{'MSE': 0.31262654452950556, 'R2': -0.4470878815744852, 'Accuracy': 0.740579211469534, 'Correlation': 0.9581947149771614}
{'MSE': 0.4139036434330185, 'R2': -0.9271034626312931, 'Accuracy': 0.727136712749616, 'Correlation': 0.9452519543733832}
{'MSE': 0.3945455932026454, 'R2': -0.8340644573931988, 'Accuracy': 0.7316608294930875, 'Correlation': 0.9476600053016186}
{'MSE': 0.3998100960872937, 'R2': -0.8595109074048943, 'Accuracy': 0.7142044034818229, 'Correlation': 0.947060189286445}


 33%|███████████████                              | 1/3 [00:37<01:15, 37.72s/it]

{'MSE': 0.35077004519126753, 'R2': -0.6250790389973478, 'Accuracy': 0.7126539682539683, 'Correlation': 0.9532297534357799}
{'MSE': 0.38096174884382417, 'R2': -0.7733994461504705, 'Accuracy': 0.7547911930363543, 'Correlation': 0.9494759504220255}
{'MSE': 0.35968410024048136, 'R2': -0.6744477745129355, 'Accuracy': 0.7529632360471069, 'Correlation': 0.952180624266617}
{'MSE': 0.36020446803740624, 'R2': -0.6721452905994902, 'Accuracy': 0.7181337429595495, 'Correlation': 0.952121002882906}
{'MSE': 0.31262654452950556, 'R2': -0.4470878815744852, 'Accuracy': 0.740579211469534, 'Correlation': 0.9581947149771614}
{'MSE': 0.4139036434330185, 'R2': -0.9271034626312931, 'Accuracy': 0.727136712749616, 'Correlation': 0.9452519543733832}
{'MSE': 0.3945455932026454, 'R2': -0.8340644573931988, 'Accuracy': 0.7316608294930875, 'Correlation': 0.9476600053016186}
{'MSE': 0.3998100960872937, 'R2': -0.8595109074048943, 'Accuracy': 0.7142044034818229, 'Correlation': 0.947060189286445}


 67%|██████████████████████████████               | 2/3 [01:15<00:37, 37.78s/it]

{'MSE': 0.35077004519126753, 'R2': -0.6250790389973478, 'Accuracy': 0.7126539682539683, 'Correlation': 0.9532297534357799}
{'MSE': 0.20043854411256662, 'R2': -0.584358836080643, 'Accuracy': 0.7093788018433179, 'Correlation': 0.9990300722133689}
{'MSE': 0.1971109401070051, 'R2': -0.5029438282408814, 'Accuracy': 0.7090154633896569, 'Correlation': 0.9990697503991168}
{'MSE': 0.21524041502542998, 'R2': -0.49894688489037414, 'Accuracy': 0.6617681515616999, 'Correlation': 0.9989151142739155}
{'MSE': 0.19265129045602353, 'R2': -0.3191236483584917, 'Accuracy': 0.697053149001536, 'Correlation': 0.9991093356539972}
{'MSE': 0.22888943881375137, 'R2': -0.6966292774174742, 'Accuracy': 0.6606990271377368, 'Correlation': 0.998842165998596}
{'MSE': 0.24006601320970833, 'R2': -0.6141613197891107, 'Accuracy': 0.665368561187916, 'Correlation': 0.9988804338023366}
{'MSE': 0.22820536648596637, 'R2': -0.6668713014982955, 'Accuracy': 0.6383490015360984, 'Correlation': 0.9988361922414836}


100%|█████████████████████████████████████████████| 3/3 [01:53<00:00, 37.79s/it]

{'MSE': 0.21055305811577196, 'R2': -0.4823323738101248, 'Accuracy': 0.6451535074244752, 'Correlation': 0.9989869502257841}





In [14]:

os.chdir("scores")
import json
with open("pt_subj1_encoding_scores.json", "w") as f:
    json.dump(subj1_encoding_scores, f, indent=4)
with open("pt_subj2_encoding_scores.json", "w") as f:
    json.dump(subj2_encoding_scores, f, indent=4)
with open("pt_subj1_decoding_scores.json", "w") as f:
    json.dump(subj1_decoding_scores, f, indent=4)
with open("pt_subj2_decoding_scores.json", "w") as f:
    json.dump(subj2_decoding_scores, f, indent=4)

In [15]:
# print cwd
print(os.getcwd())

/home2/advaith.malladi/csai_a5/scores
