In [1]:
%load_ext autoreload
%autoreload 2

from coeditor.common import *
import os

from coeditor.dataset import TokenizedEditDataset, is_repetitive_edit
from coeditor.model import CoeditorModel, _Tokenizer, DecodingArgs, EvalArgs
from coeditor.history import show_change
from coeditor.encoding import TokenizedEdit, decode_tokens, WindowArgs, tokens_to_change
import shutil
import random

# Work from the project root so the relative paths used later in this notebook
# (e.g. the "output/..." directory) resolve there, wherever Jupyter was launched.
project_root = proj_root()
os.chdir(project_root)

In [2]:
n_samples = 50
# Set to True to evaluate on a seeded random subset of `n_samples` repetitive
# edits instead of all of them (kept off to reproduce the full-subset numbers).
subsample_rep_edits = False
wargs = WindowArgs(4096)

test_data_name = "small"
test_data: TokenizedEditDataset = pickle_load(
    get_dataset_dir(test_data_name) / "tokenized-file_based" / "test.pkl"
)
# Truncate every edit's context to the `wargs` window used for evaluation.
test_data = test_data.map(lambda e: e.truncate_ctx(wargs))

# Materialize once so we don't walk the dataset twice (filter + count).
all_edits = list(test_data.all_edits())
rep_edits = [e for e in all_edits if is_repetitive_edit(e)]
print("Total edits: ", len(all_edits))
print("Repetitive edits: ", len(rep_edits))

if subsample_rep_edits:
    random.seed(123)
    random.shuffle(rep_edits)
    rep_edits = rep_edits[:n_samples]


Total edits:  2398
Repetitive edits:  191


In [3]:
model_dir = get_model_dir(trained=True) / "small"
model = CoeditorModel.load_pretrained(model_dir)
# Device to run the model on; adjust to match the local hardware.
device = "cuda:1"
model.to(device)

# Reuse `wargs` (the window `test_data` was truncated to) so evaluation and
# truncation always use the same context size instead of two copies of 4096.
# NOTE(review): the first argument (4096 * 2) presumably bounds tokens per
# batch — confirm against the EvalArgs definition.
eval_args = EvalArgs(4096 * 2, wargs)


In [4]:
# Score the model on the repetitive-edit subset only, packaged as its own dataset.
rep_data = TokenizedEditDataset.from_edits(rep_edits)
display(rep_result := model.eval_on_data(rep_data, eval_args))


evaluate loss: 100%|██████████| 57/57 [00:31<00:00,  1.84batch/s]


{'loss_per_ex': (mean=4.9499, weight=191),
 'loss_per_tk': (mean=0.10015, weight=9440),
 'prob_per_ex': (mean=0.38713, weight=191)}

In [16]:
# Baseline: the same evaluation over the whole (truncated) test set, for
# comparison with the repetitive-edit subset.
display(all_result := model.eval_on_data(test_data, eval_args))


evaluate loss: 100%|██████████| 663/663 [05:38<00:00,  1.96batch/s]


{'loss_per_ex': (mean=38.65, weight=2398),
 'loss_per_tk': (mean=0.53366, weight=173674),
 'prob_per_ex': (mean=0.061428, weight=2398)}

In [None]:
# Decode one prediction per repetitive edit (one `model.predict` call each, so
# this can take a while). The next cell pairs `predictions[i]` with `rep_edits[i]`.
decode_args = DecodingArgs()
predictions = []
for rep_ex in tqdm(rep_edits):
    predictions.append(model.predict(rep_ex.input_tks, decode_args))

In [5]:
out_dir = Path("output/inspect_coeditor_model")
# Wipe results from any previous run, then recreate one sub-directory per verdict.
shutil.rmtree(out_dir, ignore_errors=True)
(out_dir / "correct").mkdir(parents=True, exist_ok=True)
(out_dir / "incorrect").mkdir(parents=True, exist_ok=True)

# `predictions` must line up one-to-one with `rep_edits`. If `rep_edits` was
# rebuilt (e.g. re-sampled) after the prediction cell ran, pairs would be
# silently misaligned, so fail loudly instead.
assert len(predictions) == len(rep_edits), (
    f"got {len(predictions)} predictions for {len(rep_edits)} edits; "
    "re-run the prediction cell"
)

exact_match = WeightedSum(0, 0)

for i, (ex, pred_tks) in enumerate(zip(tqdm(rep_edits), predictions)):
    # Rebuild the edit with the predicted output tokens and compare the code
    # after `normalize_code_by_ast` (presumably so formatting-only differences
    # are not counted as mismatches).
    pred_change = TokenizedEdit(ex.path, ex.input_tks, pred_tks).as_change(True)
    truth_change = ex.as_change(True)
    is_correct = normalize_code_by_ast(pred_change.after) == normalize_code_by_ast(
        truth_change.after
    )
    exact_match += WeightedSum(int(is_correct), 1)

    # Save `show_prediction`'s rendering under correct/ or incorrect/ for manual review.
    compare_str = ex.show_prediction(pred_tks)
    out_file = out_dir / ("correct" if is_correct else "incorrect") / f"ex-{i}.txt"
    # Explicit UTF-8: code can contain non-ASCII text and the platform default
    # encoding is not guaranteed to be able to encode it.
    out_file.write_text(compare_str, encoding="utf-8")

print("Exact match: ", exact_match.average())


100%|██████████| 191/191 [05:28<00:00,  1.72s/it]

Exact match:  0.5497382198952879



