In [7]:
import random
import sys


import torch
from optimum.bettertransformer import BetterTransformer
from transformers import RobertaConfig, RobertaForMaskedLM

from rl.tokenizer import ASTTokenizer
from utils.logging import setup_logging

sys.setrecursionlimit(20000)
setup_logging()
MAX_FRAGMENT_SEQ_LEN = 512  # Maximum length of the AST fragment sequence

PAD_TOKEN = "<pad>"
CLS_TOKEN = "<s>"
SEP_TOKEN = "</s>"
MASK_TOKEN = "<mask>"
UNK_TOKEN = "<unk>"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [8]:
import pickle
import tqdm

with open("../ASTBERTa/frag_data.pkl", "rb") as f:
    frag_data = pickle.load(f)

with open("../ASTBERTa/vocab_data.pkl", "rb") as f:
    vocab_data = pickle.load(f)


frag_seqs = frag_data["frag_seqs"]
frag_id_to_type = frag_data["frag_id_to_type"]
frag_id_to_frag = frag_data["frag_id_to_frag"]

vocab = vocab_data["vocab"]
token_to_id = vocab_data["token_to_id"]
id_to_token = vocab_data["id_to_token"]
special_token_ids = vocab_data["special_token_ids"]

  exec(code_obj, self.user_global_ns, self.user_ns)
  exec(code_obj, self.user_global_ns, self.user_ns)
  exec(code_obj, self.user_global_ns, self.user_ns)


In [9]:
from js_ast.fragmentise import hash_frag


def tokenize(frag_seq):
    frag_id_seq: list[int] = []
    frag_id_seq.append(token_to_id[CLS_TOKEN])

    for frag in frag_seq:
        frag_hash = hash_frag(frag)
        if frag_hash in token_to_id:
            frag_id_seq.append(token_to_id[frag_hash])
        else:
            oov_frag: dict[str, str] = {"type": frag["type"]}
            oov_frag_hash = hash_frag(oov_frag)
            if oov_frag_hash in token_to_id:
                frag_id_seq.append(token_to_id[oov_frag_hash])
            else:
                print(f"UNK_TOKEN: {frag_hash}")
                frag_id_seq.append(token_to_id[UNK_TOKEN])

        if len(frag_id_seq) >= MAX_FRAGMENT_SEQ_LEN:
            break

    if len(frag_id_seq) < MAX_FRAGMENT_SEQ_LEN:
        frag_id_seq.append(token_to_id[SEP_TOKEN])

    random_start_idx = random.randint(1, len(frag_id_seq) - 1)
    frag_id_seq = [token_to_id[CLS_TOKEN]] + frag_id_seq[
        random_start_idx : random_start_idx + MAX_FRAGMENT_SEQ_LEN - 1
    ]

    return torch.tensor([frag_id_seq], dtype=torch.long)

In [24]:
data = []

for i, seq in (
    bar := tqdm.tqdm(enumerate(frag_seqs[:1000]), total=len(frag_seqs[:1000]))
):
    labels = tokenize(seq).to(device)
    inputs = labels.clone()
    attention_mask = torch.ones_like(inputs, device=device)

    probability_matrix = torch.full(labels.shape, 0.15)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100
    inputs[masked_indices] = token_to_id[MASK_TOKEN]

    data.append((inputs, attention_mask, labels))

100%|██████████| 1000/1000 [00:00<00:00, 2283.53it/s]


In [25]:
with open("data.pkl", "wb") as f:
    pickle.dump(data, f)

In [41]:
vocab_size = len(vocab)  # size of vocabulary
intermediate_size = 2048  # embedding dimension
hidden_size = 512

num_hidden_layers = 3
num_attention_heads = 8
dropout = 0

config = RobertaConfig(
    vocab_size=vocab_size,
    hidden_size=hidden_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    intermediate_size=intermediate_size,
    hidden_dropout_prob=dropout,
    max_position_embeddings=MAX_FRAGMENT_SEQ_LEN + 2,
)


# # Load the ASTBERTa model
tokenizer = ASTTokenizer(vocab, token_to_id, MAX_FRAGMENT_SEQ_LEN, device)
pretrained_model = torch.load(
    "../ASTBERTa/models/2023-06-20T14:44:.514456/model_8500.pt"
)
# pretrained_model = torch.load("../ASTBERTa/models/final/model_27500.pt")


if isinstance(pretrained_model, torch.nn.DataParallel):
    pretrained_model = pretrained_model.module

# ast_net = RobertaForMaskedLM.from_pretrained(
#     "../ASTBERTa/models/new/chimport sklearn


ast_net = RobertaForMaskedLM.from_pretrained(
    pretrained_model_name_or_path=None,
    state_dict=pretrained_model.state_dict(),
    config=config,
).to(device)

ast_net = BetterTransformer.transform(ast_net)

In [42]:
import sklearn.metrics
import tqdm
from torchmetrics.classification import MulticlassF1Score, Accuracy

losses = []

l1_loss = torch.nn.L1Loss()

for inputs, attention_mask, labels in (bar := tqdm.tqdm(data, total=len(data))):
    out = ast_net(input_ids=inputs, attention_mask=attention_mask, labels=labels)

    loss = out.loss
    if torch.isnan(loss):
        print("NaN loss")
        continue

    losses.append(loss.item())

    bar.set_postfix({"loss": sum(losses) / len(losses)})

  6%|▋         | 64/1000 [00:00<00:02, 313.46it/s, loss=3.74]

NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss


 14%|█▎        | 136/1000 [00:00<00:02, 342.22it/s, loss=3.78]

NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss


 21%|██        | 212/1000 [00:00<00:02, 360.35it/s, loss=3.48]

NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss


 30%|██▉       | 295/1000 [00:00<00:01, 389.68it/s, loss=3.32]

NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss


 37%|███▋      | 373/1000 [00:01<00:01, 379.74it/s, loss=3.25]

NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss


 45%|████▌     | 450/1000 [00:01<00:01, 363.77it/s, loss=3.2] 

NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss


 52%|█████▏    | 524/1000 [00:01<00:01, 362.53it/s, loss=3.19]

NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss


 64%|██████▍   | 643/1000 [00:01<00:00, 378.43it/s, loss=3.2] 

NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss


 68%|██████▊   | 681/1000 [00:01<00:00, 378.17it/s, loss=3.23]

NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss


 80%|███████▉  | 797/1000 [00:02<00:00, 377.22it/s, loss=3.28]

NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss


 88%|████████▊ | 876/1000 [00:02<00:00, 381.96it/s, loss=3.32]

NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss


 95%|█████████▌| 953/1000 [00:02<00:00, 365.37it/s, loss=3.41]

NaN loss
NaN loss
NaN loss
NaN loss
NaN loss
NaN loss


100%|██████████| 1000/1000 [00:02<00:00, 366.64it/s, loss=3.41]


NaN loss
NaN loss


In [43]:
sum(losses) / len(losses)

3.406374173677297