In [9]:
import pickle
import sys
from pathlib import Path
from typing import Any

import tqdm

from preprocessing.normalise import collect_id, normalize_id
from utils.loader import load_raw_corpus
from utils.logging import setup_logging

sys.setrecursionlimit(20000)
setup_logging()

In [10]:
TERM_TYPE = set(
    [
        "DebuggerStatement",
        "ThisExpression",
        "Super",
        "EmptyStatement",
        "Import",
    ]
)

MAX_SEQ_LEN = 512

In [11]:
corpus = load_raw_corpus(Path("../corpus/DIE"))
for ast in tqdm.tqdm(corpus):
    id_idx = {"v": 0, "f": 0, "c": 0}
    id_map = {}
    collect_id(ast, id_map, id_idx)
    normalize_id(ast, id_map)

  ast = load_ast(code, file, ast_path)
  ast = load_ast(code, file, ast_path)
  ast = load_ast(code, file, ast_path)
100%|██████████| 14199/14199 [00:28<00:00, 504.24it/s] 
100%|██████████| 14017/14017 [00:37<00:00, 375.63it/s] 


In [4]:
with open("../ASTBERTa/corpus.pkl", "wb") as f:
    pickle.dump(corpus, f)

In [5]:
with open("../ASTBERTa/corpus.pkl", "rb") as f:
    corpus = pickle.load(f)

  exec(code_obj, self.user_global_ns, self.user_ns)
  exec(code_obj, self.user_global_ns, self.user_ns)
  exec(code_obj, self.user_global_ns, self.user_ns)


In [12]:
from js_ast.fragmentise import node_to_frags


frag_seqs: list[list[dict[str, Any]]] = []
frag_info_seqs: list[list[tuple[int, str]]] = []
all_node_types: set[str] = set()

for ast in tqdm.tqdm(corpus):
    frag_seq: list[dict[str, Any]] = []
    node_types: set[str] = set()

    node_to_frags(ast, frag_seq, node_types)

    frag_seqs.append(frag_seq)
    all_node_types.update(node_types)

100%|██████████| 14017/14017 [00:28<00:00, 495.07it/s] 


In [13]:
frag_seqs_len = list(sorted(map(lambda x: len(x), frag_seqs)))
frag_seqs_below_max = [frag_seqs for x in frag_seqs if len(x) < 512]

print("Length of frag_seqs:", len(frag_seqs))
print("Node types:", all_node_types)

print("Max length of frag_seqs:", frag_seqs_len[-100:])
print("Min length of frag_seqs:", frag_seqs_len[:5])
print("Avg length of frag_seqs:", sum(frag_seqs_len) / len(frag_seqs_len))
print("Percentage of frag_seqs below 1024:", len(frag_seqs_below_max) / len(frag_seqs))

Length of frag_seqs: 14017
Node types: {'LabeledStatement', 'LogicalExpression', 'Identifier', 'ArrayExpression', 'MemberExpression', 'BreakStatement', 'DoWhileStatement', 'AssignmentPattern', 'ForOfStatement', 'UpdateExpression', 'MethodDefinition', 'ClassDeclaration', 'YieldExpression', 'ForStatement', 'BinaryExpression', 'RestElement', 'ClassExpression', 'TryStatement', 'TemplateElement', 'Literal', 'ClassBody', 'WithStatement', 'MetaProperty', 'ConditionalExpression', 'ReturnStatement', 'ArrowFunctionExpression', 'FunctionExpression', 'NewExpression', 'AssignmentExpression', 'ForInStatement', 'VariableDeclaration', 'UnaryExpression', 'SwitchStatement', 'WhileStatement', 'ArrayPattern', 'TaggedTemplateExpression', 'TemplateLiteral', 'Program', 'AwaitExpression', 'ExpressionStatement', 'CatchClause', 'VariableDeclarator', 'FunctionDeclaration', 'SwitchCase', 'BlockStatement', 'ObjectExpression', 'ThrowStatement', 'ContinueStatement', 'ObjectPattern', 'Property', 'CallExpression', 'Se

In [50]:
from collections import defaultdict

from js_ast.fragmentise import hash_frag

frag_freq: dict[str, int] = defaultdict(int)
hash_to_frag: dict[str, dict[str, Any]] = {}
frag_hash_to_type: dict[str, str] = {}

for frag_seq in tqdm.tqdm(frag_seqs):
    for frag in frag_seq:
        frag_hash = hash_frag(frag)
        frag_freq[frag_hash] += 1

        if frag_hash not in hash_to_frag:
            hash_to_frag[frag_hash] = frag

        if frag_hash not in frag_hash_to_type:
            frag_hash_to_type[frag_hash] = frag["type"]

100%|██████████| 14017/14017 [00:23<00:00, 599.66it/s] 


In [66]:
frag_freq_list = list(sorted(frag_freq.items(), reverse=True, key=lambda x: x[1]))
oov_frags: list[str] = []

# Add OOV anonymous frag type for those not in vocabulary
for frag_type in all_node_types:
    oov_frag = {"type": frag_type}
    oov_frag_hash = hash_frag(oov_frag)
    oov_frags.append(oov_frag_hash)
    frag_hash_to_type[oov_frag_hash] = frag_type
    hash_to_frag[oov_frag_hash] = oov_frag

vocab_size = 20000

threshold_frags = [frag_hash for frag_hash, freq in frag_freq_list if freq > 4]
unique_vocab_frags = set(threshold_frags + oov_frags)

vocab_frags = list(unique_vocab_frags)

print("Number of unique fragments:", len(frag_freq))
print("Max frequency:", max(frag_freq.values()))
print("Min frequency:", min(frag_freq.values()))

print("Number of unique fragments with freq > 4:", len(threshold_frags))
print([hash_to_frag[frag] for frag in vocab_frags[:10]])

Number of unique fragments: 376817
Max frequency: 446351
Min frequency: 1
Number of unique fragments with freq > 4: 16250
[{'type': 'Literal', 'value': 'Done earlyReturnFromNestedTFTC', 'raw': '"Done earlyReturnFromNestedTFTC"', 'regex': None, 'bigint': None}, {'type': 'ArrayExpression', 'elements': [{'type': 'FunctionExpression'}, {'type': 'FunctionExpression'}, {'type': 'FunctionExpression'}, {'type': 'FunctionExpression'}, {'type': 'FunctionExpression'}, {'type': 'FunctionExpression'}, {'type': 'FunctionExpression'}, {'type': 'FunctionExpression'}, {'type': 'FunctionExpression'}, {'type': 'FunctionExpression'}, {'type': 'FunctionExpression'}, {'type': 'FunctionExpression'}, {'type': 'FunctionExpression'}, {'type': 'FunctionExpression'}, {'type': 'FunctionExpression'}, {'type': 'FunctionExpression'}, {'type': 'FunctionExpression'}]}, {'type': 'Literal', 'value': 755, 'raw': '755', 'regex': None, 'bigint': None}, {'type': 'Identifier', 'name': 'console'}, {'type': 'BinaryExpression', 

In [59]:
PAD_TOKEN = "<pad>"
CLS_TOKEN = "<s>"
SEP_TOKEN = "</s>"
MASK_TOKEN = "<mask>"
UNK_TOKEN = "<unk>"

special_tokens = [PAD_TOKEN, CLS_TOKEN, MASK_TOKEN, SEP_TOKEN, UNK_TOKEN]

In [60]:
import numpy as np

ordered_vocab = special_tokens + list(vocab_frags)
vocab = set(ordered_vocab)

token_to_id = {token: i for i, token in enumerate(ordered_vocab)}
id_to_token = {i: token for token, i in token_to_id.items()}

special_token_ids = set([token_to_id[token] for token in special_tokens])

# Dictionary for fragment to type, special tokens are mapped to their string representation
frag_id_to_type = {token_to_id[frag]: frag_hash_to_type[frag] for frag in vocab_frags}
frag_id_to_frag = {token_to_id[frag]: hash_to_frag[frag] for frag in vocab_frags}

In [62]:
frag_data = {
    "frag_seqs": frag_seqs,
    "frag_id_to_type": frag_id_to_type,
    "frag_id_to_frag": frag_id_to_frag,
}

vocab_data = {
    "vocab": vocab,
    "token_to_id": token_to_id,
    "id_to_token": id_to_token,
    "special_token_ids": special_token_ids,
}

pickle.dump(frag_data, open("../ASTBERTa/frag_data_old.pkl", "wb"))
pickle.dump(vocab_data, open("../ASTBERTa/vocab_data_old.pkl", "wb"))

In [67]:
with open("../ASTBERTa/frag_data_old.pkl", "rb") as f:
    frag_data = pickle.load(f)

frag_seqs = frag_data["frag_seqs"]
frag_id_to_type = frag_data["frag_id_to_type"]
frag_id_to_frag = frag_data["frag_id_to_frag"]

In [68]:
print(len(frag_id_to_type))

29289


In [48]:
vocab = set(list(special_token_ids) + list(frag_id_to_type.keys()))
print(len(vocab))

token_to_id = {}

for frag_id, frag in frag_id_to_frag.items():
    token_to_id[hash_frag(frag)] = frag_id

for i, item in enumerate(special_tokens):
    token_to_id[item] = i

id_to_token = {i: token for token, i in token_to_id.items()}
special_token_ids = set([token_to_id[token] for token in special_tokens])
print(list(token_to_id.items())[:10])

29294
[('11d846149aa81fe2e9510fd04e20c4d2ab934ee11cebd55a08652f2ca340a696', 5), ('9835be32ffc8f319ad5b0145a24af482a4c50ebb13b4ad7c74a49675a93d6d9f', 6), ('b16b7634a2989fce41070893f537b174da89a512cc3cae2dbd065ecf7f683d3e', 7), ('d99c8e32260126c9b9409e3165d75da6f258650a7d02dadb4ab7f7e3debf120c', 8), ('b5a305c3c5791c44f3dd7f9bd11c3ebc0181005319f4ab099fae5f8b2e272a8a', 9), ('da680ac8460e95f0a5b148f82dd93bcfa2bba11ca23351e809597bc78f278a7f', 10), ('75e01a796900fc042894b6835874e7d5a4c8614532d4240974a2095a2dbf9ba4', 11), ('4b1e068e3f69f5d94f1851f66b69414ab4a82bbc40084696177833f1ccb7b7d4', 12), ('6a1a0fb3138dafcf29a526a7a179b760aa3b3c351e8089c549408c52beebd10f', 13), ('8628b19a279dbf1edd748bf47b919b8f9eca4817db1fd654557ab55683e367f4', 14)]


In [63]:
import pickle
import tqdm

with open("../ASTBERTa/frag_data.pkl", "rb") as f:
    frag_data = pickle.load(f)

with open("../ASTBERTa/vocab_data.pkl", "rb") as f:
    vocab_data = pickle.load(f)

frag_seqs = frag_data["frag_seqs"]
frag_id_to_type = frag_data["frag_id_to_type"]
frag_id_to_frag = frag_data["frag_id_to_frag"]

vocab = vocab_data["vocab"]
token_to_id = vocab_data["token_to_id"]
id_to_token = vocab_data["id_to_token"]
special_token_ids = vocab_data["special_token_ids"]

  exec(code_obj, self.user_global_ns, self.user_ns)
  exec(code_obj, self.user_global_ns, self.user_ns)
  exec(code_obj, self.user_global_ns, self.user_ns)


In [64]:
from js_ast.fragmentise import hash_frag

data: list[list[int]] = []

for frag_seq in tqdm.tqdm(frag_seqs):
    seq: list[int] = []

    for frag in frag_seq:
        frag_hash = hash_frag(frag)
        if frag_hash in vocab:
            seq.append(token_to_id[frag_hash])
        else:
            oov_frag = {"type": frag_hash_to_type[frag_hash]}
            oov_frag_hash = hash_frag(oov_frag)
            if oov_frag_hash in vocab:
                seq.append(token_to_id[oov_frag_hash])
            else:
                print("UNK_TOKEN")
                seq.append(token_to_id[UNK_TOKEN])

    data.append([token_to_id[CLS_TOKEN]] + seq + [token_to_id[SEP_TOKEN]])

100%|██████████| 14017/14017 [00:23<00:00, 601.91it/s] 


In [65]:
import pickle

pickle.dump(data, open("../ASTBERTa/data_old.pkl", "wb"))
# pickle.dump(token_to_id, open("ASTBERTa/token_to_id.pkl", "wb"))
# pickle.dump(vocab, open("ASTBERTa/vocab.pkl", "wb"))
# pickle.dump(hash_to_frag, open("ASTBERTa/hash_to_frag.pkl", "wb"))
# pickle.dump(id_to_token, open("ASTBERTa/id_to_token.pkl", "wb"))