In [21]:
import random
import sys


import torch
from optimum.bettertransformer import BetterTransformer
from transformers import RobertaConfig, RobertaForMaskedLM

from rl.tokenizer import ASTTokenizer
from utils.logging import setup_logging

sys.setrecursionlimit(20000)
setup_logging()
MAX_FRAGMENT_SEQ_LEN = 512  # Maximum length of the AST fragment sequence

PAD_TOKEN = "<pad>"
CLS_TOKEN = "<s>"
SEP_TOKEN = "</s>"
MASK_TOKEN = "<mask>"
UNK_TOKEN = "<unk>"

In [2]:
import pickle
import tqdm

with open("../ASTBERTa/frag_data.pkl", "rb") as f:
    frag_data = pickle.load(f)

with open("../ASTBERTa/vocab_data.pkl", "rb") as f:
    vocab_data = pickle.load(f)


frag_seqs = frag_data["frag_seqs"]
frag_id_to_type = frag_data["frag_id_to_type"]
frag_id_to_frag = frag_data["frag_id_to_frag"]

vocab = vocab_data["vocab"]
token_to_id = vocab_data["token_to_id"]
id_to_token = vocab_data["id_to_token"]
special_token_ids = vocab_data["special_token_ids"]

  exec(code_obj, self.user_global_ns, self.user_ns)
  exec(code_obj, self.user_global_ns, self.user_ns)
  exec(code_obj, self.user_global_ns, self.user_ns)


In [17]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

vocab_size = len(vocab)  # size of vocabulary
intermediate_size = 2048  # embedding dimension
hidden_size = 512

num_hidden_layers = 3
num_attention_heads = 8
dropout = 0.1

config = RobertaConfig(
    vocab_size=vocab_size,
    hidden_size=hidden_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    intermediate_size=intermediate_size,
    hidden_dropout_prob=dropout,
    max_position_embeddings=MAX_FRAGMENT_SEQ_LEN + 2,
)


# # Load the ASTBERTa model
tokenizer = ASTTokenizer(vocab, token_to_id, MAX_FRAGMENT_SEQ_LEN, device)
pretrained_model = torch.load("../ASTBERTa/models/small_vocab/model_17500.pt")
# pretrained_model = torch.load("../ASTBERTa/models/final/model_27500.pt")


if isinstance(pretrained_model, torch.nn.DataParallel):
    pretrained_model = pretrained_model.module

# ast_net = RobertaForMaskedLM.from_pretrained(
#     "../ASTBERTa/models/new/checkpoint-35000"
# ).to(device)

ast_net = RobertaForMaskedLM.from_pretrained(
    pretrained_model_name_or_path=None,
    state_dict=pretrained_model.state_dict(),
    config=config,
).to(device)

ast_net = BetterTransformer.transform(ast_net)

In [18]:
from js_ast.fragmentise import hash_frag


def tokenize(frag_seq):
    frag_id_seq: list[int] = []
    frag_id_seq.append(token_to_id[CLS_TOKEN])

    for frag in frag_seq:
        frag_hash = hash_frag(frag)
        if frag_hash in token_to_id:
            frag_id_seq.append(token_to_id[frag_hash])
        else:
            oov_frag: dict[str, str] = {"type": frag["type"]}
            oov_frag_hash = hash_frag(oov_frag)
            if oov_frag_hash in token_to_id:
                frag_id_seq.append(token_to_id[oov_frag_hash])
            else:
                print(f"UNK_TOKEN: {frag_hash}")
                frag_id_seq.append(token_to_id[UNK_TOKEN])

        if len(frag_id_seq) >= MAX_FRAGMENT_SEQ_LEN:
            break

    if len(frag_id_seq) < MAX_FRAGMENT_SEQ_LEN:
        frag_id_seq.append(token_to_id[SEP_TOKEN])

    random_start_idx = random.randint(1, len(frag_id_seq) - 1)
    frag_id_seq = [token_to_id[CLS_TOKEN]] + frag_id_seq[
        random_start_idx : random_start_idx + MAX_FRAGMENT_SEQ_LEN - 1
    ]

    return torch.tensor([frag_id_seq], dtype=torch.long)

In [19]:
accs = []

for i, seq in (bar := tqdm.tqdm(enumerate(frag_seqs), total=len(frag_seqs))):
    labels = tokenize(seq).to(device)
    inputs = labels.clone()
    attention_mask = torch.ones_like(inputs, device=device)

    probability_matrix = torch.full(labels.shape, 0.15)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100
    inputs[masked_indices] = token_to_id[MASK_TOKEN]

    out = ast_net(input_ids=inputs, attention_mask=attention_mask, labels=labels)

    preds = out.logits.argmax(dim=-1)
    acc = (labels[masked_indices] == preds[masked_indices]).sum() / masked_indices.sum()

    if not torch.isnan(acc):
        accs.append(acc.item())

    if i % 100 == 0:
        bar.set_postfix({"acc": sum(accs) / len(accs)})
        acc = []

100%|██████████| 14017/14017 [01:04<00:00, 216.26it/s, acc=0.709]


In [20]:
print(sum(accs) / len(accs))

0.7090650768758162
