In [1]:
from datasets import load_from_disk, load_metric
from transformers import (
    DataCollatorForSeq2Seq,
    RobertaTokenizer,
    T5ForConditionalGeneration,
)
import torch
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm
import evaluate
from CodeBLEU.calc_code_bleu import compute_code_bleu

In [2]:
# Load the evaluation dataset that was previously saved to disk. Per the
# printed summary it already carries the reference column ("seq") and the
# model output column ("generated_decoded") used by the cells that follow.
eval_dataset_path = "/data/nicolasmaier/dataset/hf_clean_seq_dataset_3_eval"
dataset = load_from_disk(eval_dataset_path)
print(dataset)

Dataset({
    features: ['code', 'contents', 'xmi', 'originalLine', 'input_ids', 'attention_mask', 'seq', 'labels', 'generated', 'generated_decoded'],
    num_rows: 21563
})


In [3]:
# Fetch the exact-match metric through `evaluate` and print the input
# schema it declares, to confirm it takes plain strings for both
# predictions and references.
exact_match_metric_id = "exact_match"
metric_exactmatch = evaluate.load(exact_match_metric_id)
print(metric_exactmatch.features)

{'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}


In [4]:
# Exact-match rate of the decoded generations against the reference
# sequences, computed by the `evaluate` metric over the whole dataset.
decoded_predictions = dataset["generated_decoded"]
reference_seqs = dataset["seq"]
res = metric_exactmatch.compute(
    predictions=decoded_predictions,
    references=reference_seqs,
)
print(res)


{'exact_match': 0.9825163474470158}


In [8]:
# Sanity check of the metric above: compare prediction and reference with
# plain string equality, row by row, and print
# (matching rows, total rows, match ratio).
# NOTE(review): the saved output for this cell contains two lines with
# different match counts although the cell prints once, and the execution
# counts in this notebook are not sequential — re-run on a fresh kernel to
# confirm which figure is current.
res = [row["generated_decoded"] == row["seq"] for row in dataset]
num_matches = sum(res)  # True counts as 1
num_rows = len(res)
print(num_matches, num_rows, num_matches / num_rows)


21186 21563 0.9825163474470158
20909 21563 0.9696702685155126


In [15]:
# CodeBLEU of all decoded generations against their references.
# Three equal weights plus a zero: no dataflow information is used.
# presumably the order is (ngram, weighted ngram, syntax, dataflow) —
# confirm against compute_code_bleu.
code_bleu_weights = [1/3, 1/3, 1/3, 0]
res = compute_code_bleu(
    ref=dataset["seq"],
    hyp=dataset["generated_decoded"],
    lang="json",
    params=code_bleu_weights,
)
print(res)


{'code_bleu_score': 0.9988401176314992, 'ngram_match_score': 0.9985903974199987, 'weighted_ngram_match_score': 0.9985867559506246, 'syntax_match_score': 0.9993431995238742}


In [16]:
# Restrict to the rows where the generation is NOT an exact match, then
# measure how close those wrong outputs still are to the reference
# according to CodeBLEU.
dataset_mistakes = dataset.filter(lambda x: x["generated_decoded"] != x["seq"])
mistake_references = dataset_mistakes["seq"]
mistake_hypotheses = dataset_mistakes["generated_decoded"]
res = compute_code_bleu(
    ref=mistake_references,
    hyp=mistake_hypotheses,
    lang="json",
    # Same weighting as for the full-dataset score: no dataflow information.
    params=[1/3, 1/3, 1/3, 0],
)
print(res)


Loading cached processed dataset at /data/nicolasmaier/dataset/hf_clean_seq_dataset_3_eval/cache-f5fdabeba1d1df7d.arrow


{'code_bleu_score': 0.9531869440870048, 'ngram_match_score': 0.9436915950432125, 'weighted_ngram_match_score': 0.9436201246008289, 'syntax_match_score': 0.9722491126169732}
