In [2]:
import pickle
import sys
from pathlib import Path
from typing import Any

import tqdm

from preprocessing.normalise import collect_id, normalize_id
from utils.loader import load_raw_corpus
from utils.logging import setup_logging

sys.setrecursionlimit(20000)
setup_logging()

In [3]:
TERM_TYPE = [
    "DebuggerStatement",
    "ThisExpression",
    "Super",
    "EmptyStatement",
    "Import",
]

MAX_SEQ_LEN = 510

In [3]:
corpus = load_raw_corpus(Path("data/DIE-corpus/"))
for ast in tqdm.tqdm(corpus):
    id_idx = {"v": 0, "f": 0, "c": 0}
    id_map = {}
    collect_id(ast, id_map, id_idx)
    normalize_id(ast, id_map)

with open("ASTBERTa/corpus.pkl", "wb") as f:
    pickle.dump(corpus, f)

  ast = load_ast(code, ast_path)
  ast = load_ast(code, ast_path)
  ast = load_ast(code, ast_path)
100%|██████████| 18162/18162 [01:02<00:00, 289.33it/s] 
 61%|██████▏   | 11016/17942 [00:21<00:13, 502.81it/s] 


KeyboardInterrupt: 

In [4]:
with open("ASTBERTa/corpus.pkl", "rb") as f:
    corpus = pickle.load(f)

  exec(code_obj, self.user_global_ns, self.user_ns)
  exec(code_obj, self.user_global_ns, self.user_ns)
  exec(code_obj, self.user_global_ns, self.user_ns)


In [5]:
from js_ast.nodes import Node


subtrees: list[Node] = []

for ast in tqdm.tqdm(corpus):
    for node in ast.traverse():
        subtrees.append(node)

100%|██████████| 17942/17942 [00:27<00:00, 653.33it/s] 


In [6]:
from js_ast.fragmentise import node_to_frags


frag_seqs: list[list[dict[str, Any]]] = []
frag_info_seqs: list[list[tuple[int, str]]] = []
all_node_types: set[str] = set()


for ast in tqdm.tqdm(corpus):
    frag_seq: list[dict[str, Any]] = []
    frag_info_seq: list[tuple[int, str]] = []
    node_types: set[str] = set()

    node_to_frags(ast, frag_seq, frag_info_seq, node_types)

    frag_seqs.append(frag_seq)
    frag_info_seqs.append(frag_info_seq)
    all_node_types.update(node_types)

100%|██████████| 17942/17942 [00:56<00:00, 319.88it/s] 


In [7]:
frag_seqs_len = list(sorted(map(lambda x: len(x), frag_seqs)))
frag_seqs_below_max = [frag_seqs for x in frag_seqs if len(x) < 512]

print("Length of frag_seqs:", len(frag_seqs))
print("Length of frag_info_seqs:", len(frag_info_seqs))
print("Node types:", all_node_types)

print("Max length of frag_seqs:", frag_seqs_len[-100:])
print("Min length of frag_seqs:", frag_seqs_len[:5])
print("Avg length of frag_seqs:", sum(frag_seqs_len) / len(frag_seqs_len))
print("Percentage of frag_seqs below 1024:", len(frag_seqs_below_max) / len(frag_seqs))

Length of frag_seqs: 17942
Length of frag_info_seqs: 17942
Node types: {'ObjectPattern', 'ArrayPattern', 'VariableDeclarator', 'NewExpression', 'ClassDeclaration', 'TemplateElement', 'MetaProperty', 'TryStatement', 'Literal', 'FunctionExpression', 'LogicalExpression', 'ForInStatement', 'WithStatement', 'CallExpression', 'ConditionalExpression', 'AssignmentPattern', 'ArrayExpression', 'MethodDefinition', 'RestElement', 'ContinueStatement', 'ClassBody', 'UnaryExpression', 'UpdateExpression', 'ObjectExpression', 'AssignmentExpression', 'ForOfStatement', 'ThrowStatement', 'ForStatement', 'TemplateLiteral', 'BlockStatement', 'CatchClause', 'Program', 'Identifier', 'ClassExpression', 'ReturnStatement', 'BinaryExpression', 'DoWhileStatement', 'BreakStatement', 'TaggedTemplateExpression', 'LabeledStatement', 'ArrowFunctionExpression', 'SequenceExpression', 'VariableDeclaration', 'IfStatement', 'WhileStatement', 'Property', 'SwitchStatement', 'SwitchCase', 'FunctionDeclaration', 'ExpressionStat

In [8]:
from collections import defaultdict

from js_ast.fragmentise import hash_frag

frag_freq: dict[str, int] = defaultdict(int)
hash_to_frag: dict[str, dict[str, Any]] = {}
frag_hash_to_type: dict[str, str] = {}

for frag_seq, frag_info_seq in tqdm.tqdm(
    zip(frag_seqs, frag_info_seqs), total=len(frag_seqs)
):
    assert len(frag_seq) == len(frag_info_seq)
    for frag, (_, frag_type) in zip(frag_seq, frag_info_seq):
        frag_hash = hash_frag(frag)
        frag_freq[frag_hash] += 1

        if frag_hash not in hash_to_frag:
            hash_to_frag[frag_hash] = frag

        if frag_hash not in frag_hash_to_type:
            frag_hash_to_type[frag_hash] = frag_type

100%|██████████| 17942/17942 [00:30<00:00, 597.58it/s] 


In [9]:
frag_freq_list = list(sorted(frag_freq.items(), reverse=True, key=lambda x: x[1]))
unique_vocab_frags = set([frag_hash for frag_hash, freq in frag_freq_list if freq > 5])
oov_frags: list[str] = []

# Add OOV anonymous frag type for those not in vocabulary
for frag_type in all_node_types:
    oov_frag = {"type": frag_type}
    oov_frag_hash = hash_frag(oov_frag)
    oov_frags.append(oov_frag_hash)
    frag_hash_to_type[oov_frag_hash] = frag_type

unique_vocab_frags.update(oov_frags)

vocab_frags = list(unique_vocab_frags)

print("Number of unique fragments:", len(frag_freq))
print("Max frequency:", max(frag_freq.values()))
print("Min frequency:", min(frag_freq.values()))

print("Number of unique fragments with freq > 5:", len(vocab_frags))
print([hash_to_frag[frag] for frag in vocab_frags[:10]])

Number of unique fragments: 481301
Max frequency: 512920
Min frequency: 1
Number of unique fragments with freq > 5: 20537
[{'type': 'Literal', 'value': 1801, 'raw': '1801', 'regex': None, 'bigint': None}, {'type': 'Identifier', 'name': 'uniformContext'}, {'type': 'SwitchStatement', 'discriminant': {'type': 'Identifier'}, 'cases': [{'type': 'SwitchCase'}, {'type': 'SwitchCase'}, {'type': 'SwitchCase'}, {'type': 'SwitchCase'}, {'type': 'SwitchCase'}]}, {'type': 'Identifier', 'name': 'v810'}, {'type': 'Identifier', 'name': 'f2167'}, {'type': 'Literal', 'value': 3.3, 'raw': '3.3', 'regex': None, 'bigint': None}, {'type': 'Identifier', 'name': 'blah'}, {'type': 'Literal', 'value': 'search', 'raw': '"search"', 'regex': None, 'bigint': None}, {'type': 'BinaryExpression', 'operator': '<=', 'left': {'type': 'MemberExpression'}, 'right': {'type': 'MemberExpression'}}, {'type': 'BinaryExpression', 'operator': '<', 'left': {'type': 'CallExpression'}, 'right': {'type': 'Literal'}}]


In [10]:
PAD_TOKEN = "<pad>"
CLS_TOKEN = "<s>"
SEP_TOKEN = "</s>"
MASK_TOKEN = "<mask>"
UNK_TOKEN = "<unk>"

special_tokens = [PAD_TOKEN, CLS_TOKEN, MASK_TOKEN, SEP_TOKEN, UNK_TOKEN]

In [11]:
import numpy as np

ordered_vocab = special_tokens + list(vocab_frags)
vocab = set(ordered_vocab)

token_to_id = {token: i for i, token in enumerate(ordered_vocab)}
id_to_token = {i: token for token, i in token_to_id.items()}

special_token_ids = set([token_to_id[token] for token in special_tokens])

# Dictionary for fragment to type, special tokens are mapped to their string representation
frag_type_to_ids: dict[str, list[int]] = defaultdict(list)

for i, frag_hash in enumerate(vocab_frags):
    frag_type = frag_hash_to_type[frag_hash]
    frag_type_to_ids[frag_type].append(token_to_id[frag_hash])
    frag_id_to_type = {i: frag_type}

In [None]:
frag_data = {
    "frag_seqs": frag_seqs,
    "frag_info_seqs": frag_info_seqs,
}

vocab_data = {
    "vocab": vocab,
    "token_to_id": token_to_id,
    "id_to_token": id_to_token,
    "special_token_ids": special_token_ids,
    "frag_type_to_ids": frag_type_to_ids,
    "frag_id_to_type": frag_id_to_type,
}

pickle.dump(frag_data, open("ASTBERTa/frag_data.pkl", "wb"))
pickle.dump(vocab_data, open("ASTBERTa/vocab_data.pkl", "wb"))

In [6]:
import pickle
import tqdm

with open("ASTBERTa/frag_data.pkl", "rb") as f:
    frag_data = pickle.load(f)

with open("ASTBERTa/vocab_data.pkl", "rb") as f:
    vocab_data = pickle.load(f)


frag_seqs = frag_data["frag_seqs"]
frag_info_seqs = frag_data["frag_info_seqs"]

token_to_id = vocab_data["token_to_id"]
id_to_token = vocab_data["id_to_token"]
special_token_ids = vocab_data["special_token_ids"]
vocab = vocab_data["vocab"]
frag_type_to_ids = vocab_data["frag_type_to_ids"]
frag_id_to_type = vocab_data["frag_id_to_type"]

  exec(code_obj, self.user_global_ns, self.user_ns)
  exec(code_obj, self.user_global_ns, self.user_ns)
  exec(code_obj, self.user_global_ns, self.user_ns)


In [14]:
from js_ast.fragmentise import node_to_frags


subtree_frag_seqs: list[list[dict[str, Any]]] = []
subtree_frag_info_seqs: list[list[tuple[int, str]]] = []

for ast in tqdm.tqdm(subtrees):
    subtree_frag_seq: list[dict[str, Any]] = []
    subtree_frag_info_seq: list[tuple[int, str]] = []

    node_to_frags(ast, subtree_frag_seq, subtree_frag_info_seq, set())

    subtree_frag_seqs.append(subtree_frag_seq[:MAX_SEQ_LEN])
    subtree_frag_info_seqs.append(subtree_frag_info_seq[:MAX_SEQ_LEN])

 48%|████▊     | 4658283/9728634 [04:13<18:23, 4596.25it/s]  

: 

: 

In [28]:
lengths = [len(frag_seq) for frag_seq in subtree_frag_seqs]

print("Max length of subtree_frag_seqs:", max(lengths))
print("Min length of subtree_frag_seqs:", min(lengths))
print("Avg length of subtree_frag_seqs:", sum(lengths) / len(lengths))
print("Number of seqs with length = 1:", len([l for l in lengths if l < 1]))

Max length of subtree_frag_seqs: 59
Min length of subtree_frag_seqs: 59
Avg length of subtree_frag_seqs: 59.0
Number of seqs with length < 5: 0


In [22]:
from js_ast.fragmentise import hash_frag

data: list[list[int]] = []

for frag_seq in tqdm.tqdm(subtree_frag_seqs):
    seq: list[int] = []

    for frag in frag_seq:
        frag_hash = hash_frag(frag)
        if frag_hash in vocab:
            seq.append(token_to_id[frag_hash])
        else:
            oov_frag = {"type": frag_hash_to_type[frag_hash]}
            oov_frag_hash = hash_frag(oov_frag)
            if oov_frag_hash in vocab:
                seq.append(token_to_id[oov_frag_hash])
            else:
                print("UNK_TOKEN")
                seq.append(token_to_id[UNK_TOKEN])

    data.append([token_to_id[CLS_TOKEN]] + seq + [token_to_id[SEP_TOKEN]])

100%|██████████| 9728634/9728634 [24:44<00:00, 6555.56it/s]  


In [23]:
import pickle

pickle.dump(data, open("ASTBERTa/data_subtrees.pkl", "wb"))
# pickle.dump(token_to_id, open("ASTBERTa/token_to_id.pkl", "wb"))
# pickle.dump(vocab, open("ASTBERTa/vocab.pkl", "wb"))
# pickle.dump(hash_to_frag, open("ASTBERTa/hash_to_frag.pkl", "wb"))
# pickle.dump(id_to_token, open("ASTBERTa/id_to_token.pkl", "wb"))