In [1]:
%load_ext autoreload
%autoreload 2

from coeditor.common import *
import os

os.chdir(proj_root())

In [2]:
from coeditor.encoding import encode_basic, decode_tokens
import torch
from coeditor.retrieval_model import (
    RetrievalEditorModel,
    T5LayerSelfAttention,
    t5_cross_attention,
    T5Stack,
    encode_query_stack,
)


In [3]:
model = RetrievalEditorModel.from_code_t5("base")
device = torch.device("cuda:2")
model.to(device)
model.eval()
print(f"num_parameters: {model.num_parameters()/1e6:.2f}M")


num_parameters: 222.88M


In [4]:
# test self attention
config = model.config
sa = T5LayerSelfAttention(config, has_relative_attention_bias=True)
sa = sa.to(device)
sa.eval()

hidden_states = torch.randn(2, 3, config.d_model).to(device)
out1 = sa.forward(hidden_states)[0]
out2 = t5_cross_attention(sa, hidden_states, key_value_states=hidden_states)[0]

torch.all(out1 == out2)


tensor(True, device='cuda:2')

In [5]:
stack = model.encoder
stack.eval()

query_ids = torch.LongTensor([[1, 2, 5], [8,3,0]]).to(device)
query_mask = query_ids.ne(0)
ref_states = tuple(torch.randn(2, 5, config.d_model).to(device) for _ in stack.block)


In [6]:
ref_mask = torch.zeros(2, 5, dtype=torch.bool).to(device)
out1 = stack.forward(query_ids, attention_mask=query_mask)[0]
out2 = encode_query_stack(
    stack,
    query_ids,
    ref_states,
    ref_attention_mask=ref_mask,
).last_hidden_state

print(f"{out1.dtype=}")
torch.max(torch.abs((out1 - out2) * query_mask.unsqueeze(-1))) < 1e-5


out1.dtype=torch.float32


tensor(True, device='cuda:2')

In [15]:

query = [
    "<s>assert weather == <extra_id_0>\n</s>",
    "<s>assert time == <extra_id_0> # make this longer\n</s>",
    "<s>assert name == <extra_id_0>\n</s>",
]
good_refs = ["<s>weather = 'Icey'\n</s>", "<s>time = '1:25AM'\n</s>", "<s>name = 'Tako'\n</s>"]
bad_refs = ["<s>weather = 'Sunny'\n</s>", "<s>time = '5:21PM'\n</s>", "<s>name = 'Shmi'\n</s>"]
answer = [
    "<pad><s><extra_id_0>'Icey'",
    "<pad><s><extra_id_0>'1:25AM'",
    "<pad><s><extra_id_0>'Tako'",
]

model.eval()

for q_id in [0]:
    q_ids = slice(0, 2)
    print(f"{q_id=}")
    for query_attened_ref in [False, True]:
        print(f"{query_attened_ref=}")
        model.query_attened_ref = query_attened_ref

        out = model.forward(
            model.encode_token_seqs(query[q_ids]),
            references=[encode_basic(x) for x in good_refs],
            labels=model.encode_token_seqs(answer[q_ids], -100),
        )
        print("Loss with good ref:", out.loss.item())
        
        with torch.autocast("cuda"):
            out = model.forward(
                model.encode_token_seqs(query[q_ids]),
                references=[encode_basic(x) for x in reversed(good_refs)],
                # query_ref_list=[[1, 0], []],
                labels=model.encode_token_seqs(answer[q_ids], -100),
            )
            print("Loss with reversed good ref:", out.loss.item())

        out = model.forward(
            model.encode_token_seqs(query[q_ids]),
            references=[encode_basic(x) for x in bad_refs],
            labels=model.encode_token_seqs(answer[q_ids], -100),
        )
        print("Loss with bad ref:", out.loss.item())

        out = model.forward(
            model.encode_token_seqs(query[q_ids]),
            references=None,
            labels=model.encode_token_seqs(answer[q_ids], -100),
        )
        print("Loss with no ref:", out.loss.item())


q_id=0
query_attened_ref=False
Loss with good ref: 2.984902858734131
Loss with reversed good ref: 2.9848129749298096
Loss with bad ref: 4.678882122039795
Loss with no ref: 5.392837047576904
query_attened_ref=True
Loss with good ref: 2.831233263015747
Loss with reversed good ref: 2.8314459323883057
Loss with bad ref: 4.534156799316406
Loss with no ref: 5.392837047576904


In [15]:
from coeditor.model import CodeT5Model

single_inputs = ["".join([*good_refs, q]) for q in query]
print("Single input:")
print("\n-------\n".join(single_inputs))

codet5 = cast(CodeT5Model, CodeT5Model.from_pretrained("Salesforce/codet5-base"))
codet5.to(device)
codet5.eval()

out = codet5.forward(
    model.encode_token_seqs(single_inputs),
    labels=model.encode_token_seqs(answer),
)
print("Loss of CodeT5:", out.loss)


Single input:
<s>weather = 'Icey'
</s><s>time = '1:25AM'
</s><s>name = 'Tako'
</s><s>assert weather == <extra_id_0>
</s>
-------
<s>weather = 'Icey'
</s><s>time = '1:25AM'
</s><s>name = 'Tako'
</s><s>assert time == <extra_id_0> # make this longer
</s>
-------
<s>weather = 'Icey'
</s><s>time = '1:25AM'
</s><s>name = 'Tako'
</s><s>assert name == <extra_id_0>
</s>
Loss of CodeT5: tensor(3.4302, device='cuda:2', grad_fn=<NllLossBackward0>)


In [16]:
codet5_seq = codet5.generate(
    model.encode_token_seqs(single_inputs),
    max_length=50,
    num_beams=8,
)
for i, y in enumerate(codet5_seq):
    print(f"Output {i}:")
    print(decode_tokens(y))

Output 0:
<pad><s><extra_id_0>Tako <s> public class TakoWeather {</s>
Output 1:
<pad><s><extra_id_0>'1:25AM' name = 'Tako'</s>
Output 2:
<pad><s><extra_id_0>Tako <s> public class TakoWeather</s><pad>


In [17]:
model.query_attened_ref = True
out_seq = model.generate(
    model.encode_token_seqs(query),
    references=[encode_basic(x) for x in reversed(good_refs)],
    # num_beams=8,
    max_length=50,
)
for i, y in enumerate(out_seq):
    print(f"Output {i}:")
    print(decode_tokens(y))


Output 0:
<pad><s><extra_id_0>'Tako' unction( c ount)</s><pad><pad><pad><pad><pad>
Output 1:
<pad><s><extra_id_0>'1:25AM' ata = '1:25AM' ata</s>
Output 2:
<pad><s><extra_id_0>'Tako' unction s tarts=0 unction</s><pad><pad>


In [3]:
from coeditor.history import *

history = get_commit_history(proj_root(), 25)
for cinfo in history[:10]:
    print(cinfo.msg)

# edits = edits_from_commit_history(proj_root(), history)


Implement file-level dataset creation.
Implement file-level edit encoder.
Update installation instructions.
Add DevGuide.md.
Implement encoding format for CodeT5.
Line-diff-based format for encoding edits.
Implement EditSelectors.
Improve diff visualization.
Improve edit context construction.
- Bugfix: `from_code_changes` mutates original copy. - Collect only usees in editing ctx. - Improve editing visualization.


In [None]:
from coeditor.dataset import dataset_from_projects

dataset = dataset_from_projects([proj_root()])


In [5]:
dataset.edits[0].print()


-------------------- Training Example: spot.type_checking/collect_annotations --------------------
Input:
    from dataclasses import dataclass
    from distutils.log import error
     <add> from logging import warn
    from posixpath import dirname, realpath
    import re
    from typing import Iterable
    from.utils import *
    import subprocess


     <add> @dataclass(order=True, unsafe_hash=True)
     <del> @dataclass(frozen=True)
    class AnnotPath:

        value: Tuple[str,...]

     <add>     def __repr__(self):
     <add>         return f"AnnotPath('{'.'.join(self.value)}')"
     <add> 
     <add>     def __str__(self):
     <add>         return f"'{'.'.join(self.value)}'"
     <add> 
     <add> 
    def annot_path(*segs: str) -> AnnotPath:
        return AnnotPath(tuple(segs))
     <add> 

    <extra_id_0>def collect_annotations(code: cst.CSTNode) -> Dict[AnnotPath, Optional[cst.Annotation]]:
    <extra_id_1>    collector = AnnotCollector()
    <extra_id_2>    code.visit(c

Retriving edits:   0%|          | 0/1 [00:00<?, ?it/s]Starting task: Retriving initial project from commit: 3c17c4ea794ce495edd62698a46aa696e384bed1
(0.1s) Finished task: Retriving initial project from commit: 3c17c4ea794ce495edd62698a46aa696e384bed1
Edits from commits: 100%|██████████| 215/215 [03:32<00:00,  1.01it/s]
Retriving edits: 100%|██████████| 1/1 [05:24<00:00, 324.13s/it]
Encoding edits: 100%|██████████| 215/215 [05:21<00:00,  1.50s/it]

In [3]:
from coeditor.history import *
from coeditor.encoding import *

all_mods = [c for e in edits for c in e.all_elem_changes() if isinstance(c, Modified)]
c = all_mods[0]
print(show_change(c))


* Modified: 
    def preds_to_accuracies(
        preds: Sequence[Sequence[PythonType]],
        dataset: ChunkedDataset,
        metric: AccuracyMetric,
    ):
            cats = [an.cat for info in dataset.chunks_info for an in info.annots_info]
            labels = [ty for info in dataset.chunks_info for ty in info.types]
    -       poses = [i for info in dataset.chunks_info for i in info.label_ids]
            return type_accuracies(
                list(seq_flatten(preds)),
                labels,
                cats,
    -           poses,
                metric=metric,
            )


In [4]:
from coeditor.history import *

n_add = n_del = n_mod = 0
for e in edits:
    for c in e.all_elem_changes():
        if isinstance(c, Added):
            n_add += 1
        elif isinstance(c, Deleted):
            n_del += 1
        elif isinstance(c, Modified):
            n_mod += 1
        else:
            raise ValueError(c)
print("n_commit:", len(edits))
print(f"n_add: {n_add}")
print(f"n_del: {n_del}")
print(f"n_mod: {n_mod}")


n_commit: 49
n_add: 339
n_del: 246
n_mod: 240


In [5]:
analyzed_edits = analyze_edits(edits)


Starting task: Performing intial module-level analysis...
(6.4s) Finished task: Performing intial module-level analysis...


Analyzing edits: 100%|██████████| 49/49 [03:52<00:00,  4.75s/it]


Unnamed: 0,name,count,avg_time,total_time
1,UsageAnalysis,98,1.579122,154.753982
2,ModuleAnlaysis/Incremental,182,0.412723,75.115604
0,ModuleAnlaysis/Initial,1,6.426628,6.426628
3,_select_change_ctx,240,5e-06,0.001288


In [27]:
selected, all_cedits = select_edits(
    analyzed_edits, EditSelectors.api_change_to_callsite
)
coverage = set[tuple[ProjectPath, str]]()

out_file = Path("output/api_change_to_callsite.txt")
with open(out_file, "w") as f:
    for ce in selected:
        for c in ce.grouped_ctx_changes["users"]:
            coverage.add((get_change_path(c), not_none(ce.commit_info).hash))

        ce.pprint(file=f)
        print("~" * 50, "\n", file=f)

print("All modifications:", len(all_cedits))
print("User changes:", len(coverage))
print("Coverage:", f"{len(coverage) / len(all_cedits):.1%}")


All modifications: 240
User changes: 29
Coverage: 12.1%


In [6]:
selected2, all_cedits2 = select_edits(
    analyzed_edits, EditSelectors.usee_changes_to_user
)

out_file = Path("output/pretrain.txt")
with open(out_file, "w") as f:
    for ce in selected2:
        ce.pprint(file=f)
        print("~" * 50, "\n", file=f)

print("All modifications:", len(all_cedits2))
print("User changes:", len(selected2))
print("Coverage:", f"{len(selected2) / len(all_cedits2):.1%}")


All modifications: 240
User changes: 156
Coverage: 65.0%


In [None]:
# ==== End of new contents ====


In [None]:
dataset = "ManyTypes4Py"

result_paths = {
    "CodeT5": get_eval_dir(dataset, ""),
    "TypeT5": get_eval_dir(
        dataset,
        "(implicit_imports, new) model-v7--TrainingConfig(drop_env_types=False, add_implicit_rel_imports=True)",
    ),
}


In [None]:
ex_proj = PythonProject.from_root(Path("/home/jiayi/Projects/type4py"))
analysis = UsageAnalysis(
    ex_proj, add_implicit_rel_imports=True, add_override_usages=True
)
pretty_print_dict(analysis.get_stats())


In [None]:
from spot.data import (
    create_tokenized_srcsets,
    get_tk_dataset_name,
    load_tokenized_srcsets,
    TypeCheckSettings,
)
from spot.tokenized_src import PreprocessArgs

pre_args = PreprocessArgs()
dataset = "InferTypes4Py"
sdata_name = get_tk_dataset_name(dataset, pre_args, False)
sdata_path = get_dataroot() / "TokenizedSrcSets" / sdata_name
create_tokenized_srcsets(
    dataset,
    sdata_path,
    func_only=False,
    pre_args=pre_args,
)
tk_dataset = load_tokenized_srcsets(sdata_path)
tk_dataset["test"].print_stats()


In [None]:
from spot import proj_root
from spot.static_analysis import ProjectPath, UsageAnalysis, PythonProject
from pprint import pprint


proj = PythonProject.from_root(proj_root())
for caller, callees in UsageAnalysis(proj).user2used.items():
    if caller.module == "spot.static_analysis":
        print(caller)
        for callee in callees:
            print("\t", callee.used, "" if callee.is_certain else "  (maybe)")


In [None]:
import libcst as cst

from spot.tokenized_src import TokenizedSrc, PreprocessArgs
from spot.utils import Path, decode_tokens

ex_code = '''# document comment 1
  # document comment 2
"""String document commnet"""
import os; import spot;
from sys import argv, exit
# after import
@wraps(function)
def catch_permission_denied(function):
    import some.inner.imports
    """
    Decorator to catch :class:`psycopg2.ProgrammingError` exceptions with the
    ``INSUFFICIENT_PRIVILEGE`` error code and rethrow them as
    :class:`~werkzeug.exceptions.Forbidden` exceptions instead.
    """
    @wraps(function)
    def decorated(x: str, y: int) -> str:
        try:
            # comment 1
            # comment 1 cont
            return function(*args, **kwargs)

        except InsufficientPrivilege as error:
            LOG.error("Forbidden: %s", error) # comment 2
            raise Forbidden()

    return decorated
'''
pre_args = PreprocessArgs(stub_in_preamble=True)
ex_src = TokenizedSrc.parse(ex_code, Path("test_file"), Path("test_repo"), pre_args)
print(decode_tokens(ex_src.tokenized_code))


In [None]:
from spot.data import src_to_chunks_, CtxArgs, PreprocessArgs
from ipywidgets import interactive

pre_args = PreprocessArgs(stub_in_preamble=True)
ex_src = TokenizedSrc.parse(ex_code, Path("test_file"), Path("test_repo"), pre_args)


def print_code(
    preamble: int,
    left: int,
    right: int,
    ctx_size: int,
    max_labels: int,
    chunk_id: int,
    inline_prev: bool,
):
    chunks = []
    args = CtxArgs(
        ctx_size,
        preamble,
        left,
        right,
        max_labels=max_labels,
        inline_prev_gold=inline_prev,
    )
    src_to_chunks_(chunks, [], ex_src, (0, len(ex_src.types)), args)
    print(decode_tokens(chunks[chunk_id]["input_ids"]))


interactive(
    print_code,
    preamble=(1, 100),
    left=(1, 200),
    right=(1, 100),
    ctx_size=(1, 500),
    max_labels=(1, 10),
    chunk_id=(0, 1),
    inline_prev=True,
)
