In [7]:
# Notebook setup: hot-reload edited project modules, import the evaluation
# helpers, and switch the working directory to the project root.
%load_ext autoreload
%autoreload 2

# NOTE(review): star import. proj_root, pfilter, PickleCache and get_model_dir are
# used later without any other visible import, so they presumably come from here.
from coeditor.common import *
import os

from coeditor.dataset import TokenizedEditDataset
from coeditor.model import CoeditorModel, EvalArgs, DecodingArgs, input_cost_model
from coeditor.encoding import TokenizedEdit, decode_tokens, tokens_to_change, AnalysisBasedEditEncoder, CstBasedEditEncoder
from coeditor.history import Added, Modified
# NOTE(review): shutil, random, decode_tokens, tokens_to_change, input_cost_model and
# get_commit_history are not used in the cells visible here; confirm before pruning.
import shutil
import random
from prepare_data import make_or_load_datasets, dataset_from_projects, get_commit_history

# Work from the project root (presumably so relative data/model paths resolve there).
os.chdir(proj_root())

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Small ad-hoc test set built from this repository's own commit history ("SPOT").
# The "medium" benchmark split is loaded in a later cell.
test_data_name = "SPOT"
encoder = AnalysisBasedEditEncoder(extra_ctx_names=("usees", "post-usees"))
test_data = dataset_from_projects(
    [proj_root()], encoder, [False], max_history_per_repo=100
)

# Walk the edits once and count each change kind from the same list.
spot_edits = list(test_data.all_edits())
n_added = sum(1 for edit in spot_edits if isinstance(edit.change_type, Added))
n_modified = sum(1 for edit in spot_edits if isinstance(edit.change_type, Modified))
print("num addition: ", n_added)
print("num modifications: ", n_modified)

Getting commit histories: 100%|██████████| 1/1 [00:01<00:00,  1.17s/repo]
Create tokenized edits: 100%|██████████| 2/2 [03:26<00:00, 103.07s/chunk]

num addition:  0
num modifications:  639





In [2]:
# Load the "medium" split encoded with the analysis-based encoder and bucket
# its test edits into the subsets evaluated further down.
test_data_name = "medium"
encoder = AnalysisBasedEditEncoder(extra_ctx_names=("usees", "post-usees"))
datasets = make_or_load_datasets(test_data_name, encoder, predict_added_in_training=False)
test_data = datasets["test"]

# Materialize the edits once instead of re-walking all_edits() for every bucket.
all_test_edits = list(test_data.all_edits())

rep_edits = pfilter(TokenizedEdit.is_repetitive_edit, all_test_edits)
small_edits = pfilter(TokenizedEdit.is_small_edit, all_test_edits)
# An edit is treated as a refactoring when its `updated_calls` is non-empty.
is_refactor_edits = [bool(e.updated_calls) for e in all_test_edits]
refactor_edits = [
    e for e, is_refactor in zip(all_test_edits, is_refactor_edits) if is_refactor
]
print("Total edits: ", len(all_test_edits))
print("Repetitive edits: ", len(rep_edits))
print("Small edits: ", len(small_edits))
print("Refactoring edits: ", len(refactor_edits))

Starting task: Loading datasets from disk
(11.7s) Finished task: Loading datasets from disk


filtering: 100%|██████████| 10178/10178 [00:06<00:00, 1621.72it/s]
filtering: 100%|██████████| 10178/10178 [00:01<00:00, 8050.52it/s]


Total edits:  10178
Repetitive edits:  648
Small edits:  6573
Refactoring edits:  868


In [3]:
# Reload the "medium" split with the CST-based encoder and recompute the
# repetitive / small buckets for it.
# NOTE(review): `refactor_edits` is NOT recomputed here, so cells run after this
# one keep using the refactoring subset from the analysis-based encoder cell.
test_data_name = "medium"
encoder = CstBasedEditEncoder()
datasets = make_or_load_datasets(test_data_name, encoder, predict_added_in_training=False)
test_data = datasets["test"]

# Materialize the edits once instead of re-walking all_edits() for every bucket.
all_test_edits = list(test_data.all_edits())

rep_edits = pfilter(TokenizedEdit.is_repetitive_edit, all_test_edits)
small_edits = pfilter(TokenizedEdit.is_small_edit, all_test_edits)
print("Total edits: ", len(all_test_edits))
print("Repetitive edits: ", len(rep_edits))
print("Small edits: ", len(small_edits))

Starting task: Loading datasets from disk
(7.5s) Finished task: Loading datasets from disk


filtering: 100%|██████████| 10178/10178 [00:05<00:00, 1867.63it/s]
filtering: 100%|██████████| 10178/10178 [00:01<00:00, 8552.27it/s]


Total edits:  10178
Repetitive edits:  598
Small edits:  6566
Refactoring edits:  0


In [8]:
# Checkpoint under evaluation. Alternative checkpoint previously used here:
# "coeditor-medium-analysis-post_usees".
MODEL_NAME = "coeditor-medium-sig-cst-no_added"
DEVICE = "cuda:1"

model_dir = get_model_dir() / MODEL_NAME
model = CoeditorModel.load_pretrained(model_dir)
model.to(DEVICE)

# Default evaluation arguments; decoding with a single beam.
dec_args = DecodingArgs(num_beams=1)
eval_args = EvalArgs()

In [4]:
# Evaluation artifacts live under <model_dir>/evals/<test_data_name>.
# NOTE(review): the path ignores which encoder built `test_data`, and it ignores
# the eval/decoding args, so runs with different settings share one directory.
evals_root = model_dir / "evals"
eval_dir = evals_root / test_data_name
eval_cache = PickleCache(eval_dir)

In [9]:
# Call-update accuracy on the refactoring subset. Predictions go through
# eval_cache under a fixed key. NOTE(review): the key does not encode
# eval_args/dec_args, so a stored pickle is presumably reused after they change.
refactor_data = TokenizedEditDataset.from_edits(refactor_edits)


def _predict_refactor():
    return model.predict_on_data(refactor_data, eval_args, dec_args)


refactor_result = eval_cache.cached("RefactorCallUpdate.pkl", _predict_refactor)
call_acc, call_correct = refactor_result.call_update_accuracy()
print("Call update accuracy: ", call_acc)

Preprocessing edits: 100%|██████████| 868/868 [00:00<00:00, 4182.41it/s]
Arranging batches: 100%|██████████| 868/868 [00:00<00:00, 401262.63it/s]
decoding: 100%|██████████| 248/248 [14:27<00:00,  3.50s/batch]


122 / 1176 calls were considered incorrect since they failed to parse.
Call update accuracy:  (mean=0.26871, weight=1176)


In [19]:
# Exact-match accuracy on the repetitive-edit subset. Predictions go through
# eval_cache under a fixed key. NOTE(review): the key does not encode
# eval_args/dec_args, so a stored pickle is presumably reused after they change.
rep_data = TokenizedEditDataset.from_edits(rep_edits)


def _predict_repetitive():
    return model.predict_on_data(rep_data, eval_args, dec_args)


rep_result = eval_cache.cached("RepetitiveEdits.pkl", _predict_repetitive)
rep_acc, rep_correct = rep_result.exact_match_accuracy()
print("Repetitive edits accuracy: ", rep_acc)

Preprocessing edits: 100%|██████████| 648/648 [00:00<00:00, 4186.05it/s]
Arranging batches: 100%|██████████| 648/648 [00:00<00:00, 474164.16it/s]


In [14]:
# Re-run the refactoring evaluation with signature prefixes enabled (uncached).
# NOTE(review): this mutates the shared `model` in place and rebinds the globals
# `eval_args`, `refactor_result` and `call_correct`. Every cell executed after
# this one therefore sees the signature prefix and the larger EvalArgs.
SIG_PREFIX_EVAL_BUDGET = 4096 * 4  # positional EvalArgs value; see coeditor.model — TODO confirm meaning

model.data_args.use_signature_prefix = True
refactor_data = TokenizedEditDataset.from_edits(refactor_edits)
eval_args = EvalArgs(SIG_PREFIX_EVAL_BUDGET)
refactor_result = model.predict_on_data(refactor_data, eval_args, dec_args)
call_acc, call_correct = refactor_result.call_update_accuracy()
print("(use_signature_prefix) Call update accuracy: ", call_acc)

decoding: 100%|██████████| 38/38 [02:33<00:00,  4.05s/batch]


18 / 200 calls were considered incorrect since they failed to parse.
(use_signature_prefix) Call update accuracy:  (mean=0.39, weight=200)


In [18]:
# Write the current refactor_result's examples to disk for manual inspection,
# together with the per-call correctness from call_update_accuracy().
examples_subdir = "CallUpdateAccuracy"
out_dir = eval_dir / examples_subdir
refactor_result.save_examples_to_dir(out_dir, call_correct)
print(out_dir)

saving examples: 100%|██████████| 868/868 [00:10<00:00, 85.05it/s] 

/mnt/nas/jiayi/coeditor/models/trained/coeditor-medium-sig-cst-no_added/evals/medium/CallUpdateAccuracy





In [9]:
# Uncached re-evaluation of the repetitive-edit subset with whatever model
# settings and eval_args are current in the kernel.
rep_data = TokenizedEditDataset.from_edits(rep_edits)
rep_result = model.predict_on_data(rep_data, eval_args, dec_args)
rep_exact_acc = rep_result.exact_match_accuracy()[0]
display(rep_exact_acc)

(mean=0.74667, weight=75)

In [7]:
# Decode the whole currently loaded test set (single beam), cached in eval_cache.
# NOTE(review): `test_data` is whichever encoder's split was loaded last.
dec_args = DecodingArgs(num_beams=1)


def _decode_full_test_set():
    return model.predict_on_data(test_data, eval_args, dec_args)


dec_result = eval_cache.cached("DatasetDecodingResult.pkl", _decode_full_test_set)