# Read Data

In [1]:
from datasets import load_dataset
import numpy as np
import pickle
from parser import remove_comments_and_docstrings
from tqdm import tqdm

class Example:
    """One code/description pair together with its language tag.

    Attributes:
        code: Source text of the function.
        desc: Documentation string paired with the code.
        lang: Language name of the sample.
    """

    def __init__(self, code, desc, lang):
        # Arguments are stored unchanged under attributes of the same name.
        self.code, self.desc, self.lang = code, desc, lang
        
def _make_example(code, desc, lang):
    """Strip comments from ``code`` and wrap the sample in an ``Example``.

    Returns:
        The cleaned ``Example``, or ``None`` when cleaning raised an error so
        the caller can skip the sample.
    """
    try:
        code = remove_comments_and_docstrings(code, lang)
        # PHP snippets get explicit open/close tags when the open tag is missing.
        if lang == "php" and not code.startswith('<?php'):
            code = "<?php" + code + "?>"
        return Example(code, desc, lang)
    except Exception:
        # Best-effort: a sample that cannot be cleaned is dropped. Unlike a
        # bare `except:`, KeyboardInterrupt/SystemExit still propagate.
        return None


def _collect_examples(rows, langs, limit):
    """Return up to ``limit`` cleaned examples from ``rows`` whose language is in ``langs``."""
    collected = []
    for row in rows:
        if row['language'] not in langs:
            continue
        example = _make_example(row['func_code_string'],
                                row['func_documentation_string'],
                                row['language'])
        if example is not None:
            collected.append(example)
        if len(collected) >= limit:
            break
    return collected


def read_examples(langs=('javascript',), train_fraction=0.05, seed=2789,
                  max_train=100, max_valid=10, max_test=10):
    """Load the ``code_search_net`` dataset and build small example splits.

    The training split is visited in a seeded random order (a sample of
    ``train_fraction`` of its rows, drawn without replacement); the validation
    and test splits are scanned from the start. Each scan stops as soon as its
    cap is reached.

    Args:
        langs: Language names to keep; rows in other languages are skipped.
        train_fraction: Fraction of the training split to draw candidates from.
        seed: Seed passed to ``np.random.seed`` before sampling.
        max_train: Maximum number of training examples to return.
        max_valid: Maximum number of validation examples to return.
        max_test: Maximum number of test examples to return.

    Returns:
        A ``(train, valid, test)`` tuple of lists of ``Example``.
    """
    dataset = load_dataset('code_search_net')
    train_split = dataset['train']

    np.random.seed(seed)
    n_rows = len(train_split)
    sampled = np.random.choice(range(n_rows), int(train_fraction * n_rows), replace=False)
    sampled = list(map(int, sampled))

    # Each sampled row is fetched from the dataset exactly once.
    train = _collect_examples((train_split[i] for i in tqdm(sampled)), langs, max_train)
    valid = _collect_examples(tqdm(dataset['validation']), langs, max_valid)
    test = _collect_examples(tqdm(dataset['test']), langs, max_test)
    return train, valid, test

# Build the train/validation/test example lists once; the preprocessing cell
# below concatenates and featurizes them.
train_examples, valid_examples, test_examples = read_examples()

No config specified, defaulting to: code_search_net/all
Reusing dataset code_search_net (/home/anoushkav/.cache/huggingface/datasets/code_search_net/all/1.0.0/80a244ab541c6b2125350b764dc5c2b715f65f00de7a56107a28915fac173a27)


  0%|          | 0/3 [00:00<?, ?it/s]

  2%|███▎                                                                                                                                                                                                              | 1503/94042 [00:00<00:26, 3484.39it/s]
 43%|██████████████████████████████████████████████████████████████████████████████████████████                                                                                                                       | 38444/89154 [00:06<00:09, 5606.11it/s]
 49%|█████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                          | 49094/100529 [00:08<00:09, 5500.63it/s]


# Preprocess Data

In [2]:
from parser import DFG_python,DFG_java,DFG_ruby,DFG_go,DFG_php,DFG_javascript,DFG_csharp
from parser import (remove_comments_and_docstrings,
                   tree_to_token_index,
                   index_to_code_token,
                   tree_to_variable_index, 
                   detokenize_code)
from tree_sitter import Language, Parser
from transformers import RobertaTokenizer
import pickle
from collections import defaultdict
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
import os


# Maps a language name to the DFG_* callable imported from `parser` for it.
# Only these four of the imported DFG_* names are registered here.
dfg_function={
    'python':DFG_python,
    'java':DFG_java,
    'php':DFG_php,
    'javascript':DFG_javascript,
}

# Build one tree-sitter Parser per registered language.
# parsers[lang] is a two-element list: [Parser instance, DFG_* callable].
parsers={}        
for lang in dfg_function:
    # Grammar for `lang` is loaded from the shared library at this relative path.
    LANGUAGE = Language('parser/my-languages.so', lang)
    parser = Parser()
    parser.set_language(LANGUAGE) 
    # `parser` is rebound from the Parser object to the [Parser, DFG] pair.
    parser = [parser,dfg_function[lang]]    
    parsers[lang]= parser
    
class InputFeatures:
    """Plain container for the values computed for one example.

    Every constructor argument is stored unchanged under an attribute of the
    same name.
    """

    def __init__(self, example_index, code_tokens_ids, desc_tokens_ids,
                 ast_node_types, ast_adj, graph_feature):
        self.__dict__.update(
            example_index=example_index,
            code_tokens_ids=code_tokens_ids,
            desc_tokens_ids=desc_tokens_ids,
            ast_node_types=ast_node_types,
            ast_adj=ast_adj,
            graph_feature=graph_feature,
        )
    
def gather_node_types(node):
    """Append the type of ``node`` and of every descendant to ``ast_node_types``.

    Types are appended in pre-order (a node first, then its children from
    first to last) to the module-level list ``ast_node_types``, which must
    already exist when this is called.

    The walk uses an explicit stack instead of recursion so that deeply nested
    trees cannot exceed Python's recursion limit.

    Args:
        node: Root of the (sub)tree; needs ``.type`` and ``.children``.
    """
    global ast_node_types
    pending = [node]
    while pending:
        current = pending.pop()
        ast_node_types.append(current.type)
        # Push children reversed so the first child is popped (visited) first.
        pending.extend(reversed(current.children))
        
def get_leaf_nodes(node, output):
    """Append every childless node under ``node`` to ``output``.

    ``node`` itself is appended when it has no children. Leaves are appended
    in left-to-right order. The walk is iterative, so deeply nested trees
    cannot exceed Python's recursion limit.

    Args:
        node: Root of the (sub)tree; needs ``.children``.
        output: List that receives the leaf nodes (mutated in place).
    """
    pending = [node]
    while pending:
        current = pending.pop()
        children = current.children
        if children:
            # Reverse so the first child is popped (visited) first.
            pending.extend(reversed(children))
        else:
            output.append(current)
        
def overlap(s1,e1,s2,e2):
    """Tell whether two spans that start on the same row overlap.

    Each argument is indexed with ``[0]`` and ``[1]``; presumably these are
    (row, column) points -- TODO confirm against callers. The first span must
    start and end on one row. Column intervals are treated as half-open
    (start inclusive, end exclusive).

    Raises:
        Exception: if ``s1`` and ``e1`` are on different rows.
    """
    if s1[0] != e1[0]:
        raise Exception()
    if s1[0] != s2[0]:
        return False
    second_starts_in_first = s1[1] <= s2[1] < e1[1]
    first_starts_in_second = s2[1] <= s1[1] < e2[1]
    return second_starts_in_first | first_starts_in_second

def get_lr_path(leaf):
    """Return the chain of nodes from ``leaf`` up to the root.

    The list starts with ``leaf`` and follows ``.parent`` links until a node
    whose parent is ``None`` is reached; that node is the last element.
    """
    chain = [leaf]
    ancestor = leaf.parent
    while ancestor is not None:
        chain.append(ancestor)
        ancestor = ancestor.parent
    return chain
        
def get_ll_sim(p1, p2):
    """Score how much two paths agree when compared from their ends.

    The last elements are counted as shared without being compared. Elements
    are then compared pairwise from the second-to-last backwards, stopping at
    the first mismatch or once the offset reaches ``min(len(p1), len(p2)) - 1``.

    Returns:
        ``shared**2 / (len(p1) * len(p2))``, where ``shared`` is the number of
        positions counted as shared.
    """
    limit = min(len(p1), len(p2)) - 1
    shared = 1
    offset = 2
    while offset < limit and p1[-offset] == p2[-offset]:
        shared += 1
        offset += 1
    return shared ** 2 / (len(p1) * len(p2))

def get_adj_list(ast):
    """Build a parent -> children adjacency list for a parsed tree.

    Nodes are numbered in pre-order starting at 0 for ``ast.root_node``. Only
    nodes that have children appear as keys; each value lists the child ids
    from first child to last.

    All working state is local and the walk is iterative, so the function does
    not touch module globals and deeply nested trees cannot exceed Python's
    recursion limit.

    Args:
        ast: Object exposing ``root_node``; nodes need ``.children``.

    Returns:
        ``defaultdict(list)`` mapping a node id to the ids of its children.
    """
    adjacency = defaultdict(list)
    next_id = 0
    # Each stack entry pairs a node with the id of its parent (None for root).
    pending = [(ast.root_node, None)]
    while pending:
        current, parent_id = pending.pop()
        current_id = next_id
        next_id += 1
        if parent_id is not None:
            adjacency[parent_id].append(current_id)
        # Reverse so the first child is popped (and numbered) first.
        pending.extend((child, current_id) for child in reversed(current.children))
    return adjacency

def get_graph_feature_list(ast,dict_node_types):
    """Return one feature per tree node, looked up by node type.

    Nodes are visited in pre-order starting at ``ast.root_node`` (a node
    first, then its children from first to last), and ``dict_node_types`` is
    indexed with each node's ``.type``.

    All working state is local and the walk is iterative, so the function does
    not touch module globals and deeply nested trees cannot exceed Python's
    recursion limit.

    Args:
        ast: Object exposing ``root_node``; nodes need ``.type`` and ``.children``.
        dict_node_types: Mapping from node type to its feature value.

    Returns:
        List of looked-up values, one per node, in pre-order.
    """
    features = []
    pending = [ast.root_node]
    while pending:
        current = pending.pop()
        features.append(dict_node_types[current.type])
        # Reverse so the first child is popped (visited) first.
        pending.extend(reversed(current.children))
    return features
        
          
def convert_examples_to_features(examples, tokenizer):
    """Turn examples into ``InputFeatures`` (token ids plus AST-derived data).

    Two passes are made over ``examples``:

    1. Parse each example's code and collect every AST node type into the
       module-level list ``ast_node_types`` (reset at the start of the call).
    2. For each example, tokenize code and description, re-parse the code,
       and build the AST adjacency list and per-node one-hot type features.

    The one-hot vocabulary is the sorted set of node types seen in pass 1, so
    it depends on exactly which examples are passed in.

    Args:
        examples: Sequence of objects with ``code``, ``desc`` and ``lang``.
        tokenizer: Object providing ``tokenize`` and ``convert_tokens_to_ids``.

    Returns:
        List of ``InputFeatures``, one per example, in input order. An empty
        ``examples`` yields an empty list.
    """
    features = []
    global ast_node_types
    ast_node_types = []

    for example in tqdm(examples, total=len(examples)):
        ast = parsers[example.lang][0].parse(bytes(example.code, 'utf-8'))
        gather_node_types(ast.root_node)

    # Sorted unique node types; position in this array is the one-hot index.
    node_type_names = np.unique(ast_node_types)
    # Label-encoding an already sorted unique array and one-hot encoding the
    # result yields the identity matrix, so build it directly. This avoids
    # OneHotEncoder's `sparse` keyword, which newer scikit-learn releases no
    # longer accept, and also works when there are no node types at all.
    onehot_encoded = np.eye(len(node_type_names))

    dict_node_types = defaultdict(list)
    for idx, ntype in enumerate(node_type_names):
        dict_node_types[ntype] = onehot_encoded[idx]

    for example_index, example in enumerate(tqdm(examples, total=len(examples))):
        code_tokens = tokenizer.tokenize(example.code)
        desc_tokens = tokenizer.tokenize(example.desc)

        code_tokens_ids = tokenizer.convert_tokens_to_ids(code_tokens)
        desc_tokens_ids = tokenizer.convert_tokens_to_ids(desc_tokens)

        ast = parsers[example.lang][0].parse(bytes(example.code, 'utf-8'))
        ast_adj = get_adj_list(ast)
        graph_feature = get_graph_feature_list(ast, dict_node_types)

        # NOTE(review): `ast_node_types` is the single list of node types
        # gathered over ALL examples, so every feature references the same
        # list -- confirm a per-example list was not intended.
        features.append(
            InputFeatures(
                example_index,
                code_tokens_ids,
                desc_tokens_ids,
                ast_node_types,
                ast_adj,
                graph_feature
            )
        )
    return features

# Featurize all three splits in a single call so they share one node-type
# one-hot vocabulary, then cut the flat result back into splits by position.
tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base')
features = convert_examples_to_features(train_examples+valid_examples+test_examples, tokenizer)

n_train = len(train_examples)
n_valid = len(valid_examples)
feature_splits = {
    'train': features[:n_train],
    'valid': features[n_train:n_train + n_valid],
    'test': features[n_train + n_valid:],
}

# `with` guarantees the pickle file is flushed and closed even if dump fails.
with open('features_pt_javascript.pkl', 'wb') as pkl_file:
    pickle.dump(feature_splits, pkl_file)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:00<00:00, 5287.27it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:00<00:00, 498.97it/s]
