In [1]:

import logging
from pathlib import Path

In [2]:
from pathlib import Path
import seaborn as sns; sns.set_theme()
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import sys
import torch


pd.options.display.float_format = '{:.3f}'.format

In [3]:
#| export
logging.basicConfig(
    filename="../datax/logs/bart_evaluation_log.txt",
    filemode='a',
    format='%(asctime)s : %(levelname)s : %(message)s', 
    level=logging.INFO
    )

In [4]:
def param_default():
    corpus = 'fm_fc_ms_ff' #<-- Scope
    data_path = Path('/workspaces/code-rationales/semeru-datasets/athena_test/' + corpus + '/')
    data_path_raw = Path(data_path/ 'raw')
    return {
        'bpe_path' : '/workspaces/code-rationales/scripts/tokenizer/universal_tokenizer/roberta_aug_spaces',
        'eval_raw': [data_path_raw / 'eval/input.methods.txt',
                        data_path_raw / 'eval/output.tests.txt'],
        'test_raw': [data_path_raw / 'test/input.methods.txt', 
                        data_path_raw / 'test/output.tests.txt'],
        'train_raw': [data_path_raw / 'train/input.methods.txt', 
                        data_path_raw / 'train/output.tests.txt'],
        'data_labels' : ['test_raw'],#['eval_raw','test_raw','train_raw'], <----- Just Test
        'super_data_checkpoint' : data_path / 'pandas',
        'out_processed' : '/datasets/out_processed/',
        'model_name_or_path' : '/workspaces/code-rationales/data/bart-fairseq/checkpoint_dir_athena_ms/models/', #Model Path
        'checkpoint_file': 'checkpoint_best.pt', #Model
        'output_sample' : '/workspaces/code-rationales/data/sampling/bart/',
        'corpus': corpus
    }

In [5]:
#sys.path.clear()

In [6]:
from fairseq.models.transformer import TransformerModel
from tokenizers import ByteLevelBPETokenizer


In [7]:
def load_tokenizer(bpe_path):
    return ByteLevelBPETokenizer(str(bpe_path)+'-vocab.json',str(bpe_path)+'-merges.txt')

In [8]:
def lazy_decode(bpe_java):
    return bpe_java.replace(' ','').replace('Ġ',' ').replace('Ċ','\n')

In [9]:
sys_params = param_default()

### Load Model

*NOTE:*  Load the same checkpoint or model as the input data was created

In [10]:
#Loading a pretrain model
model = TransformerModel.from_pretrained(
  model_name_or_path = sys_params['model_name_or_path'],
  checkpoint_file = sys_params['checkpoint_file'],
)

In [11]:
## Move model to GPU if available and trigger evaluation mode
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [12]:
## If GPU busy the use CPU
#device = 'cpu'

In [13]:
model = model.to( device ) #WARNING, Verify the device before assigning to memory
model.eval()

GeneratorHubInterface(
  (models): ModuleList(
    (0): BARTModel(
      (encoder): TransformerEncoderBase(
        (dropout_module): FairseqDropout()
        (embed_tokens): Embedding(50348, 512, padding_idx=1)
        (embed_positions): SinusoidalPositionalEmbedding()
        (layers): ModuleList(
          (0): TransformerEncoderLayerBase(
            (self_attn): MultiheadAttention(
              (dropout_module): FairseqDropout()
              (k_proj): Linear(in_features=512, out_features=512, bias=True)
              (v_proj): Linear(in_features=512, out_features=512, bias=True)
              (q_proj): Linear(in_features=512, out_features=512, bias=True)
              (out_proj): Linear(in_features=512, out_features=512, bias=True)
            )
            (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (dropout_module): FairseqDropout()
            (activation_dropout_module): FairseqDropout()
            (fc1): Linear(in_features=512,

### Load Tokenizer

In [14]:
tokenizer = load_tokenizer(sys_params['bpe_path'])

In [15]:
def prettify_java(minified_java):
    "tries to undo Michele's minification. Works decently, although for loops and sets get newlines inserted, and there are no empty lines or comments"
    minified_java = minified_java.replace('{','{\n').replace('}','}\n').replace(';',';\n')
    num_indents = 0
    pretty_java = ''
    for line in minified_java.splitlines():
        if line.lstrip().startswith('}'):
            num_indents -= 1
        pretty_java += num_indents*'    '+line+'\n'
        if line.endswith('{'):
            num_indents += 1
        if line.endswith('}') and not line.lstrip().startswith('}'):
            num_indents -= 1
    return pretty_java

## Data loading

In [16]:
corpus = sys_params['corpus']
super_data = pd.read_json(sys_params['output_sample']+corpus+'_generated.tests.json')

In [17]:
super_data.head()

Unnamed: 0,index,input,input_bpe,input_method_size,output,output_bpe,output_method_size,0,1,2,...,21,22,23,24,25,26,27,28,29,input_is
0,58561,JavaTypeFilter extends RecipeFilter { @Overrid...,"[Java, Type, Filter, Ġextends, ĠRecipe, Filter...",91,@Test public void testApply_convertsArrayOfInt...,"[@, Test, Ġpublic, Ġvoid, Ġtest, Apply, _, con...",65,"[1039, 34603, 285, 13842, 1296, 47456, 43048, ...","[1039, 34603, 285, 13842, 1296, 47456, 43048, ...","[1039, 34603, 285, 13842, 1296, 47456, 43048, ...",...,"[1039, 34603, 285, 13842, 1296, 47456, 43048, ...","[1039, 34603, 285, 13842, 1296, 47456, 43048, ...","[1039, 34603, 285, 13842, 1296, 47456, 43048, ...","[1039, 34603, 285, 13842, 1296, 47456, 43048, ...","[1039, 34603, 285, 13842, 1296, 47456, 43048, ...","[1039, 34603, 285, 13842, 1296, 47456, 43048, ...","[1039, 34603, 285, 13842, 1296, 47456, 43048, ...","[1039, 34603, 285, 13842, 1296, 47456, 43048, ...","[1039, 34603, 285, 13842, 1296, 47456, 43048, ...","[32379, 40118, 47625, 14269, 38945, 47625, 255..."
1,46602,SafeguardLimitValidator { public void validate...,"[S, af, egu, ard, Limit, Valid, ator, Ġ{, Ġpub...",172,@Test(expectedExceptions = InvalidPropertyExce...,"[@, Test, (, expected, Ex, ceptions, Ġ=, ĠInva...",100,"[1039, 34603, 285, 13842, 1296, 20320, 32890, ...","[1039, 34603, 285, 13842, 1296, 20320, 32890, ...","[1039, 34603, 285, 13842, 1296, 20320, 32890, ...",...,"[1039, 34603, 285, 13842, 1296, 20320, 32890, ...","[1039, 34603, 285, 13842, 1296, 20320, 32890, ...","[1039, 34603, 285, 13842, 1296, 20320, 32890, ...","[1039, 34603, 285, 13842, 1296, 20320, 32890, ...","[1039, 34603, 285, 13842, 1296, 20320, 32890, ...","[1039, 34603, 285, 13842, 1296, 20320, 32890, ...","[1039, 34603, 285, 13842, 1296, 20320, 32890, ...","[1039, 34603, 285, 13842, 1296, 20320, 32890, ...","[1039, 34603, 285, 13842, 1296, 20320, 32890, ...","[104, 2001, 33870, 1120, 47593, 48911, 2630, 2..."
2,69877,DeviceNetworkAddressCleanupService { public vo...,"[Device, Network, Address, Clean, up, Service,...",215,@Test public void noDevicesAreCleanedWhenTheNe...,"[@, Test, Ġpublic, Ġvoid, Ġno, Dev, ices, Are,...",126,"[1039, 34603, 1640, 10162, 5457, 28283, 40534,...","[1039, 34603, 1640, 10162, 5457, 28283, 40534,...","[1039, 34603, 1640, 10162, 5457, 28283, 40534,...",...,"[1039, 34603, 1640, 10162, 5457, 28283, 40534,...","[1039, 34603, 1640, 10162, 5457, 28283, 40534,...","[1039, 34603, 1640, 10162, 5457, 28283, 40534,...","[1039, 34603, 1640, 10162, 5457, 28283, 40534,...","[1039, 34603, 285, 13842, 1296, 18938, 34965, ...","[1039, 34603, 1640, 10162, 5457, 28283, 40534,...","[1039, 34603, 285, 13842, 1296, 18938, 34965, ...","[1039, 34603, 1640, 10162, 5457, 28283, 40534,...","[1039, 34603, 1640, 10162, 5457, 28283, 40534,...","[47580, 40283, 46486, 40827, 658, 32537, 25522..."
3,8781,FileEntityProvider implements EntityProvider<F...,"[File, Entity, Provider, Ġimplements, ĠEntity,...",245,@Test public void isNotWritableForTypeOtherTha...,"[@, Test, Ġpublic, Ġvoid, Ġis, Not, Writ, able...",40,"[1039, 34603, 285, 13842, 1296, 6209, 45714, 8...","[1039, 34603, 285, 13842, 1296, 6209, 45714, 8...","[1039, 34603, 285, 13842, 1296, 6209, 45714, 8...",...,"[1039, 34603, 285, 13842, 1296, 6209, 45714, 8...","[1039, 34603, 285, 13842, 1296, 6209, 45714, 8...","[1039, 34603, 285, 13842, 1296, 6209, 45714, 8...","[1039, 34603, 285, 13842, 1296, 6209, 45714, 8...","[1039, 34603, 285, 13842, 1296, 6209, 45714, 8...","[1039, 34603, 285, 13842, 1296, 6209, 45714, 8...","[1039, 34603, 285, 13842, 1296, 6209, 45714, 8...","[1039, 34603, 285, 13842, 1296, 6209, 45714, 8...","[1039, 34603, 285, 13842, 1296, 6209, 45714, 8...","[9966, 49448, 48903, 36987, 46718, 48903, 4155..."
4,65172,DefaultUserAuthService implements UserAuthServ...,"[Default, User, Auth, Service, Ġimplements, ĠU...",86,@Test public void authenticatedUserIsAdminRetu...,"[@, Test, Ġpublic, Ġvoid, Ġauthenticated, User...",61,"[1039, 34603, 285, 13842, 1296, 48151, 5554, 4...","[1039, 34603, 285, 13842, 1296, 48151, 5554, 4...","[1039, 34603, 285, 13842, 1296, 48151, 5554, 4...",...,"[1039, 34603, 285, 13842, 44723, 44518, 6209, ...","[1039, 34603, 285, 13842, 1296, 48151, 5554, 4...","[1039, 34603, 285, 13842, 1296, 48151, 5554, 4...","[1039, 34603, 285, 13842, 1296, 48151, 5554, 4...","[1039, 34603, 285, 13842, 1296, 48151, 5554, 4...","[1039, 34603, 285, 13842, 1296, 48151, 5554, 4...","[1039, 34603, 285, 13842, 1296, 48151, 5554, 4...","[1039, 34603, 285, 13842, 44723, 44518, 6209, ...","[1039, 34603, 285, 13842, 1296, 48151, 5554, 4...","[48398, 44518, 44298, 32537, 36987, 27913, 442..."


In [18]:
decoded = model.decode(super_data['0'].values[0])
decoded

'@ Test Ġpublic Ġvoid Ġtest Apply () Ġthrows ĠException Ġ{ ĠJava Type Filter Ġfilter Ġ= Ġnew ĠJava Type Filter ( Cook book Ut ils . get Cook book ()); ĠObject Ġvalue Ġ= Ġfilter . apply (" abc ", Ġ" abc "); Ġassert Equ als (" abc ", Ġvalue ); Ġassert Equ als (" abc ", Ġvalue ); Ġ}'

In [19]:
prettify_java( lazy_decode( decoded ) )

'@Test public void testApply() throws Exception {\n     JavaTypeFilter filter = new JavaTypeFilter(CookbookUtils.getCookbook());\n     Object value = filter.apply("abc", "abc");\n     assertEquals("abc", value);\n     assertEquals("abc", value);\n }\n'

In [20]:
arr_prettify_generated = [prettify_java(lazy_decode(model.decode(input))) for input in super_data['0'].values]

In [21]:
arr_prettify_generated

['@Test public void testApply() throws Exception {\n     JavaTypeFilter filter = new JavaTypeFilter(CookbookUtils.getCookbook());\n     Object value = filter.apply("abc", "abc");\n     assertEquals("abc", value);\n     assertEquals("abc", value);\n }\n',
 '@Test public void testValidate() {\n     when(propertyHolder.getLong("safeguard.responseFIdecoder.OFFlimit")).thenReturn(123L);\n     when(propertyHolder.getLong("safeguard.responseFIdecoder.OFFlimit")).thenReturn(123L);\n     underTest.validate();\n     verify(propertyHolder).getLong("safeguard.responseFIdecoder.OFFlimit");\n }\n',
 '@Test(expected = UnknownHostException.class) public void testClearDuplicateAddresses() throws UnknownHostException {\n     service.clearDuplicateAddresses("127.0.0.1", "127.0.0.1");\n }\n',
 '@Test public void testIsWriteable() {\n     assertTrue(fileEntityProvider.isWriteable(File.class, null, null, null));\n     assertTrue(fileEntityProvider.isWriteable(File.class, null, null, null, null));\n     asse

## Decoding each sample

In [22]:
bart_df = pd.DataFrame()

In [23]:
## In this case the ground truth is the ouput
bart_df['ground_truth'] = super_data['output'].values

In [24]:
SAMPLES = 30

In [25]:
for i in range(SAMPLES):
    bart_df['outcome_'+str(i)] = [prettify_java(lazy_decode(model.decode(input))) for input in super_data[str(i)].values]

In [26]:
bart_df

Unnamed: 0,ground_truth,outcome_0,outcome_1,outcome_2,outcome_3,outcome_4,outcome_5,outcome_6,outcome_7,outcome_8,...,outcome_20,outcome_21,outcome_22,outcome_23,outcome_24,outcome_25,outcome_26,outcome_27,outcome_28,outcome_29
0,@Test public void testApply_convertsArrayOfInt...,@Test public void testApply() throws Exception...,@Test public void testApply() throws Exception...,@Test public void testApply() throws Exception...,@Test public void testApply() throws Exception...,@Test public void testApply() throws Exception...,@Test public void testApply() throws Exception...,@Test public void testApply() throws Exception...,@Test public void testApply() throws Exception...,@Test public void testJavaType() throws Except...,...,@Test public void testJavaType() throws Except...,@Test public void testApply() throws Exception...,@Test public void testApply() throws Exception...,@Test public void testApply() throws Exception...,@Test public void testApply() throws Exception...,@Test public void testApply() throws Exception...,@Test public void testApply() throws Exception...,@Test public void testApply() throws Exception...,@Test public void testApply() throws Exception...,@Test public void testApply() throws Exception...
1,@Test(expectedExceptions = InvalidPropertyExce...,@Test public void testValidate() {\n when(...,@Test public void testValidate() {\n when(...,@Test public void testValidate() {\n when(...,@Test public void testValidate() {\n when(...,@Test public void testValidate() {\n when(...,@Test public void testValidate() {\n when(...,@Test public void testValidate() {\n when(...,@Test public void testValidate() {\n when(...,@Test public void testValidate() {\n when(...,...,@Test public void testValidate() {\n Safeg...,@Test public void testValidate() {\n when(...,@Test public void testValidate() {\n when(...,@Test public void testValidate() {\n when(...,@Test public void testValidate() {\n when(...,@Test public void testValidate() {\n when(...,@Test public void testValidate() {\n Safeg...,@Test public void testValidate() {\n Safeg...,@Test public void testValidate() {\n Safeg...,@Test public void testValidate() {\n testV...
2,@Test public void noDevicesAreCleanedWhenTheNe...,@Test(expected = UnknownHostException.class) p...,@Test(expected = UnknownHostException.class) p...,@Test(expected = UnknownHostException.class) p...,@Test(expectedExceptions = UnknownHostExceptio...,@Test(expectedExceptions = UnknownHostExceptio...,@Test(expectedExceptions = UnknownHostExceptio...,@Test(expected = UnknownHostException.class) p...,@Test(expected = UnknownHostException.class) p...,@Test(expected = UnknownHostException.class) p...,...,@Test(expectedExceptions = UnknownHostExceptio...,@Test(expected = UnknownHostException.class) p...,@Test(expected = UnknownHostException.class) p...,@Test(expected = UnknownHostException.class) p...,@Test(expected = UnknownHostException.class) p...,@Test public void testClearDuplicateAddresses(...,@Test(expected = UnknownHostException.class) p...,@Test public void testClearDuplicateAddresses(...,@Test(expected = UnknownHostException.class) p...,@Test(expected = UnknownHostException.class) p...
3,@Test public void isNotWritableForTypeOtherTha...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...,...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...,@Test public void testIsWriteable() {\n as...
4,@Test public void authenticatedUserIsAdminRetu...,@Test public void testAuthenticatedUserIsAdmin...,@Test public void testAuthenticatedUserIsAdmin...,@Test public void testAuthenticatedUserIsAdmin...,@Test public void testAuthenticatedUserIsAdmin...,@Test public void testAuthenticatedUserIsAdmin...,@Test public void testAuthenticatedUserIsAdmin...,@Test public void testAuthenticatedUserIsAdmin...,@Test public void testAuthenticatedUserIsAdmin...,@Test public void testAuthenticatedUserIsAdmin...,...,@Test public void testAuthenticatedUserIsAdmin...,@Test public void authenticatedUserIsAdmin() {...,@Test public void testAuthenticatedUserIsAdmin...,@Test public void testAuthenticatedUserIsAdmin...,@Test public void testAuthenticatedUserIsAdmin...,@Test public void testAuthenticatedUserIsAdmin...,@Test public void testAuthenticatedUserIsAdmin...,@Test public void testAuthenticatedUserIsAdmin...,@Test public void authenticatedUserIsAdmin() {...,@Test public void testAuthenticatedUserIsAdmin...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,@Test public void isRandomPort() throws Except...,@Test public void isRandomPort() throws Except...,@Test public void isRandomPort() throws Except...,@Test public void isRandomPort() throws Except...,@Test public void isRandomPort() throws Except...,@Test public void isRandomPort() throws Except...,@Test public void isRandomPort() throws Except...,@Test public void isRandomPort() throws Except...,@Test public void isRandomPort() throws Except...,@Test public void isRandomPort() throws Except...,...,@Test public void isRandomPort() throws Except...,@Test public void isRandomPort() throws Except...,@Test public void testIsRandomPort() throws Ex...,@Test public void testIsRandomPort() throws Ex...,@Test public void testIsRandomPort() throws Ex...,@Test public void testIsRandomPort() throws Ex...,@Test public void testIsRandomPort() throws Ex...,@Test public void testIsRandomPort() throws Ex...,@Test public void isRandomPort() throws Except...,@Test public void isRandomPort() {\n }\n
96,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...,@Test public void testGetCredentials() throws ...
97,@Test public void itThrowsExnOnNotFoundDefault...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...,...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...,@Test(expected = IllegalArgumentException.clas...
98,@Test public void testDecode2() throws Excepti...,@Test public void testDecodeEmptyString() {\n ...,@Test public void decodeEmptyString() throws E...,@Test public void decodeEmptyString() {\n ...,@Test public void testDecodeEmptyString() thro...,@Test public void testDecodeEmpty() {\n as...,@Test public void emptyString() {\n assert...,@Test public void testDecodeEmptyString() {\n ...,@Test public void testDecode() throws Exceptio...,@Test public void empty() {\n assertEquals...,...,@Test(expected = IllegalArgumentException.clas...,@Test public void testDecode() throws Exceptio...,@Test public void empty() {\n assertEquals...,@Test public void testDecodeString() throws Ex...,@Test(expected = NullPointerException.class) p...,@Test public void testDecodeEmptyString() thro...,@Test public void testDecode() throws Exceptio...,@Test public void decode() throws Exception {\...,@Test public void testDecodeString() throws Ex...,@Test public void testDecode() {\n assertE...


## Levenshtein evaluation

In [27]:
import textdistance

In [28]:
levenshtein_similarity = textdistance.levenshtein


In [29]:
SIZE_SAMPLING = 146

In [30]:
def reduce_sequence_size():
    bart_df['ground_truth'] = bart_df['ground_truth'].values[:SIZE_SAMPLING]
    for i in range(SAMPLES):
        bart_df['outcome_'+str(i)] = bart_df['outcome_'+str(i)].values[:SIZE_SAMPLING]


Load from previous checkpoint if it exists

In [31]:
path= '/workspaces/code-rationales/semeru-datasets/athena_test/' +corpus +'/pandas/calc_lev_30_100_samples.parquet'
#bart_lev_df = pd.read_parquet(path)

In [32]:
## !!!! ALERT TIME COMSUMMING!!! 
# Use reduce_sequence_size to reeduce sequence size and computation time or load from SAVING!!
bart_lev_df = pd.DataFrame()
for i in range(SAMPLES):
    bart_lev_df['lev_'+str(i)] = [levenshtein_similarity.normalized_similarity(x["ground_truth"].strip(), x["outcome_"+str(i)].strip() ) for _,x in bart_df.iterrows()]

In [33]:
bart_lev_df

Unnamed: 0,lev_0,lev_1,lev_2,lev_3,lev_4,lev_5,lev_6,lev_7,lev_8,lev_9,...,lev_20,lev_21,lev_22,lev_23,lev_24,lev_25,lev_26,lev_27,lev_28,lev_29
0,0.288,0.301,0.301,0.301,0.291,0.297,0.283,0.314,0.279,0.275,...,0.251,0.297,0.284,0.319,0.301,0.289,0.293,0.297,0.306,0.314
1,0.209,0.275,0.169,0.237,0.313,0.241,0.326,0.250,0.326,0.231,...,0.256,0.285,0.288,0.282,0.285,0.288,0.215,0.218,0.215,0.165
2,0.212,0.210,0.221,0.225,0.223,0.234,0.212,0.212,0.208,0.212,...,0.221,0.212,0.221,0.212,0.208,0.219,0.205,0.260,0.205,0.221
3,0.186,0.188,0.206,0.208,0.210,0.233,0.235,0.238,0.269,0.272,...,0.415,0.514,0.545,0.529,0.568,0.662,0.623,0.689,0.649,0.438
4,0.405,0.519,0.540,0.401,0.502,0.527,0.523,0.418,0.511,0.532,...,0.430,0.426,0.376,0.380,0.350,0.409,0.350,0.405,0.422,0.401
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,0.260,0.379,0.079,0.079,0.079,0.079,0.079,0.079,0.079,0.079,...,0.079,0.079,0.078,0.078,0.078,0.078,0.078,0.078,0.981,0.660
96,0.473,0.473,0.476,0.490,0.503,0.473,0.374,0.480,0.490,0.476,...,0.463,0.476,0.469,0.473,0.459,0.435,0.490,0.473,0.422,0.340
97,0.152,0.163,0.173,0.184,0.147,0.147,0.200,0.132,0.147,0.145,...,0.244,0.284,0.238,0.284,0.273,0.272,0.278,0.268,0.258,0.258
98,0.266,0.262,0.247,0.243,0.266,0.240,0.243,0.316,0.240,0.266,...,0.202,0.281,0.198,0.262,0.186,0.247,0.281,0.262,0.262,0.232


## Checkpoint saving 

In [34]:
bart_lev_df.to_parquet(path)

In [35]:
bart_lev_df.describe()

Unnamed: 0,lev_0,lev_1,lev_2,lev_3,lev_4,lev_5,lev_6,lev_7,lev_8,lev_9,...,lev_20,lev_21,lev_22,lev_23,lev_24,lev_25,lev_26,lev_27,lev_28,lev_29
count,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,...,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0
mean,0.301,0.305,0.303,0.302,0.306,0.299,0.304,0.305,0.3,0.3,...,0.301,0.305,0.307,0.308,0.306,0.306,0.302,0.305,0.317,0.289
std,0.132,0.131,0.13,0.134,0.136,0.136,0.134,0.141,0.135,0.139,...,0.134,0.135,0.139,0.143,0.141,0.14,0.139,0.146,0.159,0.151
min,0.109,0.109,0.079,0.079,0.079,0.079,0.079,0.079,0.079,0.079,...,0.079,0.079,0.078,0.078,0.078,0.078,0.078,0.078,0.032,0.016
25%,0.215,0.215,0.225,0.217,0.212,0.204,0.212,0.209,0.207,0.207,...,0.204,0.217,0.203,0.195,0.204,0.207,0.203,0.2,0.204,0.183
50%,0.268,0.277,0.274,0.274,0.27,0.271,0.282,0.276,0.271,0.271,...,0.27,0.279,0.275,0.285,0.28,0.281,0.28,0.266,0.294,0.257
75%,0.361,0.369,0.359,0.359,0.366,0.353,0.363,0.365,0.357,0.343,...,0.381,0.382,0.393,0.397,0.402,0.399,0.395,0.403,0.387,0.383
max,0.879,0.854,0.816,0.936,0.741,0.772,0.842,0.792,0.87,0.793,...,0.898,0.807,0.868,0.784,0.844,0.782,0.784,0.821,0.981,0.811


In [36]:
bart_avg_df = pd.DataFrame()
bart_avg_df['levenshtein'] = bart_lev_df.mean().values

In [37]:
bart_avg_df

Unnamed: 0,levenshtein
0,0.301
1,0.305
2,0.303
3,0.302
4,0.306
5,0.299
6,0.304
7,0.305
8,0.3
9,0.3


## Bootstrapping

In [38]:
n_bootstraps = 30

In [39]:

#| export
import numpy as np
from statistics import NormalDist

In [40]:
#| export
def bootstrapping( col, np_func, size, flag_clean_nan = False ):
    """
    @size: number of bootstrapping samples
    @np_funct: numpy function for reducing the samples (e.g., median, mean, max)
    @flag_clean_nan: flag to eliminate Nan values in the np tensor
    """
    np_data = col.values
    col_name = col.name
    #Cleaning NaNs
    if flag_clean_nan:
        np_data = np_data[ np.logical_not( np.isnan(np_data) ) ] 
    
    #Creating the boostrap replicates as long as the original data size
    #This strategy might work as imputation 
    bootstrap_repl = [ np_func( np.random.choice( np_data, size=len(np_data) ) ) for i in range( size ) ]
    
    logging.info("Empirical Estimate {}: {}".format(col_name, str(np_func( np_data ))) ) #Empirical Mean,Median,Max, etc
    logging.info("Bootstrapped Estimate {}: {} ".format(col_name, str( np_func( bootstrap_repl ) ) )) #Bootstrapped Mean,Median,Max, etc
    
    return np.array( bootstrap_repl )

In [41]:
def confidence_intervals_large_samples(data, confidence=0.95):
    """
    @confidence: confidence interval 
    @return: tuple (lowerbound, uperbound, h-value)
    """
    dist = NormalDist.from_samples( data )
    z = NormalDist().inv_cdf((1 + confidence) / 2.)
    h = dist.stdev * z / ((len(data) - 1) ** .5)
    return dist.mean - h, dist.mean + h, h

In [42]:
def standard_error(bootstrapped_data):
    return np.std( bootstrapped_data )

In [43]:
lev_median_np = bootstrapping( bart_avg_df.levenshtein, np_func=np.median, size=500, flag_clean_nan = False ) #Bootstrapped Complexity


In [44]:
lev_mean_np = bootstrapping( bart_avg_df.levenshtein, np_func=np.mean, size=500, flag_clean_nan = False ) #Bootstrapped Complexity


In [45]:
#Bootrapped Estimates
np.median( lev_median_np ) , np.mean( lev_mean_np )

(0.30437249570078906, 0.30376048327185173)

In [46]:
standard_error(lev_median_np), standard_error(lev_mean_np)


(0.0008392288400799651, 0.0007720482620957034)

In [47]:
test_confidence_lev_median = confidence_intervals_large_samples(data = lev_median_np, confidence=0.95)


In [48]:
test_confidence_lev_median

(0.30411156422816554, 0.3042589796580317, 7.370771493306898e-05)

## 2.0 Calculating BLUE and codeBLEU

In [49]:
## Params for codebleu: alpha, beta, gamma, theta
params='0.25,0.25,0.25,0.25'
lang= 'python'

In [50]:
# This line is needed to load the local CodeBLEU library. Do not use it to export this notebook!!
sys.path.append('/workspaces/code-rationales/scripts')

In [51]:
## based on microsoft script for calculating codeBLEU in codeSearchNet
import CodeBLEU.bleu as bleu
import CodeBLEU.weighted_ngram_match as weighted_ngram_match
import CodeBLEU.syntax_match as syntax_match
import CodeBLEU.dataflow_match as dataflow_match

In [52]:
def calculate_bleu_codeBleu(lang,params,df, gt_col, pred_col, keywords):
    alpha,beta,gamma,theta = [float(x) for x in params.split(',')]
    # preprocess inputs
    pre_references = [df[gt_col].to_list()]
    hypothesis = df[pred_col].to_list()
    for i in range(len(pre_references)):
        assert len(hypothesis) == len(pre_references[i])

    references = []
    for i in range(len(hypothesis)):
        ref_for_instance = []
        for j in range(len(pre_references)):
            ref_for_instance.append(pre_references[j][i])
        references.append(ref_for_instance)
    assert len(references) == len(pre_references)*len(hypothesis)


    # calculate ngram match (BLEU)
    tokenized_hyps = [x.split() for x in hypothesis]
    tokenized_refs = [[x.split() for x in reference] for reference in references]

    ngram_match_score = bleu.corpus_bleu(tokenized_refs,tokenized_hyps)
    
    # calculate weighted ngram match
    keywords = [x.strip() for x in open(keywords, 'r', encoding='utf-8').readlines()]
    def make_weights(reference_tokens, key_word_list):
        return {token:1 if token in key_word_list else 0.2 \
                for token in reference_tokens}
    tokenized_refs_with_weights = [[[reference_tokens, make_weights(reference_tokens, keywords)]\
                for reference_tokens in reference] for reference in tokenized_refs]

    weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(tokenized_refs_with_weights,tokenized_hyps)

    # calculate syntax match
    syntax_match_score = syntax_match.corpus_syntax_match(references, hypothesis,lang)

    # calculate dataflow match
    dataflow_match_score = dataflow_match.corpus_dataflow_match(references, hypothesis,lang)

    logging.info('ngram match: {0}, weighted ngram match: {1}, syntax_match: {2}, dataflow_match: {3}'.\
                        format(ngram_match_score, weighted_ngram_match_score, syntax_match_score, dataflow_match_score))

    code_bleu_score = alpha*ngram_match_score\
                    + beta*weighted_ngram_match_score\
                    + gamma*syntax_match_score\
                    + theta*dataflow_match_score

    logging.info('CodeBLEU score: '+ str(code_bleu_score))
    return ngram_match_score, code_bleu_score

In [53]:
lang='python'
keywords = '/workspaces/code-rationales/scripts/CodeBLEU/keywords/'+lang+'.txt'

In [54]:
samples=30
for i in range(0,samples):
    bleuScore, codebleuScore = calculate_bleu_codeBleu(lang,params,bart_df,'ground_truth','outcome_'+str(i),keywords)
    bart_avg_df.loc[i,'bleu'] = bleuScore
    bart_avg_df.loc[i,'codebleu'] = codebleuScore

In [55]:
bart_avg_df

Unnamed: 0,levenshtein,bleu,codebleu
0,0.301,0.063,0.181
1,0.305,0.06,0.182
2,0.303,0.06,0.187
3,0.302,0.063,0.183
4,0.306,0.063,0.182
5,0.299,0.059,0.181
6,0.304,0.063,0.184
7,0.305,0.061,0.182
8,0.3,0.062,0.181
9,0.3,0.063,0.18


In [56]:
bart_avg_df.describe()

Unnamed: 0,levenshtein,bleu,codebleu
count,30.0,30.0,30.0
mean,0.304,0.057,0.172
std,0.004,0.006,0.013
min,0.289,0.039,0.131
25%,0.301,0.055,0.167
50%,0.304,0.059,0.178
75%,0.306,0.062,0.181
max,0.317,0.066,0.187


### Calculate bootstraping for each metric

In [57]:
bart_avg_df.apply(lambda col: bootstrapping( col, np_func=np.median, size=500, flag_clean_nan = False ))

Unnamed: 0,levenshtein,bleu,codebleu
0,0.304,0.057,0.178
1,0.305,0.057,0.177
2,0.306,0.057,0.181
3,0.304,0.060,0.180
4,0.305,0.060,0.177
...,...,...,...
495,0.305,0.057,0.177
496,0.304,0.060,0.178
497,0.304,0.057,0.179
498,0.303,0.058,0.178


In [58]:
bart_avg_df.apply(lambda col: bootstrapping( col, np_func=np.mean, size=500, flag_clean_nan = False ))

Unnamed: 0,levenshtein,bleu,codebleu
0,0.305,0.057,0.172
1,0.304,0.057,0.172
2,0.304,0.058,0.173
3,0.305,0.056,0.175
4,0.305,0.057,0.172
...,...,...,...
495,0.305,0.057,0.172
496,0.304,0.059,0.175
497,0.304,0.056,0.173
498,0.304,0.058,0.175
