In [6]:
import pandas as pd

folder = 'experiments/neg_dev_matched'

binary_finetune_pos_preds = pd.read_csv(folder + '/binary_finetune_pos_preds.tsv', header=None, delimiter='\t')
binary_finetune_pos_preds=binary_finetune_pos_preds.rename({0: "pos"}, axis=1)

binary_finetune_neg_preds = pd.read_csv(folder + '/binary_finetune_neg_preds.tsv', header=None, delimiter='\t')
binary_finetune_neg_preds=binary_finetune_neg_preds.rename({0: "neg"}, axis=1)

full_finetune_pos_preds = pd.read_csv(folder + '/full_two_way_finetune_pos_preds.tsv', header=None, delimiter='\t')
full_finetune_pos_preds=full_finetune_pos_preds.rename({0: "pos"}, axis=1)

full_finetune_neg_preds = pd.read_csv(folder + '/full_two_way_finetune_neg_preds.tsv', header=None, delimiter='\t')
full_finetune_neg_preds=full_finetune_neg_preds.rename({0: "neg"}, axis=1)

examples = pd.read_csv('neg_dev_matched.tsv', delimiter='\t')
binary_preds = pd.concat([binary_finetune_pos_preds, binary_finetune_neg_preds, examples], axis=1)
binary_preds['pos'] = binary_preds.apply(lambda row: 'c' if row['pos'] == 0 else 'e', axis=1)
binary_preds['neg'] = binary_preds.apply(lambda row: 'c' if row['neg'] == 0 else 'e', axis=1)
binary_preds['change'] = binary_preds.apply(lambda row: 1 if row['pos'] != row['neg'] else 0, axis=1)

full_preds = pd.concat([full_finetune_pos_preds, full_finetune_neg_preds, examples], axis=1)
full_preds['pos'] = full_preds.apply(lambda row: 'n' if row['pos'] == 0 else 'e', axis=1)
full_preds['neg'] = full_preds.apply(lambda row: 'n' if row['neg'] == 0 else 'e', axis=1)
full_preds['change'] = full_preds.apply(lambda row: 1 if row['pos'] != row['neg'] else 0, axis=1)

def combine_preds(row):
    return row['pos']+row['neg']

binary_preds['change_type'] = binary_preds.apply(combine_preds, axis=1)
full_preds['change_type'] = full_preds.apply(combine_preds, axis=1)


# display(binary_preds['change_type'].value_counts())
# display(full_preds['change_type'].value_counts())

print("percentage of initial examples: ")
print("binary: ")
total = len(binary_preds)
print(binary_preds['pos'].value_counts()/total)

print("full: ")
total = len(full_preds)
print(full_preds['pos'].value_counts()/total)


subset = full_preds[full_preds['pos']=="e"]
display(subset['neg'].value_counts()/len(subset))


percentage of initial examples: 
binary: 
e    0.528623
c    0.471377
Name: pos, dtype: float64
full: 
n    0.682227
e    0.317773
Name: pos, dtype: float64


n    0.654591
e    0.345409
Name: neg, dtype: float64

In [7]:
#first: pos logits
#second: neg logits
#after that: direct, indirect effects for each i

hidden_dim = 768
column_names = [None]*(2*hidden_dim+2)
column_names[0] = "pos_logits"
column_names[1] = "neg_logits"
for i in range(hidden_dim):
    column_names[i*2+2] = str(i) + "_dir"
    column_names[i*2+3] = str(i) + "_indir"

logits = pd.read_csv('experiments/may7/binary_finetune.tsv', header=None, delimiter='\t', names=column_names)

In [11]:
entailment_logits = logits.copy()
display(entailment_logits)

Unnamed: 0,pos_logits,neg_logits,0_dir,0_indir,1_dir,1_indir,2_dir,2_indir,3_dir,3_indir,...,763_dir,763_indir,764_dir,764_indir,765_dir,765_indir,766_dir,766_indir,767_dir,767_indir
0,"[0.8189356923103333, 0.18106429278850555]","[0.7719004154205322, 0.22809956967830658]","[0.7719480395317078, 0.22805191576480865]","[0.8188955783843994, 0.18110442161560059]","[0.7719001770019531, 0.22809989750385284]","[0.8189359307289124, 0.18106405436992645]","[0.7719041109085083, 0.2280958741903305]","[0.8189325928688049, 0.18106745183467865]","[0.7720416784286499, 0.22795835137367249]","[0.8188166618347168, 0.18118330836296082]",...,"[0.771982729434967, 0.22801733016967773]","[0.8188663721084595, 0.18113359808921814]","[0.7719653248786926, 0.22803469002246857]","[0.8188809752464294, 0.18111899495124817]","[0.7718992829322815, 0.2281007021665573]","[0.8189366459846497, 0.18106338381767273]","[0.7719295024871826, 0.22807052731513977]","[0.8189111948013306, 0.18108879029750824]","[0.7719926238059998, 0.22800730168819427]","[0.8188579678535461, 0.18114203214645386]"
1,"[0.898227334022522, 0.10177261382341385]","[0.9275789260864258, 0.0724211260676384]","[0.9275546669960022, 0.07244531065225601]","[0.898260235786438, 0.10173973441123962]","[0.9275789260864258, 0.0724211260676384]","[0.898227334022522, 0.10177261382341385]","[0.9275751113891602, 0.07242486625909805]","[0.8982325196266174, 0.10176751762628555]","[0.927501380443573, 0.07249864935874939]","[0.8983327746391296, 0.10166722536087036]",...,"[0.9275227785110474, 0.07247716188430786]","[0.8983036279678345, 0.1016964316368103]","[0.927540123462677, 0.07245989143848419]","[0.8982800245285034, 0.10171990096569061]","[0.9275792241096497, 0.07242083549499512]","[0.8982269763946533, 0.10177303105592728]","[0.9275604486465454, 0.07243954390287399]","[0.8982524275779724, 0.1017475575208664]","[0.9275371432304382, 0.07246282696723938]","[0.8982840776443481, 0.10171589255332947]"
2,"[0.9673073291778564, 0.03269265592098236]","[0.9730486869812012, 0.026951290667057037]","[0.973042368888855, 0.02695763297379017]","[0.9673150181770325, 0.03268501162528992]","[0.9730488061904907, 0.026951231062412262]","[0.9673073291778564, 0.03269273415207863]","[0.9730470180511475, 0.026952994987368584]","[0.9673094153404236, 0.03269060701131821]","[0.9730289578437805, 0.02697104588150978]","[0.9673311710357666, 0.03266885504126549]",...,"[0.9730322360992432, 0.026967739686369896]","[0.9673271775245667, 0.032672833651304245]","[0.973039448261261, 0.02696061320602894]","[0.967318594455719, 0.0326814241707325]","[0.9730488061904907, 0.02695116586983204]","[0.9673072099685669, 0.0326928049325943]","[0.9730437397956848, 0.02695632167160511]","[0.9673134684562683, 0.03268659487366676]","[0.9730383157730103, 0.02696164697408676]","[0.9673197865486145, 0.032680172473192215]"
3,"[0.9671698212623596, 0.03283024951815605]","[0.9813395142555237, 0.018660515546798706]","[0.9813249707221985, 0.01867503672838211]","[0.9671949148178101, 0.032805077731609344]","[0.9813395142555237, 0.018660519272089005]","[0.9671698212623596, 0.03283023461699486]","[0.9813387393951416, 0.0186612568795681]","[0.9671710133552551, 0.03282895311713219]","[0.9813007116317749, 0.018699323758482933]","[0.9672369360923767, 0.03276308998465538]",...,"[0.9813162088394165, 0.018683746457099915]","[0.9672099351882935, 0.03279000148177147]","[0.9813207983970642, 0.018679175525903702]","[0.9672021269798279, 0.032797910273075104]","[0.9813397526741028, 0.018660252913832664]","[0.9671692252159119, 0.032830700278282166]","[0.9813299775123596, 0.01866994984447956]","[0.9671860933303833, 0.03281388431787491]","[0.9813106060028076, 0.018689386546611786]","[0.9672197699546814, 0.03278025984764099]"
4,"[0.5252885818481445, 0.47471141815185547]","[0.5897467732429504, 0.4102531969547272]","[0.5896796584129333, 0.41032031178474426]","[0.5253577828407288, 0.47464224696159363]","[0.5897470116615295, 0.41025298833847046]","[0.5252884030342102, 0.47471165657043457]","[0.5897458791732788, 0.4102541506290436]","[0.5252895951271057, 0.4747104346752167]","[0.5895894765853882, 0.4104105830192566]","[0.5254507660865784, 0.474549263715744]",...,"[0.5896725654602051, 0.41032734513282776]","[0.5253650546073914, 0.47463491559028625]","[0.5896754860877991, 0.41032445430755615]","[0.5253620147705078, 0.4746379256248474]","[0.5897496342658997, 0.4102504253387451]","[0.5252857208251953, 0.4747142791748047]","[0.5897166728973389, 0.41028329730033875]","[0.5253196358680725, 0.4746803641319275]","[0.5896226167678833, 0.4103774428367615]","[0.5254166126251221, 0.4745834171772003]"
5,"[0.9219512939453125, 0.07804868370294571]","[0.9240758419036865, 0.07592419534921646]","[0.9240705370903015, 0.07592941075563431]","[0.9219566583633423, 0.07804333418607712]","[0.9240758419036865, 0.07592412829399109]","[0.921951174736023, 0.07804876565933228]","[0.9240764379501343, 0.07592351734638214]","[0.9219505786895752, 0.07804936915636063]","[0.9240713715553284, 0.07592860609292984]","[0.9219558238983154, 0.07804419100284576]",...,"[0.9240836501121521, 0.0759163647890091]","[0.9219433069229126, 0.07805673778057098]","[0.9240656495094299, 0.07593433558940887]","[0.9219617247581482, 0.07803831249475479]","[0.9240754246711731, 0.07592456042766571]","[0.9219517111778259, 0.07804834842681885]","[0.9240731000900269, 0.07592695951461792]","[0.9219541549682617, 0.07804587483406067]","[0.9240595698356628, 0.07594043761491776]","[0.9219679236412048, 0.07803204655647278]"
6,"[0.9943175911903381, 0.005682376213371754]","[0.9943515062332153, 0.005648417863994837]","[0.9943509101867676, 0.005649076774716377]","[0.9943183064460754, 0.005681710783392191]","[0.9943515062332153, 0.0056484173983335495]","[0.9943175911903381, 0.005682370625436306]","[0.9943519830703735, 0.00564801087602973]","[0.9943172335624695, 0.005682780873030424]","[0.9943515062332153, 0.005648430902510881]","[0.9943175911903381, 0.005682360380887985]",...,"[0.9943535327911377, 0.005646479781717062]","[0.9943156838417053, 0.005684322211891413]","[0.9943521022796631, 0.005647922866046429]","[0.9943171143531799, 0.0056828721426427364]","[0.9943516254425049, 0.005648372229188681]","[0.9943175911903381, 0.005682419519871473]","[0.9943516254425049, 0.005648385733366013]","[0.9943175911903381, 0.005682408809661865]","[0.9943494200706482, 0.005650636740028858]","[0.9943197965621948, 0.005680142901837826]"
7,"[0.630669891834259, 0.36933013796806335]","[0.41133302450180054, 0.5886669754981995]","[0.41156473755836487, 0.5884352326393127]","[0.6304469704627991, 0.3695530295372009]","[0.4113335907459259, 0.5886663794517517]","[0.6306692957878113, 0.36933067440986633]","[0.41131892800331116, 0.5886811017990112]","[0.630683422088623, 0.36931657791137695]","[0.41188204288482666, 0.5881180167198181]","[0.6301417350769043, 0.3698582649230957]",...,"[0.4115045368671417, 0.5884954929351807]","[0.6305049061775208, 0.36949512362480164]","[0.4115867614746094, 0.5884132981300354]","[0.630425751209259, 0.3695741891860962]","[0.41132333874702454, 0.5886766314506531]","[0.6306791305541992, 0.3693207800388336]","[0.41145554184913635, 0.5885443687438965]","[0.6305519938468933, 0.3694480359554291]","[0.4118358790874481, 0.5881640911102295]","[0.630186140537262, 0.3698139190673828]"
8,"[0.9768192172050476, 0.023180777207016945]","[0.9311858415603638, 0.06881412863731384]","[0.931291401386261, 0.06870860606431961]","[0.9767819046974182, 0.023218123242259026]","[0.9311859011650085, 0.06881408393383026]","[0.9768192172050476, 0.023180799558758736]","[0.9311859011650085, 0.06881404668092728]","[0.9768192172050476, 0.023180805146694183]","[0.9314442276954651, 0.06855582445859909]","[0.9767275452613831, 0.02327238768339157]",...,"[0.9313108325004578, 0.06868915259838104]","[0.9767749905586243, 0.02322501689195633]","[0.9313068985939026, 0.0686931386590004]","[0.9767763614654541, 0.023223603144288063]","[0.931182324886322, 0.06881769001483917]","[0.9768204689025879, 0.023179518058896065]","[0.9312462210655212, 0.06875375658273697]","[0.9767978191375732, 0.023202132433652878]","[0.9313976764678955, 0.06860236078500748]","[0.9767441749572754, 0.023255839943885803]"
9,"[0.7230625152587891, 0.2769375145435333]","[0.7076705098152161, 0.29232946038246155]","[0.7077067494392395, 0.2922932803630829]","[0.7230274677276611, 0.27697259187698364]","[0.7076705694198608, 0.29232943058013916]","[0.7230623960494995, 0.2769375443458557]","[0.7076646089553833, 0.2923353612422943]","[0.7230682373046875, 0.2769318222999573]","[0.7077183127403259, 0.2922816574573517]","[0.7230162620544434, 0.2769837975502014]",...,"[0.7076385021209717, 0.2923614978790283]","[0.7230934500694275, 0.2769065201282501]","[0.7076905965805054, 0.29230931401252747]","[0.7230429649353027, 0.2769570052623749]","[0.7076689004898071, 0.29233109951019287]","[0.7230640649795532, 0.2769359350204468]","[0.7076766490936279, 0.29232335090637207]","[0.7230565547943115, 0.2769434154033661]","[0.7077353596687317, 0.2922646701335907]","[0.7229998111724854, 0.2770002484321594]"


In [23]:
for i, column in enumerate(logits.columns):
    entailment_logits[column] = logits[column].apply(lambda x: (x.split(",")[0][1:]))

In [26]:
entailment_logits = entailment_logits.astype('float')

IndexError: index 1537 is out of bounds for axis 0 with size 1537

In [None]:
print(len(entailment_logits))

In [None]:
#get average change for direct effects



#get average change for indirect effects

