In [1]:
import pandas as pd
import os
import sys
import warnings
import javalang
from javalang.ast import Node
import numpy as np


warnings.filterwarnings('ignore')

In [2]:
Train_path = "../../Dataset/AST/log4j/df_log4j_v10.csv"
source = pd.read_csv(Train_path)
source['b_label'] = np.where(source['label']<0.5,0,1)

In [3]:
# Train_path = "../../Dataset/AST/camel/camel-1.2.csv"
# source = pd.read_csv(Train_path)
# source['b_label'] = np.where(source['label']<0.5,0,1)

In [4]:
# Train_path = "../../Dataset/AST/xerces/xerces-1.2.csv"
# source = pd.read_csv(Train_path)
# source['b_label'] = np.where(source['label']<0.5,0,1)

In [5]:
source[:3]

Unnamed: 0.1,Unnamed: 0,metric_name,java_name,file,label,b_label
0,0,org.apache.log4j.helpers.ISO8601DateFormat,./log4j-v_1_0/src/java/org/apache/log4j/helper...,/*\n * Copyright (C) The Apache Software Found...,0,0
1,1,org.apache.log4j.xml.Transform,./log4j-v_1_0/src/java/org/apache/log4j/xml/Tr...,\npackage org.apache.log4j.xml;\n\nimport org....,0,0
2,2,org.apache.log4j.helpers.AppenderAttachableImpl,./log4j-v_1_0/src/java/org/apache/log4j/helper...,/*\n * Copyright (C) The Apache Software Found...,0,0


# Parse

In [6]:
def parse_program(x):
    try:
        tree = javalang.parse.parse(x)
        return tree
    except:
        print("invalid")
    

In [7]:
source['AST'] = source['file'].apply(parse_program)

invalid
invalid
invalid
invalid


# Build-Vocab

In [8]:
def trans_to_sequences(ast):
    sequence = []
    get_sequence(ast, sequence)
    return sequence

def get_sequence(node, sequence):
    token, children = get_token(node), get_children(node)
    sequence.append(token)

    for child in children:
        get_sequence(child, sequence)

    if token in ['ForStatement', 'WhileStatement', 'DoStatement','SwitchStatement', 'IfStatement']:
        sequence.append('End')
        
def get_children(root):
    if isinstance(root, Node):
        children = root.children
    elif isinstance(root, set):
        children = list(root)
    else:
        children = []

    def expand(nested_list):
        for item in nested_list:
            if isinstance(item, list):
                for sub_item in expand(item):
                    yield sub_item
            elif item:
                yield item
    return list(expand(children))

def get_token(node):
    token = ''
    if isinstance(node, str):
        token = node
    elif isinstance(node, set):
        token = 'Modifier'#node.pop()
    elif isinstance(node, Node):
        token = node.__class__.__name__

    return token

In [9]:
source['AST_seq'] = source['AST'].apply(trans_to_sequences)

In [10]:
from gensim.models.word2vec import Word2Vec

corpus = source['AST_seq']
w2v = Word2Vec(corpus, size=50, workers=16, sg=1, min_count = 50, max_final_vocab=3000)
vocab = w2v.wv.vocab
max_token = w2v.wv.syn0.shape[0]

len(vocab)

95

In [11]:
w2v.save('word2vec_node_50')

# Block-process

In [12]:
class BlockNode(object):
    def __init__(self, node):
        self.node = node
        self.is_str = isinstance(self.node, str)
        self.token = self.get_token(node)
        self.children = self.add_children()

    def is_leaf(self):
        if self.is_str:
            return True
        return len(self.node.children) == 0

    def get_token(self, node):
        if isinstance(node, str):
            token = node
        elif isinstance(node, set):
            token = 'Modifier'
        elif isinstance(node, Node):
            token = node.__class__.__name__
        else:
            token = ''
        return token

    def ori_children(self, root):
        if isinstance(root, Node):
            if self.token in ['MethodDeclaration', 'ConstructorDeclaration']:
                children = root.children[:-1]
            else:
                children = root.children
        elif isinstance(root, set):
            children = list(root)
        else:
            children = []

        def expand(nested_list):
            for item in nested_list:
                if isinstance(item, list):
                    for sub_item in expand(item):
                        yield sub_item
                elif item:
                    yield item

        return list(expand(children))

    def add_children(self):
        if self.is_str:
            return []
        logic = ['SwitchStatement', 'IfStatement', 'ForStatement', 'WhileStatement', 'DoStatement']
        children = self.ori_children(self.node)
        if self.token in logic:
            return [BlockNode(children[0])]
        elif self.token in ['MethodDeclaration', 'ConstructorDeclaration']:
            return [BlockNode(child) for child in children]
        else:
            return [BlockNode(child) for child in children if self.get_token( child) not in logic]


In [13]:
def get_blocks_v1(node, block_seq):
    name, children = get_token(node), get_children(node)
    logic = ['ClassDeclaration', 'MethodDeclaration', 'SwitchStatement','IfStatement', 'ForStatement', 'WhileStatement', 'DoStatement']
    
    if name in ['CompilationUnit']:
        block_seq.append(BlockNode(node))
        body = node.types
        for child in body:
            if get_token(child) not in logic and not hasattr(child, 'block'):
                block_seq.append(BlockNode(child))
            else:
                get_blocks_v1(child, block_seq)

    elif name in ['ClassDeclaration', 'MethodDeclaration', 'ConstructorDeclaration']:
        block_seq.append(BlockNode(node))
        body = node.body
        if body: 
            for child in body:
                if get_token(child) not in logic and not hasattr(child, 'block'):
                    block_seq.append(BlockNode(child))
                else:
                    get_blocks_v1(child, block_seq)
    
    elif name in logic:
        block_seq.append(BlockNode(node))
        for child in children[1:]:
            token = get_token(child)
            if not hasattr(node, 'block') and token not in logic+['BlockStatement']:
                block_seq.append(BlockNode(child))
            else:
                get_blocks_v1(child, block_seq)
            block_seq.append(BlockNode('End'))
    
    elif name is 'BlockStatement' or hasattr(node, 'block'):
        block_seq.append(BlockNode(name))
        for child in children:
            if get_token(child) not in logic:
                block_seq.append(BlockNode(child))
            else:
                get_blocks_v1(child, block_seq)
    else:
        for child in children:
            get_blocks_v1(child, block_seq)
            
def trans2seq(r):
    blocks = []
    get_blocks_v1(r, blocks)
    tree = []
    for b in blocks:
        btree = tree_to_index(b)
        tree.append(btree)
    return tree
        
def tree_to_index(node):
    token = node.token
    result = [vocab[token].index if token in vocab else max_token]
    children = node.children
    for child in children:
        result.append(tree_to_index(child))
    return result

In [14]:
source['block_seq'] = source['AST'].apply(trans2seq)

In [15]:
source[:3]

Unnamed: 0.1,Unnamed: 0,metric_name,java_name,file,label,b_label,AST,AST_seq,block_seq
0,0,org.apache.log4j.helpers.ISO8601DateFormat,./log4j-v_1_0/src/java/org/apache/log4j/helper...,/*\n * Copyright (C) The Apache Software Found...,0,0,CompilationUnit(imports=[Import(path=java.util...,"[CompilationUnit, PackageDeclaration, org.apac...","[[46, [47, [95]], [9, [95]], [9, [95]], [9, [9..."
1,1,org.apache.log4j.xml.Transform,./log4j-v_1_0/src/java/org/apache/log4j/xml/Tr...,\npackage org.apache.log4j.xml;\n\nimport org....,0,0,CompilationUnit(imports=[Import(path=org.apach...,"[CompilationUnit, PackageDeclaration, org.apac...","[[46, [47, [95]], [9, [95]], [9, [95]], [9, [9..."
2,2,org.apache.log4j.helpers.AppenderAttachableImpl,./log4j-v_1_0/src/java/org/apache/log4j/helper...,/*\n * Copyright (C) The Apache Software Found...,0,0,CompilationUnit(imports=[Import(path=org.apach...,"[CompilationUnit, PackageDeclaration, org.apac...","[[46, [47, [95]], [9, [95]], [9, [95]], [9, [9..."


In [16]:
source.to_pickle("parsed_source.pkl")

In [17]:
# Testing

Test_path = "../../Dataset/AST/log4j/df_log4j_v11.csv"
# Test_path = "../../Dataset/AST/camel/camel-1.4.csv"
# Test_path = "../../Dataset/AST/xerces/xerces-1.3.csv"

In [18]:
source = pd.read_csv(Test_path)
source['b_label'] = np.where(source['label']<0.5,0,1)

source['AST'] = source['file'].apply(parse_program)
source['block_seq'] = source['AST'].apply(trans2seq)

invalid
invalid
invalid
invalid


In [19]:
source.to_pickle("parsed_source_test.pkl")

In [20]:
source[:3]

Unnamed: 0.1,Unnamed: 0,metric_name,java_name,file,label,b_label,AST,block_seq
0,0,org.apache.log4j.xml.examples.XCategory,./log4j-v_1_1/src/java/org/apache/log4j/xml/ex...,/*\n * Copyright (C) The Apache Software Found...,1,1,CompilationUnit(imports=[Import(path=org.apach...,"[[46, [47, [95]], [9, [95]], [9, [95]], [9, [9..."
1,1,org.apache.log4j.helpers.DateLayout,./log4j-v_1_1/src/java/org/apache/log4j/helper...,/*\n * Copyright (C) The Apache Software Found...,1,1,CompilationUnit(imports=[Import(path=org.apach...,"[[46, [47, [95]], [9, [95]], [9, [95]], [9, [9..."
2,2,org.apache.log4j.net.test.SocketMin,./log4j-v_1_1/src/java/org/apache/log4j/net/te...,/*\n * Copyright (C) The Apache Software Found...,0,0,CompilationUnit(imports=[Import(path=org.apach...,"[[46, [47, [95]], [9, [95]], [9, [95]], [9, [9..."


In [21]:
max([len(x) for x in source['block_seq']])

274