In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle
from pathlib import Path
from typing import *

import pandas as pd
import plotly.express as px

from spot.utils import cst, proj_root, run_long_task, tqdm

# Switch the working directory to whatever `proj_root()` returns.
os.chdir(proj_root())

# `os.getenv` returns None when the variable is unset, and `Path(None)` then fails
# with a TypeError that never mentions which variable is missing. Check explicitly
# so the failure points straight at the configuration problem.
_datadir_env = os.getenv("datadir")
if _datadir_env is None:
    raise RuntimeError(
        "Environment variable 'datadir' is not set; "
        "point it at the directory that contains 'SPOT-data'."
    )

# Root of all data used by this notebook, and the repos sub-directory beneath it.
datadir = Path(_datadir_env)
repos_dir = datadir / "SPOT-data/repos"

In [2]:
from spot.data import SrcDataset

# Load the pickled "train" / "valid" / "test" datasets into a dict keyed by split name
# and point each loaded dataset's `repos_root` at `repos_dir`.
# NOTE(review): `pickle.load` can execute arbitrary code from the file being read;
# only run this against dataset files you produced or otherwise trust.
src_datasets_path = datadir / "SPOT-data/src_datasets"
src_datasets: dict[str, SrcDataset] = {}
for split in ("train", "valid", "test"):
    pkl_path = src_datasets_path / f"{split}.pkl"
    with pkl_path.open("rb") as fh:
        loaded = pickle.load(fh)
    src_datasets[split] = loaded
    loaded.repos_root = repos_dir


In [8]:
# Pick the training source whose `tokenized_code` is longest, then print the length
# of its tokenized form followed by the length of its original code.
train_srcs = src_datasets["train"].all_srcs
largest_file = max(train_srcs, key=lambda s: len(s.tokenized_code))
n_tokenized = len(largest_file.tokenized_code)
n_origin = len(largest_file.origin_code)
print(n_tokenized)
print(n_origin)

471656
1061323


In [10]:
# For each source returned by `srcs_with_labels()` on the training set, record the
# length of its longest line (splitting on "\n"), then show the distribution.
width_list = [
    max(len(line) for line in src.origin_code.split("\n"))
    for src in src_datasets["train"].srcs_with_labels()
]
max_width_df = pd.DataFrame({"max_width": width_list})
px.histogram(max_width_df, x="max_width")


In [None]:
import torch
from spot.model import ModelWrapper, ModelSPOT

# Settings that select which saved R0 checkpoint to load.
with_margin = True
data_reduction = 1

# Tags that make up the checkpoint name.
if with_margin:
    margin_tag = "with_margin"
else:
    margin_tag = "no_margin"

if data_reduction == 1:
    data_tag = "data_full"
else:
    data_tag = f"data_1-{data_reduction}"

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the saved R0 wrapper, move its model to the chosen device, and keep a handle
# on its tokenizer for later cells.
r0_model_name = f"SPOT-R0-{margin_tag}-{data_tag}"
r0_checkpoint_dir = datadir / f"checkpoints/saved/{r0_model_name}"
r0_wrapper = ModelWrapper.from_pretrained(r0_checkpoint_dir)
r0_wrapper.model.to(device)
tokenizer = r0_wrapper.tokenizer


In [5]:
# Factor passed to `scale_ctx_size` for the R0 wrapper; set it to whichever
# context-size factor turned out best.
best_r0_ctx_factor = 3
r0_wrapper_best = r0_wrapper.scale_ctx_size(best_r0_ctx_factor)


In [6]:
from spot.utils import TaskLoggingMonitor
from spot.model import CtxArgs, DecodingArgs, ModelSPOT, ModelWrapper, TokenizerSPOT

# Toggle: True trains R1 starting from "Salesforce/codet5-base"; False loads the
# previously saved R1 checkpoint from `datadir` instead.
train_r1 = True

r1_model_name = f"SPOT-R1-{margin_tag}-{data_tag}"
r1_model_path = (
    "Salesforce/codet5-base"
    if train_r1
    else datadir / f"checkpoints/saved/{r1_model_name}"
)

r1_model: ModelSPOT = ModelSPOT.from_pretrained(r1_model_path).to(device)
r1_monitor = TaskLoggingMonitor("R1")

# Context-window settings handed to the R1 decoding arguments.
r1_ctx_args = CtxArgs(ctx_size=512, ctx_margin=128, types_in_ctx=False)
r1_args = DecodingArgs(
    sampling_batch_size=512,
    ctx_args=r1_ctx_args,
    max_workers=20,
)

# Build the wrapper around the R1 model (reusing the R0 tokenizer), then scale its
# context size by 2.
r1_base_wrapper = ModelWrapper(r1_model, tokenizer, r1_args, r1_monitor)
r1_wrapper = r1_base_wrapper.scale_ctx_size(2)




In [14]:
import pickle

from spot.data import ChunkedDataset, save_datasets
from spot.utils import PickleCache

# When True, only the first 16 sources from `srcs_with_labels()` of each split are used.
test_r1_generation = False
use_file_level_feedback = False

feedback_tag = "iso_file"  # "per_file" if use_file_level_feedback else "per_project"

# The cache directory name embeds the flag value, so reduced and full runs use
# different cache locations.
r1_cache = PickleCache(datadir / f"cache/r1_src_datasets-{test_r1_generation}")

with run_long_task("Generating R1 datasets", notify=False):
    r1_src_datasets = {}
    for split in ("test", "valid", "train"):
        print("Working on:", split)
        r0_src = src_datasets[split]
        if test_r1_generation:
            r0_subset = r0_src.srcs_with_labels()[:16]
            r0_src = SrcDataset(r0_subset, r0_src.repos_root)

        # Step 1: evaluate the R0 wrapper on this split (routed through the cache).
        # NOTE(review): presumably `cached` reuses a stored result for the key when
        # one exists -- confirm against PickleCache.
        r0_eval = r1_cache.cached(
            f"eval_r0/{split}",
            lambda: r0_wrapper_best.eval_on_dataset(r0_src, tqdm_args={"leave": False}),
        )
        _, r0_data, r0_preds = r0_eval

        # Step 2: turn the R0 data and predictions into the R1 source dataset.
        r1_src_datasets[split] = r1_cache.cached(
            f"r1_src_datasets/{split}",
            lambda: r1_wrapper.generate_r1_srcs(
                r0_src,
                r0_data,
                r0_preds,
                tqdm_args={"leave": False},
            ),
        )


Working on: test


processing chunks:   0%|          | 0/1896 [00:00<?, ?it/s]

predict:   0%|          | 0/1624 [00:00<?, ?it/s]

type_check_src:   0%|          | 0/891 [00:00<?, ?it/s]

feedbacks_to_tokenized_src:   0%|          | 0/891 [00:00<?, ?it/s]

Working on: valid


processing chunks:   0%|          | 0/2536 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
import wandb
from spot.model import ModelTrainingArgs

# Training hyper-parameters for R1.
# NOTE(review): the `// 8` divisors are kept as written (38 // 8 == 4, 256 // 8 == 32);
# the reason for dividing by 8 is not visible here -- confirm before changing.
r1_train_args = ModelTrainingArgs(
    train_batch_size=38 // 8,
    eval_batch_size=256 // 8,
    max_epochs=3,
)

if train_r1:
    # Chunk the R1 validation and training datasets with the R1 context settings.
    r1_chunks: dict[str, ChunkedDataset] = {}
    with run_long_task("Preparing R1 chunked datasets", notify=False):
        for n in ["valid", "train"]:
            r1_chunks[n] = r1_src_datasets[n].to_chunks(
                tokenizer, r1_wrapper.args.ctx_args, max_workers=20
            )

    r1_trainer = r1_wrapper.build_trainer(
        datadir / "checkpoints" / r1_model_name,
        r1_train_args,
        dataset=r1_chunks["train"].data,
        eval_dataset=r1_chunks["valid"].data,
    )

    wandb.init(
        project=r1_model_name,
        dir=str(datadir),
        config={"r1_decoding_args": r1_args, "r1_train_args": r1_train_args},
    )

    # Everything between `wandb.init` and `wandb.finish` is guarded so the run is
    # always closed: with a failure exit code if anything raises (including a
    # manual interrupt), and normally otherwise.
    try:
        with run_long_task(f"Training {r1_model_name}"):
            init_perf = r1_trainer.evaluate(max_length=r1_args.generation_max_length)
            print("initial performance:", init_perf)
            r1_trainer.train()

        # Save as soon as training completes, so that a failure in the logging or
        # evaluation below cannot cost the trained weights.
        r1_wrapper.save_pretrained(datadir / f"checkpoints/saved/{r1_model_name}")

        wandb.log({"time_stats": r1_monitor.timer.total_times()})

        final_perf = r1_trainer.evaluate(max_length=r1_args.generation_max_length)
        print("final performance:", final_perf)
    except BaseException:
        wandb.finish(exit_code=1)
        raise
    else:
        wandb.finish()


PyTorch: setting up devices
Using amp half precision backend


In [None]:
from spot.data import preds_to_accuracies, pretty_print_accuracies
from spot.visualization import display_code_sequence, visualize_batch

# Evaluate the R1 wrapper on the R1 test dataset and print the accuracies.
r1_wrapper_test = r1_wrapper
r1_accs, r1_data, r1_preds = r1_wrapper_test.eval_on_dataset(r1_src_datasets["test"])
pretty_print_accuracies(r1_accs)

# Render at most the first 16 entries of `chunks_info`, one visualization per index.
n_shown = min(16, len(r1_data.chunks_info))
batch_views = [
    visualize_batch(
        r1_data,
        idx,
        r1_preds,
        tokenizer,
        r1_wrapper_test.args.ctx_args,
    )
    for idx in range(n_shown)
]
display_code_sequence(batch_views)


processing chunks:   0%|          | 0/26 [00:00<?, ?it/s]

predict:   0%|          | 0/15 [00:00<?, ?it/s]

partial_acc: 75.00%
partial_acc_wo_any: 75.00%
partial_accs:
   FuncArg: 75.00%
   FuncReturn: 78.57%
   LocalVar: 50.00%
full_acc: 57.50%
full_accs:
   FuncArg: 50.00%
   FuncReturn: 71.43%
   LocalVar: 50.00%
n_labels: 40


Tab(children=(HTML(value="<pre style='line-height: 1.2; padding: 10px; color: rgb(212,212,212); background-col…