# Athena Rationales Global
Large Scale Empirical Analysis 

In [1]:
from pathlib import Path
import csv
import seaborn as sns; sns.set_theme()
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import functools

pd.options.display.float_format = '{:.2f}'.format

In [2]:
from tokenizers import ByteLevelBPETokenizer
import torch
import importlib
from fairseq.models.transformer import TransformerModel

In [3]:
import warnings
from matplotlib import colors
import os
from rationalization import rationalize_lm, rationalize_conditional_model

In [4]:
def param_default():
    """Build the default configuration dictionary for the experiments.

    The returned dict contains the BPE tokenizer prefix, and for each split
    (eval/test/train) one ``[input.methods.txt, output.tests.txt]`` pair per
    corpus (in the order of ``corpora``). It also holds the split labels to
    process, the model/checkpoint locations and the output directories.
    """
    # Corpus variants, from the focal method alone up to the full scope.
    corpora = ['fm', 'fm_fc', 'fm_fc_co', 'fm_fc_ms', 'fm_fc_ms_ff']

    raw_dirs = [Path('../whole-dateset/corpus/' + corpus + '/raw/') for corpus in corpora]

    def pairs(source_name, target_name):
        # One [source file, target file] pair for every corpus directory.
        return [[raw / source_name, raw / target_name] for raw in raw_dirs]

    tokenizer_dir = Path('../tokenizer/')
    config = {
        'bpe_path': tokenizer_dir / 'universal_tokenizer/universal_tokenizer/roberta_aug_spaces',
        'eval_raw': pairs('eval/input.methods.txt', 'eval/output.tests.txt'),
        'test_raw': pairs('test/input.methods.txt', 'test/output.tests.txt'),
        'train_raw': pairs('train/input.methods.txt', 'train/output.tests.txt'),
        'data_labels': ['test_raw'],  # only the test split is processed
        'out_processed': '/datasets/out_processed/',
        'model_name_or_path': 'models/checkpoint_dir_01/models/',  # model directory
        'checkpoint_file': 'checkpoint_best.pt',  # checkpoint inside that directory
        'data_preprocessed': '/home/davidna/data/dummy/sequential-rationales/fairseq/fairseq/data-bin/bins/',
        'output_results': 'results/',
        'corpora': corpora,
    }
    return config

In [5]:
params = param_default()
params['corpora']

['fm', 'fm_fc', 'fm_fc_co', 'fm_fc_ms', 'fm_fc_ms_ff']

## Rationalizations Utilities

In [6]:
rationalization = importlib.import_module("sequential-rationales.huggingface.rationalization")
rationalize = rationalization.rationalize_lm
warnings.filterwarnings("ignore")

## Universal Tokenizer

In [7]:
def lazy_decode(bpe_java):
    """Convert space-separated byte-level BPE tokens back into source text.

    Plain spaces (token separators) are removed. The byte-level markers are
    then restored: 'Ġ' becomes a space and 'Ċ' becomes a newline. This is
    done in a single translation pass.
    """
    table = str.maketrans({' ': None, 'Ġ': ' ', 'Ċ': '\n'})
    return bpe_java.translate(table)

In [8]:
def prettify_java(minified_java):
    """Re-indent minified Java source so that it is readable.

    Tries to undo Michele's minification. A line break is inserted after
    every '{', '}' and ';', and each line is indented four spaces per open
    brace. This works decently, although for-loop headers and inline
    initialisers also get line breaks. No empty lines or comments are
    restored.

    Args:
        minified_java: minified (typically single-line) Java source.

    Returns:
        The re-indented source. Every emitted line ends with a newline, and
        an empty input gives an empty string.
    """
    expanded = minified_java.replace('{', '{\n').replace('}', '}\n').replace(';', ';\n')
    depth = 0
    pretty_lines = []
    for raw_line in expanded.splitlines():
        # Discard the minifier's leftover spacing so indentation comes only from depth.
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith('}'):
            # Clamp at zero so unbalanced input cannot disable indentation.
            depth = max(depth - 1, 0)
        pretty_lines.append('    ' * depth + line)
        if line.endswith('{'):
            depth += 1
        elif line.endswith('}') and not line.startswith('}'):
            # A block closed on the same line as its content, e.g. "1, 2}".
            depth = max(depth - 1, 0)
    return ''.join(line + '\n' for line in pretty_lines)

## Model Loading and Testing

In [9]:
#Loading a pretrain model
model = TransformerModel.from_pretrained(
  model_name_or_path = params['model_name_or_path'],
  checkpoint_file = params['checkpoint_file'],
  #data_name_or_path = params['data_preprocessed']
)

In [10]:
#Setting experiments 
#! export CUDA_VISIBLE_DEVICES="0,1"

In [11]:
## Move model to GPU if available and trigger evaluation mode
def model_activate(model = model):
    """Prepare the fairseq hub model for rationalization.

    The model is moved to the GPU when CUDA is available. It is always put
    in evaluation mode, which disables dropout, and its first ensemble
    member is exposed as ``model.model``. That alias is presumably what the
    rationalization code reads; TODO confirm.

    Args:
        model: object returned by ``TransformerModel.from_pretrained``.
            Defaults to the notebook's global ``model``.
    """
    if torch.cuda.is_available():
        model.cuda()
    # Eval mode and the `model.model` alias are needed on CPU as well; they
    # used to be applied only when a GPU was present.
    model.eval()
    model.model = model.models[0]
    print("Model Activated")

## Universal Tokenizer

In [12]:
def load_tokenizer(bpe_path):
    """Load the byte-level BPE tokenizer whose files share the prefix ``bpe_path``.

    The two files read are ``<bpe_path>-vocab.json`` and ``<bpe_path>-merges.txt``.
    """
    prefix = str(bpe_path)
    vocab_file = prefix + '-vocab.json'
    merges_file = prefix + '-merges.txt'
    return ByteLevelBPETokenizer(vocab_file, merges_file)

In [13]:
tokenizer = load_tokenizer(params['bpe_path'])

## Data Loading and Testing

In [14]:
#export
def method_size_vector( method_vector ):
    '''Return the number of tokens in each tokenized method.

    `method_vector` is an iterable (list or array) whose elements are token
    sequences. The result is a list with one length per element, in order.
    '''
    return list(map(len, method_vector))

In [15]:
def super_set_code():
    """Load and BPE-tokenize the raw input methods of every corpus.

    For each split label in ``params['data_labels']`` and each corpus in
    ``params['corpora']``, the corpus' input file (``path_data[0]``, one
    method per line) is read. Each line is tokenized with the global
    ``tokenizer``, and a DataFrame is built with three columns: the raw
    text, its BPE tokens, and the token count.

    Returns:
        dict mapping ``'<label>_input_<corpus>'`` to that DataFrame.
    """
    data = {}
    corpora = params['corpora']
    for label in params['data_labels']:
        for i, path_data in enumerate(params[label]):
            corpus = corpora[i]
            input_path = path_data[0]
            print(i, corpus, input_path)

            name = label + '_input_' + corpus
            bpe_col = label + '_bpe_' + '_input_' + corpus  # double underscore kept for compatibility
            size_col = 'method_size' + '_input_' + corpus

            # Read plain lines instead of pd.read_csv(sep="\n"). Newer pandas
            # rejects "\n" as a separator, and the CSV parser would also
            # interpret quote characters and drop blank lines in the Java
            # source. That would break the row-for-row alignment between
            # corpora and the parallel output file.
            with open(input_path, encoding='utf-8') as fh:
                methods = [line.rstrip('\n') for line in fh]

            df = pd.DataFrame({name: methods})
            df[bpe_col] = [enc.tokens for enc in tokenizer.encode_batch(methods)]
            df[size_col] = method_size_vector(df[bpe_col].values)
            data[name] = df
    return data

In [16]:
super_data = super_set_code()

0 fm ../whole-dateset/corpus/fm/raw/test/input.methods.txt
1 fm_fc ../whole-dateset/corpus/fm_fc/raw/test/input.methods.txt
2 fm_fc_co ../whole-dateset/corpus/fm_fc_co/raw/test/input.methods.txt
3 fm_fc_ms ../whole-dateset/corpus/fm_fc_ms/raw/test/input.methods.txt
4 fm_fc_ms_ff ../whole-dateset/corpus/fm_fc_ms_ff/raw/test/input.methods.txt


In [17]:
super_data.keys()

dict_keys(['test_raw_input_fm', 'test_raw_input_fm_fc', 'test_raw_input_fm_fc_co', 'test_raw_input_fm_fc_ms', 'test_raw_input_fm_fc_ms_ff'])

In [18]:
super_data['test_raw_input_fm'].shape

(78388, 3)

In [20]:
super_data['test_raw_input_fm_fc'].shape

(78388, 3)

In [19]:
flat_result = pd.concat(super_data, axis=1)

In [53]:
flat_result.head()

Unnamed: 0_level_0,test_raw_input_fm,test_raw_input_fm,test_raw_input_fm,test_raw_input_fm_fc,test_raw_input_fm_fc,test_raw_input_fm_fc,test_raw_input_fm_fc_co,test_raw_input_fm_fc_co,test_raw_input_fm_fc_co,test_raw_input_fm_fc_ms,test_raw_input_fm_fc_ms,test_raw_input_fm_fc_ms,test_raw_input_fm_fc_ms_ff,test_raw_input_fm_fc_ms_ff,test_raw_input_fm_fc_ms_ff
Unnamed: 0_level_1,test_raw_input_fm,test_raw_bpe__input_fm,method_size_input_fm,test_raw_input_fm_fc,test_raw_bpe__input_fm_fc,method_size_input_fm_fc,test_raw_input_fm_fc_co,test_raw_bpe__input_fm_fc_co,method_size_input_fm_fc_co,test_raw_input_fm_fc_ms,test_raw_bpe__input_fm_fc_ms,method_size_input_fm_fc_ms,test_raw_input_fm_fc_ms_ff,test_raw_bpe__input_fm_fc_ms_ff,method_size_input_fm_fc_ms_ff
0,public static Date yearStart() { final Gregori...,"[public, Ġstatic, ĠDate, Ġyear, Start, (), Ġ{,...",42,DateUtils { public static Date yearStart() { f...,"[Date, Ut, ils, Ġ{, Ġpublic, Ġstatic, ĠDate, Ġ...",47,DateUtils { public static Date yearStart() { f...,"[Date, Ut, ils, Ġ{, Ġpublic, Ġstatic, ĠDate, Ġ...",47,DateUtils { public static Date yearStart() { f...,"[Date, Ut, ils, Ġ{, Ġpublic, Ġstatic, ĠDate, Ġ...",227,DateUtils { public static Date yearStart() { f...,"[Date, Ut, ils, Ġ{, Ġpublic, Ġstatic, ĠDate, Ġ...",227
1,public static Date yearEnd() { final Gregorian...,"[public, Ġstatic, ĠDate, Ġyear, End, (), Ġ{, Ġ...",65,DateUtils { public static Date yearEnd() { fin...,"[Date, Ut, ils, Ġ{, Ġpublic, Ġstatic, ĠDate, Ġ...",70,DateUtils { public static Date yearEnd() { fin...,"[Date, Ut, ils, Ġ{, Ġpublic, Ġstatic, ĠDate, Ġ...",70,DateUtils { public static Date yearEnd() { fin...,"[Date, Ut, ils, Ġ{, Ġpublic, Ġstatic, ĠDate, Ġ...",250,DateUtils { public static Date yearEnd() { fin...,"[Date, Ut, ils, Ġ{, Ġpublic, Ġstatic, ĠDate, Ġ...",250
2,public void validate(TokenBinding clientDataTo...,"[public, Ġvoid, Ġvalidate, (, Token, B, inding...",170,TokenBindingValidator { public void validate(T...,"[Token, B, inding, Valid, ator, Ġ{, Ġpublic, Ġ...",177,TokenBindingValidator { public void validate(T...,"[Token, B, inding, Valid, ator, Ġ{, Ġpublic, Ġ...",177,TokenBindingValidator { public void validate(T...,"[Token, B, inding, Valid, ator, Ġ{, Ġpublic, Ġ...",197,TokenBindingValidator { public void validate(T...,"[Token, B, inding, Valid, ator, Ġ{, Ġpublic, Ġ...",197
3,public static int getUnsignedShort(ByteBuffer ...,"[public, Ġstatic, Ġint, Ġget, Un, signed, Shor...",29,UnsignedNumberUtil { public static int getUnsi...,"[Un, signed, Number, Ut, il, Ġ{, Ġpublic, Ġsta...",36,UnsignedNumberUtil { public static int getUnsi...,"[Un, signed, Number, Ut, il, Ġ{, Ġpublic, Ġsta...",43,UnsignedNumberUtil { public static int getUnsi...,"[Un, signed, Number, Ut, il, Ġ{, Ġpublic, Ġsta...",191,UnsignedNumberUtil { public static int getUnsi...,"[Un, signed, Number, Ut, il, Ġ{, Ġpublic, Ġsta...",239
4,public static boolean isWithinUnsignedLong(Big...,"[public, Ġstatic, Ġboolean, Ġis, Within, Un, s...",41,UnsignedNumberUtil { public static boolean isW...,"[Un, signed, Number, Ut, il, Ġ{, Ġpublic, Ġsta...",48,UnsignedNumberUtil { public static boolean isW...,"[Un, signed, Number, Ut, il, Ġ{, Ġpublic, Ġsta...",55,UnsignedNumberUtil { public static boolean isW...,"[Un, signed, Number, Ut, il, Ġ{, Ġpublic, Ġsta...",203,UnsignedNumberUtil { public static boolean isW...,"[Un, signed, Number, Ut, il, Ġ{, Ġpublic, Ġsta...",251


In [None]:
df_indexes = pd.DataFrame()

In [49]:
df_indexes[['fm','fm_fc','fm_fc_co','fm_fc_ms','fm_fc_ms_ff']] = flat_result[[
    ('test_raw_input_fm','method_size_input_fm'),
    ('test_raw_input_fm_fc', 'method_size_input_fm_fc'),
    ('test_raw_input_fm_fc_co', 'method_size_input_fm_fc_co'),
    ('test_raw_input_fm_fc_ms', 'method_size_input_fm_fc_ms'),
    ('test_raw_input_fm_fc_ms_ff', 'method_size_input_fm_fc_ms_ff')
    ]].copy()

In [51]:
df_indexes['fm_fc__fm'] =  df_indexes['fm_fc'].values - df_indexes['fm']

In [52]:
df_indexes

Unnamed: 0,fm,fm_fc,fm_fc_co,fm_fc_ms,fm_fc_ms_ff,fm_fc__fm
0,42,47,47,227,227,5
1,65,70,70,250,250,5
2,170,177,177,197,197,7
3,29,36,43,191,239,7
4,41,48,55,203,251,7
...,...,...,...,...,...,...
78383,22,39,43,80,80,17
78384,62,70,89,281,281,8
78385,23,35,77,605,605,12
78386,21,27,41,329,329,6


In [66]:
df_index_sampled = df_indexes.sample(
    n= 1000,
    random_state=3
).copy()

In [67]:
df_index_sampled

Unnamed: 0,fm,fm_fc,fm_fc_co,fm_fc_ms,fm_fc_ms_ff,fm_fc__fm
17231,195,212,221,277,277,17
23059,273,278,278,423,423,5
24609,16,32,112,317,317,16
30509,92,99,114,114,114,7
62188,144,148,169,204,204,4
...,...,...,...,...,...,...
60348,1805,1818,1911,2097,2120,13
20066,1019,1029,1056,1080,1154,10
50521,361,386,446,552,552,25
76473,374,387,387,2528,2563,13


In [None]:
############################################

In [64]:
#Loading Code Generation
df_generated_input = pd.read_json( params['output_results'] + 'generation_01.json' )

In [61]:
df_indexes['fm_fc__fm'][0]

5

In [69]:
len(df_generated_input['source_sampling'][0])

163

Bad pipe message: %s [b'Q\x91Y\xbe\xe4\xb9u.E?\xd1\x1b\x98H\x16=[\xa9 \xab;\x8c\x18\xf7:Dur`\x99}jA\xd2\xd9\xa4\x0b\xe9Y\x7f\x84\xeb\x17\x87\xd3;\xeb\xf9\x01yk\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01\x00+\x00\x03\x02\x03\x04\x00-\x00\x02\x01\x01\x003\x00&\x00$\x00\x1d\x00 k\r\xc3\x86\xe0\x8f\xc55\x8d\x8d\x00\xb55dYh\xb5+%\xac\x92fT\xd3\xf3{\xf6\x80\x92L\xd0']
Bad pipe message: %s [b'\x0b\xeeJ\xf2\xf7U\xb6"\xf7b\xed\xf6\xbf\xf5b\x14G\xbd\x00\x00\xa6\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc']
Bad pipe message: %s [b"\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\x

In [63]:
#Example of ContextWindow
temp_class = df_generated_input['source_sampling'][0][:df_indexes['fm_fc__fm'][0] ] #classes
model.decode( temp_class )

'Thread s Process ing Item'

In [65]:
print('df readit')
df_generated_input.head()

df readit


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,21,22,23,24,25,26,27,28,29,source_sampling
0,"[1039, 34603, 1640, 10162, 5457, 36993, 45621,...","[1039, 34603, 1640, 10162, 5457, 36993, 45621,...","[1039, 34603, 1640, 10162, 5457, 36993, 45621,...","[1039, 34603, 1640, 10162, 5457, 44840, 26170,...","[1039, 34603, 1640, 10162, 5457, 36993, 45621,...","[1039, 34603, 1640, 10162, 5457, 44840, 26170,...","[1039, 34603, 1640, 10162, 5457, 44840, 26170,...","[1039, 34603, 1640, 10162, 5457, 36993, 45621,...","[1039, 34603, 1640, 10162, 5457, 44840, 26170,...","[1039, 34603, 1640, 10162, 5457, 36993, 45621,...",...,"[1039, 34603, 1640, 10162, 5457, 36993, 45621,...","[1039, 34603, 1640, 10162, 5457, 36993, 45621,...","[1039, 34603, 1640, 10162, 5457, 36993, 45621,...","[1039, 34603, 1640, 10162, 5457, 36993, 45621,...","[1039, 34603, 1640, 10162, 5457, 36993, 45621,...","[1039, 34603, 1640, 10162, 5457, 36993, 45621,...","[1039, 34603, 1640, 10162, 5457, 36993, 45621,...","[1039, 34603, 1640, 10162, 5457, 36993, 45621,...","[1039, 34603, 1640, 10162, 5457, 36993, 45621,...","[43542, 48455, 14269, 11056, 104, 44306, 4741,..."
1,"[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...",...,"[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[1039, 34603, 285, 13842, 197, 44514, 43048, 6...","[47181, 49187, 47599, 39962, 33177, 282, 868, ..."
2,"[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...",...,"[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1039, 34603, 285, 13842, 1296, 14181, 3750, 1...","[1121, 4771, 2068, 139, 5320, 20636, 25522, 28..."
3,"[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...","[1039, 34603, 285, 13842, 197, 22011, 1090, 20...","[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...","[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...","[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...","[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...","[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...","[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...","[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...","[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...",...,"[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...","[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...","[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...","[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...","[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...","[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...","[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...","[1039, 34603, 285, 13842, 197, 22011, 1090, 20...","[1039, 34603, 285, 13842, 1296, 22011, 1090, 2...","[104, 571, 506, 49707, 25522, 285, 9527, 41552..."
4,"[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...",...,"[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[1039, 34603, 285, 13842, 1296, 21527, 45788, ...","[45408, 26579, 48557, 40025, 2630, 14269, 4364..."


In [14]:
#tst decoding
decoded = model.decode(df_generated_input['0'][0])
decoded

'@ Test Ġpublic Ġvoid Ġtest Process Event () Ġ{ Ġ}'

In [15]:
prettify_java( lazy_decode( decoded ) )

'@Test public void testProcessEvent() {\n }\n'

## Running Rationales

In [16]:
#Statistics
np.mean( [len(i) for i in df_generated_input['0'].values] )

74.853

In [17]:
#TODO Run the distribution of each experiment. The mean value of tokens or size for each experiment. 
np.mean( [len(i) for i in df_generated_input['source_sampling'].values] )

353.018

In [17]:
len(df_generated_input['0'].values[2])

40

In [18]:
MAX_TOKEN_SIZE = 100 # Hard-coded: maximum number of source tokens passed to rationalization (rationalize_model slices the source to this length)

In [19]:
# If the model is not fine-tuned or compatible, this will raise an error.
# Bear in mind that Athena is a translation model (not a language model).
# This function works on one tensor of source tokens and one tensor of target tokens.
def rationalize_model(t_source_tokens, t_target_tokens, model, verbose=True,
                      max_source_tokens=None, max_steps=1024):
    """Run greedy rationalization of one target sequence given its source.

    Args:
        t_source_tokens: tensor of source token ids. Only the first
            ``max_source_tokens`` entries are used.
        t_target_tokens: tensor of target token ids, used in full.
        model: fairseq model forwarded to ``rationalize_conditional_model``.
        verbose: forwarded to ``rationalize_conditional_model``.
        max_source_tokens: source truncation length. ``None`` means the
            module-level ``MAX_TOKEN_SIZE``, which is the previous behaviour.
        max_steps: maximum number of steps for greedy rationalization.

    Returns:
        ``(all_source_rationales, all_target_rationales, log)`` exactly as
        returned by ``rationalize_conditional_model``.
    """
    if max_source_tokens is None:
        max_source_tokens = MAX_TOKEN_SIZE
    all_source_rationales, all_target_rationales, log = rationalize_conditional_model(
        model = model,
        source_tokens = t_source_tokens[:max_source_tokens],
        target_tokens = t_target_tokens,  # the target is not truncated
        verbose = verbose,
        max_steps = max_steps,
    )
    return all_source_rationales, all_target_rationales, log

In [20]:
#tst
def tst_rationalize_model():
    """Smoke test: rationalize the first generated sample of experiment '0'.

    Returns:
        The ``(source_rationales, target_rationales, log)`` tuple from
        ``rationalize_model``. Earlier callers ignored the return value,
        so returning it is backward-compatible.
    """
    import gc  # make gc available regardless of notebook cell execution order

    gc.collect()
    torch.cuda.empty_cache()  # Cleaning Cache

    # Convert only the two sequences that are used, rather than moving every
    # column of df_generated_input to the device.
    source = torch.tensor(df_generated_input['source_sampling'].values[0]).to(model.device)
    target = torch.tensor(df_generated_input['0'].values[0]).to(model.device)
    return rationalize_model(
        t_source_tokens = source,
        t_target_tokens = target,
        model = model,
    )

#tst_rationalize_model()

In [78]:
def run_multiple_rational(arr_source_tokens, arr_target_tokens, model, verbose=True):
    """Rationalize several (source, target) pairs and flatten the results.

    Args:
        arr_source_tokens: sequence of source token tensors.
        arr_target_tokens: sequence of target token tensors, index-aligned
            with ``arr_source_tokens``.
        model: passed through to ``rationalize_model``.
        verbose: passed through to ``rationalize_model``.

    Returns:
        ``(arr_code_rationales, arr_from_sentence)``. The first is a flat
        list of every entry of each log's ``'rationalizations'`` list. The
        second is a parallel list giving, for each entry, the position of
        the pair it came from.
    """
    arr_code_rationales = []
    arr_from_sentence = []
    for index, source_tokens in enumerate(arr_source_tokens):
        _, _, log = rationalize_model(
            t_source_tokens = source_tokens,
            t_target_tokens = arr_target_tokens[index],
            model = model,
            verbose = verbose,
        )
        rationalizations = log['rationalizations']
        # extend() keeps flattening linear; sum(list_of_lists, []) was quadratic.
        arr_code_rationales.extend(rationalizations)
        arr_from_sentence.extend([index] * len(rationalizations))
    return arr_code_rationales, arr_from_sentence

In [87]:
#tst
def tst_run_multiple_rationa():
    """Smoke test: rationalize the first two generated samples of experiment '0'.

    Returns:
        ``(arr_rations, from_seq)`` as produced by ``run_multiple_rational``.
    """
    import gc  # make gc available regardless of notebook cell execution order

    gc.collect()
    torch.cuda.empty_cache()  # Cleaning Cache

    n_seqs = 2  # With 2 Sequences

    def to_device(seqs):
        # Convert only the rows actually used instead of the whole frame.
        return [torch.tensor(s).to(model.device) for s in seqs]

    arr_rations, from_seq = run_multiple_rational(
        arr_source_tokens = to_device(df_generated_input['source_sampling'].values[:n_seqs]),
        arr_target_tokens = to_device(df_generated_input['0'].values[:n_seqs]),
        model = model,
        verbose = False,
    )
    return arr_rations, from_seq
#tst_arr_rations, tst_from_seq = tst_run_multiple_rationa()

In [23]:
import gc

In [89]:
def pandas_rationales( arr_code_rationales, arr_from_sentence ):
    """Tabulate rationalization logs, one row per rationalized goal token.

    Args:
        arr_code_rationales: list of dicts, each with a ``'goal_word'`` and a
            ``'log'`` list. Each log entry is a dict with ``'from'``,
            ``'added_token_text'``, ``'added_token_position'`` and
            ``'true_token_prob'``. Only entries whose ``'from'`` is
            ``'target'`` or ``'source'`` are reported.
        arr_from_sentence: per-row sequence id, the same length as
            ``arr_code_rationales``.

    Returns:
        DataFrame with columns ``goal_token``, ``from_seq_id``, and then for
        the target (``_tgt``) and source (``_src``) side:
        ``typesets_*`` (``(token_text, prob)`` tuples), ``rationale_pos_*``
        (positions) and ``rationale_prob_*`` (probabilities). Each cell holds
        a list in log order.
    """
    from operator import itemgetter

    logs = [dict_token['log'] for dict_token in arr_code_rationales]

    def per_side(typeset, extract):
        # For every log, keep the entries that came from `typeset` and project them.
        return [[extract(entry) for entry in log_row if entry['from'] == typeset] for log_row in logs]

    extractors = (
        ('typesets', itemgetter('added_token_text', 'true_token_prob')),  # (token, prob) pairs
        ('rationale_pos', itemgetter('added_token_position')),            # position of the rationale
        ('rationale_prob', itemgetter('true_token_prob')),                # rationale probability
    )
    sides = (('tgt', 'target'), ('src', 'source'))

    p_rationale = pd.DataFrame()
    p_rationale['goal_token'] = [dict_token['goal_word'] for dict_token in arr_code_rationales]
    p_rationale['from_seq_id'] = arr_from_sentence
    # Column order: typesets_tgt, typesets_src, rationale_pos_tgt, ... (same as before).
    for prefix, extract in extractors:
        for suffix, typeset in sides:
            p_rationale[prefix + '_' + suffix] = per_side(typeset, extract)
    return p_rationale

In [96]:
#Running Rationalization
def run_code_rational( 
        df_generated_input,
        tensor_size = None, #Control the size of the experiment
        experiment = '5',
        batch_size = 100, 
        model = model, 
        verbose = True 
    ):
    """Rationalize the generated sequences of one experiment column, in batches.

    Args:
        df_generated_input: DataFrame with one column per experiment (target
            token-id lists) plus a ``'source_sampling'`` column (source
            token ids).
        tensor_size: number of leading rows to process. ``None`` means every
            row.
        experiment: column label of the experiment to rationalize.
        batch_size: number of rows converted to device tensors and
            rationalized at a time.
        model: fairseq model. Defaults to the notebook's global ``model``.
        verbose: forwarded to ``run_multiple_rational``.

    Returns:
        DataFrame built by ``pandas_rationales``. Its ``from_seq_id`` is the
        row position within ``df_generated_input``.
    """
    import gc  # make gc available regardless of notebook cell execution order

    if tensor_size is None:
        tensor_size = len(df_generated_input)

    # Loop-invariant: activate the model once instead of once per batch.
    model_activate(model = model)

    targets = df_generated_input[ experiment ].values
    sources = df_generated_input['source_sampling'].values

    arr_rationals = []
    arr_from_seq = []
    for start in range( 0 , tensor_size , batch_size ):
        print('************************' + str(start) + '************************')
        # Never run past tensor_size (the last batch used to overshoot it).
        stop = min(start + batch_size, tensor_size)
        t_generated_input = [ torch.tensor(s).to(model.device) for s in targets[start:stop] ]
        t_source_sampling = [ torch.tensor(s).to(model.device) for s in sources[start:stop] ]

        t_arr_rationals, t_arr_from_seq = run_multiple_rational(
            arr_source_tokens =  t_source_sampling, 
            arr_target_tokens =  t_generated_input, 
            model = model,
            verbose = verbose
        )

        arr_rationals.extend(t_arr_rationals)
        # run_multiple_rational numbers sequences from 0 inside each batch;
        # shift by the batch start so ids are unique across batches.
        arr_from_seq.extend(start + seq_id for seq_id in t_arr_from_seq)

        # Drop the device tensors before asking CUDA to release cached memory.
        del t_generated_input, t_source_sampling
        gc.collect()
        torch.cuda.empty_cache() #Cleaning Cache

    print("Experiment Finished: " + str(experiment))
    return pandas_rationales( arr_rationals, arr_from_seq )

In [97]:
gc.collect()
torch.cuda.empty_cache()

In [98]:
torch.cuda.is_available()

True

In [103]:
#tst
def tst_run_code_rational(exp, tensor_n = 1000, batch = 100):
    """Rationalize a reproducible random sample of one experiment and save it as JSON.

    Args:
        exp: experiment column label in ``df_generated_input``.
        tensor_n: number of rows to sample. It is capped at the frame size,
            because ``DataFrame.sample(replace=False)`` raises when asked
            for more rows than exist.
        batch: batch size passed to ``run_code_rational``.

    Returns:
        The rationale DataFrame. It is also written to
        ``<output_results>rationales_[t_<rows>]_[max_<MAX_TOKEN_SIZE>]_<exp>``.
    """
    import gc  # make gc available regardless of notebook cell execution order

    gc.collect()
    torch.cuda.empty_cache()
    n_rows = min(tensor_n, len(df_generated_input))
    sampled = df_generated_input.sample( n = n_rows, replace = False, random_state=2)
    test_arr_rationals = run_code_rational( 
            df_generated_input = sampled,
            tensor_size = n_rows,
            experiment = exp,
            batch_size = batch, 
            model = model, 
            verbose = False 
        )
    #Saving process
    print('Saving process')
    # Derive the file name from the actual settings instead of hard-coding 1000/100.
    out_path = f"{params['output_results']}rationales_[t_{n_rows}]_[max_{MAX_TOKEN_SIZE}]_{exp}"
    test_arr_rationals.to_json( out_path )
    return test_arr_rationals
#df_test_run = tst_run_code_rational('0')

Model Activated
************************0************************
Experiment Finished: 0
Saving process


In [None]:
# Rationalize every experiment column; exclude the source column by name, not position.
for i in [c for c in df_generated_input.columns if c != 'source_sampling']:
    df_test_run = tst_run_code_rational(i)

In [102]:
df_test_run.head(1)

Unnamed: 0,goal_token,from_seq_id,typesets_tgt,typesets_src,rationale_pos_tgt,rationale_pos_src,rationale_prob_tgt,rationale_prob_src
0,Test,0,"[(@, 0.9980469942092896)]",[],[0],[],[0.9980469942092896],[]


In [28]:
#Running all Experiments
def exp_run_all_rationales():
    """Rationalize every experiment column of ``df_generated_input``.

    Returns:
        dict mapping each experiment label (every column except
        ``'source_sampling'``) to its rationale DataFrame, computed over
        all rows.
    """
    experiments = [key for key in df_generated_input.columns if key != 'source_sampling']
    dict_arr_rations = {
        key: run_code_rational(
            df_generated_input = df_generated_input,
            # tensor_size is passed explicitly; omitting it raised TypeError
            # because run_code_rational took it as a required argument.
            tensor_size = len(df_generated_input),
            experiment = key,
            batch_size = 10,
            model = model,
            verbose = False
        )
        for key in experiments
    }
    return dict_arr_rations

In [None]:
#arr_df_rationale = [pandas_rationales(dict_arr_rations[key]) for key in dict_arr_rations.keys()]