# Imports


In [1]:
import os
import sys
import copy
import glob
import tqdm
from torch import nn
import random
import torch
import platform
from typing import Callable, List, Optional, Dict
import numpy as np
import scipy.sparse as sp

import warnings
warnings.filterwarnings('ignore')

from transformers import AutoTokenizer, AutoModel

import torch_geometric
from torch_geometric.data import (
    Data,
    InMemoryDataset,
    Batch
    )
import torch_geometric.datasets as datasets
import torch_geometric.transforms as transforms
from torch_geometric.data import Data
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.nn import global_mean_pool

# Visualization dependencies.
%matplotlib inline
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

import umap
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.cluster import Birch
from sklearn.cluster import SpectralClustering

from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, precision_score, recall_score, silhouette_score

# To ensure determinism
seed = 1234


def seed_everything(seed: int):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator plus all CUDA devices), exports ``PYTHONHASHSEED`` into
    the process environment and turns on cuDNN's deterministic mode.

    Args:
        seed: Integer seed handed to each generator.
    """
    # NOTE(review): PYTHONHASHSEED is presumably only honoured at interpreter
    # start-up, so this assignment likely affects child processes only - confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)

    for apply_seed in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True


seed_everything(seed)

# Check versions
# Report, in order: PyTorch, the CUDA version PyTorch was built against,
# the Python interpreter and PyTorch Geometric.
for version_string in (
    torch.__version__,
    torch.version.cuda,
    platform.python_version(),
    torch_geometric.__version__,
):
    print(version_string)



1.8.1+cu101
10.1
3.8.18
1.7.0


# Filter Java Code Snippets

## Get all accepted Java submissions

In [8]:
import pandas as pd
import glob
import os

metadata_folder = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/metadata"
dataset_folder = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/Java250_data"
metadata_files = glob.glob(os.path.join(metadata_folder, '*.csv'))
allowed_status = ["Accepted"]

# Maps problem id (metadata file name without ".csv") to the list of paths of
# accepted Java submissions whose source file exists under dataset_folder.
accepted_submissions = {}
# Accepted Java rows whose source file is absent from dataset_folder. The
# original cell had print statements for this case, but they sat after a
# `continue` and could never run; a single summary is reported instead.
missing_code_file_count = 0

for file_loc in tqdm.tqdm(metadata_files):
    problem_id = os.path.splitext(os.path.basename(file_loc))[0]
    metadata_file = pd.read_csv(file_loc)

    # Select the wanted rows with one vectorised mask instead of testing
    # every row in a Python-level iterrows() loop.
    is_accepted_java = (metadata_file["filename_ext"] == "java") & metadata_file["status"].isin(allowed_status)
    wanted_rows = metadata_file.loc[is_accepted_java, ["problem_id", "submission_id"]]

    accepted_submissions[problem_id] = []
    for row_problem_id, submission_id in zip(wanted_rows["problem_id"], wanted_rows["submission_id"]):
        code_file_location = f"{dataset_folder}/{row_problem_id}/{submission_id}.java"
        if os.path.isfile(code_file_location):
            accepted_submissions[problem_id].append(code_file_location)
        else:
            missing_code_file_count += 1

print("Accepted Java submissions without a source file in dataset_folder:", missing_code_file_count)

100%|██████████| 4053/4053 [10:24<00:00,  6.49it/s]


## Filter the code-snippets that have only one class and one method

In [10]:
import javalang
from javalang.parser import JavaSyntaxError

def count_methods(java_file_loc):
  """Count the type and method declarations in one Java source file.

  Args:
    java_file_loc: Path of the Java file to parse.

  Returns:
    A ``(class_count, method_count)`` tuple: the number of nodes matched by
    ``javalang.tree.TypeDeclaration`` and by
    ``javalang.tree.MethodDeclaration`` in the parsed file.

  Raises:
    Whatever ``javalang.parse.parse`` raises on input it cannot parse; the
    caller in this notebook catches ``JavaSyntaxError``, ``LexerError`` and
    ``AssertionError``.
  """
  with open(java_file_loc, 'r') as source_handle:
    syntax_tree = javalang.parse.parse(source_handle.read())

  type_nodes = list(syntax_tree.filter(javalang.tree.TypeDeclaration))
  method_nodes = list(syntax_tree.filter(javalang.tree.MethodDeclaration))
  return len(type_nodes), len(method_nodes)


# Keep, per problem, only the submissions that parse and contain exactly one
# type declaration and exactly one method declaration.
accepted_submissions_filtered = {}
for pid in tqdm.tqdm(accepted_submissions):
    accepted_submissions_filtered[pid] = single_method_paths = []
    for path in accepted_submissions[pid]:
        try:
            declaration_counts = count_methods(path)
        except (AssertionError, JavaSyntaxError, javalang.tokenizer.LexerError):
            # Files javalang cannot tokenize or parse are dropped on purpose.
            continue
        if declaration_counts == (1, 1):
            single_method_paths.append(path)

100%|██████████| 4053/4053 [04:10<00:00, 16.17it/s]


## Get top 10 problems by submission count

In [11]:
# Number of retained submissions per problem id.
accepted_submissions_count = {
    problem: len(paths) for problem, paths in accepted_submissions_filtered.items()
}

# (problem id, count) pairs, largest count first.
accepted_submissions_count_sorted = sorted(
    accepted_submissions_count.items(),
    key=lambda entry: entry[1],
    reverse=True,
)
print(accepted_submissions_count_sorted[:10])

sum_ = sum(count for _, count in accepted_submissions_count_sorted[:10])
print("Total accepted java submissions in top 10 problems: ", sum_)

[('p02381', 280), ('p02419', 280), ('p02396', 277), ('p02400', 277), ('p02399', 277), ('p02415', 275), ('p02421', 275), ('p02416', 274), ('p02407', 272), ('p02417', 271)]
Total accepted java submissions in top 10 problems:  2758


## Copy the filtered Java files

In [12]:
import shutil

output_folder_location = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/Java250_accepted_sols_1_method"

# Problems with fewer retained submissions than this are not copied.
min_submissions_per_problem = 20

for pid, count in tqdm.tqdm(accepted_submissions_count_sorted):
    if count < min_submissions_per_problem:
        continue

    output_project_folder_location = os.path.join(output_folder_location, pid)
    # makedirs also creates output_folder_location itself on a first run,
    # where os.mkdir would raise because the parent folder is missing.
    os.makedirs(output_project_folder_location, exist_ok=True)

    for original_file in accepted_submissions_filtered[pid]:
        file_name = os.path.basename(original_file)
        shutil.copyfile(original_file, os.path.join(output_project_folder_location, file_name))
    

100%|██████████| 4053/4053 [00:03<00:00, 1027.39it/s]


# Preprocess for PDG Generation

In [4]:
import sys
import json
import pandas as pd
import random
import os
from pathlib import Path
import re
import glob
import tqdm
import pyparsing

# Hand-picked subsets used in earlier experiments:
# projects_to_consider = ['p02388', 'p02389', 'p02391', 'p02393', 'p02396']
# projects_to_consider = ['p02381', 'p02419', 'p02396', 'p02400', 'p02399', 'p02415', 'p02421', 'p02416', 'p02407', 'p02417']
input_folder_location = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/Java250_accepted_sols_1_method"

# Every sub-directory of the input folder is treated as one project.
project_folders = []
for entry_name in os.listdir(input_folder_location):
    if os.path.isdir(os.path.join(input_folder_location, entry_name)):
        project_folders.append(entry_name)

projects_to_consider = project_folders

In [5]:
input_folder_location = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/Java250_accepted_sols_1_method"

# Rebuild the per-project file lists from the folder written by the copy step.
accepted_submissions_filtered = {}
for project in tqdm.tqdm(projects_to_consider):
    input_project_folder = input_folder_location + "/" + project
    java_files = glob.glob(os.path.join(input_project_folder, '*.java'))
    accepted_submissions_filtered[project] = java_files

accepted_submissions_count = {}
for pid in accepted_submissions_filtered:
    accepted_submissions_count[pid] = len(accepted_submissions_filtered[pid])

# Ascending order: the first ten entries are the projects with the FEWEST files.
accepted_submissions_count_sorted = sorted(accepted_submissions_count.items(), key=lambda x:x[1], reverse = False)
print(accepted_submissions_count_sorted[:10])

sum_ = 0
for pid in accepted_submissions_count_sorted[:10]:
    sum_ += pid[1]
# The label used to say "top 10", which contradicted the ascending sort above.
print("Total accepted java submissions in the 10 smallest problems: ", sum_)

100%|██████████| 249/249 [00:00<00:00, 2357.11it/s]

[('p02713', 21), ('p00005', 92), ('p02831', 96), ('p02264', 98), ('p02257', 99), ('p03160', 102), ('p02255', 104), ('p03161', 110), ('p02268', 111), ('p02819', 130)]
Total accepted java submissions in top 10 problems:  963





In [9]:
import sys
import json
import pandas as pd
import random
import os
from pathlib import Path
import re
import glob
import tqdm
import pyparsing

# Folder the preprocessing loop reads from: one sub-folder per project, each holding .java files.
input_folder_location = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/Java250_accepted_sols_1_method"
# Folder the cleaned copies are written to, mirroring the per-project layout.
output_folder_location = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/processed_data_for_pdg_java250_100_class"

def copyFile(input_file_path, output_file_path):
    """Write a cleaned copy of a Java source file.

    The copy has Java-style comments stripped (``pyparsing.javaStyleComment``)
    and non-breaking spaces (U+00A0) replaced by regular spaces. These lines
    are dropped: blank lines, lines whose stripped text starts with ``@``
    (e.g. ``@Test``), and lines starting with ``import`` or ``package``.
    Every kept line is written followed by a newline.

    Args:
        input_file_path: Path of the Java file to read.
        output_file_path: Path of the cleaned file; overwritten if it exists.
    """
    # Read the whole input first. The output file is only created once the
    # source has been read and filtered, so a failure here leaves no empty file.
    with open(input_file_path, "r") as input_file:
        original_code = input_file.read()

    # Remove all comments
    commentFilter = pyparsing.javaStyleComment.suppress()
    modified_code = commentFilter.transformString(original_code)

    # The context manager closes the handle even if a write raises.
    with open(output_file_path, "w") as newfile:
        for line in modified_code.split("\n"):
            line = line.replace('\u00A0', " ")
            stripped_line = line.strip()

            # Skip empty lines, lines like @Test
            if not stripped_line or stripped_line.startswith("@"):
                continue

            # Remove import statements and packages.
            # NOTE(review): the prefix is tested on the unstripped line, so an
            # indented import is kept and any column-0 line that merely begins
            # with "import"/"package" is dropped - confirm this is intended.
            if line.startswith(("import", "package")):
                continue

            # split("\n") removed every newline, so one is always appended.
            newfile.write(line + "\n")

# Write a cleaned copy of every Java file of every selected project.
for project in tqdm.tqdm(projects_to_consider):
    input_project_folder = os.path.join(input_folder_location, project)
    java_files = glob.glob(os.path.join(input_project_folder, '*.java'))
    print(project, ":", len(java_files))

    output_project_folder_location = os.path.join(output_folder_location, project)
    # makedirs also creates output_folder_location itself on a first run,
    # where os.mkdir would raise because the parent folder is missing.
    os.makedirs(output_project_folder_location, exist_ok=True)

    for input_java_file_path in java_files:
        file_name = os.path.basename(input_java_file_path)
        output_java_file_path = os.path.join(output_project_folder_location, file_name)
        copyFile(input_java_file_path, output_java_file_path)
        

  0%|          | 0/100 [00:00<?, ?it/s]

p02381 : 280


  1%|          | 1/100 [00:03<06:10,  3.74s/it]

p02419 : 280


  2%|▏         | 2/100 [00:07<05:43,  3.50s/it]

p02396 : 277


  3%|▎         | 3/100 [00:09<05:03,  3.13s/it]

p02400 : 277


  4%|▍         | 4/100 [00:12<04:38,  2.90s/it]

p02399 : 277


  5%|▌         | 5/100 [00:15<04:36,  2.91s/it]

p02415 : 275


  6%|▌         | 6/100 [00:18<04:50,  3.09s/it]

p02421 : 275


  7%|▋         | 7/100 [00:22<05:12,  3.36s/it]

p02416 : 274


  8%|▊         | 8/100 [00:25<05:01,  3.27s/it]

p02407 : 272


  9%|▉         | 9/100 [00:28<04:51,  3.20s/it]

p02417 : 271


 10%|█         | 10/100 [00:33<05:21,  3.58s/it]

p02418 : 270


 11%|█         | 11/100 [00:36<05:08,  3.47s/it]

p03493 : 268


 12%|█▏        | 12/100 [00:38<04:24,  3.01s/it]

p02420 : 268


 13%|█▎        | 13/100 [00:42<04:40,  3.22s/it]

p02397 : 268


 14%|█▍        | 14/100 [00:45<04:32,  3.17s/it]

p02410 : 266


 15%|█▌        | 15/100 [00:49<05:00,  3.53s/it]

p02402 : 265


 16%|█▌        | 16/100 [00:53<04:58,  3.56s/it]

p02413 : 265


 17%|█▋        | 17/100 [00:58<05:30,  3.99s/it]

p00001 : 264


 18%|█▊        | 18/100 [01:01<05:04,  3.72s/it]

p02401 : 262


 19%|█▉        | 19/100 [01:05<05:14,  3.88s/it]

p02414 : 260


 20%|██        | 20/100 [01:10<05:49,  4.37s/it]

p02398 : 258


 21%|██        | 21/100 [01:13<05:09,  3.92s/it]

p00006 : 258


 22%|██▏       | 22/100 [01:15<04:25,  3.41s/it]

p02405 : 257


 23%|██▎       | 23/100 [01:19<04:33,  3.56s/it]

p04043 : 254


 24%|██▍       | 24/100 [01:22<04:07,  3.25s/it]

p02403 : 254


 25%|██▌       | 25/100 [01:25<04:04,  3.26s/it]

p02412 : 253


 26%|██▌       | 26/100 [01:29<04:04,  3.30s/it]

p02404 : 253


 27%|██▋       | 27/100 [01:32<04:09,  3.42s/it]

p02411 : 252


 28%|██▊       | 28/100 [01:37<04:36,  3.85s/it]

p02394 : 251


 29%|██▉       | 29/100 [01:41<04:23,  3.71s/it]

p02390 : 250


 30%|███       | 30/100 [01:43<03:47,  3.24s/it]

p02389 : 250


 31%|███       | 31/100 [01:45<03:25,  2.99s/it]

p03109 : 250


 32%|███▏      | 32/100 [01:48<03:14,  2.86s/it]

p02946 : 250


 33%|███▎      | 33/100 [01:50<02:55,  2.62s/it]

p03470 : 250


 34%|███▍      | 34/100 [01:52<02:55,  2.66s/it]

p02393 : 250


 35%|███▌      | 35/100 [01:56<03:04,  2.85s/it]

p02408 : 249


 36%|███▌      | 36/100 [02:01<03:54,  3.66s/it]

p03315 : 249


 37%|███▋      | 37/100 [02:03<03:19,  3.16s/it]

p02724 : 248


 38%|███▊      | 38/100 [02:05<02:50,  2.75s/it]

p03814 : 248


 39%|███▉      | 39/100 [02:08<02:43,  2.67s/it]

p02718 : 248


 40%|████      | 40/100 [02:11<02:48,  2.82s/it]

p02897 : 248


 41%|████      | 41/100 [02:13<02:30,  2.55s/it]

p03075 : 248


 42%|████▏     | 42/100 [02:15<02:31,  2.62s/it]

p00002 : 248


 43%|████▎     | 43/100 [02:18<02:31,  2.66s/it]

p02880 : 247


 44%|████▍     | 44/100 [02:20<02:20,  2.51s/it]

p02706 : 246


 45%|████▌     | 45/100 [02:23<02:17,  2.50s/it]

p02711 : 246


 46%|████▌     | 46/100 [02:25<02:06,  2.34s/it]

p03370 : 246


 47%|████▋     | 47/100 [02:27<02:09,  2.44s/it]

p02409 : 246


 48%|████▊     | 48/100 [02:33<02:47,  3.22s/it]

p02388 : 246


 49%|████▉     | 49/100 [02:34<02:25,  2.85s/it]

p02688 : 246


 50%|█████     | 50/100 [02:37<02:24,  2.88s/it]

p03242 : 245


 51%|█████     | 51/100 [02:40<02:12,  2.70s/it]

p02694 : 245


 52%|█████▏    | 52/100 [02:42<01:57,  2.44s/it]

p03072 : 245


 53%|█████▎    | 53/100 [02:44<01:55,  2.46s/it]

p03469 : 245


 54%|█████▍    | 54/100 [02:46<01:41,  2.22s/it]

p02700 : 245


 55%|█████▌    | 55/100 [02:48<01:43,  2.30s/it]

p03000 : 244


 56%|█████▌    | 56/100 [02:51<01:45,  2.40s/it]

p02754 : 244


 57%|█████▋    | 57/100 [02:53<01:42,  2.37s/it]

p02681 : 243


 58%|█████▊    | 58/100 [02:56<01:39,  2.37s/it]

p02717 : 241


 59%|█████▉    | 59/100 [02:57<01:31,  2.24s/it]

p03456 : 241


 60%|██████    | 60/100 [03:00<01:31,  2.28s/it]

p02640 : 241


 61%|██████    | 61/100 [03:02<01:27,  2.25s/it]

p03136 : 241


 62%|██████▏   | 62/100 [03:05<01:29,  2.36s/it]

p02676 : 240


 63%|██████▎   | 63/100 [03:07<01:26,  2.35s/it]

p03071 : 240


 64%|██████▍   | 64/100 [03:09<01:20,  2.23s/it]

p02391 : 240


 65%|██████▌   | 65/100 [03:12<01:22,  2.36s/it]

p02772 : 239


 66%|██████▌   | 66/100 [03:14<01:21,  2.39s/it]

p02921 : 239


 67%|██████▋   | 67/100 [03:16<01:18,  2.37s/it]

p02682 : 239


 68%|██████▊   | 68/100 [03:19<01:17,  2.41s/it]

p02922 : 239


 69%|██████▉   | 69/100 [03:21<01:09,  2.26s/it]

p02783 : 239


 70%|███████   | 70/100 [03:23<01:04,  2.16s/it]

p02613 : 238


 71%|███████   | 71/100 [03:26<01:13,  2.52s/it]

p02627 : 238


 72%|███████▏  | 72/100 [03:28<01:07,  2.40s/it]

p02779 : 238


 73%|███████▎  | 73/100 [03:31<01:06,  2.45s/it]

p02910 : 238


 74%|███████▍  | 74/100 [03:34<01:07,  2.61s/it]

p02705 : 238


 75%|███████▌  | 75/100 [03:35<00:56,  2.26s/it]

p03378 : 238


 76%|███████▌  | 76/100 [03:38<00:58,  2.46s/it]

p02987 : 238


 77%|███████▋  | 77/100 [03:41<00:59,  2.61s/it]

p02860 : 238


 78%|███████▊  | 78/100 [03:44<00:57,  2.60s/it]

p03434 : 237


 79%|███████▉  | 79/100 [03:46<00:56,  2.67s/it]

p02993 : 237


 80%|████████  | 80/100 [03:49<00:51,  2.58s/it]

p02923 : 237


 81%|████████  | 81/100 [03:51<00:49,  2.59s/it]

p02899 : 237


 82%|████████▏ | 82/100 [03:54<00:45,  2.55s/it]

p02945 : 237


 83%|████████▎ | 83/100 [03:56<00:40,  2.40s/it]

p03220 : 236


 84%|████████▍ | 84/100 [03:59<00:40,  2.56s/it]

p02677 : 236


 85%|████████▌ | 85/100 [04:02<00:40,  2.73s/it]

p03455 : 236


 86%|████████▌ | 86/100 [04:04<00:35,  2.54s/it]

p03059 : 236


 87%|████████▋ | 87/100 [04:06<00:30,  2.33s/it]

p02911 : 236


 88%|████████▊ | 88/100 [04:09<00:30,  2.56s/it]

p02771 : 235


 89%|████████▉ | 89/100 [04:11<00:27,  2.50s/it]

p03448 : 235


 90%|█████████ | 90/100 [04:14<00:26,  2.63s/it]

p02848 : 235


 91%|█████████ | 91/100 [04:17<00:23,  2.62s/it]

p02953 : 235


 92%|█████████▏| 92/100 [04:20<00:21,  2.71s/it]

p02701 : 235


 93%|█████████▎| 93/100 [04:22<00:18,  2.60s/it]

p02406 : 235


 94%|█████████▍| 94/100 [04:25<00:15,  2.64s/it]

p04030 : 235


 95%|█████████▌| 95/100 [04:28<00:13,  2.65s/it]

p03239 : 235


 96%|█████████▌| 96/100 [04:31<00:10,  2.73s/it]

p02766 : 235


 97%|█████████▋| 97/100 [04:32<00:07,  2.44s/it]

p04044 : 234


 98%|█████████▊| 98/100 [04:35<00:04,  2.45s/it]

p02675 : 234


 99%|█████████▉| 99/100 [04:37<00:02,  2.52s/it]

p02707 : 234


100%|██████████| 100/100 [04:40<00:00,  2.80s/it]


# Get The PDGs

# Postprocess After PDG Generation

In [2]:
import os

# projects_to_consider = ['p02388', 'p02389', 'p02391', 'p02393', 'p02396']
# projects_to_consider = ['p02381', 'p02419', 'p02396', 'p02400', 'p02399', 'p02415', 'p02421', 'p02416', 'p02407', 'p02417']
input_folder_location = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/pdg_data_java250_100_class"

# os.listdir returns entries in arbitrary order. Sorting makes the
# project -> class id mapping identical on every run and every machine, so
# labels written into file names stay consistent.
project_folders = sorted(
    name for name in os.listdir(input_folder_location)
    if os.path.isdir(os.path.join(input_folder_location, name))
)

# Maps project folder name to its integer class id (0, 1, 2, ... in sorted order).
projects_to_consider = {}
class_id = 0
for project in project_folders:
    projects_to_consider[project] = class_id
    class_id += 1

print(projects_to_consider)

{'p00001': 0, 'p00002': 1, 'p00006': 2, 'p02381': 3, 'p02388': 4, 'p02389': 5, 'p02390': 6, 'p02391': 7, 'p02393': 8, 'p02394': 9, 'p02396': 10, 'p02397': 11, 'p02398': 12, 'p02399': 13, 'p02400': 14, 'p02401': 15, 'p02402': 16, 'p02403': 17, 'p02404': 18, 'p02405': 19, 'p02406': 20, 'p02407': 21, 'p02408': 22, 'p02409': 23, 'p02410': 24, 'p02411': 25, 'p02412': 26, 'p02413': 27, 'p02414': 28, 'p02415': 29, 'p02416': 30, 'p02417': 31, 'p02418': 32, 'p02419': 33, 'p02420': 34, 'p02421': 35, 'p02613': 36, 'p02627': 37, 'p02640': 38, 'p02675': 39, 'p02676': 40, 'p02677': 41, 'p02681': 42, 'p02682': 43, 'p02688': 44, 'p02694': 45, 'p02700': 46, 'p02701': 47, 'p02705': 48, 'p02706': 49, 'p02707': 50, 'p02711': 51, 'p02717': 52, 'p02718': 53, 'p02724': 54, 'p02754': 55, 'p02766': 56, 'p02771': 57, 'p02772': 58, 'p02779': 59, 'p02783': 60, 'p02848': 61, 'p02860': 62, 'p02880': 63, 'p02897': 64, 'p02899': 65, 'p02910': 66, 'p02911': 67, 'p02921': 68, 'p02922': 69, 'p02923': 70, 'p02945': 71, '

In [3]:
import os
import sys
import glob
import tqdm
import random
random.seed(1234)

# ALGORITHM
#
# 1. Clean the raw edge info (e.g. remove wrongly formatted edges, class edges etc.)
# 2. Merge same code-lines into a single line/node
#    NOTE(review): no merging step is visible in this cell - confirm where it happens.
# 3. Remove self-loops and duplicate edges

# Counters for the split currently being processed (reset after each split).
PRUNING_ERROR_COUNT = 0
GOOD_DATA_POINTS = 0
TOTAL_DATA_POINTS = 0

# Running totals over all splits.
PRUNING_ERROR_COUNT_IN_DATASET = 0
GOOD_DATA_POINTS_IN_DATASET = 0
TOTAL_DATA_POINTS_IN_DATASET = 0

# split name -> [total data points, good data points, pruning errors]
DATASET_STATISTICS = {}

def get_pruned_pdg(pdg_file, output_pdg_file):
    """Filter the edges of one raw PDG file and write the survivors.

    The first line of ``pdg_file`` is always discarded (the "Entry" edge).
    A remaining line is kept only if it contains exactly two ``$$`` markers,
    exactly one ``-->`` separator with one ``$$`` on each side, and the text
    before ``$$`` differs between the two sides (otherwise it is a self-loop).
    Duplicates are dropped and the first occurrence keeps its position, so
    the written order is deterministic. De-duplicating through ``set`` made
    it depend on per-process string-hash randomisation.

    Side effects: increments the global ``TOTAL_DATA_POINTS`` when at least
    one edge survives and ``GOOD_DATA_POINTS`` when at least five survive.

    Args:
        pdg_file: Readable text file object holding the raw edges, one per line.
        output_pdg_file: Writable text file object receiving one edge per line.

    Returns:
        ``(output_pdg_file, number_of_edges_written)``.
    """
    global PRUNING_ERROR_COUNT, GOOD_DATA_POINTS, TOTAL_DATA_POINTS

    raw_lines = [l.replace("\n", "").replace("\r", "").strip()
                 for l in pdg_file.readlines()]

    # dict keys keep insertion order, giving order-preserving de-duplication.
    unique_edges = {}
    for edge in raw_lines[1:]:  # skip the leading "Entry" line
        # Remove unnecessary edges (wrongly formatted edges etc.)
        if edge.count("$$") != 2:
            continue
        endpoints = edge.split("-->")
        if len(endpoints) != 2:
            continue
        source_fields = endpoints[0].split("$$")
        target_fields = endpoints[1].split("$$")
        if len(source_fields) != 2 or len(target_fields) != 2:
            continue

        # Remove self-loops: both endpoints carry the same text before "$$"
        # (presumably the node identifier).
        if source_fields[0].strip() == target_fields[0].strip():
            continue

        unique_edges[edge] = None

    all_edges = list(unique_edges)

    if len(all_edges) >= 5:
        GOOD_DATA_POINTS += 1

    output_pdg_file.writelines(edge + "\n" for edge in all_edges)
    if len(all_edges) > 0:
        TOTAL_DATA_POINTS += 1

    return output_pdg_file, len(all_edges)

def split_data(pdg_files):
    """Shuffle a list of paths and split it into train / validation / test.

    The list is sorted before it is shuffled, so for a given state of the
    ``random`` module the split depends only on WHICH paths are passed, not
    on the order they arrive in (``glob`` gives no ordering guarantee).
    ``pdg_files`` is sorted and shuffled in place.

    Args:
        pdg_files: Mutable list of sortable items (file paths in this notebook).

    Returns:
        ``(training_data, validation_data, test_data)``: the first
        ``int(n * 0.7)`` items, the next ``int(n * 0.15)`` items, and all
        remaining items, so no item is lost to rounding.
    """
    pdg_files.sort()
    random.shuffle(pdg_files)

    total = len(pdg_files)
    train_split_size = int(total * 0.7)
    valid_split_size = int(total * 0.15)
    valid_end = train_split_size + valid_split_size

    training_data = pdg_files[:train_split_size]
    validation_data = pdg_files[train_split_size:valid_end]
    test_data = pdg_files[valid_end:]
    return training_data, validation_data, test_data

PDG_FOLDER_LOCATION = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/pdg_data_java250_100_class"
OUTPUT_FOLDER_LOCATION = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/processed_pdg_data_java250_100_class"

training_data, validation_data, test_data = [], [], []
# glob gives no ordering guarantee. Both listings are sorted so the seeded
# shuffles below consume the random stream identically on every run and
# yield the same train/valid/test assignment.
project_folders = sorted(glob.glob(os.path.join(PDG_FOLDER_LOCATION, '*')))
pdg_files_list = []
for project_folder in tqdm.tqdm(project_folders):
    pdg_files = sorted(glob.glob(os.path.join(project_folder, '*.txt')))
    # Split per project so every project contributes to every split.
    s1, s2, s3 = split_data(pdg_files)
    training_data.extend(s1)
    validation_data.extend(s2)
    test_data.extend(s3)

# Mix the projects within each split.
random.shuffle(training_data)
random.shuffle(validation_data)
random.shuffle(test_data)
print("\nTraining, validation and test data size: {}, {} and {}".format(len(training_data), len(validation_data), len(test_data)))
print("\nTotal dataset size: ", len(training_data) + len(validation_data) + len(test_data))

# Prune every PDG file of every split, write the result under
# OUTPUT_FOLDER_LOCATION/<split>/ and collect per-split statistics.
# NOTE(review): output folders are not emptied first, so files from an earlier
# run with a different split assignment would remain - confirm this is acceptable.
for split in [["train", training_data], ["valid", validation_data], ["test", test_data]]:
    split_name, split_files = split
    print("\nProcessing split: ", split_name)
    OUTPUT_SPLIT_FOLDER_LOCATION = OUTPUT_FOLDER_LOCATION + "/" + split_name
    os.makedirs(OUTPUT_SPLIT_FOLDER_LOCATION, exist_ok=True)

    for pdg_file in tqdm.tqdm(split_files):
        pdg_file_name = os.path.basename(pdg_file)
        # The project id is the third "_"-separated field of the file name;
        # its class id is appended to the output file name.
        project_id = pdg_file_name.split("_")[2]
        output_file_location = OUTPUT_SPLIT_FOLDER_LOCATION + "/" + pdg_file_name[:-4] + "_" + str(projects_to_consider[project_id]) + ".txt"

        # Both handles are closed on every path, including when pruning raises
        # or when opening the output file fails after the input was opened.
        with open(pdg_file, 'r') as original_pdg_file, open(output_file_location, "w") as output_pdg_file:
            try:
                _, no_of_edges = get_pruned_pdg(original_pdg_file, output_pdg_file)
            except Exception as e:
                PRUNING_ERROR_COUNT += 1
                print("\nERROR WHILE PRUNING PDG\n")
                print("\nFile: {}\n".format(pdg_file))
                print("\nERROR: {}\n".format(e))
                # Treat a failed file like an empty one so its partial output is removed.
                no_of_edges = 0

        # Remove outputs that ended up with no edges (or failed), after the handles are closed.
        if no_of_edges == 0:
            os.remove(output_file_location)

    print("\nGOOD PDG DATA POINTS: {}\n".format(GOOD_DATA_POINTS))
    print("\nTOTAL PDG DATA POINTS: {}\n".format(TOTAL_DATA_POINTS))
    print("\nTOTAL PRUNING ERROR: {}\n".format(PRUNING_ERROR_COUNT))
    print("\n=================================================================\n")
    PRUNING_ERROR_COUNT_IN_DATASET += PRUNING_ERROR_COUNT
    GOOD_DATA_POINTS_IN_DATASET += GOOD_DATA_POINTS
    TOTAL_DATA_POINTS_IN_DATASET += TOTAL_DATA_POINTS
    DATASET_STATISTICS[split_name] = [TOTAL_DATA_POINTS, GOOD_DATA_POINTS, PRUNING_ERROR_COUNT]
    # Reset the per-split counters before the next split.
    PRUNING_ERROR_COUNT, GOOD_DATA_POINTS, TOTAL_DATA_POINTS = 0, 0, 0

print("\nTOTAL GOOD PDG DATA POINTS IN DATASET: {}\n".format(GOOD_DATA_POINTS_IN_DATASET))
print("\nTOTAL PDG DATA POINTS IN DATASET: {}\n".format(TOTAL_DATA_POINTS_IN_DATASET))
print("\nTOTAL PRUNING ERROR IN DATASET: {}\n".format(PRUNING_ERROR_COUNT_IN_DATASET))
print("\nDATASET STATISTICS: {}\n".format(DATASET_STATISTICS))

100%|██████████| 100/100 [00:00<00:00, 1236.21it/s]



Training, validation and test data size: 17329, 3676 and 3819

Total dataset size:  24824

Processing split:  train


100%|██████████| 17329/17329 [00:03<00:00, 5003.32it/s]



GOOD PDG DATA POINTS: 17159


TOTAL PDG DATA POINTS: 17309


TOTAL PRUNING ERROR: 0




Processing split:  valid


100%|██████████| 3676/3676 [00:00<00:00, 4919.01it/s]



GOOD PDG DATA POINTS: 3619


TOTAL PDG DATA POINTS: 3670


TOTAL PRUNING ERROR: 0




Processing split:  test


100%|██████████| 3819/3819 [00:00<00:00, 4842.53it/s]


GOOD PDG DATA POINTS: 3777


TOTAL PDG DATA POINTS: 3815


TOTAL PRUNING ERROR: 0




TOTAL GOOD PDG DATA POINTS IN DATASET: 24555


TOTAL PDG DATA POINTS IN DATASET: 24794


TOTAL PRUNING ERROR IN DATASET: 0


DATASET STATISTICS: {'train': [17309, 17159, 0], 'valid': [3670, 3619, 0], 'test': [3815, 3777, 0]}




