In [1]:
%load_ext autoreload
%autoreload 2

from coeditor.common import *
import os

from coeditor.encoding import encode_basic, decode_tokens
import torch
from coeditor.retrieval_model import (
    RetrievalEditorModel,
    T5LayerSelfAttention,
    t5_cross_attention,
    T5Stack,
    encode_query_stack,
    AttentionMode,
)

os.chdir(proj_root())

In [2]:
model = RetrievalEditorModel.from_code_t5("base")
device = torch.device("cuda:2")
model.to(device)
model.eval()
print(f"num_parameters: {model.num_parameters()/1e6:.2f}M")

num_parameters: 222.88M


In [3]:
# test self attention
config = model.config
sa = T5LayerSelfAttention(config, has_relative_attention_bias=True)
sa = sa.to(device)
sa.eval()

hidden_states = torch.randn(2, 3, config.d_model).to(device)
out1 = sa.forward(hidden_states)[0]
out2 = t5_cross_attention(sa, hidden_states, key_value_states=hidden_states)[0]

torch.all(out1 == out2)


tensor(True, device='cuda:2')

In [4]:
stack = model.encoder
stack.eval()

query_ids = torch.LongTensor([[1, 2, 5], [8, 3, 0]]).to(device)
query_mask = query_ids.ne(0)
ref_states = tuple(torch.randn(2, 5, config.d_model).to(device) for _ in stack.block)


ref_mask = torch.zeros(2, 5, dtype=torch.bool).to(device)
out1 = stack.forward(query_ids, attention_mask=query_mask)[0]
out2 = encode_query_stack(
    stack,
    query_ids,
    ref_states,
    ref_attention_mask=ref_mask,
).last_hidden_state

print(f"{out1.dtype=}")
torch.max(torch.abs((out1 - out2) * query_mask.unsqueeze(-1))) < 1e-5

out1.dtype=torch.float32


tensor(True, device='cuda:2')

In [34]:
query = [
    "<s>assert weather == <extra_id_0>\n</s>",
    "<s>assert time == <extra_id_0> # make this longer\n</s>",
    "<s>assert name == <extra_id_0>\n</s>",
]
good_refs = [
    "<s>weather = 'Icey'\n</s>",
    "<s>time = '1:25AM'\n</s>",
    "<s>name = 'Tako'\n</s>",
]
bad_refs = [
    "<s>weather = 'Sunny'\n</s>",
    "<s>time = '5:21PM'\n</s>",
    "<s>name = 'Shmi'\n</s>",
]
answer = [
    "<pad><s><extra_id_0>'Icey'",
    "<pad><s><extra_id_0>'1:25AM'",
    "<pad><s><extra_id_0>'Tako'",
]

model.eval()

for q_id in [0]:
    q_ids = slice(0, 2)
    print(f"{q_id=}")
    for attention_mode in [AttentionMode.basic, AttentionMode.query2ref, AttentionMode.bidirectional]:
        print(f"{attention_mode=}")
        model.attention_mode = attention_mode

        out = model.forward(
            model.encode_token_seqs(query[q_ids]),
            references=[encode_basic(x) for x in good_refs],
            labels=model.encode_token_seqs(answer[q_ids], -100),
        )
        print("Loss with good ref:", out.loss.item())

        with torch.autocast("cuda"):
            out = model.forward(
                model.encode_token_seqs(query[q_ids]),
                references=[encode_basic(x) for x in reversed(good_refs)],
                # query_ref_list=[[1, 0], []],
                labels=model.encode_token_seqs(answer[q_ids], -100),
            )
            print("Loss with reversed good ref:", out.loss.item())

        out = model.forward(
            model.encode_token_seqs(query[q_ids]),
            references=[encode_basic(x) for x in bad_refs],
            labels=model.encode_token_seqs(answer[q_ids], -100),
        )
        print("Loss with bad ref:", out.loss.item())

        out = model.forward(
            model.encode_token_seqs(query[q_ids]),
            references=None,
            labels=model.encode_token_seqs(answer[q_ids], -100),
        )
        print("Loss with no ref:", out.loss.item())


q_id=0
attention_mode=<AttentionMode.basic: 1>
Loss with good ref: 2.9849116802215576
Loss with reversed good ref: 2.9848129749298096
Loss with bad ref: 4.67891263961792
Loss with no ref: 5.39285135269165
attention_mode=<AttentionMode.query2ref: 2>
Loss with good ref: 2.831239700317383
Loss with reversed good ref: 2.831474781036377
Loss with bad ref: 4.534174919128418
Loss with no ref: 5.39285135269165
attention_mode=<AttentionMode.bidirectional: 3>
Loss with good ref: 2.4803545475006104
Loss with reversed good ref: 2.480644702911377
Loss with bad ref: 4.176389217376709
Loss with no ref: 5.392851829528809


In [12]:
from coeditor.model import CodeT5Model

single_inputs = ["".join([*good_refs, q]) for q in query]
print("Single input:")
print("\n-------\n".join(single_inputs))

codet5 = cast(CodeT5Model, CodeT5Model.from_pretrained("Salesforce/codet5-base"))
codet5.to(device)
codet5.eval()

out = codet5.forward(
    model.encode_token_seqs(single_inputs),
    labels=model.encode_token_seqs(answer),
)
print("Loss of CodeT5:", out.loss)


Single input:
<s>weather = 'Icey'
</s><s>time = '1:25AM'
</s><s>name = 'Tako'
</s><s>assert weather == <extra_id_0>
</s>
-------
<s>weather = 'Icey'
</s><s>time = '1:25AM'
</s><s>name = 'Tako'
</s><s>assert time == <extra_id_0> # make this longer
</s>
-------
<s>weather = 'Icey'
</s><s>time = '1:25AM'
</s><s>name = 'Tako'
</s><s>assert name == <extra_id_0>
</s>
Loss of CodeT5: tensor(3.4302, device='cuda:2', grad_fn=<NllLossBackward0>)


In [13]:
codet5_seq = codet5.generate(
    model.encode_token_seqs(single_inputs),
    max_length=50,
    num_beams=8,
)
for i, y in enumerate(codet5_seq):
    print(f"Output {i}:")
    print(decode_tokens(y))

model.attention_mode = AttentionMode.bidirectional
out_seq = model.generate(
    model.encode_token_seqs(query),
    references=[encode_basic(x) for x in reversed(good_refs)],
    # num_beams=8,
    max_length=50,
)
for i, y in enumerate(out_seq):
    print(f"Output {i}:")
    print(decode_tokens(y))


Output 0:
<pad><s><extra_id_0>Tako <s> public class TakoWeather {</s>
Output 1:
<pad><s><extra_id_0>'1:25AM' name = 'Tako'</s>
Output 2:
<pad><s><extra_id_0>Tako <s> public class TakoWeather</s><pad>
Output 0:
<pad><s><extra_id_0>'Icey' ata = 'Icey' ata=</s><pad><pad><pad><pad><pad><pad>
Output 1:
<pad><s><extra_id_0>'1:25AM'
 = new Tako(weather)
.name=weather</s>
Output 2:
<pad><s><extra_id_0>'Tako'
 = 'Tako' unction</s><pad><pad><pad><pad><pad><pad><pad><pad>
