In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoConfig


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)


In [3]:
fine_tuned_summ_path = "./models/summ_ft.pt"
fine_tuned_qna_path = "./models/qa_ft.pt"
fine_tuned_trans_model = "./models/trans_ft.pt"

In [4]:
finetuned_qa_model = torch.load(fine_tuned_qna_path)
finetuned_summ_model = torch.load(fine_tuned_summ_path)
finetuned_trans_model = torch.load(fine_tuned_trans_model)

### Func for Feature Extraction

In [5]:
from tqdm import tqdm
finetuned_qa_model = finetuned_qa_model.to("cpu")
finetuned_summ_model = finetuned_summ_model.to("cpu")
finetuned_trans_model = finetuned_trans_model.to("cpu")
def get_summ_feats(encoded_input):
        output = finetuned_summ_model(encoded_input)
        summ_row = 0
        for row in output[1]:
            for loc_row in row:
                # shape of loc row is 1, a, b, c
                # reshape it to 1, b, a*c
                reshaped = loc_row.reshape(1, loc_row.shape[2], loc_row.shape[1]*loc_row.shape[3])
                #print(summ_row)
                if type(summ_row) == int:
                    summ_row = reshaped
                else:
                    summ_row = summ_row + reshaped
        return summ_row
    
def get_qa_feats(encoded_input):
    output = finetuned_qa_model(encoded_input)
    summ_row = 0
    for row in output[1]:
        for loc_row in row:
            # shape of loc row is 1, a, b, c
            # reshape it to 1, b, a*c
            reshaped = loc_row.reshape(1, loc_row.shape[2], loc_row.shape[1]*loc_row.shape[3])
            #print(summ_row)
            if type(summ_row) == int:
                summ_row = reshaped
            else:
                summ_row = summ_row + reshaped
    return summ_row

def get_trans_feats(encoded_input):
    output = finetuned_trans_model(encoded_input)
    summ_row = 0
    for row in output[1]:
        for loc_row in row:
            # shape of loc row is 1, a, b, c
            # reshape it to 1, b, a*c
            reshaped = loc_row.reshape(1, loc_row.shape[2], loc_row.shape[1]*loc_row.shape[3])
            #print(summ_row)
            if type(summ_row) == int:
                summ_row = reshaped
            else:
                summ_row = summ_row + reshaped
    return summ_row

def get_feats(list_sents, max_length):
    summ_reps = []
    qa_reps = []
    trans_reps = []
    tokenizer.pad_token = tokenizer.eos_token
    for sent in tqdm(list_sents, total = len(list_sents)):
        input_s = tokenizer.encode(sent,return_tensors="pt", max_length=max_length, truncation=True, padding="max_length")   
        summ_feats = get_summ_feats(input_s)
        trans_feats = get_trans_feats(input_s)
        qa_feats = get_qa_feats(input_s)
        # convert shape from 1, x, 768 to x, 768
        summ_feats = summ_feats.squeeze(0)
        trans_feats = trans_feats.squeeze(0)
        qa_feats = qa_feats.squeeze(0)
        summ_feats = np.array(torch.mean(summ_feats, dim=0).detach().numpy())
        trans_feats = np.array(torch.mean(trans_feats, dim=0).detach().numpy())
        qa_feats = np.array(torch.mean(qa_feats, dim=0).detach().numpy())   
        summ_reps.append(summ_feats)
        qa_reps.append(qa_feats)
        trans_reps.append(trans_feats)
    return summ_reps, qa_reps, trans_reps  


### Loading the stim and fmri data

In [6]:
# print cwd
import os
print(os.getcwd())
working_dir = "csai_a5"
# switch to working dir
os.chdir(working_dir)

parent_dir = "assignment5"
stim_file = "stimuli.txt"
subj1 = "subj1.npy"
subj2 = "subj2.npy"
# load the stims file
import regex as re
with open(os.path.join(parent_dir, stim_file), "r") as f:
    stims = f.readlines()
final_stims = []
for item in stims:
    # remove all numbers and punctuation
    item = re.sub(r'[^\w\s]','',item)
    # remove all numbers
    item = re.sub(r'\d+', '', item)
    item = item.lower()
    final_stims.append(item)


/home2/advaith.malladi


### feature extraction

In [7]:
max_len = 0
import numpy as np
for sent in final_stims:
    length= len(sent.split())
    if max_len < length:
        max_len  = length
print(max_len)
max_len = max_len + 5
summ_reps, qa_reps, trans_reps = get_feats(final_stims, max_len)

23


100%|█████████████████████████████████████████| 627/627 [00:49<00:00, 12.64it/s]


### preparing fmri data

In [8]:
# load fmri data of subj1 and subj2
import numpy as np
subj1_data = np.load(os.path.join(parent_dir, subj1), allow_pickle=True).item()
subj2_data = np.load(os.path.join(parent_dir, subj2), allow_pickle=True).item()

subj1_lang = subj1_data["language"]
subj1_vis = subj1_data["vision"]
subj1_dmn = subj1_data["dmn"]
subj1_task = subj1_data["task"]
subj2_lang = subj2_data["language"]
subj2_vis = subj2_data["vision"]
subj2_dmn = subj2_data["dmn"]
subj2_task = subj2_data["task"]
print(subj1_lang.shape, subj1_vis.shape, subj1_dmn.shape, subj1_task.shape, subj2_lang.shape, subj2_vis.shape, subj2_dmn.shape, subj2_task.shape)

(627, 11437) (627, 33792) (627, 17190) (627, 35120) (627, 10791) (627, 31109) (627, 15070) (627, 30594)


In [9]:
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error
# import R2 as well
from sklearn.metrics import r2_score
def cos_d(y1, y2):
    cos_sim = np.dot(y1, y2)/(np.linalg.norm(y1)*np.linalg.norm(y2))
    cos_d = 1 - cos_sim
    return cos_d
def get_2v2_accuracy(y_pred, y_test):
    N = y_pred.shape[0]
    tot_cnt = 0
    pos_cnt = 0
    for i in range(0, N):
        for j in range(1, N):
            yi_test = y_test[i]
            yi_pred = y_pred[i]
            yj_test = y_test[j]
            yj_pred = y_pred[j]
            score1 = cos_d(yi_test, yi_pred) + cos_d(yj_test, yj_pred)
            score2 = cos_d(yi_test, yj_pred) + cos_d(yj_test, yi_pred)
            if score1 < score2:
                pos_cnt += 1
            tot_cnt += 1
    return pos_cnt/tot_cnt

def get_avg_corr(y_pred, y_test):
    N = y_pred.shape[0]
    tot_corr = 0
    for i in range(0, N):
        corr = np.corrcoef(y_pred[i], y_test[i])[0,1]
        tot_corr += corr
    return tot_corr/N

def decoding_module(X, y):
    kf = KFold(n_splits=5)
    mse = []
    r2 = []
    acc = []
    acc_scores = []
    corr_scores = []
    for train_index, test_index in  kf.split(X) :
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        model = Ridge()
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        get_2v2_accuracy(y_pred, y_test)
        mse.append(mean_squared_error(y_test, y_pred))
        r2.append(r2_score(y_test, y_pred))
        acc_scores.append(get_2v2_accuracy(y_pred, y_test))
        corr_scores.append(get_avg_corr(y_pred, y_test))
    loc_dict = {}
    loc_dict["MSE"] = np.mean(mse)
    loc_dict["R2"] = np.mean(r2)
    loc_dict["Accuracy"] = np.mean(acc_scores)
    loc_dict["Correlation"] = np.mean(corr_scores)
    return loc_dict

    


subj1_decoding_scores = {}
subj2_decoding_scores = {}

### Brain Decoding

In [10]:
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error
# import R2 as well
from sklearn.metrics import r2_score
def cos_d(y1, y2):
    cos_sim = np.dot(y1, y2)/(np.linalg.norm(y1)*np.linalg.norm(y2))
    cos_d = 1 - cos_sim
    return cos_d
def get_2v2_accuracy(y_pred, y_test):
    N = y_pred.shape[0]
    tot_cnt = 0
    pos_cnt = 0
    for i in range(0, N):
        for j in range(1, N):
            yi_test = y_test[i]
            yi_pred = y_pred[i]
            yj_test = y_test[j]
            yj_pred = y_pred[j]
            score1 = cos_d(yi_test, yi_pred) + cos_d(yj_test, yj_pred)
            score2 = cos_d(yi_test, yj_pred) + cos_d(yj_test, yi_pred)
            if score1 < score2:
                pos_cnt += 1
            tot_cnt += 1
    return pos_cnt/tot_cnt

def get_avg_corr(y_pred, y_test):
    N = y_pred.shape[0]
    tot_corr = 0
    for i in range(0, N):
        corr = np.corrcoef(y_pred[i], y_test[i])[0,1]
        tot_corr += corr
    return tot_corr/N

def decoding_module(X, y):
    kf = KFold(n_splits=5)
    mse = []
    r2 = []
    acc = []
    acc_scores = []
    corr_scores = []
    for train_index, test_index in  kf.split(X) :
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        model = Ridge()
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        get_2v2_accuracy(y_pred, y_test)
        mse.append(mean_squared_error(y_test, y_pred))
        r2.append(r2_score(y_test, y_pred))
        acc_scores.append(get_2v2_accuracy(y_pred, y_test))
        corr_scores.append(get_avg_corr(y_pred, y_test))
    loc_dict = {}
    loc_dict["MSE"] = np.mean(mse)
    loc_dict["R2"] = np.mean(r2)
    loc_dict["Accuracy"] = np.mean(acc_scores)
    loc_dict["Correlation"] = np.mean(corr_scores)
    return loc_dict

    


subj1_decoding_scores = {}
subj2_decoding_scores = {}

In [11]:
list_feats = [summ_reps, qa_reps, trans_reps]
list_fmris = [subj1_lang, subj1_vis, subj1_dmn, subj1_task, subj2_lang, subj2_vis, subj2_dmn, subj2_task]
feats_names = ["prompt-tuned-summarization", "prompt-tuned-question-answering", "prompt-tuned-translation"]
fmri_names = ['subj1_lang', 'subj1_vis', 'subj1_dmn', 'subj1_task', 'subj2_lang', 'subj2_vis', 'subj2_dmn', 'subj2_task']
for i in tqdm(range(len(list_feats)), total = len(list_feats)):
    for j in range(len(list_fmris)):
        label = "decoding " + feats_names[i] + " with " + fmri_names[j]
        reps = list_feats[i]
        fmris = list_fmris[j]
        x = np.array(fmris)
        y = np.array(reps)
        decoding_scores = decoding_module(x, y)
        if j < 4:
            subj1_decoding_scores[label] = decoding_scores
        else:
            subj2_decoding_scores[label] = decoding_scores
        print(decoding_scores)
        
        


  0%|                                                     | 0/3 [00:00<?, ?it/s]

{'MSE': 0.41692516945067587, 'R2': -0.8019612808247804, 'Accuracy': 0.673699743983615, 'Correlation': 0.9544017007764758}
{'MSE': 0.39084791091594606, 'R2': -0.6896086010065352, 'Accuracy': 0.6787705069124425, 'Correlation': 0.9571530443214094}
{'MSE': 0.3938045747674267, 'R2': -0.6971755711243084, 'Accuracy': 0.6303516641065028, 'Correlation': 0.9568386984066433}
{'MSE': 0.34106144839845903, 'R2': -0.46737424504526703, 'Accuracy': 0.6436487455197133, 'Correlation': 0.9624091724452206}
{'MSE': 0.4564297108846068, 'R2': -0.9738772939264088, 'Accuracy': 0.6318384024577574, 'Correlation': 0.9502035474353422}
{'MSE': 0.43658729847395855, 'R2': -0.883764801655705, 'Accuracy': 0.6313110087045571, 'Correlation': 0.9522742495151993}
{'MSE': 0.442553782367756, 'R2': -0.9131307502611111, 'Accuracy': 0.6039723502304147, 'Correlation': 0.951676027193234}


 33%|███████████████                              | 1/3 [00:38<01:17, 38.78s/it]

{'MSE': 0.38321955636877814, 'R2': -0.651942407208009, 'Accuracy': 0.6319266769073221, 'Correlation': 0.9579453815496233}
{'MSE': 0.419272794888068, 'R2': -0.8096041084448489, 'Accuracy': 0.6524940092165898, 'Correlation': 0.9544916960618728}
{'MSE': 0.39466602043034593, 'R2': -0.7027753948125922, 'Accuracy': 0.6457745007680492, 'Correlation': 0.9570191542898444}
{'MSE': 0.39719917784617065, 'R2': -0.709646130766465, 'Accuracy': 0.6028923707117255, 'Correlation': 0.956723694549883}
{'MSE': 0.34327483559744476, 'R2': -0.4752228768488817, 'Accuracy': 0.6298750640040962, 'Correlation': 0.9624611233021264}
{'MSE': 0.45814226800481517, 'R2': -0.9781965060677544, 'Accuracy': 0.6253650793650793, 'Correlation': 0.9503129674625115}
{'MSE': 0.4341286481766932, 'R2': -0.8715602041796604, 'Accuracy': 0.6388729134664619, 'Correlation': 0.9528045623275956}
{'MSE': 0.44237874163702573, 'R2': -0.9103170684459837, 'Accuracy': 0.6024202764976959, 'Correlation': 0.9519984689888886}


 67%|██████████████████████████████               | 2/3 [01:18<00:39, 39.45s/it]

{'MSE': 0.38442586414089824, 'R2': -0.6578138943462977, 'Accuracy': 0.6226164874551972, 'Correlation': 0.9581040126827242}
{'MSE': 0.4125637011300068, 'R2': -0.8093906454333762, 'Accuracy': 0.6661824884792626, 'Correlation': 0.9498060662601333}
{'MSE': 0.38763659582470406, 'R2': -0.700535513373924, 'Accuracy': 0.672794674859191, 'Correlation': 0.9527344753658804}
{'MSE': 0.3891405586023013, 'R2': -0.7023881561868409, 'Accuracy': 0.6258461853558628, 'Correlation': 0.9525146528186641}
{'MSE': 0.3333748383002843, 'R2': -0.45542660792621137, 'Accuracy': 0.6779006656426011, 'Correlation': 0.9590640741140166}
{'MSE': 0.45179643834667454, 'R2': -0.9820464806508887, 'Accuracy': 0.6273511520737327, 'Correlation': 0.9451405046345769}
{'MSE': 0.4300895490380466, 'R2': -0.8835103301434867, 'Accuracy': 0.6233701996927803, 'Correlation': 0.9476669012705873}
{'MSE': 0.4347281196997847, 'R2': -0.9069411499493547, 'Accuracy': 0.6050431131592422, 'Correlation': 0.9471308867966709}


100%|█████████████████████████████████████████████| 3/3 [01:56<00:00, 38.92s/it]

{'MSE': 0.3806434855816832, 'R2': -0.6662523945201411, 'Accuracy': 0.6107594470046083, 'Correlation': 0.953483643996106}





### Brain Encoding

In [12]:
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error
# import R2 as well
from sklearn.metrics import r2_score
def cos_d(y1, y2):
    cos_sim = np.dot(y1, y2)/(np.linalg.norm(y1)*np.linalg.norm(y2))
    cos_d = 1 - cos_sim
    return cos_d
def get_2v2_accuracy(y_pred, y_test):
    N = y_pred.shape[0]
    tot_cnt = 0
    pos_cnt = 0
    for i in range(0, N):
        for j in range(1, N):
            yi_test = y_test[i]
            yi_pred = y_pred[i]
            yj_test = y_test[j]
            yj_pred = y_pred[j]
            score1 = cos_d(yi_test, yi_pred) + cos_d(yj_test, yj_pred)
            score2 = cos_d(yi_test, yj_pred) + cos_d(yj_test, yi_pred)
            if score1 < score2:
                pos_cnt += 1
            tot_cnt += 1
    return pos_cnt/tot_cnt

def get_avg_corr(y_pred, y_test):
    N = y_pred.shape[0]
    tot_corr = 0
    for i in range(0, N):
        corr = np.corrcoef(y_pred[i], y_test[i])[0,1]
        tot_corr += corr
    return tot_corr/N

def encoding_module(X, y):
    kf = KFold(n_splits=5)
    mse = []
    r2 = []
    acc = []
    acc_scores = []
    corr_scores = []
    for train_index, test_index in  kf.split(X) :
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        model = Ridge()
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        get_2v2_accuracy(y_pred, y_test)
        mse.append(mean_squared_error(y_test, y_pred))
        r2.append(r2_score(y_test, y_pred))
        acc_scores.append(get_2v2_accuracy(y_pred, y_test))
        corr_scores.append(get_avg_corr(y_pred, y_test))
    loc_dict = {}
    loc_dict["MSE"] = np.mean(mse)
    loc_dict["R2"] = np.mean(r2)
    loc_dict["Accuracy"] = np.mean(acc_scores)
    loc_dict["Correlation"] = np.mean(corr_scores)
    return loc_dict

    


subj1_encoding_scores = {}
subj2_encoding_scores = {}

In [13]:
list_feats = [summ_reps, qa_reps, trans_reps]
list_fmris = [subj1_lang, subj1_vis, subj1_dmn, subj1_task, subj2_lang, subj2_vis, subj2_dmn, subj2_task]
feats_names = ["prompt-tuned-summarization", "prompt-tuned-question-answering", "prompt-tuned-translation"]
fmri_names = ['subj1_lang', 'subj1_vis', 'subj1_dmn', 'subj1_task', 'subj2_lang', 'subj2_vis', 'subj2_dmn', 'subj2_task']
for i in tqdm(range(len(list_feats)), total = len(list_feats)):
    for j in range(len(list_fmris)):
        label = "encoding " + feats_names[i] + " with " + fmri_names[j]
        reps = list_feats[i]
        fmris = list_fmris[j]
        x = np.array(fmris)
        y = np.array(reps)
        encoding_scores = encoding_module(x, y)
        if j < 4:
            subj1_encoding_scores[label] = encoding_scores
        else:
            subj2_encoding_scores[label] = encoding_scores
        print(encoding_scores)
        
        


  0%|                                                     | 0/3 [00:00<?, ?it/s]

{'MSE': 0.41692516945067587, 'R2': -0.8019612808247804, 'Accuracy': 0.673699743983615, 'Correlation': 0.9544017007764758}
{'MSE': 0.39084791091594606, 'R2': -0.6896086010065352, 'Accuracy': 0.6787705069124425, 'Correlation': 0.9571530443214094}
{'MSE': 0.3938045747674267, 'R2': -0.6971755711243084, 'Accuracy': 0.6303516641065028, 'Correlation': 0.9568386984066433}
{'MSE': 0.34106144839845903, 'R2': -0.46737424504526703, 'Accuracy': 0.6436487455197133, 'Correlation': 0.9624091724452206}
{'MSE': 0.4564297108846068, 'R2': -0.9738772939264088, 'Accuracy': 0.6318384024577574, 'Correlation': 0.9502035474353422}
{'MSE': 0.43658729847395855, 'R2': -0.883764801655705, 'Accuracy': 0.6313110087045571, 'Correlation': 0.9522742495151993}
{'MSE': 0.442553782367756, 'R2': -0.9131307502611111, 'Accuracy': 0.6039723502304147, 'Correlation': 0.951676027193234}


 33%|███████████████                              | 1/3 [00:37<01:14, 37.31s/it]

{'MSE': 0.38321955636877814, 'R2': -0.651942407208009, 'Accuracy': 0.6319266769073221, 'Correlation': 0.9579453815496233}
{'MSE': 0.419272794888068, 'R2': -0.8096041084448489, 'Accuracy': 0.6524940092165898, 'Correlation': 0.9544916960618728}
{'MSE': 0.39466602043034593, 'R2': -0.7027753948125922, 'Accuracy': 0.6457745007680492, 'Correlation': 0.9570191542898444}
{'MSE': 0.39719917784617065, 'R2': -0.709646130766465, 'Accuracy': 0.6028923707117255, 'Correlation': 0.956723694549883}
{'MSE': 0.34327483559744476, 'R2': -0.4752228768488817, 'Accuracy': 0.6298750640040962, 'Correlation': 0.9624611233021264}
{'MSE': 0.45814226800481517, 'R2': -0.9781965060677544, 'Accuracy': 0.6253650793650793, 'Correlation': 0.9503129674625115}
{'MSE': 0.4341286481766932, 'R2': -0.8715602041796604, 'Accuracy': 0.6388729134664619, 'Correlation': 0.9528045623275956}
{'MSE': 0.44237874163702573, 'R2': -0.9103170684459837, 'Accuracy': 0.6024202764976959, 'Correlation': 0.9519984689888886}


 67%|██████████████████████████████               | 2/3 [01:14<00:37, 37.17s/it]

{'MSE': 0.38442586414089824, 'R2': -0.6578138943462977, 'Accuracy': 0.6226164874551972, 'Correlation': 0.9581040126827242}
{'MSE': 0.4125637011300068, 'R2': -0.8093906454333762, 'Accuracy': 0.6661824884792626, 'Correlation': 0.9498060662601333}
{'MSE': 0.38763659582470406, 'R2': -0.700535513373924, 'Accuracy': 0.672794674859191, 'Correlation': 0.9527344753658804}
{'MSE': 0.3891405586023013, 'R2': -0.7023881561868409, 'Accuracy': 0.6258461853558628, 'Correlation': 0.9525146528186641}
{'MSE': 0.3333748383002843, 'R2': -0.45542660792621137, 'Accuracy': 0.6779006656426011, 'Correlation': 0.9590640741140166}
{'MSE': 0.45179643834667454, 'R2': -0.9820464806508887, 'Accuracy': 0.6273511520737327, 'Correlation': 0.9451405046345769}
{'MSE': 0.4300895490380466, 'R2': -0.8835103301434867, 'Accuracy': 0.6233701996927803, 'Correlation': 0.9476669012705873}
{'MSE': 0.4347281196997847, 'R2': -0.9069411499493547, 'Accuracy': 0.6050431131592422, 'Correlation': 0.9471308867966709}


100%|█████████████████████████████████████████████| 3/3 [01:51<00:00, 37.19s/it]

{'MSE': 0.3806434855816832, 'R2': -0.6662523945201411, 'Accuracy': 0.6107594470046083, 'Correlation': 0.953483643996106}





In [15]:
# save subj1_encoding_scores as subj1_encoding_scores.json with indent 4
# make dir scoreas and move to that dir
os.chdir("scores")
import json
with open("ft_subj1_encoding_scores.json", "w") as f:
    json.dump(subj1_encoding_scores, f, indent=4)
with open("ft_subj2_encoding_scores.json", "w") as f:
    json.dump(subj2_encoding_scores, f, indent=4)
with open("ft_subj1_decoding_scores.json", "w") as f:
    json.dump(subj1_decoding_scores, f, indent=4)
with open("ft_subj2_decoding_scores.json", "w") as f:
    json.dump(subj2_decoding_scores, f, indent=4)

FileNotFoundError: [Errno 2] No such file or directory: 'scores'