# Imports


In [1]:
import os
import sys
import copy
import glob
import tqdm
from torch import nn
import random
import torch
import platform
from typing import Callable, List, Optional, Dict
import numpy as np
import scipy.sparse as sp

import warnings
warnings.filterwarnings('ignore')

from transformers import AutoTokenizer, AutoModel

import torch_geometric
from torch_geometric.data import (
    Data,
    InMemoryDataset,
    Batch
    )
import torch_geometric.datasets as datasets
import torch_geometric.transforms as transforms
from torch_geometric.data import Data
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.nn import global_mean_pool

# Helper function for visualization.
%matplotlib inline
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

import umap
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.cluster import Birch
from sklearn.cluster import SpectralClustering

from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, precision_score, recall_score, silhouette_score

# To ensure determinism
seed = 1234
def seed_everything(seed: int):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), and configures cuDNN to avoid
    non-deterministic algorithm selection.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # NOTE: hash randomisation of the *current* interpreter is fixed at
    # start-up; this assignment only influences processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes cuDNN kernels and may pick different
    # algorithms from run to run, so it has to be off for reproducibility.
    torch.backends.cudnn.benchmark = False
seed_everything(seed)

# Report the PyTorch, CUDA, Python and PyTorch Geometric versions (one per line)
print(
    torch.__version__,
    torch.version.cuda,
    platform.python_version(),
    torch_geometric.__version__,
    sep="\n",
)



1.8.1+cu101
10.1
3.8.18
1.7.0


# Filter Java Code Snippets

## Get all accepted Java submissions

In [8]:
import pandas as pd
import glob
import os

metadata_folder = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/metadata"
dataset_folder = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/Java250_data"
metadata_files = glob.glob(os.path.join(metadata_folder, '*.csv'))
allowed_status = ["Accepted"]

# Maps the problem id (metadata file name without ".csv") to the list of
# Java source paths whose status is allowed and which exist on disk.
accepted_submissions = {}
for file_loc in tqdm.tqdm(metadata_files):
    problem_id = os.path.splitext(os.path.basename(file_loc))[0]
    metadata_file = pd.read_csv(file_loc)

    # Select the wanted rows in one vectorised pass instead of testing every
    # row through the (slow) iterrows() protocol.
    is_wanted = (metadata_file["filename_ext"] == "java") & metadata_file["status"].isin(allowed_status)
    wanted_rows = metadata_file.loc[is_wanted, ["problem_id", "submission_id"]]

    accepted_submissions[problem_id] = []
    for row_problem_id, submission_id in zip(wanted_rows["problem_id"], wanted_rows["submission_id"]):
        code_file_location = f"{dataset_folder}/{row_problem_id}/{submission_id}.java"
        # Submissions listed in the metadata but absent from the dataset
        # folder are skipped silently.
        if os.path.isfile(code_file_location):
            accepted_submissions[problem_id].append(code_file_location)

100%|██████████| 4053/4053 [10:24<00:00,  6.49it/s]


## Filter the code-snippets that have only one class and one method

In [10]:
import javalang
from javalang.parser import JavaSyntaxError

def count_methods(java_file_loc):
    """Count the type and method declarations found in a Java source file.

    Args:
        java_file_loc: Path of the Java file to parse.

    Returns:
        A ``(class_count, method_count)`` tuple: the number of
        ``javalang.tree.TypeDeclaration`` nodes and the number of
        ``javalang.tree.MethodDeclaration`` nodes yielded when filtering
        the parsed tree.
    """
    with open(java_file_loc, 'r') as source_file:
        source_text = source_file.read()

    syntax_tree = javalang.parse.parse(source_text)

    def count_nodes(node_type):
        # filter() yields (path, node) pairs; only their number matters here.
        return len(list(syntax_tree.filter(node_type)))

    method_total = count_nodes(javalang.tree.MethodDeclaration)
    type_total = count_nodes(javalang.tree.TypeDeclaration)
    return type_total, method_total


# Keep, per problem, only the submissions made of exactly one type
# declaration and one method declaration.
accepted_submissions_filtered = {}
for pid in tqdm.tqdm(accepted_submissions):
    single_method_paths = accepted_submissions_filtered[pid] = []
    for path in accepted_submissions[pid]:
        try:
            class_count, method_count = count_methods(path)
        except (AssertionError, JavaSyntaxError, javalang.tokenizer.LexerError):
            # Files that javalang cannot tokenize/parse are left out.
            continue
        if (class_count, method_count) == (1, 1):
            single_method_paths.append(path)

100%|██████████| 4053/4053 [04:10<00:00, 16.17it/s]


## Get top 10 problems by number of accepted submissions

In [11]:
# Number of filtered submissions available for each problem id
accepted_submissions_count = {
    pid: len(paths) for pid, paths in accepted_submissions_filtered.items()
}

# (problem id, count) pairs ordered from most to fewest submissions
accepted_submissions_count_sorted = sorted(
    accepted_submissions_count.items(), key=lambda x: x[1], reverse=True
)
top_ten = accepted_submissions_count_sorted[:10]
print(top_ten)

sum_ = sum(count for _pid, count in top_ten)
print("Total accepted java submissions in top 10 problems: ", sum_)

[('p02381', 280), ('p02419', 280), ('p02396', 277), ('p02400', 277), ('p02399', 277), ('p02415', 275), ('p02421', 275), ('p02416', 274), ('p02407', 272), ('p02417', 271)]
Total accepted java submissions in top 10 problems:  2758


## Copy the filtered Java files

In [12]:
import shutil

output_folder_location = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/Java250_accepted_sols_1_method"

# Problems with fewer filtered submissions than this are not copied.
min_submissions_per_problem = 20

for pid, count in tqdm.tqdm(accepted_submissions_count_sorted):
    if count < min_submissions_per_problem:
        continue
    output_project_folder_location = os.path.join(output_folder_location, pid)
    # makedirs(exist_ok=True) also creates output_folder_location itself on a
    # fresh run, where a plain os.mkdir would raise FileNotFoundError.
    os.makedirs(output_project_folder_location, exist_ok=True)

    # Copy every filtered submission of this problem, keeping its file name.
    for original_file in accepted_submissions_filtered[pid]:
        file_name = os.path.basename(original_file)
        shutil.copyfile(original_file, os.path.join(output_project_folder_location, file_name))
    

100%|██████████| 4053/4053 [00:03<00:00, 1027.39it/s]


# Preprocess for PDG Generation

In [13]:
import sys
import json
import pandas as pd
import random
import os
from pathlib import Path
import re
import glob
import tqdm
import pyparsing

# Folder holding the filtered one-class/one-method submissions, one sub-folder
# per problem id (same path the copy step above writes to).
input_folder_location = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/Java250_accepted_sols_1_method"
# Folder that receives the cleaned sources, again one sub-folder per problem id.
output_folder_location = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/processed_data_for_pdg_java250"

#projects_to_consider = ['p02388', 'p02389', 'p02391', 'p02393', 'p02396']
# Problem ids to preprocess; these are the ten ids printed by the ranking cell.
projects_to_consider = ['p02381', 'p02419', 'p02396', 'p02400', 'p02399', 'p02415', 'p02421', 'p02416', 'p02407', 'p02417']

def copyFile(input_file_path, output_file_path):
    """Write a cleaned copy of a Java source file.

    Java-style comments are removed with ``pyparsing.javaStyleComment``,
    non-breaking spaces are replaced by plain spaces, and the following
    lines are dropped: blank lines, lines whose first non-blank character
    is ``@``, and ``import`` / ``package`` statements (regardless of their
    indentation). Every kept line is written followed by a newline.

    Args:
        input_file_path: Path of the Java file to read.
        output_file_path: Path of the cleaned file to create or overwrite.
    """
    # Read the whole input first; the context manager closes the handle even
    # if reading fails, and no empty output file is left behind in that case.
    with open(input_file_path, "r") as input_file:
        original_code = input_file.read()

    # Remove all comments
    comment_filter = pyparsing.javaStyleComment.suppress()
    modified_code = comment_filter.transformString(original_code)

    with open(output_file_path, "w") as newfile:
        for line in modified_code.split("\n"):
            line = line.replace('\u00A0', " ")
            stripped = line.strip()

            # Skip empty lines, lines like @Test
            if not stripped or stripped.startswith("@"):
                continue

            # Remove import statements and packages. The word boundary keeps
            # ordinary identifiers that merely begin with "import"/"package",
            # while matching on the stripped text also catches indented
            # statements.
            if re.match(r"(import|package)\b", stripped):
                continue

            # Lines produced by split("\n") never carry their newline.
            newfile.write(line + "\n")

# Clean every Java file of the selected problems into the output folder,
# mirroring the one-sub-folder-per-problem layout of the input folder.
for project in tqdm.tqdm(projects_to_consider):
    input_project_folder = os.path.join(input_folder_location, project)
    java_files = glob.glob(os.path.join(input_project_folder, '*.java'))

    output_project_folder_location = os.path.join(output_folder_location, project)
    # makedirs(exist_ok=True) also creates output_folder_location itself on a
    # fresh run, where a plain os.mkdir would raise FileNotFoundError.
    os.makedirs(output_project_folder_location, exist_ok=True)

    for input_java_file_path in java_files:
        file_name = os.path.basename(input_java_file_path)
        output_java_file_path = os.path.join(output_project_folder_location, file_name)
        copyFile(input_java_file_path, output_java_file_path)
        

100%|██████████| 10/10 [00:33<00:00,  3.40s/it]


# Get The PDGs

# Postprocess After PDG Generation

In [2]:
import os
import sys
import glob
import tqdm
import random
random.seed(1234)

""" ALGORITHM

a. Clean the raw edge info (eg. remove wrongly formatted edges, class edges etc.)
b. Merge same code-lines into a single line/node
e. Remove self-loops and duplicate edges

"""

# Per-split counters, reset to zero after each split has been processed:
#   PRUNING_ERROR_COUNT - files for which get_pruned_pdg raised an exception
#   GOOD_DATA_POINTS    - files left with at least 5 edges after pruning
#   TOTAL_DATA_POINTS   - files left with at least 1 edge after pruning
PRUNING_ERROR_COUNT, GOOD_DATA_POINTS, TOTAL_DATA_POINTS = 0, 0, 0
# Running totals of the three counters above, accumulated over all splits.
PRUNING_ERROR_COUNT_IN_DATASET, GOOD_DATA_POINTS_IN_DATASET, TOTAL_DATA_POINTS_IN_DATASET = 0, 0, 0
# Maps a split name to [TOTAL_DATA_POINTS, GOOD_DATA_POINTS, PRUNING_ERROR_COUNT].
DATASET_STATISTICS = {}

def get_pruned_pdg(pdg_file, output_pdg_file):
    """Clean the edge list of one PDG file and write it to ``output_pdg_file``.

    Each input line is stripped of line breaks and surrounding whitespace.
    The first line is always discarded (the "Entry" edge). A remaining line
    is kept only if it contains exactly one ``-->`` separator, exactly two
    ``$$`` markers (one on each side of the separator), and the text before
    the ``$$`` differs between the two sides (no self-loop). Duplicate edges
    are written once, in order of first appearance, so the output is
    identical from run to run.

    Side effects: increments the global ``GOOD_DATA_POINTS`` when at least
    5 edges are kept and ``TOTAL_DATA_POINTS`` when at least 1 edge is kept.

    Args:
        pdg_file: Readable text file object holding the raw edges.
        output_pdg_file: Writable text file object receiving one kept edge
            per line.

    Returns:
        Tuple ``(output_pdg_file, number_of_kept_edges)``.
    """
    global PRUNING_ERROR_COUNT, GOOD_DATA_POINTS, TOTAL_DATA_POINTS

    raw_lines = [l.replace("\n", "").replace("\r", "").strip()
                 for l in pdg_file.readlines()]

    # A dict is used as an insertion-ordered set: it removes duplicates while
    # keeping a deterministic order (a plain set orders strings by a hash
    # that changes between interpreter runs).
    unique_edges = {}
    # Remove unnecessary edges ("Entry" edge, wrongly formatted edges etc.)
    for edge in raw_lines[1:]:
        if edge.find("-->") == -1 or edge.count("$$") != 2:
            continue
        endpoints = edge.split("-->")
        if len(endpoints) != 2:
            continue
        source_parts = endpoints[0].split("$$")
        target_parts = endpoints[1].split("$$")
        if len(source_parts) != 2 or len(target_parts) != 2:
            continue
        # Remove self-loops: both endpoints carry the same leading field.
        if source_parts[0].strip() == target_parts[0].strip():
            continue
        unique_edges[edge] = None

    all_edges = list(unique_edges)

    if len(all_edges) >= 5:
        GOOD_DATA_POINTS += 1

    output_pdg_file.writelines(edge + "\n" for edge in all_edges)
    if len(all_edges) > 0:
        TOTAL_DATA_POINTS += 1

    return output_pdg_file, len(all_edges)

def split_data(pdg_files):
    """Randomly split a list of files into train / validation / test parts.

    The items are shuffled with the global ``random`` generator; the first
    80% (rounded down) form the training part, the next 10% (rounded down)
    the validation part, and everything left over the test part. The input
    list itself is left untouched.

    Args:
        pdg_files: List of items (PDG file paths) to split.

    Returns:
        Tuple ``(training_data, validation_data, test_data)`` of lists.
    """
    # Shuffle a copy so the caller's list is not reordered as a side effect.
    shuffled_files = list(pdg_files)
    random.shuffle(shuffled_files)

    train_split_size = int(len(shuffled_files) * 0.8)
    valid_split_size = int(len(shuffled_files) * 0.1)
    valid_end = train_split_size + valid_split_size

    training_data = shuffled_files[:train_split_size]
    validation_data = shuffled_files[train_split_size:valid_end]
    # The test part takes the remainder, so no item is lost to rounding.
    test_data = shuffled_files[valid_end:]
    return training_data, validation_data, test_data

PDG_FOLDER_LOCATION = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/pdg_data_java250"
OUTPUT_FOLDER_LOCATION = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/processed_pdg_data_java250"
#projects_to_consider = {'p02388': 0, 'p02389': 1, 'p02391': 2, 'p02393': 3, 'p02396': 4}
# Maps each problem id to the integer label appended to the output file names.
projects_to_consider = {'p02381' : 0, 'p02419': 1, 'p02396': 2, 'p02400': 3, 'p02399': 4, 'p02415': 5, 'p02421': 6, 'p02416': 7, 'p02407': 8, 'p02417': 9}

training_data, validation_data, test_data = [], [], []
# glob returns paths in arbitrary (filesystem-dependent) order; sorting them
# first is what makes the seeded shuffles below reproducible across machines.
project_folders = sorted(glob.glob(os.path.join(PDG_FOLDER_LOCATION, '*')))
# Not used in this cell.
pdg_files_list = []
# Split every project separately so each one contributes roughly 80/10/10 to
# the train/validation/test sets.
for project_folder in tqdm.tqdm(project_folders):
    pdg_files = sorted(glob.glob(os.path.join(project_folder, '*.txt')))
    s1, s2, s3 = split_data(pdg_files)
    training_data.extend(s1)
    validation_data.extend(s2)
    test_data.extend(s3)

# Mix the projects within each split.
random.shuffle(training_data)
random.shuffle(validation_data)
random.shuffle(test_data)
print("\nTraining, validation and test data size: {}, {} and {}".format(len(training_data), len(validation_data), len(test_data)))
print("\nTotal dataset size: ", len(training_data) + len(validation_data) + len(test_data))

# Prune every PDG file of every split, write the result under
# OUTPUT_FOLDER_LOCATION/<split name>/ and collect per-split statistics.
for split in [["train", training_data], ["valid", validation_data], ["test", test_data]]:
    print("\nProcessing split: ", split[0])
    OUTPUT_SPLIT_FOLDER_LOCATION = OUTPUT_FOLDER_LOCATION + "/" + split[0]
    os.makedirs(OUTPUT_SPLIT_FOLDER_LOCATION, exist_ok=True)
    for pdg_file in tqdm.tqdm(split[1]):
        pdg_file_name = os.path.basename(pdg_file)
        # NOTE(review): relies on the third "_"-separated token of the file
        # name being the problem id - confirm against the PDG generator.
        project_id = pdg_file_name.split("_")[2]
        # Output name = input name without ".txt" + "_<label>.txt"
        output_file_location = OUTPUT_SPLIT_FOLDER_LOCATION + "/" + pdg_file_name[:-4] + "_" + str(projects_to_consider[project_id]) + ".txt"

        pruned_successfully = False
        no_of_edges = 0
        # The context managers close both handles on every path, including
        # when pruning raises, before the output file is possibly removed.
        with open(pdg_file, 'r') as original_pdg_file, open(output_file_location, "w") as output_pdg_file:
            try:
                no_of_edges = get_pruned_pdg(original_pdg_file, output_pdg_file)[1]
                pruned_successfully = True
            except Exception as e:
                PRUNING_ERROR_COUNT += 1
                print("\nERROR WHILE PRUNING PDG\n")
                print("\nFile: {}\n".format(pdg_file))
                print("\nERROR: {}\n".format(e))

        # Do not keep outputs of failed prunings or outputs without any edge.
        if not pruned_successfully or no_of_edges == 0:
            os.remove(output_file_location)

    print("\nGOOD PDG DATA POINTS: {}\n".format(GOOD_DATA_POINTS))
    print("\nTOTAL PDG DATA POINTS: {}\n".format(TOTAL_DATA_POINTS))
    print("\nTOTAL PRUNING ERROR: {}\n".format(PRUNING_ERROR_COUNT))
    print("\n=================================================================\n")
    # Fold the per-split counters into the dataset totals, then reset them
    # for the next split.
    PRUNING_ERROR_COUNT_IN_DATASET += PRUNING_ERROR_COUNT
    GOOD_DATA_POINTS_IN_DATASET += GOOD_DATA_POINTS
    TOTAL_DATA_POINTS_IN_DATASET += TOTAL_DATA_POINTS
    DATASET_STATISTICS[split[0]] = [TOTAL_DATA_POINTS, GOOD_DATA_POINTS, PRUNING_ERROR_COUNT]
    PRUNING_ERROR_COUNT, GOOD_DATA_POINTS, TOTAL_DATA_POINTS = 0, 0, 0

print("\nTOTAL GOOD PDG DATA POINTS IN DATASET: {}\n".format(GOOD_DATA_POINTS_IN_DATASET))
print("\nTOTAL PDG DATA POINTS IN DATASET: {}\n".format(TOTAL_DATA_POINTS_IN_DATASET))
print("\nTOTAL PRUNING ERROR IN DATASET: {}\n".format(PRUNING_ERROR_COUNT_IN_DATASET))
print("\nDATASET STATISTICS: {}\n".format(DATASET_STATISTICS))

100%|██████████| 10/10 [00:00<00:00, 1007.88it/s]



Training, validation and test data size: 2202, 272 and 282

Total dataset size:  2756

Processing split:  train


  0%|          | 0/2202 [00:00<?, ?it/s]

100%|██████████| 2202/2202 [00:00<00:00, 4667.17it/s]



GOOD PDG DATA POINTS: 2185


TOTAL PDG DATA POINTS: 2202


TOTAL PRUNING ERROR: 0




Processing split:  valid


100%|██████████| 272/272 [00:00<00:00, 4694.18it/s]



GOOD PDG DATA POINTS: 269


TOTAL PDG DATA POINTS: 272


TOTAL PRUNING ERROR: 0




Processing split:  test


100%|██████████| 282/282 [00:00<00:00, 4687.60it/s]


GOOD PDG DATA POINTS: 280


TOTAL PDG DATA POINTS: 282


TOTAL PRUNING ERROR: 0




TOTAL GOOD PDG DATA POINTS IN DATASET: 2734


TOTAL PDG DATA POINTS IN DATASET: 2756


TOTAL PRUNING ERROR IN DATASET: 0


DATASET STATISTICS: {'train': [2202, 2185, 0], 'valid': [272, 269, 0], 'test': [282, 280, 0]}




