In [1]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
from typing import *

from spot.utils import proj_root, get_data_dir

# Switch the working directory to the project root (presumably so that
# relative paths used by the `spot` package resolve from there).
os.chdir(proj_root())

# Root data directory, and the repos directory beneath it.
datadir = get_data_dir()
repos_dir = datadir / "SPOT-data/repos"

In [2]:
# experiment configurations

import torch

from spot.data import (
    SrcDataset,
    get_dataset_name,
    load_src_datasets,
)
from copy import copy
from spot.train import TrainingConfig, TypeCheckArgs

# Experiment configuration: not a quicktest, with `all_labels` enabled.
config = TrainingConfig(quicktest=False, all_labels=True)
train_R1: bool = True
gpu_id = 1  # CUDA device index used when building `device` (if CUDA is available)

project_name = "test-SPOT" if config.quicktest else "SPOT"

max_tokens_per_file = config.ctx_size

datasets_name = get_dataset_name(
    drop_comments=config.drop_comments,
    all_labels=config.all_labels,
)

tc_args = TypeCheckArgs(check_in_isolation=config.check_in_isolation)

# The checkpoint name is always derived from the non-quicktest configuration.
r0_model_name = "R0-model--" + config._replace(quicktest=False).as_name()

# Reuse `repos_dir` (defined in the setup cell) instead of re-spelling the path.
src_datasets = load_src_datasets(
    datadir,
    datasets_name,
    data_reduction=config.data_reduction,
    repos_root=repos_dir,
    quicktest=config.quicktest,
)


  warn(f"Failed to load image Python extension: {e}")


Loading datasets:  src_datasets-all_labels-drop_comments


In [3]:
# load trained model
from spot.utils import pickle_load, pickle_dump
from spot.model import ModelWrapper


# Load the saved R0 model wrapper from the checkpoint directory named after
# `r0_model_name`.
r0_wrapper = ModelWrapper.from_pretrained(
    datadir / f"checkpoints/lit-saved/{r0_model_name}"
)
# Disabled: loading of the extra R1 datasets stored next to the checkpoint.
# if train_R1:
#     r0_extra = pickle_load(datadir / f"checkpoints/lit-saved/{r0_model_name}/extra.pkl")
#     r1_src_datasets: dict[str, SrcDataset] = r0_extra["R1-src_datasets"]
# Use the configured GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
r0_wrapper.to(device)
# Show the decoding arguments stored with the loaded wrapper.
print(r0_wrapper.args)


DecodingArgs(ctx_args=CtxArgs(ctx_size=4096, left_margin=2048, right_margin=1024), sampling_max_tokens=32768, max_workers=20)


In [5]:
# load the critics

from spot.critic import CriticModel, get_critic_name

# Critic models keyed by the `new_data` flag they were looked up with.
critics = dict[bool, CriticModel]()
for new_data in [False]:  # only new_data=False is loaded; use [True, False] for both
    # The critic checkpoint name is derived from the non-quicktest configuration.
    critic_name = get_critic_name(
        no_feedback=False, new_data=new_data, config=config._replace(quicktest=False)
    )
    critic = CriticModel.load(datadir / f"checkpoints/lit-saved/{critic_name}")
    critic.to(device)
    critics[new_data] = critic
print("Critics loaded.")


Critics loaded.


In [6]:
# set up the inference

from spot.model import DatasetPredResult
from spot.utils import pretty_print_dict, run_long_task, PickleCache
from spot.model import CtxArgs, DecodingArgs, ModelSPOT

# Subsample the test split with a stride of 10, starting at index 1 and
# stopping before the last element (assuming SrcDataset slicing follows
# Python slice semantics — TODO confirm).
testset = src_datasets["test"][1:-1:10]

# used for inference
# Passed as `num_beams` for the beam-search configs below and as `n_samples`
# when sampling candidates in the results cell.
n_samples = 16
dec_ctx_args = config.dec_ctx_args()
dec_ctx_args.max_labels = 1  # one type per chunk
# Greedy decoding (no sampling, no beams); passed as `dec_args` to the critic
# selection step in the results cell.
greedy_args = DecodingArgs(
    sampling_max_tokens=8 * max_tokens_per_file,
    ctx_args=dec_ctx_args,
    max_workers=28,
    do_sample=False,
)

# Sampling with top_p=0.9 (not referenced by the cells visible here).
sample_args = DecodingArgs(
    sampling_max_tokens=8 * max_tokens_per_file,
    ctx_args=dec_ctx_args,
    max_workers=28,
    do_sample=True,
    top_p=0.9,
)

# Beam search with `n_samples` beams; installed on `r0_wrapper` before
# candidates are sampled.
bs_args = DecodingArgs(
    sampling_max_tokens=max_tokens_per_file,
    ctx_args=dec_ctx_args,
    max_workers=28,
    do_sample=False,
    num_beams=n_samples,
)

# Same as `bs_args` plus `max_tokens_per_type=16`; installed on `r0_wrapper`
# before the incremental-inference runs.
bs_incr_args = DecodingArgs(
    ctx_args=dec_ctx_args,
    sampling_max_tokens=max_tokens_per_file,
    max_workers=28,
    max_tokens_per_type=16,
    do_sample=False,
    num_beams=n_samples,
)


# Pickle-backed cache for expensive results, stored per R0 model name.
eval_cache = PickleCache(proj_root() / "caches" / "inference_spot" / r0_model_name)
# eval_cache.clear()


In [7]:
# compute results
from spot.decode import (
    sample_candidates,
    select_candidates_by_type_errors,
    select_candidates_using_oracle,
    select_candidates_using_critic,
    select_first_candidates,
    incr_inference_with_feedback,
    SelectByOracle,
    SelectByCounting,
    SelectByCritic,
)

# def score_transform(x: float):
#     if x <= 0.1:
#         return -1.0
#     if x >= 0.9:
#         return 1.0
#     return 0.0

# Results of selecting among beam-search candidates, keyed by strategy name.
results = dict[str, DatasetPredResult]()
# Results of incremental inference with feedback, keyed by strategy name.
incr_results = dict[str, Any]()


def _cached_incr_inference(name: str, make_selector: Callable[[], Any]) -> Any:
    """Run incremental inference with feedback for one selector, via the cache.

    Args:
        name: strategy name; used for the cache key (``Result-<name>``) and
            the example log directory (``caches/<name>-Examples``).
        make_selector: zero-argument factory for the selector. It is only
            invoked inside the cached computation, so nothing is constructed
            unless the cache actually runs the computation.

    Returns:
        Whatever ``eval_cache.cached`` returns for this computation.
    """
    return eval_cache.cached(
        f"Result-{name}",
        lambda: incr_inference_with_feedback(
            r0_wrapper,
            testset,
            beam_width=8,
            selector=make_selector(),
            log_to=proj_root() / f"caches/{name}-Examples",
        ),
    )


with run_long_task("Computing results"):
    # Incremental inference uses the beam-search config with a per-type token cap.
    r0_wrapper.args = bs_incr_args

    # Evaluated in insertion order: IncrCount, IncrCritic, IncrOracle.
    incr_selectors: dict[str, Callable[[], Any]] = {
        "IncrCount": SelectByCounting,
        "IncrCritic": lambda: SelectByCritic(critics[False]),
        # Disabled; requires the new_data=True critic to be loaded first:
        # "IncrCritic-new": lambda: SelectByCritic(critics[True]),
        "IncrOracle": SelectByOracle,
    }
    for incr_name, make_selector in incr_selectors.items():
        incr_results[incr_name] = _cached_incr_inference(incr_name, make_selector)

    # Candidate sampling uses the plain beam-search config.
    r0_wrapper.args = bs_args
    test_chunks, pred_candidates = eval_cache.cached(
        "sample_candidates",
        lambda: sample_candidates(r0_wrapper, testset, n_samples=n_samples),
    )

    results["BS"] = select_first_candidates(test_chunks, pred_candidates)

    results["Counting"] = eval_cache.cached(
        "Result-Counting",
        lambda: select_candidates_by_type_errors(testset, test_chunks, pred_candidates),
    )

    critic = critics[False]
    r_name = "Critic"
    results[r_name] = eval_cache.cached(
        f"Result-{r_name}",
        lambda: select_candidates_using_critic(
            critic,
            False,
            testset,
            test_chunks,
            pred_candidates,
            dec_args=greedy_args,
            # score_transform=score_transform,
        ),
    )

    results["Oracle"] = eval_cache.cached(
        "Result-Oracle",
        lambda: select_candidates_using_oracle(test_chunks, pred_candidates),
    )


Starting task: Computing results


incr_inference [SelectByCritic]:   1%|          | 22/1884 [01:05<1:32:14,  2.97s/it]


KeyboardInterrupt: 

In [None]:
from spot.visualization import display_persist, visualize_dicts
from spot.data import src_preds_to_accuracies
from spot.visualization import display_persist, dict_widget


# Accuracies of the candidate-selection results, one entry per strategy.
accs_list = [x.accuracies for x in results.values()]
titles = list(results.keys())

# Append accuracies for the incremental-inference results.
# NOTE(review): each incremental result is indexed as a pair and passed as
# (r[1], r[0]); presumably (predictions, chunks) order — confirm against
# `incr_inference_with_feedback` and `src_preds_to_accuracies`.
for n, r in incr_results.items():
    accs = src_preds_to_accuracies(r[1], r[0])
    accs_list.append(accs)
    titles.append(n)

# Show all accuracy dicts side by side, labelled by strategy name.
display_persist(visualize_dicts(accs_list, titles))


In [29]:
from spot.debug_critic import check_delta
from spot.utils import pretty_print_dict

# Compare the critic-selected result against the beam-search baseline and
# the oracle. The critic result is stored under the "Critic" key by the
# results cell; the previously used `critic_result_name` helper is not
# defined anywhere in this notebook.
delta_stats = check_delta(results["BS"], results["Oracle"], results["Critic"])
pretty_print_dict(delta_stats, max_show_level=0)


diff_ratio: 0.085987
diff_critic_error: 0.47598
all_critic_error: 0.067515
diff_critic_abs_error: 0.56754
all_critic_abs_error: 0.26956
diff_scores_distr: ...
all_scores_distr: ...


In [None]:
from spot.visualization import visualize_counts


def display_score_dist(score_counts: Counter):
    """Display a "critic score" count chart covering every key in `score_counts`."""
    shown_keys = score_counts.keys()
    chart = visualize_counts(score_counts, "critic score", top_k=shown_keys)
    display(chart)


display_score_dist(delta_stats["diff_scores_distr"])
display_score_dist(delta_stats["all_scores_distr"])


In [None]:
from spot.debug_critic import inspect_critic_on_beams
from spot.visualization import dict_widget, display_persist


def check_critic(no_feedback: bool):
    """Inspect the critic's behaviour on the sampled beam candidates.

    Args:
        no_feedback: which critic variant to inspect. Only the variant stored
            under ``results["Critic"]`` is computed by this notebook (the
            results cell passes ``False`` for the corresponding argument —
            presumably `no_feedback`; confirm), so ``True`` is rejected.

    Raises:
        ValueError: if `no_feedback` is true, since no such result exists.
    """
    if no_feedback:
        raise ValueError(
            "No critic result was computed with no_feedback=True; "
            "only results['Critic'] is available."
        )
    # The previously used `critic_result_name` helper is not defined in this
    # notebook; the results cell stores the critic result under "Critic".
    sample_eval = results["Critic"]
    r = inspect_critic_on_beams(sample_eval, pred_candidates)
    display_persist(dict_widget(r))


check_critic(False)


In [None]:
from spot.utils import pd, display

# Per-repo "full_acc" of the critic-selected predictions. The results cell
# stores them under "Critic" / "Oracle"; the older "BS + critic-False" and
# "BS + oracle" keys are never written in this notebook.
grouped_res = results["Critic"].group_by_repo()
grouped_full_acc = {k: v.accuracies["full_acc"] for k, v in grouped_res.items()}
# Order repos by number of labels, largest first.
repos = list(grouped_full_acc.keys())
repos.sort(key=lambda x: grouped_full_acc[x].n_total, reverse=True)

grouped_acc_bs = {
    k: v.accuracies["full_acc"] for k, v in results["BS"].group_by_repo().items()
}
grouped_oracle_bs = {
    k: v.accuracies["full_acc"] for k, v in results["Oracle"].group_by_repo().items()
}

# One row per repo, comparing the three selection strategies.
table = pd.DataFrame(
    {
        "Repo": [r.name for r in repos],
        "BS": [str(grouped_acc_bs[r]) for r in repos],
        "Critic": [str(grouped_full_acc[r]) for r in repos],
        "Oracle": [str(grouped_oracle_bs[r]) for r in repos],
    }
)
display(table)


Unnamed: 0,Repo,BS,Critic,Oracle
0,basilisp-lang__basilisp,47.43% (count=3.4k),47.14% (count=3.4k),53.56% (count=3.4k)
1,kornicameister__axion,47.11% (count=1.2k),46.08% (count=1.2k),54.52% (count=1.2k)
2,nabla-c0d3__sslyze,74.02% (count=1.1k),71.90% (count=1.1k),82.00% (count=1.1k)
3,marcosschroh__dataclasses-avroschema,60.10% (count=822),57.06% (count=822),69.46% (count=822)
4,scalableminds__webknossos-connect,67.80% (count=736),68.61% (count=736),77.31% (count=736)
5,rakitaj__daily-programmer,76.50% (count=634),79.65% (count=634),91.48% (count=634)
6,seattleflu__id3c,59.58% (count=621),57.17% (count=621),72.46% (count=621)
7,nubark__instark,84.91% (count=570),85.09% (count=570),96.14% (count=570)
8,lucaswerkmeister__tool-quickcategories,79.62% (count=422),77.96% (count=422),86.73% (count=422)
9,paulcwatts__drf-json-schema,66.25% (count=400),67.50% (count=400),76.25% (count=400)


In [42]:
from spot.utils import not_none
from spot.visualization import visualize_preds_on_code

# The critic-selected result lives under "Critic" (see the results cell);
# the previously used `critic_result_name` helper is not defined in this
# notebook.
critic_eval = results["Critic"]
# For every chunk, the critic's per-label scores of the candidate it chose.
preds_extra = {
    "critic_preds": [
        x.candidate_label_scores[x.best_candidate] for x in critic_eval.extra_info
    ]
}
visualize_preds_on_code(critic_eval.chunks, critic_eval.predictions, preds_extra)


VBox(children=(IntSlider(value=0, continuous_update=False, max=178), VBox(children=(VBox(children=(Box(childre…

In [None]:
from spot.decode import collect_type_errors_from_predictions
from spot.model import DatasetPredResult
from spot.type_check import PythonType, MypyFeedback
from spot.data import SrcDataset


def collect_base_errors(dataset: SrcDataset):
    """Collect the type errors triggered by replacing all labels with `Any`.

    Chunking uses the tokenizer and whatever decoding args `r0_wrapper`
    currently holds (the results cell reassigns `r0_wrapper.args`).
    """
    chunked = dataset.to_chunks(
        r0_wrapper.tokenizer, r0_wrapper.args.ctx_args, tqdm_args={"disable": True}
    )
    # One fresh `Any` prediction per label slot in every chunk.
    any_preds = []
    for chunk_info in chunked.chunks_info:
        any_preds.append([PythonType(("Any",)) for _ in chunk_info.types])
    baseline = DatasetPredResult(chunked, any_preds)
    return collect_type_errors_from_predictions(dataset, baseline, max_workers=30)


def collect_gold_errors(dataset: SrcDataset):
    """Collect the type errors triggered by ground-truth labels.

    Chunking uses the tokenizer and whatever decoding args `r0_wrapper`
    currently holds (the results cell reassigns `r0_wrapper.args`).
    """
    chunked = dataset.to_chunks(
        r0_wrapper.tokenizer, r0_wrapper.args.ctx_args, tqdm_args={"disable": True}
    )
    # Feed each chunk's own label types back in as its "predictions".
    gold_labels = []
    for chunk_info in chunked.chunks_info:
        gold_labels.append(chunk_info.types)
    gold_result = DatasetPredResult(chunked, gold_labels)
    return collect_type_errors_from_predictions(dataset, gold_result, max_workers=30)


# Total number of labels across the test sources that have labels.
num_labels = sum(len(s.types) for s in testset.srcs_with_labels())
print("Total number of labels: ", num_labels)
# Type-checker feedback per setting: "default" (all labels replaced by `Any`),
# "gold" (ground-truth labels), plus one entry per key in `results`.
type_errors = dict[str, list[tuple[Path, MypyFeedback]]]()
type_errors["default"] = collect_base_errors(testset)
type_errors["gold"] = collect_gold_errors(testset)
for k, v in results.items():
    type_errors[k] = collect_type_errors_from_predictions(testset, v, max_workers=30)

from spot.visualization import dict_widget, display_persist

# Show the number of collected errors per setting.
display_persist(dict_widget({k: len(v) for k, v in type_errors.items()}))


In [None]:
from spot.visualization import seq_flatten, visualize_counts
from spot.utils import Counter
from spot.type_check import count_type_frequency


def show_type_distr(
    recursive: bool, top_k: int, names: Optional[Sequence[str]] = None
):
    """Display how often each type is predicted, one series per result.

    Args:
        recursive: forwarded to `count_type_frequency`.
        top_k: forwarded to `visualize_counts`.
        names: keys of `results` to compare. Defaults to every result
            currently in `results`; the previously hard-coded "greedy" and
            "BS + feedback" keys are never produced by this notebook.

    Raises:
        KeyError: if a requested name is not present in `results`.
    """
    if names is None:
        names = list(results.keys())

    counts = dict[str, Counter]()
    for name in names:
        types = seq_flatten(results[name].predictions)
        counts[name] = count_type_frequency(types, recursive=recursive)

    display(visualize_counts(counts, x_name="Predicted Type", top_k=top_k))


show_type_distr(recursive=True, top_k=15)


In [None]:
from spot.visualization import visualize_counts, visualize_sequence_tabs, display
from spot.utils import Counter

# Error-code counts for the "default" setting (all labels replaced by `Any`),
# used as the baseline below.
default_counts = Counter(e.error_code for _, e in type_errors["default"])

# Per-setting error-code counts relative to that baseline.
error_counts = dict[str, Counter]()
for name in ["gold"]:  # other keys of `type_errors` can be added here
    c = Counter(e.error_code for _, e in type_errors[name])
    # Subtract the baseline count for every code seen in the baseline; codes
    # that are rarer here than in the baseline end up negative (assuming
    # `Counter` behaves like `collections.Counter` — TODO confirm).
    for e, v in default_counts.items():
        c[e] -= v
    error_counts[name] = c
display(visualize_counts(error_counts, "Error"))


In [None]:
from spot.visualization import visualize_conf_matrix

# Pass every entry of `results` to the confusion-matrix visualiser
# (presumably one matrix per selection strategy — confirm in spot.visualization).
visualize_conf_matrix(results)
