In [1]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
from typing import *

from spot.utils import cst, proj_root, run_long_task, tqdm

# Make the project root the working directory for the rest of the notebook.
os.chdir(proj_root())

# Datasets and checkpoints are addressed relative to the directory named by the
# `datadir` environment variable. Fail fast with an actionable message when it
# is unset or empty: `Path(None)` would raise an opaque TypeError, and an empty
# string would silently resolve to the current working directory.
_datadir_env = os.getenv("datadir")
if not _datadir_env:
    raise RuntimeError(
        "Environment variable 'datadir' is not set; point it at the directory "
        "that contains 'SPOT-data' before running this notebook."
    )
datadir = Path(_datadir_env)
repos_dir = datadir / "SPOT-data/repos"

In [2]:
# experiment configurations

import torch

from spot.data import SrcDataset, TokenizedSrc, get_model_name, get_dataset_name, load_src_datasets
from spot.model import CtxArgs, DecodingArgs, ModelSPOT, ModelWrapper

# --- run-mode switches ---
# Forwarded to get_model_name, load_src_datasets and train_spot_model.
quicktest = False
# Forwarded to train_spot_model as `record_batches`; also gates loading the
# "R1-src_datasets" entry in the model-loading cell.
train_R1 = False
# Forwarded to ModelTrainingArgs and evaluate_model.
check_in_isolation = False

# --- dataset variant (both flags feed get_dataset_name) ---
drop_comments = True
all_labels = True
# Forwarded to get_model_name and load_src_datasets.
data_reduction = 1

# --- token budget ---
# Used as the context size and as the base unit for the training, evaluation
# and sampling token limits.
max_tokens_per_file = 4096


# Context-window layout used for every model input.
# NOTE(review): the DecodingArgs repr recorded further down shows
# `left=2048, window=1024, right=1024` for these values, so the window appears
# to be ctx_size - left_margin - right_margin — confirm in CtxArgs.
ctx_args = CtxArgs(
    ctx_size=max_tokens_per_file,
    left_margin=2048,
    right_margin=1024,
    max_labels=32,
)

# Decoding settings; the sampling budget is 8x the per-file token cap.
dec_args = DecodingArgs(
    sampling_max_tokens=8 * max_tokens_per_file,
    ctx_args=ctx_args,
    max_workers=20,
)


# Dataset name is derived from the preprocessing flags.
datasets_name = get_dataset_name(
    drop_comments=drop_comments,
    all_labels=all_labels,
)

# Model name is derived from the dataset name and the context layout; it is
# later used to locate the saved checkpoint directory.
r0_model_name = get_model_name(
    datasets_name,
    ctx_args=ctx_args,
    data_reduction=data_reduction,
    quicktest=quicktest,
)

# Reuse `repos_dir` from the setup cell instead of rebuilding the same path
# here, so the repository location is defined in exactly one place.
src_datasets = load_src_datasets(
    datadir,
    datasets_name,
    data_reduction=data_reduction,
    repos_root=repos_dir,
    quicktest=quicktest,
)


In [3]:
# train the model
from spot.train import ModelTrainingArgs, train_spot_model

# Training hyper-parameters. Both token budgets are tied to the per-file cap
# from the configuration cell; evaluation gets twice the training budget.
train_args = ModelTrainingArgs(
    train_max_tokens=max_tokens_per_file,
    eval_max_tokens=2 * max_tokens_per_file,
    max_epochs=3,
    check_in_isolation=check_in_isolation,
)

# Train the model; yields the trained wrapper and an "extra" object.
# NOTE(review): the GPU index is hard-coded to 1 here, while the model-loading
# cell selects "cuda:0" — confirm the difference is intentional.
# NOTE(review): `record_batches` follows `train_R1`; presumably this is what
# produces the "R1-src_datasets" entry read by the loading cell — verify.
r0_wrapper, r0_extra = train_spot_model(
    src_datasets,
    r0_model_name,
    gpus=[1],  # training with GPU 1
    dec_args=dec_args,
    train_args=train_args,
    record_batches=train_R1,
    quicktest=quicktest,
    use_small_model=False,
)


  warn(f"Failed to load image Python extension: {e}")


chunk_srcs_per_file:   0%|          | 0/1087 [00:00<?, ?it/s]

verify_labels:   0%|          | 0/1569 [00:00<?, ?it/s]

chunk_srcs_per_file:   0%|          | 0/933 [00:00<?, ?it/s]

verify_labels:   0%|          | 0/1260 [00:00<?, ?it/s]

chunk_srcs_per_file:   0%|          | 0/16281 [00:00<?, ?it/s]

verify_labels:   0%|          | 0/22834 [00:00<?, ?it/s]

Pushover: (Finished: 'Preparing chunked datasets'.) Time taken: 20.1s


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmrvplusone[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1,2]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M 
-----------------------------------------------------
222 M     Trainable params
0         Non-trainable params
222 M     Total params
445.764   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

                not been set for this class (_ResultMetric). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
# load trained model
from spot.utils import pickle_load, pickle_dump

# Single source of truth for the checkpoint directory: the model and its
# pickled extras are both read from here, so the two paths cannot drift apart.
r0_model_dir = datadir / f"checkpoints/lit-saved/{r0_model_name}"

r0_wrapper = ModelWrapper.from_pretrained(r0_model_dir)
if train_R1:
    # NOTE: unpickling can execute arbitrary code — only load files that this
    # project wrote itself.
    r0_extra = pickle_load(r0_model_dir / "extra.pkl")
    r1_src_datasets: dict[str, SrcDataset] = r0_extra["R1-src_datasets"]

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
r0_wrapper.to(device)

# Turn off `do_sample` on the wrapper's decoding args, then show the settings
# that will be used.
r0_wrapper.args.do_sample = False
print(r0_wrapper.args)


DecodingArgs(ctx_args=CtxArgs(left=2048, window=1024, right=1024, max_labels=32), sampling_max_tokens=32768, max_workers=20, max_tokens_per_type=10, do_sample=False, top_p=0.9, num_beams=None)


In [None]:
# model evaluation

import plotly.express as px

from spot.train import evaluate_model, visualize_accuracies

# Evaluate the loaded model on the "test" split.
# NOTE(review): the second positional argument is passed as None; its meaning
# is not visible from this notebook — check evaluate_model's signature.
# NOTE(review): with reeval=False the recorded output shows a "[PickleCache]
# Saving to cache" line, so results appear to be cached under `datadir` —
# confirm whether a stale cache can be returned after retraining.
r0_eval = evaluate_model(
    r0_wrapper,
    None,
    r0_model_name,
    src_datasets["test"],
    datadir=datadir,
    check_in_isolation=check_in_isolation,
    reeval=False,
)
# Last expression of the cell; the recorded output is an interactive widget.
visualize_accuracies(r0_eval)


chunk_srcs_per_file:   0%|          | 0/11 [00:00<?, ?it/s]

verify_labels:   0%|          | 0/18 [00:00<?, ?it/s]

predict:   0%|          | 0/18 [00:00<?, ?it/s]

[PickleCache] Saving to cache: '/mnt/data0/jiayi/checkpoints/lit-saved/quicktest-SPOT-model-(2048, 1024, 1024, 32)--src_datasets-all_labels-drop_comments/eval/r0_eval-DecodingArgs(ctx_args=CtxArgs(left=2048, window=1024, right=1024, max_labels=32), sampling_max_tokens=32768, max_workers=20, max_tokens_per_type=10, do_sample=False, top_p=0.9, num_beams=None).pkl'


interactive(children=(IntSlider(value=0, description='round', max=0), Checkbox(value=True, description='expand…