In [25]:
%load_ext autoreload
%autoreload 2

from coeditor.common import *
import os

from coeditor.dataset import TokenizedEditDataset, is_repetitive_edit
from coeditor.model import *
from coeditor.encoding import TokenizedEdit, decode_tokens, WindowArgs, tokens_to_change
import shutil
import random

# Work from the project root so relative paths used later (e.g. "output/inspect_coeditor") resolve there.
os.chdir(proj_root())

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
# Context-window size applied to every test example.
wargs = WindowArgs(4096)

# test_data_name = "medium"
test_data_name = "SPOT"
# NOTE(review): if pickle_load uses pickle under the hood, only load trusted dataset files.
test_data: TokenizedEditDataset = pickle_load(
    get_dataset_dir(test_data_name) / "tokenized-file_based" / "test.pkl"
)
# Truncate each edit's context (`truncate_ctx`) according to the window args above.
test_data = test_data.map(lambda e: e.truncate_ctx(wargs))

# Materialize the edits once instead of traversing the dataset twice
# (once for filtering, once for counting).
test_edits = list(test_data.all_edits())
rep_edits = pfilter(is_repetitive_edit, test_edits)
print("Total edits: ", len(test_edits))
print("Repetitive edits: ", len(rep_edits))


filtering: 100%|██████████| 1296/1296 [00:00<00:00, 2077.33it/s]


Total edits:  1296
Repetitive edits:  96


In [10]:
# Trained checkpoint to inspect; swap in the commented line to evaluate another model.
model_dir = get_model_dir(trained=True) / "coeditor-medium"
# model_dir = get_model_dir(trained=True) / "coeditor-small-skip"
model = CoeditorModel.load_pretrained(model_dir)
# NOTE(review): hard-coded GPU index; change if cuda:1 is not available on this machine.
model.to("cuda:1")

# Reuse the window args the test data was truncated with, so preprocessing and
# evaluation cannot silently disagree. The first argument (4096 * 2) is kept as-is;
# presumably a per-batch token budget — TODO confirm against EvalArgs.
eval_args = EvalArgs(4096 * 2, wargs)


In [14]:
# Evaluate only on the repetitive edits selected earlier.
rep_data = TokenizedEditDataset.from_edits(rep_edits)
rep_result = model.eval_on_data(
    rep_data,
    eval_args,
)
display(rep_result)

evaluate loss: 100%|██████████| 47/47 [00:30<00:00,  1.53batch/s]


{'loss_per_ex': (mean=3.1937, weight=96),
 'loss_per_tk': (mean=0.055383, weight=5536),
 'prob_per_ex': (mean=0.45247, weight=96)}

In [19]:
# Evaluate on the whole test split, for comparison with the repetitive subset.
all_result = model.eval_on_data(
    test_data,
    eval_args,
)
display(all_result)


evaluate loss: 100%|██████████| 663/663 [05:32<00:00,  1.99batch/s]


{'loss_per_ex': (mean=41.051, weight=2398),
 'loss_per_tk': (mean=0.76586, weight=128535),
 'prob_per_ex': (mean=0.035363, weight=2398)}

In [7]:
# Only the first `max_saved_samples` examples get a comparison file on disk.
max_saved_samples = 200

# Fresh per-model output directory: any previous inspection results are deleted.
out_dir = Path("output/inspect_coeditor") / model_dir.name
shutil.rmtree(out_dir, ignore_errors=True)
(out_dir / "correct").mkdir(parents=True, exist_ok=True)
(out_dir / "incorrect").mkdir(parents=True, exist_ok=True)

decode_args = DecodingArgs()
exact_match = WeightedSum(0, 0)
predictions = []

for i, ex in enumerate(tqdm(rep_edits)):
    pred_tks = model.predict(ex.input_tks, decode_args)
    predictions.append(pred_tks)
    # Compare predicted vs. ground-truth post-edit code after AST normalization
    # (presumably so formatting-only differences are not counted as mismatches).
    pred_change = TokenizedEdit(ex.input_tks, pred_tks, ex.path).as_change(True)
    truth_change = ex.as_change(True)
    is_correct = normalize_code_by_ast(pred_change.after) == normalize_code_by_ast(truth_change.after)
    exact_match += WeightedSum(int(is_correct), 1)

    if i < max_saved_samples:
        compare_str = ex.show_prediction(pred_tks)
        out_file = out_dir / ("correct" if is_correct else "incorrect") / f"ex-{i}.txt"
        # Explicit encoding: the platform default may not be able to encode arbitrary source text.
        out_file.write_text(compare_str, encoding="utf-8")

print("Exact match: ", exact_match.average())


100%|██████████| 96/96 [04:03<00:00,  2.54s/it]

Exact match:  0.5833333333333334



