In [None]:
import pickle
import sys
from pathlib import Path
from typing import Any

import tqdm

from preprocessing.normalise import collect_id, normalize_id
from utils.loader import load_raw_corpus
from utils.logging import setup_logging

# Raise the interpreter recursion limit; presumably the AST helpers used below
# (collect_id / normalize_id / node_to_frags) recurse over deep trees — TODO confirm.
sys.setrecursionlimit(20_000)
# Project-wide logging configuration.
setup_logging()

In [None]:
# AST node types presumably treated as terminal (leaf) nodes — TODO confirm.
# NOTE(review): TERM_TYPE is not referenced elsewhere in this notebook.
TERM_TYPE = {
    "DebuggerStatement",
    "ThisExpression",
    "Super",
    "EmptyStatement",
    "Import",
}

# Maximum fragment-sequence length; presumably the model's input length limit — TODO confirm.
MAX_SEQ_LEN = 512

In [None]:
corpus = load_raw_corpus(Path("../corpus/DIE"))

# Normalise identifiers program by program: every AST gets a fresh counter dict
# and mapping, so numbering presumably restarts for each program — TODO confirm.
for program_ast in tqdm.tqdm(corpus):
    id_map = {}
    id_idx = dict(v=0, f=0, c=0)
    collect_id(program_ast, id_map, id_idx)
    normalize_id(program_ast, id_map)

In [None]:
# Cache the identifier-normalised corpus to disk.
with Path("../ASTBERTa/corpus.pkl").open("wb") as corpus_file:
    pickle.dump(corpus, corpus_file)

In [None]:
# Reload the cached corpus. pickle.load can execute arbitrary code: only load trusted files.
with Path("../ASTBERTa/corpus.pkl").open("rb") as corpus_file:
    corpus = pickle.load(corpus_file)

In [None]:
from js_ast.fragmentise import node_to_frags


# One fragment sequence per program, plus the union of node types seen in the corpus.
frag_seqs: list[list[dict[str, Any]]] = []
# NOTE(review): frag_info_seqs is never populated in this notebook.
frag_info_seqs: list[list[tuple[int, str]]] = []
all_node_types: set[str] = set()

for program in tqdm.tqdm(corpus):
    program_frags: list[dict[str, Any]] = []
    program_node_types: set[str] = set()
    # The return value is ignored; the two containers are expected to be filled in place.
    node_to_frags(program, program_frags, program_node_types)

    all_node_types |= program_node_types
    frag_seqs.append(program_frags)

In [None]:
# Length statistics for the fragment sequences, measured against MAX_SEQ_LEN.
frag_seqs_len = sorted(len(seq) for seq in frag_seqs)
# Sequences strictly shorter than MAX_SEQ_LEN (the sequences themselves, not copies
# of the whole corpus list).
frag_seqs_below_max = [seq for seq in frag_seqs if len(seq) < MAX_SEQ_LEN]

print("Length of frag_seqs:", len(frag_seqs))
print("Node types:", all_node_types)

if frag_seqs_len:
    print("Largest 100 lengths of frag_seqs:", frag_seqs_len[-100:])
    print("Smallest 5 lengths of frag_seqs:", frag_seqs_len[:5])
    print("Avg length of frag_seqs:", sum(frag_seqs_len) / len(frag_seqs_len))
    print(
        f"Percentage of frag_seqs below {MAX_SEQ_LEN}:",
        100 * len(frag_seqs_below_max) / len(frag_seqs),
    )
else:
    # Avoid ZeroDivisionError on an empty corpus.
    print("No fragment sequences to summarise.")

In [None]:
from collections import defaultdict

from js_ast.fragmentise import hash_frag

# Occurrence count per fragment hash, plus the first fragment (and its node type)
# observed for each hash.
frag_freq: dict[str, int] = defaultdict(int)
hash_to_frag: dict[str, dict[str, Any]] = {}
frag_hash_to_type: dict[str, str] = {}

for seq in tqdm.tqdm(frag_seqs):
    for fragment in seq:
        key = hash_frag(fragment)
        frag_freq[key] += 1

        # Keep only the first-seen fragment for each hash.
        hash_to_frag.setdefault(key, fragment)
        if key not in frag_hash_to_type:
            frag_hash_to_type[key] = fragment["type"]

In [None]:
# (hash, count) pairs by descending count; sorted() is stable, so ties keep
# first-seen order.
frag_freq_list = sorted(frag_freq.items(), reverse=True, key=lambda item: item[1])
oov_frags: list[str] = []

# Add one anonymous OOV fragment ({"type": <node type>}) per node type, used as the
# fallback for fragments of that type that are not in the vocabulary.
# Iterate node types in sorted order: iteration order of a set of str differs between
# interpreter runs (hash randomisation), which would make token ids non-reproducible.
for frag_type in sorted(all_node_types):
    oov_frag = {"type": frag_type}
    oov_frag_hash = hash_frag(oov_frag)
    oov_frags.append(oov_frag_hash)
    frag_hash_to_type[oov_frag_hash] = frag_type
    hash_to_frag[oov_frag_hash] = oov_frag

# NOTE(review): vocab_size is not applied anywhere; vocabulary size is controlled
# only by freq_threshold below.
vocab_size = 20000
# A fragment enters the vocabulary only if it occurs more than this many times.
freq_threshold = 4

threshold_frags = [frag_hash for frag_hash, freq in frag_freq_list if freq > freq_threshold]
unique_vocab_frags = set(threshold_frags + oov_frags)

# Deduplicate with a deterministic order (frequent fragments first, then OOV
# fragments). list(set(...)) would order token ids by hash, which varies per run.
vocab_frags = list(dict.fromkeys(threshold_frags + oov_frags))

print("Number of unique fragments:", len(frag_freq))
print("Max frequency:", max(frag_freq.values(), default=0))
print("Min frequency:", min(frag_freq.values(), default=0))

print(f"Number of unique fragments with freq > {freq_threshold}:", len(threshold_frags))
print([hash_to_frag[frag] for frag in vocab_frags[:10]])

In [None]:
# Special token strings. The order of `special_tokens` matters: it is prepended to
# the fragment vocabulary, so list position determines each token's id.
PAD_TOKEN, CLS_TOKEN, SEP_TOKEN, MASK_TOKEN, UNK_TOKEN = "<pad>", "<s>", "</s>", "<mask>", "<unk>"

special_tokens = [PAD_TOKEN, CLS_TOKEN, MASK_TOKEN, SEP_TOKEN, UNK_TOKEN]

In [None]:
import numpy as np  # NOTE(review): np is not used in this cell

# Special tokens take ids 0..len(special_tokens)-1; vocabulary fragments follow.
ordered_vocab = [*special_tokens, *vocab_frags]
vocab = set(ordered_vocab)

token_to_id = dict(zip(ordered_vocab, range(len(ordered_vocab))))
id_to_token = {idx: tok for tok, idx in token_to_id.items()}

special_token_ids = {token_to_id[tok] for tok in special_tokens}

# Fragment id -> node type and fragment id -> fragment dict. Only vocabulary
# fragments get entries; special-token ids appear in neither mapping.
frag_id_to_type = {}
frag_id_to_frag = {}
for frag_hash in vocab_frags:
    frag_id = token_to_id[frag_hash]
    frag_id_to_type[frag_id] = frag_hash_to_type[frag_hash]
    frag_id_to_frag[frag_id] = hash_to_frag[frag_hash]

In [None]:
# Fragment-level artefacts: raw fragment sequences and id -> type / fragment lookups.
frag_data = {
    "frag_seqs": frag_seqs,
    "frag_id_to_type": frag_id_to_type,
    "frag_id_to_frag": frag_id_to_frag,
}

# Token-level vocabulary artefacts.
vocab_data = {
    "vocab": vocab,
    "token_to_id": token_to_id,
    "id_to_token": id_to_token,
    "special_token_ids": special_token_ids,
}

# Context managers flush and close the files deterministically; a bare open()
# passed to pickle.dump leaves closing to garbage collection.
with open("../ASTBERTa/frag_data_old.pkl", "wb") as f:
    pickle.dump(frag_data, f)
with open("../ASTBERTa/vocab_data_old.pkl", "wb") as f:
    pickle.dump(vocab_data, f)

In [None]:
# Reload the fragment artefacts (trusted local pickle; pickle.load can execute code).
with open("../ASTBERTa/frag_data_old.pkl", "rb") as fh:
    frag_data = pickle.load(fh)

frag_seqs, frag_id_to_type, frag_id_to_frag = (
    frag_data["frag_seqs"],
    frag_data["frag_id_to_type"],
    frag_data["frag_id_to_frag"],
)

In [None]:
# Number of entries in the fragment id -> node type mapping.
n_frag_types = len(frag_id_to_type)
print(n_frag_types)

In [None]:
from js_ast.fragmentise import hash_frag

# Rebuild the token <-> id mappings from the reloaded fragment data. This cell
# depends only on frag_id_to_frag and special_tokens.
token_to_id = {}

for frag_id, frag in frag_id_to_frag.items():
    token_to_id[hash_frag(frag)] = frag_id

# Special tokens occupy ids 0..len(special_tokens)-1, matching the vocabulary layout
# where special tokens are placed first.
for i, item in enumerate(special_tokens):
    token_to_id[item] = i

# `vocab` holds token strings (fragment hashes and special tokens), not ids, so that
# `frag_hash in vocab` membership tests during encoding can succeed.
vocab = set(token_to_id)
print(len(vocab))

id_to_token = {i: token for token, i in token_to_id.items()}
special_token_ids = set([token_to_id[token] for token in special_tokens])
print(list(token_to_id.items())[:10])

In [None]:
import pickle
import tqdm

# Load the fragment and vocabulary artefacts (trusted local pickles).
# NOTE(review): these are frag_data.pkl / vocab_data.pkl, not the *_old.pkl files
# written earlier in this notebook — confirm which step produces them.
with open("../ASTBERTa/frag_data.pkl", "rb") as fh:
    frag_data = pickle.load(fh)

with open("../ASTBERTa/vocab_data.pkl", "rb") as fh:
    vocab_data = pickle.load(fh)

frag_seqs, frag_id_to_type, frag_id_to_frag = (
    frag_data["frag_seqs"],
    frag_data["frag_id_to_type"],
    frag_data["frag_id_to_frag"],
)

vocab, token_to_id, id_to_token, special_token_ids = (
    vocab_data["vocab"],
    vocab_data["token_to_id"],
    vocab_data["id_to_token"],
    vocab_data["special_token_ids"],
)

In [None]:
from js_ast.fragmentise import hash_frag

# Encode each fragment sequence as token ids wrapped in CLS ... SEP.
# Per-fragment lookup order: exact fragment hash -> anonymous OOV fragment of the
# same node type ({"type": ...}) -> UNK_TOKEN.
data: list[list[int]] = []
cls_id = token_to_id[CLS_TOKEN]
sep_id = token_to_id[SEP_TOKEN]
n_unk = 0

for frag_seq in tqdm.tqdm(frag_seqs):
    seq: list[int] = []

    for frag in frag_seq:
        # Ids can be 0 (the pad token), so test against None, not truthiness.
        token_id = token_to_id.get(hash_frag(frag))
        if token_id is None:
            # Build the OOV fragment from the fragment's own type instead of the
            # frag_hash_to_type table. That table only covers fragments counted
            # earlier in this session, so it raises KeyError for reloaded data.
            token_id = token_to_id.get(hash_frag({"type": frag["type"]}))
        if token_id is None:
            token_id = token_to_id[UNK_TOKEN]
            n_unk += 1
        seq.append(token_id)

    data.append([cls_id] + seq + [sep_id])

# Report unknown fragments once, instead of printing one line per occurrence.
print("UNK_TOKEN count:", n_unk)

In [None]:
import pickle

# Persist the encoded dataset; `with` guarantees the file is flushed and closed.
with open("../ASTBERTa/data_old.pkl", "wb") as f:
    pickle.dump(data, f)
# pickle.dump(token_to_id, open("ASTBERTa/token_to_id.pkl", "wb"))
# pickle.dump(vocab, open("ASTBERTa/vocab.pkl", "wb"))
# pickle.dump(hash_to_frag, open("ASTBERTa/hash_to_frag.pkl", "wb"))
# pickle.dump(id_to_token, open("ASTBERTa/id_to_token.pkl", "wb"))