In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import *

import pandas as pd
import plotly.express as px

from spot.data import GitRepo
from spot.type_env import (
    AnnotPath,
    MypyChecker,
    SelectAnnotations,
    TypeInfAction,
    TypeInfEnv,
    TypeInfState,
    collect_annotations,
    mypy_checker,
)
from spot.utils import cst, proj_root, read_file, seq_flatten, tqdm, write_file

# Switch the working directory to the path returned by proj_root().
os.chdir(proj_root())

# The data location is read from the `datadir` environment variable. Fail early
# with an explicit message when it is missing, instead of the opaque TypeError
# that Path(None) would raise.
_datadir_env = os.getenv("datadir")
if _datadir_env is None:
    raise RuntimeError(
        "Environment variable 'datadir' is not set; it must point to the data directory."
    )
datadir = Path(_datadir_env)
repos_dir = datadir / "SPOT-data/repos"

# NOTE(security): pickle.load can execute arbitrary code while unpickling.
# Only load pickle files that come from a trusted source.
useful_repos_path = proj_root() / "scripts" / "useful_repos.pkl"
with useful_repos_path.open("rb") as f:
    useful_repos: list[GitRepo] = pickle.load(f)

# Later cells index this object with the "train" and "valid" keys.
repos_split_path = datadir / "SPOT-data/repos-processed-with_margin/repos_split.pkl"
with repos_split_path.open("rb") as f:
    repos_split = pickle.load(f)

In [6]:
import torch

from spot.model import ModelSPOT, TokenizerSPOT

# When True, start from the public CodeT5 weights; otherwise load the saved
# SPOT checkpoint found under `datadir`.
train_from_scrach = True

if train_from_scrach:
    model_path = "Salesforce/codet5-base"
else:
    model_path = datadir / "checkpoints/saved/SPOT-CodeT5-with_margin"
# model_path = datadir / "checkpoints/saved/SPOT-DAgger-scratch"

# Use the first CUDA device when one is available, else fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer and model are both loaded from the same path; the model is moved to `device`.
tokenizer: TokenizerSPOT = TokenizerSPOT.from_pretrained(model_path)
model: ModelSPOT = ModelSPOT.from_pretrained(model_path).to(device)



In [7]:
from IPython.display import display, display_pretty

import wandb
from spot.training import DAggerTrainer, DAggerTrainerArgs, CtxArgs

# Set to True for a reduced run: the run name gets a "test-" prefix and the
# repo lists built below are truncated.
test_run = False
if test_run:
    test_tag = "test-"
else:
    test_tag = ""

if train_from_scrach:
    scratch_tag = "-scratch"
else:
    scratch_tag = ""
model_name = test_tag + "SPOT-DAgger" + scratch_tag

# Trainer configuration; checkpoints are written under `datadir`/checkpoints/<model_name>.
args = DAggerTrainerArgs(
    output_dir=datadir / "checkpoints" / model_name,
    max_epochs=3,
    skip_first_eval=False,
    repos_group_size=16,
    ctx_args=CtxArgs(ctx_size=512, ctx_margin=128, types_in_ctx=False),
    sampling_batch_size=300,
    train_batch_size=42,
    generation_max_length=128,
    max_workers=16,
)

trainer = DAggerTrainer(model, tokenizer, args)

# Resolve each repo of the train/valid splits to its directory under `repos_dir`.
train_repos = [repo.repo_dir(repos_dir) for repo in repos_split["train"]]
valid_repos = [repo.repo_dir(repos_dir) for repo in repos_split["valid"]]
if test_run:
    # Keep only the first few repos of each split for a test run.
    train_repos, valid_repos = train_repos[:10], valid_repos[:5]

In [4]:
# Start a wandb run named after the model; run files are stored under `datadir`.
wandb.init(project=model_name, config=args, dir=str(datadir))

try:
    trainer.train(train_repos, valid_repos)
except Exception as e:
    wandb.alert(title="Training stopped due to exception", text=f"In {model_name}, exception: {e}")
    # Close the run and mark it as failed, so it is not left open in the kernel.
    wandb.finish(exit_code=1)
    # Bare `raise` re-raises the active exception with its original traceback.
    raise
wandb.alert(title="Training finished", text=f"{model_name} has finished.")
wandb.log({"time_stats": trainer.timer.total_times()})
wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmrvplusone[0m. Use [1m`wandb login --relogin`[0m to force relogin


DAgger Training:   0%|          | 0/1719 [00:00<?, ?it/s]

[Epoch 0] R0 stats:


{'R0_accuracy_partial': {'total': 0.058035498379842336,
  'FuncReturn': 0.06209987195902689,
  'FuncArg': 0.06948910325116113,
  'ClassAtribute': 0.010392609699769052,
  'LocalVar': 0.009784735812133072,
  'GlobalVar': 0.016129032258064516},
 'R0_accuracy_full': {'total': 0.051071238574261255,
  'FuncReturn': 0.0528169014084507,
  'FuncArg': 0.062433011789924976,
  'ClassAtribute': 0.008852963818321785,
  'LocalVar': 0.007827788649706457},
 'R0_n_labels': 20677}

[Epoch 0] R1 stats:


{'R1_accuracy_partial': {'total': 0.03617721029212614,
  'FuncArg': 0.04966502903081733,
  'FuncReturn': 0.025448143405889884,
  'ClassAtribute': 0.009237875288683603,
  'LocalVar': 0.01761252446183953},
 'R1_accuracy_full': {'total': 0.03201779841361965,
  'FuncArg': 0.04493077266636891,
  'FuncReturn': 0.02112676056338028,
  'ClassAtribute': 0.008852963818321785,
  'LocalVar': 0.007827788649706457},
 'R1_n_labels': 20676}



[Epoch 1] R0 stats:


{'R0_accuracy_partial': {'total': 0.6701165546259128,
  'FuncArg': 0.5856555912826009,
  'ClassAtribute': 0.6882217090069284,
  'FuncReturn': 0.81354033290653,
  'LocalVar': 0.6340508806262231,
  'GlobalVar': 0.8387096774193549},
 'R0_accuracy_full': {'total': 0.5960245683609808,
  'FuncArg': 0.5156305823508396,
  'ClassAtribute': 0.6127790608160123,
  'FuncReturn': 0.7413572343149808,
  'LocalVar': 0.4774951076320939,
  'GlobalVar': 0.6693548387096774},
 'R0_n_labels': 20677}

[Epoch 1] R1 stats:


{'R1_accuracy_partial': {'total': 0.672228670922809,
  'FuncArg': 0.590174184903975,
  'ClassAtribute': 0.6897613548883756,
  'FuncReturn': 0.8130601792573624,
  'LocalVar': 0.6340508806262231,
  'GlobalVar': 0.7741935483870968},
 'R1_accuracy_full': {'total': 0.5962952215128652,
  'FuncArg': 0.5173738276016079,
  'ClassAtribute': 0.6143187066974596,
  'FuncReturn': 0.7419974391805377,
  'LocalVar': 0.4500978473581213,
  'GlobalVar': 0.6048387096774194},
 'R1_n_labels': 20676}



[Epoch 2] R0 stats:


{'R0_accuracy_partial': {'total': 0.6805145814189679,
  'ClassAtribute': 0.7236335642802155,
  'FuncArg': 0.5936941764916042,
  'FuncReturn': 0.8196222791293214,
  'LocalVar': 0.6301369863013698,
  'GlobalVar': 0.8145161290322581},
 'R0_accuracy_full': {'total': 0.6068094984765682,
  'ClassAtribute': 0.6535796766743649,
  'FuncArg': 0.5205430510896749,
  'FuncReturn': 0.7511203585147247,
  'LocalVar': 0.4735812133072407,
  'GlobalVar': 0.6935483870967742},
 'R0_n_labels': 20677}

[Epoch 2] R1 stats:


{'R1_accuracy_partial': {'total': 0.6881892048752176,
  'ClassAtribute': 0.7205542725173211,
  'FuncArg': 0.6061634658329611,
  'FuncReturn': 0.8233034571062741,
  'GlobalVar': 0.782258064516129,
  'LocalVar': 0.6457925636007827},
 'R1_accuracy_full': {'total': 0.612932869026891,
  'ClassAtribute': 0.6443418013856813,
  'FuncArg': 0.531308619919607,
  'FuncReturn': 0.7544814340588989,
  'LocalVar': 0.5048923679060665,
  'GlobalVar': 0.6370967741935484},
 'R1_n_labels': 20676}

wandb: Network error (ConnectTimeout), entering retry loop.


[Epoch 3] R0 stats:


{'R0_accuracy_partial': {'total': 0.6874304783092324,
  'ClassAtribute': 0.7209391839876829,
  'FuncArg': 0.6067345480528761,
  'FuncReturn': 0.8204225352112676,
  'LocalVar': 0.62426614481409,
  'GlobalVar': 0.8306451612903226},
 'R0_accuracy_full': {'total': 0.6153213715722784,
  'ClassAtribute': 0.6505003849114703,
  'FuncArg': 0.5352804573061808,
  'FuncReturn': 0.7520806658130602,
  'LocalVar': 0.4911937377690802,
  'GlobalVar': 0.7258064516129032},
 'R0_n_labels': 20677}

[Epoch 3] R1 stats:


{'R1_accuracy_partial': {'total': 0.6973302379570516,
  'ClassAtribute': 0.737490377213241,
  'FuncArg': 0.6190263510495757,
  'FuncReturn': 0.8250640204865557,
  'GlobalVar': 0.8548387096774194,
  'LocalVar': 0.6086105675146771},
 'R1_accuracy_full': {'total': 0.6259914877152254,
  'ClassAtribute': 0.6655119322555813,
  'FuncArg': 0.5493523894595802,
  'FuncReturn': 0.7573623559539052,
  'LocalVar': 0.4774951076320939,
  'GlobalVar': 0.7096774193548387},
 'R1_n_labels': 20676}



In [8]:
# Evaluate on valid_repos[1:20] (index 0 is skipped) and show the R1 stats.
eval_subset = valid_repos[1:20]
r0_stats, r1_stats, ds, preds = trainer.eval_on_repos(eval_subset)
display(r1_stats)

{'R1_partial_acc_no_any': 0.42105263157894735,
 'R1_partial_acc': 0.7160480349344979,
 'R1_partial_accs': {'FuncArg': 0.6601084119654688,
  'FuncReturn': 0.7658889782783588,
  'ClassAtribute': 0.8468660968660968,
  'GlobalVar': 0.7115384615384616,
  'LocalVar': 0.5949367088607594},
 'R1_full_acc': 0.6512008733624454,
 'R1_full_accs': {'FuncArg': 0.5884360570166633,
  'FuncReturn': 0.6983105390185036,
  'ClassAtribute': 0.8183760683760684,
  'GlobalVar': 0.5961538461538461,
  'LocalVar': 0.4978902953586498},
 'R1_n_labels': 9160}

In [12]:
from spot.training import CtxArgs
# Positional arguments; presumably (ctx_size, ctx_margin, types_in_ctx), matching
# the keyword names used when the trainer args were built -- TODO confirm.
ctx_args_larger = CtxArgs(1024, 256, False)

# Temporarily use a smaller sampling batch size for this evaluation. The previous
# value is remembered and restored in `finally`, so a failed evaluation cannot
# leave the trainer with the overridden setting, and the restore no longer
# depends on a hard-coded constant.
_prev_sampling_batch_size = trainer.args.sampling_batch_size
trainer.args.sampling_batch_size = 64
try:
    r0_stats, r1_stats, ds, preds = trainer.eval_on_repos(valid_repos[1:20], ctx_args_larger)
finally:
    trainer.args.sampling_batch_size = _prev_sampling_batch_size
display(r1_stats)

parsing and masking sources:   0%|          | 0/1093 [00:00<?, ?it/s]

tokenizing sources:   0%|          | 0/1093 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/2767 [00:00<?, ?it/s]

predict:   0%|          | 0/1336 [00:00<?, ?it/s]

reading orginal srcs:   0%|          | 0/527 [00:00<?, ?it/s]

calling mypy:   0%|          | 0/19 [00:00<?, ?it/s]

generating augmented inputs:   0%|          | 0/527 [00:00<?, ?it/s]

tokenizing sources:   0%|          | 0/527 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/1892 [00:00<?, ?it/s]

predict:   0%|          | 0/1457 [00:00<?, ?it/s]

{'R1_partial_acc_no_any': 0.3881578947368421,
 'R1_partial_acc': 0.737117903930131,
 'R1_partial_accs': {'FuncArg': 0.6906243726159406,
  'FuncReturn': 0.7972646822204345,
  'ClassAtribute': 0.8212250712250713,
  'GlobalVar': 0.7692307692307693,
  'LocalVar': 0.5780590717299579},
 'R1_full_acc': 0.6697598253275109,
 'R1_full_accs': {'FuncArg': 0.6169443886769725,
  'FuncReturn': 0.7236524537409493,
  'ClassAtribute': 0.7927350427350427,
  'GlobalVar': 0.6346153846153846,
  'LocalVar': 0.4936708860759494},
 'R1_n_labels': 9160}

In [16]:
from spot import PythonType
from spot.data import TypeInfDataset, inline_predictions


def visualize_batch(dataset: TypeInfDataset, preds: list[list[PythonType]], i: int):
    """Build a text report for chunk ``i``: its labels, its predictions and its code.

    The predicted types of the chunk are converted to strings and encoded with the
    module-level ``tokenizer`` (no special tokens). They are then passed, together
    with the chunk's ``input_ids``, to ``inline_predictions``; the returned tokens
    are decoded with special tokens kept.

    Args:
        dataset: Dataset providing ``data["input_ids"]`` and ``chunks_info``.
        preds: Predicted types, indexed by chunk.
        i: Index of the chunk to render.

    Returns:
        A single string with a ``labels:`` line, a ``preds:`` line, a separator
        line and the decoded code.
    """
    predicted = preds[i]
    encoded_preds = []
    for ty in predicted:
        encoded_preds.append(tokenizer.encode(str(ty), add_special_tokens=False))

    inlined_tks = inline_predictions(dataset.data["input_ids"][i], encoded_preds, tokenizer)
    decoded_code = tokenizer.decode(inlined_tks, skip_special_tokens=False)
    expected = dataset.chunks_info[i].types

    report = f"labels: {expected!s}\n"
    report += f"preds: {predicted!s}\n"
    report += "========================== Code =======================\n"
    report += decoded_code + "\n"
    return report

from spot.visualization import display_code_sequence

# Show up to the first 12 chunks; clamp to the number of predictions so a small
# evaluation set does not raise IndexError.
display_code_sequence([visualize_batch(ds, preds, i) for i in range(min(12, len(preds)))])

Tab(children=(HTML(value="<pre style='line-height: 1.2; padding: 10px; color: rgb(212,212,212); background-col…

In [None]:
# Render the trainer's timer statistics as a table.
timing_table = trainer.timer.as_dataframe()
display(timing_table)

Unnamed: 0,name,count,avg_time,total_time
3,training > model fitting,7,153.695309,1075.867161
1,training > model prediction,8,84.873097,678.984775
2,training > type checking,7,66.09076,462.635319
0,training > preparing data,15,12.201682,183.025235
