In [1]:

import logging
from pathlib import Path

In [2]:
from pathlib import Path
import seaborn as sns; sns.set_theme()
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import sys
import torch


pd.set_option('display.float_format', '{:.3f}'.format)

In [3]:
#| export
# logging.basicConfig(filename=...) opens a FileHandler, which fails when the
# target folder does not exist, so create the log directory first.
LOG_FILE = Path("../datax/logs/bart_valuation_log.txt")
LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
logging.basicConfig(
    filename=str(LOG_FILE),
    filemode='a',
    format='%(asctime)s : %(levelname)s : %(message)s', 
    level=logging.INFO
    )

In [4]:
def param_default(corpus='fm_fc_ms'):
    """
    Build the default configuration dictionary used throughout this notebook.

    @corpus: dataset scope, i.e. the sub-folder of athena_test to use (default 'fm_fc_ms')
    @return: dict with the raw (input methods, output tests) file pairs per split,
             the tokenizer/model/checkpoint locations, the sampling output folder
             and the corpus name
    """
    data_path = Path('/workspaces/code-rationales/semeru-datasets/athena_test/' + corpus + '/')
    data_path_raw = data_path / 'raw'
    return {
        'bpe_path' : '/workspaces/code-rationales/scripts/tokenizer/universal_tokenizer/roberta_aug_spaces',
        'eval_raw': [data_path_raw / 'eval/input.methods.txt',
                        data_path_raw / 'eval/output.tests.txt'],
        'test_raw': [data_path_raw / 'test/input.methods.txt', 
                        data_path_raw / 'test/output.tests.txt'],
        'train_raw': [data_path_raw / 'train/input.methods.txt', 
                        data_path_raw / 'train/output.tests.txt'],
        # Only the test split; use ['eval_raw','test_raw','train_raw'] for all splits
        'data_labels' : ['test_raw'],
        'super_data_checkpoint' : data_path / 'pandas',
        'out_processed' : '/datasets/out_processed/',
        'model_name_or_path' : '/workspaces/code-rationales/data/bart-fairseq/checkpoint_dir_athena_ms/models/', # model directory
        'checkpoint_file': 'checkpoint_best.pt', # checkpoint file inside the model directory
        'output_sample' : '/workspaces/code-rationales/data/sampling/bart/',
        'corpus': corpus
    }

In [5]:
#sys.path.clear()

In [6]:
from fairseq.models.transformer import TransformerModel
from tokenizers import ByteLevelBPETokenizer


In [7]:
def load_tokenizer(bpe_path):
    """Build a byte-level BPE tokenizer from '<bpe_path>-vocab.json' and '<bpe_path>-merges.txt'."""
    prefix = str(bpe_path)
    vocab_file = prefix + '-vocab.json'
    merges_file = prefix + '-merges.txt'
    return ByteLevelBPETokenizer(vocab_file, merges_file)

In [8]:
def lazy_decode(bpe_java):
    """
    Turn a space-separated BPE string back into source text.

    Plain spaces (token separators) are dropped, 'Ġ' becomes a space and 'Ċ' a newline.
    """
    # One-pass character mapping: ' ' -> removed, 'Ġ' -> ' ', 'Ċ' -> '\n'
    decode_table = str.maketrans({' ': None, 'Ġ': ' ', 'Ċ': '\n'})
    return bpe_java.translate(decode_table)

In [9]:
sys_params = param_default()

### Load Model

*NOTE:* Load the same checkpoint/model that was used to generate the input data.

In [10]:
# Load the pretrained fairseq checkpoint configured in sys_params
# (it must be the same model that produced the sampled data loaded below)
model = TransformerModel.from_pretrained(
  model_name_or_path = sys_params['model_name_or_path'],
  checkpoint_file = sys_params['checkpoint_file'],
)

In [11]:
## Select the first CUDA GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [12]:
## If the GPU is busy, uncomment the next line to run on the CPU instead
#device = 'cpu'

In [13]:
model = model.to( device ) #WARNING, Verify the device before assigning to memory
model.eval()

GeneratorHubInterface(
  (models): ModuleList(
    (0): BARTModel(
      (encoder): TransformerEncoderBase(
        (dropout_module): FairseqDropout()
        (embed_tokens): Embedding(50348, 512, padding_idx=1)
        (embed_positions): SinusoidalPositionalEmbedding()
        (layers): ModuleList(
          (0): TransformerEncoderLayerBase(
            (self_attn): MultiheadAttention(
              (dropout_module): FairseqDropout()
              (k_proj): Linear(in_features=512, out_features=512, bias=True)
              (v_proj): Linear(in_features=512, out_features=512, bias=True)
              (q_proj): Linear(in_features=512, out_features=512, bias=True)
              (out_proj): Linear(in_features=512, out_features=512, bias=True)
            )
            (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (dropout_module): FairseqDropout()
            (activation_dropout_module): FairseqDropout()
            (fc1): Linear(in_features=512,

### Load Tokenizer

In [14]:
tokenizer = load_tokenizer(sys_params['bpe_path'])

In [15]:
def prettify_java(minified_java):
    "tries to undo Michele's minification. Works decently, although for loops and sets get newlines inserted, and there are no empty lines or comments"
    indent_unit = '    '
    # Break the text after every brace and semicolon
    expanded = minified_java
    for token in ('{', '}', ';'):
        expanded = expanded.replace(token, token + '\n')

    depth = 0
    pieces = []
    for raw in expanded.splitlines():
        closes_first = raw.lstrip().startswith('}')
        if closes_first:
            depth -= 1  # a closing brace is printed at the outer level
        pieces.append(indent_unit * depth + raw + '\n')
        if raw.endswith('{'):
            depth += 1
        if raw.endswith('}') and not closes_first:
            depth -= 1  # e.g. "x}" closes a block opened on an earlier line
    return ''.join(pieces)

## Data loading

In [16]:
corpus = sys_params['corpus']
super_data = pd.read_json(sys_params['output_sample']+corpus+'_generated.tests.json')

In [17]:
super_data.head()

Unnamed: 0,index,input,input_bpe,input_method_size,output,output_bpe,output_method_size,0,1,2,...,21,22,23,24,25,26,27,28,29,input_is
0,58524,CookbookLoader { public Cookbook load(InputStr...,"[Cook, book, Loader, Ġ{, Ġpublic, ĠCook, book,...",102,@Test public void testLoad_ingredientWithPrimi...,"[@, Test, Ġpublic, Ġvoid, Ġtest, Load, _, ing,...",89,"[1039, 34603, 1640, 10162, 9089, 35529, 5457, ...","[1039, 34603, 1640, 10162, 9089, 35529, 5457, ...","[1039, 34603, 1640, 10162, 9089, 35529, 5457, ...",...,"[1039, 34603, 1640, 10162, 9089, 35529, 5457, ...","[1039, 34603, 1640, 10162, 9089, 35529, 5457, ...","[1039, 34603, 1640, 10162, 9089, 35529, 5457, ...","[1039, 34603, 1640, 10162, 9089, 35529, 5457, ...","[1039, 34603, 1640, 10162, 9089, 35529, 5457, ...","[1039, 34603, 1640, 10162, 9089, 35529, 5457, ...","[1039, 34603, 1640, 10162, 9089, 35529, 5457, ...","[1039, 34603, 285, 13842, 1296, 47167, 43048, ...","[1039, 34603, 285, 13842, 1296, 47167, 43048, ...","[32963, 6298, 49621, 25522, 285, 4350, 6298, 7..."
1,14594,SofaResponse implements Serializable { @Overri...,"[S, of, a, Response, Ġimplements, ĠSerial, iza...",240,@Test public void testToString() { SofaRespons...,"[@, Test, Ġpublic, Ġvoid, Ġtest, To, String, (...",32,"[1039, 34603, 285, 13842, 1296, 3972, 34222, 4...","[1039, 34603, 285, 13842, 1296, 3972, 34222, 4...","[1039, 34603, 285, 13842, 1296, 3972, 34222, 4...",...,"[1039, 34603, 285, 13842, 1296, 3972, 34222, 4...","[1039, 34603, 285, 13842, 1296, 3972, 34222, 4...","[1039, 34603, 285, 13842, 1296, 3972, 34222, 4...","[1039, 34603, 285, 13842, 1296, 3972, 34222, 4...","[1039, 34603, 285, 13842, 1296, 3972, 34222, 4...","[1039, 34603, 285, 13842, 1296, 3972, 34222, 4...","[1039, 34603, 285, 13842, 1296, 3972, 34222, 4...","[1039, 34603, 285, 13842, 1296, 3972, 34222, 4...","[1039, 34603, 285, 13842, 1296, 3972, 34222, 4...","[104, 1116, 102, 47806, 36987, 42477, 38142, 2..."
2,19684,ImageViewTarget extends ViewTarget<ImageView> ...,"[Image, View, Target, Ġextends, ĠView, Target,...",238,@Test public void testOnError() { ImageView im...,"[@, Test, Ġpublic, Ġvoid, Ġtest, On, Error, ()...",210,"[1039, 34603, 285, 13842, 1296, 4148, 30192, 4...","[1039, 34603, 285, 13842, 1296, 4148, 30192, 4...","[1039, 34603, 285, 13842, 1296, 4148, 30192, 4...",...,"[1039, 34603, 285, 13842, 1296, 4148, 30192, 4...","[1039, 34603, 285, 13842, 1296, 4148, 30192, 4...","[1039, 34603, 285, 13842, 1296, 4148, 30192, 4...","[1039, 34603, 285, 13842, 1296, 4148, 30192, 4...","[1039, 34603, 285, 13842, 1296, 4148, 30192, 4...","[1039, 34603, 285, 13842, 1296, 4148, 30192, 4...","[1039, 34603, 285, 13842, 1296, 4148, 30192, 4...","[1039, 34603, 285, 13842, 1296, 4148, 30192, 4...","[1039, 34603, 285, 13842, 1296, 4148, 30192, 4...","[8532, 22130, 41858, 14269, 3756, 41858, 41552..."
3,17039,BeanFactory { public static <T> T createBean(f...,"[Be, an, Factory, Ġ{, Ġpublic, Ġstatic, Ġ<, T,...",158,@Test(expected = JpaUnitException.class) publi...,"[@, Test, (, expected, Ġ=, ĠJ, pa, Unit, Excep...",42,"[1039, 34603, 285, 13842, 1296, 44758, 9325, 2...","[1039, 34603, 285, 13842, 1296, 44758, 9325, 2...","[1039, 34603, 285, 13842, 1296, 44758, 9325, 2...",...,"[1039, 34603, 285, 13842, 1296, 44758, 9325, 2...","[1039, 34603, 285, 13842, 1296, 44758, 9325, 2...","[1039, 34603, 285, 13842, 1045, 9325, 260, 430...","[1039, 34603, 285, 13842, 1045, 9325, 260, 430...","[1039, 34603, 285, 13842, 1296, 44758, 9325, 2...","[1039, 34603, 285, 13842, 1296, 44758, 9325, 2...","[1039, 34603, 285, 13842, 1296, 44758, 9325, 2...","[1039, 34603, 285, 13842, 1296, 44758, 9325, 2...","[1039, 34603, 285, 13842, 1296, 44758, 9325, 2...","[9325, 260, 47249, 25522, 285, 25156, 28696, 5..."
4,58025,CryptoContainer { @WorkerThread public static ...,"[Crypt, o, Container, Ġ{, Ġ@, Work, er, Thread...",334,@Test public void open_fileDoesNotExist() thro...,"[@, Test, Ġpublic, Ġvoid, Ġopen, _, file, Does...",45,"[1039, 34603, 285, 13842, 1296, 25266, 43048, ...","[1039, 34603, 285, 13842, 1296, 25266, 43048, ...","[1039, 34603, 285, 13842, 1296, 25266, 43048, ...",...,"[1039, 34603, 285, 13842, 1296, 25266, 43048, ...","[1039, 34603, 285, 13842, 1296, 25266, 43048, ...","[1039, 34603, 1640, 10162, 5457, 27159, 48847,...","[1039, 34603, 1640, 10162, 5457, 27159, 48847,...","[1039, 34603, 1640, 10162, 5457, 27159, 48847,...","[1039, 34603, 285, 13842, 1296, 25266, 43048, ...","[1039, 34603, 1640, 10162, 5457, 27159, 48847,...","[1039, 34603, 285, 13842, 1296, 25266, 43048, ...","[1039, 34603, 285, 13842, 1296, 25266, 43048, ...","[44623, 139, 48557, 25522, 787, 21461, 254, 47..."


In [18]:
decoded = model.decode(super_data['0'].values[0])
decoded

'@ Test ( expected Ex ceptions Ġ= ĠRuntime Exception . class , Ġexpected Ex ceptions Message Reg Exp Ġ= ĠRuntime Exception . class , Ġexpected Ex ceptions Message Reg Exp Ġ= ĠRuntime Exception . class . get Name () Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+ Ġ" \\\\ n " Ġ+'

In [19]:
prettify_java( lazy_decode( decoded ) )

'@Test(expectedExceptions = RuntimeException.class, expectedExceptionsMessageRegExp = RuntimeException.class, expectedExceptionsMessageRegExp = RuntimeException.class.getName() + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" +\n'

In [20]:
arr_prettify_generated = [
    prettify_java(lazy_decode(model.decode(token_ids)))
    for token_ids in super_data['0'].values
]

In [21]:
arr_prettify_generated

['@Test(expectedExceptions = RuntimeException.class, expectedExceptionsMessageRegExp = RuntimeException.class, expectedExceptionsMessageRegExp = RuntimeException.class.getName() + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" + "\\\\n" +\n',
 '@Test public void testToString() {\n     StringBuilder sb = new StringBuilder();\n     sb.append("sofa-rpcErrorMsg=");\n     sb.append("sofa-rpcErrorMsg=");\n     sb.append("sofa-rpcErrorMsg=");\n     sb.append("sofa-rpcErrorMsg=");\n     sb.append("sofa-rpcErrorMsg=");\n     sb.append("sofa-rpcErrorMsg=");\n     sb.append("sofa-rpcErrorMsg=");\n     sb.append("sofa-rpcErrorMsg=");\n     sb.append("sofa-rpcErrorMsg=");\n     sb.append("sofa-rpcErrorMsg=");\n     sb.append("sofa-rpcErrorMsg

## Decoding each sample

In [22]:
bart_df = pd.DataFrame()

In [23]:
## In this case the ground truth is the output (the reference test method)
## NOTE(review): predictions are passed through prettify_java (newlines + indentation)
## while this ground truth stays minified; the formatting difference may penalize the
## Levenshtein/BLEU scores below — confirm this is intended.
bart_df['ground_truth'] = super_data['output'].values

In [24]:
SAMPLES = 30

In [25]:
for sample_idx in range(SAMPLES):
    # Column '<k>' of super_data holds the k-th sampled generation (token ids)
    bart_df[f'outcome_{sample_idx}'] = [
        prettify_java(lazy_decode(model.decode(token_ids)))
        for token_ids in super_data[str(sample_idx)].values
    ]

In [26]:
bart_df

Unnamed: 0,ground_truth,outcome_0,outcome_1,outcome_2,outcome_3,outcome_4,outcome_5,outcome_6,outcome_7,outcome_8,...,outcome_20,outcome_21,outcome_22,outcome_23,outcome_24,outcome_25,outcome_26,outcome_27,outcome_28,outcome_29
0,@Test public void testLoad_ingredientWithPrimi...,@Test(expectedExceptions = RuntimeException.cl...,@Test(expectedExceptions = RuntimeException.cl...,@Test(expectedExceptions = RuntimeException.cl...,@Test(expectedExceptions = RuntimeException.cl...,@Test(expectedExceptions = RuntimeException.cl...,@Test(expectedExceptions = RuntimeException.cl...,@Test(expectedExceptions = RuntimeException.cl...,@Test(expectedExceptions = RuntimeException.cl...,@Test(expectedExceptions = RuntimeException.cl...,...,@Test(expectedExceptions = RuntimeException.cl...,@Test(expectedExceptions = RuntimeException.cl...,@Test(expectedExceptions = RuntimeException.cl...,@Test(expectedExceptions = RuntimeException.cl...,@Test(expectedExceptions = RuntimeException.cl...,@Test(expectedExceptions = RuntimeException.cl...,@Test(expectedExceptions = RuntimeException.cl...,@Test(expectedExceptions = RuntimeException.cl...,@Test public void testLoad() throws Exception ...,@Test public void testLoad() {\n }\n
1,@Test public void testToString() { SofaRespons...,@Test public void testToString() {\n Strin...,@Test public void testToString() {\n Strin...,@Test public void testToString() {\n Strin...,@Test public void testToString() throws Except...,@Test public void testToString() throws Except...,@Test public void testToString() {\n Strin...,@Test public void testToString() {\n Strin...,@Test public void testToString() throws Except...,@Test public void testToString() throws Except...,...,@Test public void testToString() {\n Strin...,@Test public void testToString() {\n Strin...,@Test public void testToString() {\n }\n,@Test public void testToString() {\n asser...,@Test public void testToString() throws Except...,@Test public void testToString() {\n asser...,@Test public void testToString() {\n asser...,@Test public void testToString() throws Except...,@Test public void testToString() throws Except...,@Test public void testToString() {\n SofaR...
2,@Test public void testOnError() { ImageView im...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...,...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...,@Test public void testOnError() {\n when(v...
3,@Test(expected = JpaUnitException.class) publi...,@Test public void testCreateBean() throws Exce...,@Test public void testCreateBean() {\n fin...,@Test public void testCreateBean() throws Exce...,@Test public void testCreateBean() throws Exce...,@Test public void testCreateBean() {\n fin...,@Test public void testCreateBean() throws Exce...,@Test public void testCreateBean() throws Exce...,@Test public void testCreateBean() throws Exce...,@Test public void testCreateBean() throws Exce...,...,@Test public void testCreateBean() throws Exce...,@Test public void testCreateBean() {\n fin...,@Test public void testCreateBean() {\n fin...,@Test public void createBean() throws Exceptio...,@Test public void createBean() throws Exceptio...,@Test public void testCreateBean() throws Exce...,@Test public void testCreateBean() throws Exce...,@Test public void testCreateBean() throws Exce...,@Test public void testCreateBean() {\n ass...,@Test public void testCreateBean() {\n ass...
4,@Test public void open_fileDoesNotExist() thro...,@Test public void testOpen() throws Exception ...,@Test public void testOpen() throws Exception ...,@Test public void testOpen() throws Exception ...,@Test public void testOpen() throws Exception ...,@Test public void testOpen() throws Exception ...,@Test public void testOpen() throws Exception ...,@Test public void testOpen() throws Exception ...,@Test public void testOpen() throws CryptoExce...,@Test public void testOpen() throws Exception ...,...,@Test(expected = CryptoException.class) public...,@Test public void testOpen() throws CryptoExce...,@Test public void testOpen() throws Exception ...,@Test(expected = CryptoException.class) public...,@Test(expected = CryptoException.class) public...,@Test(expected = CryptoException.class) public...,@Test public void testOpen() throws Exception ...,@Test(expected = CryptoException.class) public...,@Test public void testOpen() throws Exception ...,@Test public void testOpen() throws Exception ...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,@Test public void testGetProperties() throws E...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...,...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...,@Test public void getProperties() {\n Syst...
96,@Test public void onAttachedToWindow_inEditMod...,@Test public void testOnAttachedToWindow() {\n...,@Test public void testOnAttachedToWindow() {\n...,@Test public void testOnAttachedToWindow() {\n...,@Test public void testOnAttachedToWindow() {\n...,@Test public void testOnAttachedToWindow() {\n...,@Test public void testOnAttachedToWindow() {\n...,@Test public void testOnAttachedToWindow() {\n...,@Test public void onAttachedToWindow() {\n ...,@Test public void testOnAttachedToWindow() {\n...,...,@Test public void testOnAttachedToWindow() {\n...,@Test public void testOnAttachedToWindow() {\n...,@Test public void testOnAttachedToWindow() {\n...,@Test public void testOnAttachedToWindow() {\n...,@Test public void testOnAttachedToWindow() {\n...,@Test public void testOnAttachedToWindow() {\n...,@Test public void testOnAttachedToWindow() {\n...,@Test public void testOnAttachedToWindow() {\n...,@Test public void testOnAttachedToWindow() {\n...,@Test public void testOnAttachedToWindow() {\n...
97,@Test public void shouldConvertIeShortFloatToF...,@Test public void testConvertMapping() {\n ...,@Test public void testConvert() {\n final ...,@Test public void testConvertMapping() {\n ...,@Test public void testConvertMapping() {\n ...,@Test public void testConvertMapping() {\n ...,@Test public void testConvert() {\n final ...,@Test public void testConvert() {\n final ...,@Test public void testConvert() {\n final ...,@Test public void testConvertMapping() {\n ...,...,@Test(expected = NullPointerException.class) p...,@Test(expected = NullPointerException.class) p...,@Test(expected = NullPointerException.class) p...,@Test(expected = NullPointerException.class) p...,@Test(expected = NullPointerException.class) p...,@Test(expected = NullPointerException.class) p...,@Test(expected = NullPointerException.class) p...,@Test(expected = NullPointerException.class) p...,@Test(expected = NullPointerException.class) p...,@Test(expected = NullPointerException.class) p...
98,@Test public void testNonPrintable() { SimpleE...,@Test public void testDumpEvent() {\n asse...,@Test public void testDumpEvent() {\n Asse...,@Test public void testDumpEvent() {\n Asse...,@Test public void testDumpEvent() {\n asse...,@Test public void testDumpEvent() {\n asse...,@Test public void testDumpEvent() {\n Even...,@Test public void testDumpEvent() {\n Even...,@Test public void testDumpEvent() {\n Even...,@Test public void testDumpEvent() {\n Even...,...,@Test public void testDumpEvent() {\n Even...,@Test public void testDumpEvent() {\n Even...,@Test public void testDumpEvent() {\n Even...,@Test public void testDumpEvent() {\n Even...,@Test public void testDumpEvent() {\n Even...,@Test public void testDumpEvent() {\n Even...,@Test public void testDumpEvent() {\n Even...,@Test public void testDumpEvent() throws Excep...,@Test public void testDumpEvent() {\n asse...,@Test public void testDumpEvent() {\n Even...


## Levenshtein evaluation

In [27]:
import textdistance

In [28]:
levenshtein_similarity = textdistance.levenshtein


In [29]:
SIZE_SAMPLING = 146

In [30]:
def reduce_sequence_size(df=None, size=None):
    """
    Keep only the first `size` rows of the decoded samples to cut the Levenshtein cost.

    The previous version assigned truncated arrays back into the full-length columns
    of `bart_df`, which pandas rejects (length mismatch) whenever size < number of rows.

    @df: DataFrame to truncate; when None, the global `bart_df` is truncated and rebound
    @size: number of rows to keep; defaults to SIZE_SAMPLING
    @return: the truncated DataFrame (a copy)
    """
    global bart_df
    n_rows = SIZE_SAMPLING if size is None else size
    source = bart_df if df is None else df
    reduced = source.iloc[:n_rows].copy()
    if df is None:
        bart_df = reduced
    return reduced


Load from previous checkpoint if it exists

In [31]:
path= '/workspaces/code-rationales/semeru-datasets/athena_test/' +corpus +'/pandas/calc_lev_30_100_samples.parquet'
#bart_lev_df = pd.read_parquet(path)

In [32]:
## !!!! ALERT: TIME CONSUMING !!!
# Use reduce_sequence_size() to reduce the number of rows (and computation time).
# A checkpoint previously saved at `path` is reused when its shape matches the current data.
cached_lev_df = pd.read_parquet(path) if Path(path).exists() else None
if cached_lev_df is not None and cached_lev_df.shape == (len(bart_df), SAMPLES):
    bart_lev_df = cached_lev_df
else:
    # Strip each ground truth once instead of once per sample
    stripped_truths = [truth.strip() for truth in bart_df['ground_truth']]
    bart_lev_df = pd.DataFrame({
        'lev_' + str(i): [
            levenshtein_similarity.normalized_similarity(truth, outcome.strip())
            for truth, outcome in zip(stripped_truths, bart_df['outcome_' + str(i)])
        ]
        for i in range(SAMPLES)
    })

In [None]:
bart_lev_df

## Checkpoint saving 

In [None]:
# Create the checkpoint folder if needed; to_parquet does not create parent directories
Path(path).parent.mkdir(parents=True, exist_ok=True)
bart_lev_df.to_parquet(path)

In [None]:
bart_lev_df.describe()

In [None]:
# One row per sample: row i holds the mean similarity of column lev_i
bart_avg_df = pd.DataFrame({'levenshtein': bart_lev_df.mean().values})

In [None]:
bart_avg_df

## Bootstrapping

In [None]:
n_bootstraps = 30  # NOTE(review): not used below — the bootstrapping calls hard-code size=500
# Seed numpy's global RNG so the bootstrap replicates computed below are reproducible
np.random.seed(42)

In [None]:

#| export
import numpy as np
from statistics import NormalDist

In [None]:
#| export
def bootstrapping( col, np_func, size, flag_clean_nan = False, rng = None ):
    """
    Bootstrap a summary statistic of a one-dimensional sample.

    @col: pandas Series (its name is used in the log messages) or any 1-D array-like
    @np_func: numpy function for reducing the samples (e.g., median, mean, max)
    @size: number of bootstrapping samples (replicates)
    @flag_clean_nan: flag to eliminate NaN values in the np tensor before resampling
    @rng: optional seed or numpy Generator for reproducible resampling;
          None keeps using numpy's global random state
    @return: np.ndarray holding `size` bootstrapped statistics
    @raises ValueError: if there is no data left to resample
    """
    # Accept Series-like objects (have .name and .values) as well as plain arrays;
    # callers passing `.values` previously crashed on `col.values` / `col.name`.
    if hasattr(col, 'name') and hasattr(col, 'values'):
        np_data, col_name = col.values, col.name
    else:
        np_data, col_name = np.asarray(col), None
    #Cleaning NaNs
    if flag_clean_nan:
        np_data = np_data[ np.logical_not( np.isnan(np_data) ) ] 
    if len(np_data) == 0:
        raise ValueError("bootstrapping: no data to resample for {}".format(col_name))

    sampler = np.random if rng is None else np.random.default_rng(rng)
    #Creating the boostrap replicates as long as the original data size
    #This strategy might work as imputation 
    bootstrap_repl = [ np_func( sampler.choice( np_data, size=len(np_data) ) ) for _ in range( size ) ]
    
    logging.info("Empirical Estimate {}: {}".format(col_name, str(np_func( np_data ))) ) #Empirical Mean,Median,Max, etc
    logging.info("Bootstrapped Estimate {}: {} ".format(col_name, str( np_func( bootstrap_repl ) ) )) #Bootstrapped Mean,Median,Max, etc
    
    return np.array( bootstrap_repl )

In [None]:
def confidence_intervals_large_samples(data, confidence=0.95):
    """
    Normal-approximation confidence interval for the mean of `data`.

    @data: 1-D sequence of numbers (at least two values)
    @confidence: confidence level, e.g. 0.95
    @return: tuple (lowerbound, upperbound, h-value) where h is the interval half-width
    NOTE(review): when `data` are bootstrap replicates, the CI of the statistic itself is
    usually mean ± z * std(replicates) (or replicate percentiles), not the SE of their mean.
    """
    dist = NormalDist.from_samples( data )  # sample mean and sample (n-1) standard deviation
    z = NormalDist().inv_cdf((1 + confidence) / 2.)
    # Standard error of the mean: the sample stdev already applies the n-1 correction,
    # so divide by sqrt(n) (dividing by sqrt(n-1) double-corrects).
    h = dist.stdev * z / (len(data) ** .5)
    return dist.mean - h, dist.mean + h, h

In [None]:
def standard_error(bootstrapped_data):
    """
    Standard error of a bootstrapped statistic, estimated as the (ddof=0)
    standard deviation of the bootstrap replicates.
    """
    spread = np.std(bootstrapped_data, ddof=0)
    return spread

In [None]:
# Pass the Series itself: bootstrapping() reads `.values` and `.name` from its argument
lev_median_np = bootstrapping( bart_avg_df.levenshtein, np_func=np.median, size=500, flag_clean_nan = False ) # Bootstrapped median Levenshtein similarity


In [None]:
# Pass the Series itself: bootstrapping() reads `.values` and `.name` from its argument
lev_mean_np = bootstrapping( bart_avg_df.levenshtein, np_func=np.mean, size=500, flag_clean_nan = False ) # Bootstrapped mean Levenshtein similarity


In [None]:
#Bootrapped Estimates
np.median( lev_median_np ) , np.mean( lev_mean_np )

In [None]:
standard_error(lev_median_np), standard_error(lev_mean_np)


In [None]:
test_confidence_lev_median = confidence_intervals_large_samples(data = lev_median_np, confidence=0.95)


In [None]:
test_confidence_lev_median

## 2.0 Calculating BLEU and CodeBLEU

In [None]:
## Params for codebleu: alpha, beta, gamma, theta
## (weights of ngram, weighted ngram, syntax and dataflow match)
params='0.25,0.25,0.25,0.25'
# The generated and reference tests are Java methods, so CodeBLEU must parse them as Java
lang= 'java'

In [None]:
# This line is needed to load the local CodeBLEU library. Do not use it to export this notebook!!
sys.path.append('/workspaces/code-rationales/scripts')

In [None]:
## based on microsoft script for calculating codeBLEU in codeSearchNet
import CodeBLEU.bleu as bleu
import CodeBLEU.weighted_ngram_match as weighted_ngram_match
import CodeBLEU.syntax_match as syntax_match
import CodeBLEU.dataflow_match as dataflow_match

In [None]:
def calculate_bleu_codeBleu(lang,params,df, gt_col, pred_col, keywords):
    """
    Compute corpus BLEU and CodeBLEU of one prediction column against a ground-truth column.

    Based on Microsoft's CodeBLEU script (CodeXGLUE / codeSearchNet).

    @lang: language passed to the CodeBLEU syntax and dataflow matchers (e.g. 'java')
    @params: comma-separated weights 'alpha,beta,gamma,theta' for the ngram,
             weighted ngram, syntax and dataflow match scores respectively
    @df: DataFrame holding both columns
    @gt_col: name of the ground-truth (reference) column
    @pred_col: name of the prediction (hypothesis) column
    @keywords: path to a keyword file, one keyword per line
    @return: tuple (ngram_match_score, code_bleu_score)
    """
    alpha,beta,gamma,theta = [float(x) for x in params.split(',')]

    # preprocess inputs: exactly one reference per hypothesis
    hypothesis = df[pred_col].to_list()
    references = [[reference] for reference in df[gt_col].to_list()]
    assert len(references) == len(hypothesis)

    # calculate ngram match (BLEU)
    tokenized_hyps = [x.split() for x in hypothesis]
    tokenized_refs = [[x.split() for x in reference] for reference in references]

    ngram_match_score = bleu.corpus_bleu(tokenized_refs,tokenized_hyps)
    
    # calculate weighted ngram match: keywords weigh 1, every other token 0.2
    # (`with` closes the keyword file; a set makes the membership test O(1))
    with open(keywords, 'r', encoding='utf-8') as keyword_file:
        keyword_set = frozenset(line.strip() for line in keyword_file)
    def make_weights(reference_tokens, key_word_list):
        return {token:1 if token in key_word_list else 0.2 \
                for token in reference_tokens}
    tokenized_refs_with_weights = [[[reference_tokens, make_weights(reference_tokens, keyword_set)]\
                for reference_tokens in reference] for reference in tokenized_refs]

    weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(tokenized_refs_with_weights,tokenized_hyps)

    # calculate syntax match
    syntax_match_score = syntax_match.corpus_syntax_match(references, hypothesis,lang)

    # calculate dataflow match
    dataflow_match_score = dataflow_match.corpus_dataflow_match(references, hypothesis,lang)

    logging.info('ngram match: {0}, weighted ngram match: {1}, syntax_match: {2}, dataflow_match: {3}'.\
                        format(ngram_match_score, weighted_ngram_match_score, syntax_match_score, dataflow_match_score))

    code_bleu_score = alpha*ngram_match_score\
                    + beta*weighted_ngram_match_score\
                    + gamma*syntax_match_score\
                    + theta*dataflow_match_score

    logging.info('CodeBLEU score: '+ str(code_bleu_score))
    return ngram_match_score, code_bleu_score

In [None]:
# The generated and reference tests are Java: use the Java parser and keyword list
lang='java'
keywords = '/workspaces/code-rationales/scripts/CodeBLEU/keywords/'+lang+'.txt'

In [None]:
# One BLEU/CodeBLEU pair per sampled outcome column; row i of bart_avg_df is sample i.
# Reuse SAMPLES instead of a second hard-coded sample count.
for i in range(SAMPLES):
    bleuScore, codebleuScore = calculate_bleu_codeBleu(lang,params,bart_df,'ground_truth','outcome_'+str(i),keywords)
    bart_avg_df.loc[i,'bleu'] = bleuScore
    bart_avg_df.loc[i,'codebleu'] = codebleuScore

In [None]:
bart_avg_df

In [None]:
bart_avg_df.describe()

### Calculate bootstrapping for each metric

In [None]:
bart_avg_df.apply(bootstrapping, np_func=np.median, size=500, flag_clean_nan=False)

In [None]:
bart_avg_df.apply(bootstrapping, np_func=np.mean, size=500, flag_clean_nan=False)