In [None]:
import pandas as pd
import javalang
from tqdm import tqdm
from javalang.ast import Node

In [None]:
# Dataset locations. All four files live in the same directory, so the
# directory is spelled once; change DATASET_DIR to relocate the dataset.
DATASET_DIR = '/data/dataset/CodeXGLUE/Code-Code/Clone-detection-BigCloneBench/dataset'
raw_code_url = DATASET_DIR + '/data.jsonl'  # JSON-lines corpus loaded into `raw_code`
train_url = DATASET_DIR + '/train.txt'  # tab-separated pair files read by read_ccd_pairs
test_url = DATASET_DIR + '/test.txt'
valid_url = DATASET_DIR + '/valid.txt'

In [None]:
# Load the JSON-lines corpus (one JSON record per line) into a DataFrame.
# Later cells read its `func` and `idx` columns.
# NOTE(review): the pair filter compares `int(id)` against `idx`, so this
# presumably relies on read_json inferring `idx` as an integer column -- confirm.
raw_code = pd.read_json(raw_code_url, lines=True)
raw_code.head()

In [None]:
# Use javalang to build ASTs, then traverse them depth-first to produce the AST-node corpus.
def get_token(node):
    """Return the token string that represents one AST element.

    * a ``str`` is its own token;
    * a ``set`` becomes the literal ``'Modifier'``;
    * a javalang ``Node`` becomes its class name;
    * anything else yields the empty string.
    """
    if isinstance(node, str):
        return node
    if isinstance(node, set):
        return 'Modifier'
    if isinstance(node, Node):
        return node.__class__.__name__
    return ''


def get_child(root):
    """Return the direct children of ``root`` as one flat list.

    A ``Node`` contributes its ``children`` attribute, a ``set`` contributes
    its members, and any other value has no children. Nested lists are
    flattened depth-first in their original order, and falsy entries
    (``None``, ``''``, empty sets, ...) are dropped.
    """
    if isinstance(root, Node):
        top_level = root.children
    elif isinstance(root, set):
        top_level = list(root)
    else:
        top_level = []

    flat = []
    # Stack of partially consumed iterators: the innermost list is on top,
    # so items come out in the same order a recursive flatten would give.
    cursors = [iter(top_level)]
    while cursors:
        for entry in cursors[-1]:
            if isinstance(entry, list):
                # Descend; the outer iterator resumes once this one is done.
                cursors.append(iter(entry))
                break
            if entry:
                flat.append(entry)
        else:
            # Current iterator exhausted without descending.
            cursors.pop()
    return flat


def get_sequence(node, sequence):
    """Append the tokens of ``node``'s subtree to ``sequence`` in pre-order.

    ``sequence`` is extended in place (nothing is returned): first the token
    of ``node`` itself, then the tokens of each child subtree, left to right.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        token, children = get_token(current), get_child(current)
        sequence.append(token)
        # Push children reversed so the leftmost child is popped (visited) first.
        pending.extend(reversed(children))


def parse_program(func):
    """Parse Java source text as a single member declaration.

    The text is tokenized with javalang and handed to
    ``Parser.parse_member_declaration``; the resulting tree is returned.
    Errors raised by javalang propagate to the caller.
    """
    token_stream = javalang.tokenizer.tokenize(func)
    return javalang.parser.Parser(token_stream).parse_member_declaration()

In [None]:
# Parse every function once, keeping the ASTs and recording the rows that fail.
syntax_error_indices = []  # positional row numbers in raw_code that failed to parse
syntax_error_ids = []      # the `idx` values of those rows
# NOTE(review): `nodes_num` is never filled in the visible cells (it looks like a
# typo of `node_nums`); kept only in case another cell reads it.
nodes_num = []
tree_list = []             # ASTs of the rows that parsed, in row order

func_column = raw_code['func']
idx_column = raw_code['idx']
for i in tqdm(range(len(raw_code))):
    try:
        # Only the parse itself is guarded; .iloc makes the positional access explicit.
        tree = parse_program(func_column.iloc[i])
    except Exception:
        # Best-effort: a function javalang cannot parse is recorded and skipped.
        # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit through,
        # so the long loop can still be interrupted.
        syntax_error_indices.append(i)
        syntax_error_ids.append(idx_column.iloc[i])
    else:
        tree_list.append(tree)

In [None]:
# Display the `idx` values of the functions that could not be parsed.
syntax_error_ids

In [None]:
# Count AST nodes per parsed function and collect the ids of oversized ones.
max_node_count = 1000  # functions with more nodes than this are flagged
node_nums = []         # node count of each tree in tree_list, same order
over_1000_ids = []     # `idx` values of functions whose AST exceeds max_node_count

# tree_list only holds the rows that parsed, so its positions do NOT line up
# with raw_code rows once any parse failure occurred. Rebuild the row position
# of every tree by skipping the recorded failures.
failed_positions = set(syntax_error_indices)
parsed_positions = [pos for pos in range(len(raw_code)) if pos not in failed_positions]
assert len(parsed_positions) == len(tree_list), (
    'tree_list and syntax_error_indices are out of sync; re-run the parsing cell'
)

idx_column = raw_code['idx']
for pos, tree in zip(parsed_positions, tree_list):
    sequence = []
    get_sequence(tree, sequence)
    node_nums.append(len(sequence))

    if len(sequence) > max_node_count:
        over_1000_ids.append(idx_column.iloc[pos])

In [None]:
# Number of functions flagged as having more than 1000 AST nodes.
len(over_1000_ids)

In [None]:
# Re-key the corpus by function id. The guard makes this cell safe to re-run:
# once `idx` has become the index it is no longer a column, and calling
# set_index('idx') a second time would raise KeyError.
if raw_code.index.name != 'idx':
    raw_code = raw_code.set_index('idx')
raw_code_index = raw_code.index.tolist()  # all known function ids

def read_ccd_pairs(url):
    """Read clone-detection pairs from a tab-separated file and filter them.

    Each non-blank line must hold ``id1<TAB>id2<TAB>label``. A pair is kept
    only when both ids appear in the module-level ``raw_code_index`` and
    neither appears in the module-level ``over_1000_ids``.

    Parameters
    ----------
    url : str
        Path of the pair file.

    Returns
    -------
    list of tuple
        ``(int(id1), int(id2), label)`` with ``label`` 0 when the file says
        ``'0'`` and 1 otherwise, in file order.

    Raises
    ------
    ValueError
        If a non-blank line does not have exactly three tab-separated fields
        or an id is not an integer.
    """
    # Sets give O(1) membership tests; scanning the original lists four times
    # per line is needlessly slow on large pair files.
    known_ids = set(raw_code_index)
    oversized_ids = set(over_1000_ids)
    # NOTE(review): ids in `syntax_error_ids` are not excluded here -- confirm
    # that pairs involving unparseable functions are meant to be kept.

    data = []
    with open(url) as f:
        for line in tqdm(f):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            id1, id2, label = line.split('\t')
            first, second = int(id1), int(id2)
            if first not in known_ids or second not in known_ids:
                continue
            if first in oversized_ids or second in oversized_ids:
                continue
            data.append((first, second, 0 if label == '0' else 1))
    return data

In [None]:
# Load and filter the three splits; files are read in train, valid, test order.
train_data, valid_data, test_data = [
    read_ccd_pairs(split_url) for split_url in (train_url, valid_url, test_url)
]

In [None]:
# Number of training pairs kept after filtering.
len(train_data)

In [None]:
# Number of validation pairs kept after filtering.
len(valid_data)

In [None]:
# Number of test pairs kept after filtering.
len(test_data)