#### The code to evaluate EM.

In [8]:
import os
import json

def em_code(prediction, ground_truth):
    gt = ground_truth.rstrip()
    lines = prediction.splitlines()
    if len(lines) == 0:
        if gt == "":
            return True
        else:
            return False
    pred = lines[0].rstrip()
    return pred == gt
    
base_dir = '/repo_data/repo_FID/old_medium_split_evaluation/'
experiments = os.listdir(base_dir)
for exp in experiments:
    success = 0
    total = 0
    result_file = os.path.join(base_dir, exp, 'final_output.jsonl')
    if os.path.exists(result_file):
        with open(result_file, 'r') as f:
            data = json.load(f)
            for entry in data:
                if em_code(entry['prediction'], entry['target']):
                    success += 1
                total += 1
        print(f"{exp}: {success}/{total}, EM={success*100/total}")

codet5base-2048_pretrained_True: 9016/310703, EM=2.90180654837578
codet5base-512_pretrained_True: 2956/167737, EM=1.7622826210078872
codet5large-2048_pretrained_True: 10649/255611, EM=4.16609613827261
codet5large-512_pretrained_True: 3860/167737, EM=2.301221555172681
FID-codet5base-512-63_no-truncation-codex-last_True: 134827/271060, EM=49.740647827049365
FID-codet5base-768-32_no-truncation-direct_True: 83711/167737, EM=49.90610300649231
finetuned-codet5base-512_baseline_True: 20499/167737, EM=12.220917269296578


In [12]:
file_path = '/repo_data/finetuning_checkpoints/codet5-base-ntp-java/eval_cp251000_tAll_vAll_sl512_nep1_bspd32_dn2_graccs2_lr1e-4_wup100_wd005_disha_data_2/examples/0_1.json'
total=0
success=0
with open(file_path, 'r') as f:
    data = json.load(f)
    for entry in data['first_line']:
        if em_code(entry['prediction'], entry['label']):
            success += 1
        total += 1
print(f"{success}/{total}, EM={success*100/total}")
    

48838/167737, EM=29.115818215420568


#### Testing truncation from left.

In [2]:
import transformers
model_name = 'Salesforce/codet5-base'
tokenizer = transformers.RobertaTokenizer.from_pretrained(model_name)

input = "public static void main "
tokens = tokenizer(input, \
                    max_length=3, \
                    padding='max_length', \
                    return_tensors="pt", \
                    truncation=True, \
                    truncation_side='left',)
print(tokens['input_ids'][0])
print(tokenizer.decode(tokens['input_ids'][0]))

  from .autonotebook import tqdm as notebook_tqdm
Keyword arguments {'truncation_side': 'left'} not recognized.


tensor([  1, 482,   2])
<s>public</s>


In [4]:
def truncate_from_left(text, text_maxlen):
    tokens = tokenizer(text, truncation=False).input_ids
    if len(tokens) > text_maxlen:
        tokens = tokens[-text_maxlen:]
    return tokenizer.decode(tokens, skip_special_tokens=True)
        
contexts = [{'title':1, 'text':'public static void main'},\
            {'title': 2, 'text':'public static void main(String[] args) {'},\
            {'title': 3, 'text':'public static void main(String[] args) {System.out.println("Hello World");}'}]

for context in contexts:
    context['text'] = truncate_from_left(context['text'], 5)
print(contexts)

[{'title': 1, 'text': 'public static void main'}, {'title': 2, 'text': '[] args) {'}, {'title': 3, 'text': 'Hello World");}'}]


In [2]:
contexts = [{'title':1, 'text':'public static void main'}, \
    {'title': 2, 'text':'public static void main(String[] args) {'}, \
        {'title': 3, 'text':'public static void main(String[] args) {System.out.println("Hello World");}'}]
f = "rule_name:" + " {} " + "rule_context:" + " {}"   
passages = [f.format(c['title'], c['text']) for c in contexts]
print(passages)

['rule_name: 1 rule_context: public static void main', 'rule_name: 2 rule_context: public static void main(String[] args) {', 'rule_name: 3 rule_context: public static void main(String[] args) {System.out.println("Hello World");}']
