In [1]:
# Initial coverage: 14.73665% Final coverage: 14.78238%
import json
import logging
import os
import pickle
import random
import sys
import traceback
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import tqdm
from optimum.bettertransformer import BetterTransformer
from torch import optim
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import RobertaConfig, RobertaForMaskedLM

from rl.dqn import DQN, ReplayMemory
from rl.env import FuzzingEnv
from rl.fuzzing_action import FuzzingAction
from rl.tokenizer import ASTTokenizer
from utils.logging import setup_logging

# Raise Python's recursion limit; presumably needed for deeply nested AST /
# fragment structures (e.g. when unpickling or hashing them) — TODO confirm.
sys.setrecursionlimit(20000)
setup_logging()

# Seed every RNG this notebook may touch so stochastic steps are reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

MAX_FRAGMENT_SEQ_LEN = 512  # Maximum length of the AST fragment sequence

# RoBERTa-style special-token strings, mapped to ids through `token_to_id`.
PAD_TOKEN = "<pad>"
CLS_TOKEN = "<s>"
SEP_TOKEN = "</s>"
MASK_TOKEN = "<mask>"
UNK_TOKEN = "<unk>"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the pickled fragment and vocabulary tables from the ASTBERTa directory.
# (`pickle` and `tqdm` are already imported in the first cell.)
# NOTE(review): pickle.load can execute arbitrary code — only load files you trust.
ASTBERTA_DIR = Path("../ASTBERTa")


def _load_pickle(path: Path):
    """Unpickle and return the single object stored at ``path``."""
    with open(path, "rb") as f:
        return pickle.load(f)


frag_data = _load_pickle(ASTBERTA_DIR / "frag_data.pkl")
vocab_data = _load_pickle(ASTBERTA_DIR / "vocab_data.pkl")

# Fragment tables
frag_seqs = frag_data["frag_seqs"]
frag_id_to_type = frag_data["frag_id_to_type"]
frag_id_to_frag = frag_data["frag_id_to_frag"]
frag_type_to_id = frag_data["frag_type_to_id"]

# Vocabulary tables
vocab = vocab_data["vocab"]
token_to_id = vocab_data["token_to_id"]
id_to_token = vocab_data["id_to_token"]
special_token_ids = vocab_data["special_token_ids"]

  exec(code_obj, self.user_global_ns, self.user_ns)
  exec(code_obj, self.user_global_ns, self.user_ns)
  exec(code_obj, self.user_global_ns, self.user_ns)


In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Architecture hyper-parameters. The checkpoint's state_dict is loaded into this
# config below, so these presumably must match its pre-training configuration.
vocab_size = len(vocab)  # size of vocabulary
intermediate_size = 3072  # inner dimension of each transformer feed-forward layer
hidden_size = 768  # embedding / hidden-state dimension

num_hidden_layers = 6
num_attention_heads = 12
dropout = 0.1

config = RobertaConfig(
    vocab_size=vocab_size,
    hidden_size=hidden_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    intermediate_size=intermediate_size,
    hidden_dropout_prob=dropout,
    # RoBERTa offsets position ids past the padding index, hence the +2
    # (assumes the default pad_token_id=1 — TODO confirm against pre-training).
    max_position_embeddings=MAX_FRAGMENT_SEQ_LEN + 2,
)


# Load the ASTBERTa model
tokenizer = ASTTokenizer(vocab, token_to_id, MAX_FRAGMENT_SEQ_LEN, device)
# map_location keeps the CPU fallback above working when the checkpoint was saved
# with CUDA tensors. The file appears to hold a pickled module rather than a bare
# state_dict; newer torch releases default to weights_only=True and may need
# weights_only=False here.
pretrained_model = torch.load(
    "../ASTBERTa/models/final/model_27500.pt", map_location=device
)

# The checkpoint may be a DataParallel wrapper; unwrap it to reach the model.
if isinstance(pretrained_model, torch.nn.DataParallel):
    pretrained_model = pretrained_model.module

ast_net = RobertaForMaskedLM.from_pretrained(
    pretrained_model_name_or_path=None,
    state_dict=pretrained_model.state_dict(),
    config=config,
).to(device)
ast_net = BetterTransformer.transform(ast_net)

In [27]:
from js_ast.fragmentise import hash_frag


def tokenize(frag_seq):
    """Convert a sequence of AST fragments into a 1-D tensor of vocabulary ids.

    The output has the form ``<s> id_1 ... id_k </s>``. Each fragment is resolved
    in this order:

    1. its own hash, if that hash is in ``token_to_id``;
    2. otherwise the hash of the type-only fragment ``{"type": frag["type"]}``;
    3. otherwise ``<unk>``, printing the unresolved hash.

    The result is capped at ``MAX_FRAGMENT_SEQ_LEN`` ids, with ``<s>`` counted
    toward the cap. When the cap is hit, the closing ``</s>`` is not appended.

    Args:
        frag_seq: iterable of fragment dicts. A fragment missing from the
            vocabulary must have a "type" key.

    Returns:
        ``torch.long`` tensor of token ids.
    """

    def lookup(fragment) -> int:
        # Exact fragment first, then the type-only fallback, then <unk>.
        exact_hash = hash_frag(fragment)
        if exact_hash in token_to_id:
            return token_to_id[exact_hash]
        type_only_hash = hash_frag({"type": fragment["type"]})
        if type_only_hash in token_to_id:
            return token_to_id[type_only_hash]
        print(f"UNK_TOKEN: {exact_hash}")
        return token_to_id[UNK_TOKEN]

    ids: list[int] = [token_to_id[CLS_TOKEN]]
    for fragment in frag_seq:
        ids.append(lookup(fragment))
        if len(ids) >= MAX_FRAGMENT_SEQ_LEN:
            break

    # Close the sequence only if there is still room under the cap.
    if len(ids) < MAX_FRAGMENT_SEQ_LEN:
        ids.append(token_to_id[SEP_TOKEN])

    return torch.tensor(ids, dtype=torch.long)

In [29]:
# Masked-LM sanity check. For each fragment sequence, mask about MASK_PROB of the
# fragment positions. Record the fraction whose predicted token maps (via
# `frag_id_to_type`) to the same fragment type as the original token.
MASK_PROB = 0.1  # fraction of positions masked per sequence (at least one)
EVAL_SEED = 0  # seeds the mask sampler so re-running this cell is reproducible


def _sample_mask_positions(seq_len: int, generator: torch.Generator) -> torch.Tensor:
    """Return distinct indices in ``[1, seq_len - 2]`` to mask.

    Index 0 (``<s>``) and the last index are never masked. Positions are sampled
    without replacement, so no position is scored twice. The count is
    ``max(1, int(seq_len * MASK_PROB))``, capped at the number of candidate
    positions. Callers must ensure ``seq_len >= 3``.
    """
    num_candidates = seq_len - 2
    num_masked = min(num_candidates, max(1, int(seq_len * MASK_PROB)))
    return torch.randperm(num_candidates, generator=generator)[:num_masked] + 1


acc = []  # per-sequence accuracy over that sequence's masked positions
skipped = 0  # sequences with no fragment position available to mask
mask_generator = torch.Generator().manual_seed(EVAL_SEED)
mask_token_id = token_to_id[MASK_TOKEN]

ast_net.eval()
with torch.no_grad():  # inference only: don't build the autograd graph
    for seq in tqdm.tqdm(frag_seqs):
        tokenized = tokenize(seq)
        if len(tokenized) < 3:
            # Only <s>/</s>: there is no fragment position to mask.
            skipped += 1
            continue

        mask_idxs = _sample_mask_positions(len(tokenized), mask_generator)
        masked_frag_types = [frag_id_to_type[frag_id.item()] for frag_id in tokenized[mask_idxs]]

        tokenized[mask_idxs] = mask_token_id
        inputs = tokenizer.pad_batch([tokenized])
        out = ast_net(**inputs)

        preds = out.logits.argmax(dim=-1).cpu()[0]
        out_frag_types = [frag_id_to_type[frag_id.item()] for frag_id in preds[mask_idxs]]

        correct = [a == b for a, b in zip(masked_frag_types, out_frag_types)]
        acc.append(sum(correct) / len(correct))

if skipped:
    print(f"Skipped {skipped} sequence(s) too short to mask")

100%|██████████| 14017/14017 [01:07<00:00, 207.49it/s]


In [30]:
# Mean of the per-sequence accuracies. This is a macro average: every sequence
# counts equally, however many positions were masked in it. NaN if nothing was scored.
mean_mask_acc = sum(acc) / len(acc) if acc else float("nan")
print(mean_mask_acc)

0.9211684989041612
