# Imports

In [1]:
import os
import sys
import copy
import glob
import tqdm
from torch import nn
import random
import torch
import platform
from typing import Callable, List, Optional, Dict
import numpy as np
import scipy.sparse as sp

import warnings
warnings.filterwarnings('ignore')

from transformers import AutoTokenizer, AutoModel

import torch_geometric
from torch_geometric.data import (
    Data,
    InMemoryDataset,
    Batch
    )
import torch_geometric.datasets as datasets
import torch_geometric.transforms as transforms
from torch_geometric.data import Data
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.nn import global_mean_pool, global_max_pool

# Helper function for visualization.
%matplotlib inline
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

import umap
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.cluster import Birch
from sklearn.cluster import SpectralClustering

from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, precision_score, recall_score, silhouette_score

# To ensure determinism
seed = 1234
def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(seed)

# Check versions
print(torch.__version__)
print(torch.version.cuda)
print(platform.python_version())
print(torch_geometric.__version__)



1.8.1+cu101
10.1
3.8.18
1.7.0


# Post-process PDGs

In [2]:
pdg_data_folder = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/pdg_data_java250_100_class"
project_folders = [name for name in os.listdir(pdg_data_folder) if os.path.isdir(os.path.join(pdg_data_folder, name))]
class_id = 0
projects_to_consider = {}
for project in project_folders[:10]:
    projects_to_consider[project] = class_id
    class_id += 1
    
print(projects_to_consider)

{'p00001': 0, 'p00002': 1, 'p00006': 2, 'p02381': 3, 'p02388': 4, 'p02389': 5, 'p02390': 6, 'p02391': 7, 'p02393': 8, 'p02394': 9}


In [3]:
""" ALGORITHM

a. Clean the raw edge info (eg. remove wrongly formatted edges, class edges etc.)
b. Merge same code-lines into a single line/node
c. Remove the duplicate edges
e. Add the all the edges(CD/FD) in the current subgraph

"""

PRUNING_ERROR_COUNT, GOOD_DATA_POINTS, TOTAL_DATA_POINTS = 0, 0, 0
PRUNING_ERROR_COUNT_IN_DATASET, GOOD_DATA_POINTS_IN_DATASET, TOTAL_DATA_POINTS_IN_DATASET = 0, 0, 0
DATASET_STATISTICS = {}

def get_pruned_pdg(pdg_file, output_pdg_file):
    
    global PRUNING_ERROR_COUNT, GOOD_DATA_POINTS, TOTAL_DATA_POINTS
    
    # all_edges = [bytes(l, 'utf-8').decode('utf-8', 'ignore').strip()
    #              for l in pdg_file.readlines()]
    all_edges = [l.replace("\n", "").replace("\r", "").strip()
                 for l in pdg_file.readlines()]

    # Remove unnecesssary edges("class" edge, wrongly formatted edges etc.)
    all_edges = [edge for edge in all_edges if edge.find(
        "-->") != -1 and edge.count("$$") == 2]
    all_edges = [edge for edge in all_edges if len(edge.split("-->")) == 2 and
                 len(edge.split("-->")[0].split("$$")) == 2 and
                 len(edge.split("-->")[1].split("$$")) == 2]
    all_edges = [edge for edge in all_edges if edge.split("-->")[0].find("Entry") == -1 and
                 edge.split("-->")[0].find("class") == -1]
    #print("ALL EDGES : \n")
    #print(all_edges, "\n")

    # Merge nodes referring to same code-line
    line_mapping, edge_mapping = {}, {}
    for edge in all_edges:
        node_1, node_2 = edge[:edge.rindex("[")].strip().split("-->")
        edge_type = edge[edge.rindex("[") + 1: -1].strip()
        line_numbers = []
        for node in [node_1, node_2]:
            line_number, line_code = node.strip().split("$$")
            line_number, line_code = line_number.strip(), line_code.strip()
            line_numbers.append(line_number)
            if line_number in line_mapping:
                if line_mapping[line_number] != line_code:
                    line_mapping[line_number] = line_code if len(line_code) > len(
                        line_mapping[line_number]) else line_mapping[line_number]
            else:
                line_mapping[line_number] = line_code
        if tuple(line_numbers) in edge_mapping:
            edge_mapping[tuple(line_numbers)] = list(set(edge_mapping[tuple(line_numbers)] + [edge_type]))
        else:
            edge_mapping[tuple(line_numbers)] = [edge_type]

    # Remove self-loops from subgraph
    edges_temp = {}
    for edge in edge_mapping:
        if edge[0] != edge[1]:
            edges_temp[edge] = edge_mapping[edge]
    edge_mapping = edges_temp
    #print("AFTER REMOVING SELF-LOOPS : \n")
    #print(sub_graph_edges, "\n")

    # Save the pruned PDG
    edge_data_list = []
    for edge in edge_mapping:
        for edge_type in edge_mapping[edge]:
            edge_data = edge[0].strip() + " $$ " + \
                        line_mapping[edge[0]].strip() + " --> " + \
                        edge[1].strip() + " $$ " + \
                        line_mapping[edge[1]].strip() + " [" + \
                        edge_type.strip() + "]\n"
            edge_data_list.append(edge_data)
    #print("FINAL EDGE LIST: \n")
    #print(edge_data_list, "\n")
    if len(edge_data_list) >= 3:
        GOOD_DATA_POINTS += 1
        
    output_pdg_file.writelines(edge_data_list)
    if len(edge_data_list) > 0:
        TOTAL_DATA_POINTS += 1

    return output_pdg_file, len(edge_data_list)

In [4]:
PDG_FOLDER_LOCATION = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/pdg_data_java250_100_class"
OUTPUT_FOLDER_LOCATION = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/temp/solution-classification-analysis"

for project in tqdm.tqdm(projects_to_consider):
    print("\nProcessing : ", project)
    
    OUTPUT_SPLIT_FOLDER_LOCATION = OUTPUT_FOLDER_LOCATION + "/" + project
    if not os.path.exists(OUTPUT_SPLIT_FOLDER_LOCATION):
        os.makedirs(OUTPUT_SPLIT_FOLDER_LOCATION)
        
    INPUT_PROJECT_FOLDER_LOCATION = PDG_FOLDER_LOCATION + "/" + project
    pdg_files = glob.glob(os.path.join(INPUT_PROJECT_FOLDER_LOCATION, '*.txt'))
    
    for pdg_file in pdg_files:
        original_pdg_file = open(pdg_file, 'r')
        project_id = pdg_file[pdg_file.rindex("/")+1:].split("_")[2]
        output_file_location = OUTPUT_SPLIT_FOLDER_LOCATION + "/" + pdg_file[pdg_file.rindex("/")+1:-4] + "_" + str(projects_to_consider[project_id]) + ".txt"
        output_pdg_file = open(output_file_location, "+w")
        try:
            output_pdg_file, no_of_edges = get_pruned_pdg(original_pdg_file, output_pdg_file)
        except Exception as e:
            PRUNING_ERROR_COUNT += 1
            print("\nERROR WHILE PRUNING PDG\n")
            print("\nFile: {}\n".format(pdg_file))
            print("\nERROR: {}\n".format(e))
            original_pdg_file.close()
            output_pdg_file.close()
            os.remove(output_file_location)
        else:
            output_pdg_file.close()
            if no_of_edges == 0:
                os.remove(output_file_location)
            original_pdg_file.close()
            
    print("\nGOOD PDG DATA POINTS: {}\n".format(GOOD_DATA_POINTS))
    print("\nTOTAL PDG DATA POINTS: {}\n".format(TOTAL_DATA_POINTS))
    print("\nTOTAL PRUNING ERROR: {}\n".format(PRUNING_ERROR_COUNT))
    print("\n=================================================================\n")
    PRUNING_ERROR_COUNT_IN_DATASET += PRUNING_ERROR_COUNT
    GOOD_DATA_POINTS_IN_DATASET += GOOD_DATA_POINTS
    TOTAL_DATA_POINTS_IN_DATASET += TOTAL_DATA_POINTS
    DATASET_STATISTICS[project] = [TOTAL_DATA_POINTS, GOOD_DATA_POINTS, PRUNING_ERROR_COUNT]
    PRUNING_ERROR_COUNT, GOOD_DATA_POINTS, TOTAL_DATA_POINTS = 0, 0, 0
    
print("\nTOTAL GOOD PDG DATA POINTS IN DATASET: {}\n".format(GOOD_DATA_POINTS_IN_DATASET))
print("\nTOTAL PDG DATA POINTS IN DATASET: {}\n".format(TOTAL_DATA_POINTS_IN_DATASET))
print("\nTOTAL PRUNING ERROR IN DATASET: {}\n".format(PRUNING_ERROR_COUNT_IN_DATASET))
print("\nDATASET STATISTICS: {}\n".format(DATASET_STATISTICS))

 10%|█         | 1/10 [00:00<00:01,  8.92it/s]


Processing :  p00001

GOOD PDG DATA POINTS: 263


TOTAL PDG DATA POINTS: 263


TOTAL PRUNING ERROR: 0




Processing :  p00002

GOOD PDG DATA POINTS: 245


TOTAL PDG DATA POINTS: 246


TOTAL PRUNING ERROR: 0




Processing :  p00006


 30%|███       | 3/10 [00:00<00:00, 11.21it/s]


GOOD PDG DATA POINTS: 251


TOTAL PDG DATA POINTS: 255


TOTAL PRUNING ERROR: 0




Processing :  p02381


 50%|█████     | 5/10 [00:00<00:00, 10.86it/s]


GOOD PDG DATA POINTS: 280


TOTAL PDG DATA POINTS: 280


TOTAL PRUNING ERROR: 0




Processing :  p02388

GOOD PDG DATA POINTS: 233


TOTAL PDG DATA POINTS: 244


TOTAL PRUNING ERROR: 0




Processing :  p02389


 70%|███████   | 7/10 [00:00<00:00, 12.62it/s]


GOOD PDG DATA POINTS: 246


TOTAL PDG DATA POINTS: 248


TOTAL PRUNING ERROR: 0




Processing :  p02390

GOOD PDG DATA POINTS: 242


TOTAL PDG DATA POINTS: 248


TOTAL PRUNING ERROR: 0




Processing :  p02391

GOOD PDG DATA POINTS: 237


TOTAL PDG DATA POINTS: 238


TOTAL PRUNING ERROR: 0




Processing :  p02393


100%|██████████| 10/10 [00:00<00:00, 11.92it/s]


GOOD PDG DATA POINTS: 250


TOTAL PDG DATA POINTS: 250


TOTAL PRUNING ERROR: 0




Processing :  p02394

GOOD PDG DATA POINTS: 251


TOTAL PDG DATA POINTS: 251


TOTAL PRUNING ERROR: 0




TOTAL GOOD PDG DATA POINTS IN DATASET: 2498


TOTAL PDG DATA POINTS IN DATASET: 2523


TOTAL PRUNING ERROR IN DATASET: 0


DATASET STATISTICS: {'p00001': [263, 263, 0], 'p00002': [246, 245, 0], 'p00006': [255, 251, 0], 'p02381': [280, 280, 0], 'p02388': [244, 233, 0], 'p02389': [248, 246, 0], 'p02390': [248, 242, 0], 'p02391': [238, 237, 0], 'p02393': [250, 250, 0], 'p02394': [251, 251, 0]}






# Process the Graph Data

In [15]:
def get_nodes_edges(inTextFile, add_reverse_edges = False):
  # FD = 0, CD = 1
  # to support the hetero data object as suggested by the documentation 
  nodes_dict = {}
  edge_indices_CD = []
  edge_indices_FD = []

  #to support the Data object as used by the Entities dat object as used in RGAT source code
  edge_indices = []
  edge_type = []
  
  # nodes_dict is an index_map
  node_count=0
  with open(inTextFile) as fp:
    
    file_name = inTextFile.split("/")[-1].strip()

    Lines = fp.readlines()
    for line in Lines:

      N = line.split('-->')
      N[0], N[1] = N[0].strip(), N[1].strip()
      
      #t1 = N[0].split('$$')   
      src = N[0].strip()   
      if src not in nodes_dict.keys():
        nodes_dict[src] = node_count
        node_count+=1
        
      #t2 = N[1].split('$$')
      right_idx = N[1].rfind('[')
      dst = N[1][:right_idx].strip()
      if dst not in nodes_dict.keys():
        nodes_dict[dst] = node_count
        node_count+=1

      x = N[1].strip()[right_idx + 1 : -1].strip()
      if(x == 'FD'):
        y=0
        edge_type.append(y)
        edge_indices.append([nodes_dict[src], nodes_dict[dst]])
        if add_reverse_edges:
          edge_type.append(y)
          edge_indices.append([nodes_dict[dst], nodes_dict[src]])
        edge_indices_FD.append([nodes_dict[src], nodes_dict[dst]])
      else: 
        y=1
        edge_type.append(y)
        edge_indices.append([nodes_dict[src], nodes_dict[dst]])
        if add_reverse_edges:
          edge_type.append(y)
          edge_indices.append([nodes_dict[dst], nodes_dict[src]])
        edge_indices_CD.append([nodes_dict[src], nodes_dict[dst]])
     
  return nodes_dict, edge_indices_FD, edge_indices_CD, edge_indices, edge_type, file_name

In [16]:
import gc

#Set GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Initialize the models
codebert_tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
codebert_model = AutoModel.from_pretrained("microsoft/codebert-base")
codebert_model = codebert_model.to(device)

def get_node_embedding_from_codebert(nodes):
    list_of_embeddings = []
    for code_line in nodes.keys():
        code_line = code_line.split("$$")[1].strip()
        code_tokens = codebert_tokenizer.tokenize(code_line, truncation=True, max_length=510)
        tokens = [codebert_tokenizer.cls_token]+code_tokens+[codebert_tokenizer.eos_token]
        tokens_ids = torch.tensor(codebert_tokenizer.convert_tokens_to_ids(tokens))
        tokens_ids = tokens_ids.to(device)
        context_embeddings = codebert_model(tokens_ids[None,:])
        cls_token_embedding = context_embeddings.last_hidden_state[0,0,:]
        list_of_embeddings.append(cls_token_embedding.to("cpu"))
        del tokens_ids
        del context_embeddings
        del cls_token_embedding
    gc.collect()
    torch.cuda.empty_cache()
    return torch.stack(list_of_embeddings)

In [17]:
def create_graph_dataset(folders):
  dataset =[]
  for label, folder in tqdm.tqdm(enumerate(folders)):
    print("\nProcessing: {}\n".format(folder))
    files = glob.glob(os.path.join(folder, '*.txt'))
    print("\nNumber of files: {}\n".format(len(files)))
    count = 0
    for file in files:

      if(count % 5 == 0):
          print("\nAt file: {}\n".format(count))
                        
      try:
          nodes_dict, edge_indices_FD, edge_indices_CD, edge_indices, edge_type, file_name = get_nodes_edges(file, add_reverse_edges = True)
      except Exception as e:
          print("\nError: ", e)
          continue
                    
      if(len(nodes_dict) == 0):
          print("\nNo Data: ", file)
          continue
      #print(nodes_dict, edge_indices_CD, edge_indices_FD, edge_type)

      # Node feature matrix with shape [num_nodes, num_node_features]=(N, 768).
      try:
          with torch.no_grad():
            CodeEmbedding = get_node_embedding_from_codebert(nodes_dict)
      except Exception as e :
          print("\nError: ", e)
          print(nodes_dict)
          continue
      #print(CodeEmbedding.shape)

      # FIXING DATA FOTMATS AND SHAPE
      x = torch.tensor(CodeEmbedding)
      # print(x.shape)
  
      # data.y: Target to train against (may have arbitrary shape),
      # graph-level targets of shape [1, *]
      label = 1
      y = torch.tensor([label], dtype=torch.long)
      #print(type(y))

      # edge_index (LongTensor, optional) – Graph connectivity in COO format with shape [2, num_edges]
      edge_index_CD = torch.tensor(edge_indices_CD, dtype=torch.long).t().contiguous()
      edge_index_FD = torch.tensor(edge_indices_FD, dtype=torch.long).t().contiguous()
      edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
      edge_attr = torch.tensor(edge_type, dtype=torch.long).t().contiguous()
      #print(edge_index_CD, edge_index_FD, edge_index, edge_type)
  
      data = Data(edge_index=edge_index, edge_attr=edge_attr, x=x)
      data.id = torch.tensor([count])
      data.y = y
      # data.num_nodes = len(nodes_dict)
      data.api = file_name
      dataset.append(data)
      count += 1
    
  return dataset

In [18]:
OUTPUT_FOLDER_LOCATION = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/temp/solution-classification-analysis"
folders = [os.path.join(OUTPUT_FOLDER_LOCATION, name) for name in os.listdir(OUTPUT_FOLDER_LOCATION) if os.path.isdir(os.path.join(OUTPUT_FOLDER_LOCATION, name))]
print(folders)

gnn_dataset = create_graph_dataset(folders[:5])
print("\nLength of the dataset: ", len(gnn_dataset))

['/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/temp/solution-classification-analysis/p00001', '/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/temp/solution-classification-analysis/p00002', '/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/temp/solution-classification-analysis/p00006', '/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/temp/solution-classification-analysis/p02381', '/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/temp/solution-classification-analysis/p02388', '/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/temp/solution-classification-analysis/p02389', '/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/temp/solution-cla

0it [00:00, ?it/s]


Processing: /home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/temp/solution-classification-analysis/p00001


Number of files: 263


At file: 0


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250


At file: 255


At file: 260



1it [01:31, 91.13s/it]


Processing: /home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/temp/solution-classification-analysis/p00002


Number of files: 246


At file: 0


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245



2it [02:49, 83.44s/it]


Processing: /home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/temp/solution-classification-analysis/p00006


Number of files: 255


At file: 0


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250



3it [04:02, 78.81s/it]


Processing: /home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/temp/solution-classification-analysis/p02381


Number of files: 280


At file: 0


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250


At file: 255


At file: 260


At file: 265


At file: 270


At file: 275



4it [05:55, 92.30s/it]


Processing: /home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/temp/solution-classification-analysis/p02388


Number of files: 244


At file: 0


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240



5it [07:01, 84.32s/it]


Length of the dataset:  1288





# Build/Load the Model

## Context-Prediction Model

In [19]:
from model import GNN, GNN_graphpred

#set up model
num_layer = 3
emb_dim = 768
gnn_type = "gcn"
num_tasks = 1
JK = "last"
dropout_ratio = 0.5
graph_pooling = "mean"
input_model_file = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/output/saved_models/gcn_1_3_5_e100_model_ck_code2seq.pth"

gnn_graphpred_model = GNN_graphpred(num_layer, emb_dim, num_tasks, JK = JK, drop_ratio = dropout_ratio, graph_pooling = graph_pooling, gnn_type = gnn_type)
gnn_graphpred_model.from_pretrained(input_model_file)

gnn_model = GNN(num_layer, emb_dim, JK, drop_ratio = dropout_ratio, gnn_type = gnn_type)
gnn_model.load_state_dict(torch.load(input_model_file))

print("Loaded the model!!")

Loaded the model!!


## Clone-Detection Model

In [23]:
from model_ng import CustomGCN


input_model_file = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/output/saved_models/clone_detection_GCN_L3_e10_850k_model.pth"
gnn_model = CustomGCN(num_node_features= 768)
gnn_model.load_state_dict(torch.load(input_model_file))

print("Loaded the model!!")

Loaded the model!!


# Get the Embeddings

In [24]:
gnn_embeddings = []
model_name = "clone-detection" # "context-prediction" or "clone-detection"
for i in range(len(gnn_dataset)):
    if model_name == "clone-detection":
        graph_representation = gnn_model(gnn_dataset[i].x, gnn_dataset[i].edge_index, batch = torch.tensor([0]*(len(gnn_dataset[i].x))))[0]
    else:
        node_representation = gnn_model(gnn_dataset[i].x, gnn_dataset[i].edge_index, gnn_dataset[i].edge_attr)
        graph_representation = global_mean_pool(x = node_representation, batch = torch.tensor([0]*(len(node_representation))))[0]
    gnn_dataset[i].embedding = graph_representation.detach().numpy()
    problem_name = gnn_dataset[i].api.split("_")[2].strip()
    gnn_embeddings.append([gnn_dataset[i].api, problem_name, gnn_dataset[i].embedding])

# Cluster the Embeddings

In [21]:
import copy
from sklearn.cluster import KMeans
from sklearn.cluster import Birch
from sklearn.mixture import GaussianMixture
from sklearn.cluster import AgglomerativeClustering

def cluster_and_compare(embeddings, ground_truth_cluster_number, clustering_algorithm = "Birch"):
    
    if(clustering_algorithm == "Birch"):
        birch_model = Birch(n_clusters = ground_truth_cluster_number)
        clusters_result = birch_model.fit_predict([emb[2] for emb in embeddings])
    elif(clustering_algorithm == "Agglomerative"):
        agglomerative_model = AgglomerativeClustering(n_clusters = ground_truth_cluster_number)
        clusters_result = agglomerative_model.fit_predict([emb[2] for emb in embeddings])
    elif(clustering_algorithm == "KMeans"):
        kmeans_model = KMeans(n_clusters = ground_truth_cluster_number)
        clusters_result = kmeans_model.fit_predict([emb[2] for emb in embeddings])
    elif(clustering_algorithm == "GM"):
        gaussian_model = GaussianMixture(n_components = ground_truth_cluster_number)
        clusters_result = gaussian_model.fit_predict([emb[2] for emb in embeddings])
        
    cluster_count = {}
    
    project_names = list(set([emb[1] for emb in embeddings]))
    cluster_mapping = {}
    for cluster_no in list(set(clusters_result)):
        cluster_mapping[cluster_no] = {}
        
    for i in range(len(clusters_result)):
        try:
            cluster_count[clusters_result[i]] += 1
        except:
            cluster_count[clusters_result[i]] = 1
            
        try:
            cluster_mapping[clusters_result[i]][embeddings[i][1]] += 1
        except:
            cluster_mapping[clusters_result[i]][embeddings[i][1]] = 1
    print("Cluster Counts: ", cluster_count)
    print("Project Names: ", project_names)
    print("Cluster Mapping: ", cluster_mapping)
    
    total_count, currect_count, wrong_count = 0, 0, 0
    both_right, both_wrong = 0, 0
    confusion_matrix = {"TP": 0, "TN": 0, "FP": 0, "FN": 0}

    original_one_final_two = []
    original_two_final_one = []
    
    for i in tqdm.tqdm(range(len(embeddings))):
        for j in range(i+1, len(embeddings)):
            total_count += 1
            if (embeddings[i][1] == embeddings[j][1]):
                if (clusters_result[i] == clusters_result[j]):
                    both_right += 1
                    currect_count += 1
                    confusion_matrix["TP"] += 1
                else:
                    #original_one_final_two.append([embeddings[i][0], embeddings[j][0], embeddings[i][1], clusters_result[i], clusters_result[j]])
                    wrong_count += 1
                    confusion_matrix["FN"] += 1
            else:
                if (clusters_result[i] != clusters_result[j]):
                    both_wrong += 1
                    currect_count += 1
                    confusion_matrix["TN"] += 1
                else:
                    #original_two_final_one.append([embeddings[i][0], embeddings[j][0], embeddings[i][1], embeddings[j][1], clusters_result[i]])
                    wrong_count += 1
                    confusion_matrix["FP"] += 1
                    
    print("total_count = {}, currect_count = {}, wrong_count = {}, both_right = {}, both_wrong = {}".format(total_count, currect_count, wrong_count, both_right, both_wrong))
    print(confusion_matrix)
    precision = float(format(confusion_matrix["TP"] / (confusion_matrix["TP"] + confusion_matrix["FP"]), ".3f"))
    recall = float(format(confusion_matrix["TP"] / (confusion_matrix["TP"] + confusion_matrix["FN"]), ".3f"))
    f1_score = float(format(2 * (precision * recall) / (precision + recall), ".3f"))
    accuracy = float(format(currect_count/total_count, ".3f"))
    print("Precision: {}, Recall: {} and F1-Score: {}".format(precision, recall, f1_score))
    print("Accuracy: {}".format(accuracy))
    
    return precision, recall, f1_score, accuracy

In [25]:
precision, recall, f1_score, accuracy = cluster_and_compare(gnn_embeddings, ground_truth_cluster_number = 5, clustering_algorithm = "Birch")

Cluster Counts:  {0: 1182, 4: 46, 1: 42, 2: 11, 3: 7}
Project Names:  ['p00002', 'p02381', 'p02388', 'p00006', 'p00001']
Cluster Mapping:  {0: {'p00001': 260, 'p00002': 185, 'p00006': 226, 'p02381': 280, 'p02388': 231}, 1: {'p00002': 29, 'p00006': 13}, 2: {'p00002': 8, 'p00006': 3}, 3: {'p00006': 6, 'p02388': 1}, 4: {'p00001': 3, 'p00002': 24, 'p00006': 7, 'p02388': 12}}


  0%|          | 0/1288 [00:00<?, ?it/s]

100%|██████████| 1288/1288 [00:00<00:00, 3180.33it/s]

total_count = 828828, currect_count = 248478, wrong_count = 580350, both_right = 142636, both_wrong = 105842
{'TP': 142636, 'TN': 105842, 'FP': 557307, 'FN': 23043}
Precision: 0.204, Recall: 0.861 and F1-Score: 0.33
Accuracy: 0.3





# Evalute Using CodeBERT/UnixCoder

In [3]:
import gc

#Set GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#device = torch.device('cpu')

# Initialize the models
codebert_tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
codebert_model = AutoModel.from_pretrained("microsoft/codebert-base")
codebert_model = codebert_model.to(device)
max_source_length= 512

def get_code_embeddings_from_codebert(codelines):
    gc.collect()
    torch.cuda.empty_cache()
    code = " ".join(codelines)
    source_tokens = codebert_tokenizer.tokenize(code)[:max_source_length-2]
    source_tokens = [codebert_tokenizer.cls_token]+source_tokens+[codebert_tokenizer.sep_token]
    source_ids =  codebert_tokenizer.convert_tokens_to_ids(source_tokens) 
    padding_length = max_source_length - len(source_ids)
    source_ids+=[codebert_tokenizer.pad_token_id]*padding_length
    source_ids = torch.tensor(source_ids)
    
    # tokens = []
    # for code_line in codelines:
    #     code_tokens = codebert_tokenizer.tokenize(code_line, truncation=True, max_length=510)
    #     if tokens == []:
    #         tokens = [codebert_tokenizer.cls_token] + code_tokens
    #     else:
    #         tokens = tokens + [codebert_tokenizer.sep_token] + code_tokens
    # tokens = tokens + [codebert_tokenizer.eos_token]
    # tokens_ids = torch.tensor(codebert_tokenizer.convert_tokens_to_ids(tokens))
    source_ids = source_ids.to(device)
    context_embeddings = codebert_model(source_ids[None,:])
    cls_token_embedding = context_embeddings.last_hidden_state[0,0,:]
    return cls_token_embedding

In [4]:
import gc
from unixcoder import UniXcoder

#Set GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#device = torch.device('cpu')

# Initialize the models
unixcoder_model = UniXcoder("microsoft/unixcoder-base")
unixcoder_model = unixcoder_model.to(device)
max_source_length= 512

def get_code_embeddings_from_unixcoder(codelines):
    gc.collect()
    torch.cuda.empty_cache()
    code = " ".join(codelines)
    tokens_ids = unixcoder_model.tokenize([code], max_length=512, mode="<encoder-only>")
    source_ids = torch.tensor(tokens_ids).to(device)
    tokens_embeddings, code_embedding = unixcoder_model(source_ids)
    return torch.flatten(code_embedding)

In [11]:
def get_embeddings_from_llms(folders, model):
  embeddings = []
  for label, folder in tqdm.tqdm(enumerate(folders)):
    folder_name = folder.strip().split("/")[-1]
    print("\nProcessing: {}\n".format(folder_name))
    files = glob.glob(os.path.join(folder, '*.java'))
    print("\nNumber of files: {}\n".format(len(files)))
    count = 1
    for file in files:
      sample_name = file.split("/")[-2].strip()
      file_name = file.split("/")[-1].strip()
      if(count % 5 == 0):
          print("\nAt file: {}\n".format(count))
                        
      fp = open(file,'r')
      lines = fp.readlines()
      lines = [line for line in lines if not line.startswith("import") and not len(line.strip('\n')) == 0]
      lines = [line.strip('\n').strip("\t").strip(" ") for line in lines]
      if model == "codebert":
        embedding = get_code_embeddings_from_codebert(lines)
      elif model == "unixcoder":
        embedding = get_code_embeddings_from_unixcoder(lines)
      embedding = embedding.detach().cpu().numpy()
      embeddings.append([file_name, folder_name, embedding])
      count += 1
    
  return embeddings

In [26]:
project_folders_2 = ["/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/processed_data_for_pdg_java250_100_class/p00001",
                     "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/processed_data_for_pdg_java250_100_class/p00002"]

project_folders_5 = ["/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/processed_data_for_pdg_java250_100_class/p00001",
                   "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/processed_data_for_pdg_java250_100_class/p00002",
                   "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/processed_data_for_pdg_java250_100_class/p00006",
                   "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/processed_data_for_pdg_java250_100_class/p02381",
                   "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Benchmarks/CodeNet/processed_data_for_pdg_java250_100_class/p02388"]

In [27]:
with torch.no_grad():
    codebert_embeddings = get_embeddings_from_llms(project_folders_5, "codebert")

precision, recall, f1_score, accuracy = cluster_and_compare(codebert_embeddings, ground_truth_cluster_number = 5, clustering_algorithm = "Birch")

0it [00:00, ?it/s]


Processing: p00001


Number of files: 264


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250


At file: 255


At file: 260



1it [01:19, 79.75s/it]


Processing: p00002


Number of files: 248


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245



2it [02:26, 72.32s/it]


Processing: p00006


Number of files: 258


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250


At file: 255



3it [03:45, 75.15s/it]


Processing: p02381


Number of files: 280


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45



Token indices sequence length is longer than the specified maximum sequence length for this model (597 > 512). Running this sequence through the model will result in indexing errors



At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250


At file: 255


At file: 260


At file: 265


At file: 270


At file: 275


At file: 280



4it [05:11, 79.32s/it]


Processing: p02388


Number of files: 246


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245



5it [06:27, 77.53s/it]


Cluster Counts:  {2: 607, 1: 249, 4: 261, 3: 138, 0: 41}
Project Names:  ['p00002', 'p02381', 'p02388', 'p00006', 'p00001']
Cluster Mapping:  {0: {'p00001': 11, 'p00002': 4, 'p02381': 26}, 1: {'p00001': 75, 'p00002': 19, 'p00006': 7, 'p02381': 148}, 2: {'p00001': 71, 'p00002': 114, 'p00006': 194, 'p02388': 228}, 3: {'p00001': 24, 'p00002': 11, 'p00006': 2, 'p02381': 101}, 4: {'p00001': 83, 'p00002': 100, 'p00006': 55, 'p02381': 5, 'p02388': 18}}


100%|██████████| 1296/1296 [00:00<00:00, 2631.73it/s]

total_count = 839160, currect_count = 578746, wrong_count = 260414, both_right = 83139, both_wrong = 495607
{'TP': 83139, 'TN': 495607, 'FP': 175861, 'FN': 84553}
Precision: 0.321, Recall: 0.496 and F1-Score: 0.39
Accuracy: 0.69





In [28]:
with torch.no_grad():
    unixcoder_embeddings = get_embeddings_from_llms(project_folders_5, "unixcoder")

precision, recall, f1_score, accuracy = cluster_and_compare(unixcoder_embeddings, ground_truth_cluster_number = 5, clustering_algorithm = "Birch")

0it [00:00, ?it/s]


Processing: p00001


Number of files: 264


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250


At file: 255


At file: 260



1it [01:02, 62.32s/it]


Processing: p00002


Number of files: 248


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245



2it [01:59, 59.24s/it]


Processing: p00006


Number of files: 258


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250


At file: 255



3it [02:57, 58.58s/it]


Processing: p02381


Number of files: 280


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250


At file: 255


At file: 260


At file: 265


At file: 270


At file: 275


At file: 280



4it [04:05, 62.50s/it]


Processing: p02388


Number of files: 246


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245



5it [05:01, 60.33s/it]


Cluster Counts:  {0: 263, 2: 572, 1: 290, 3: 99, 4: 72}
Project Names:  ['p00002', 'p02381', 'p02388', 'p00006', 'p00001']
Cluster Mapping:  {0: {'p00001': 258, 'p00002': 2, 'p00006': 3}, 1: {'p00002': 11, 'p02381': 279}, 2: {'p00001': 6, 'p00002': 235, 'p00006': 156, 'p02381': 1, 'p02388': 174}, 3: {'p00006': 99}, 4: {'p02388': 72}}


100%|██████████| 1296/1296 [00:00<00:00, 2700.81it/s]

total_count = 839160, currect_count = 692499, wrong_count = 146661, both_right = 134051, both_wrong = 558448
{'TP': 134051, 'TN': 558448, 'FP': 113020, 'FN': 33641}
Precision: 0.543, Recall: 0.799 and F1-Score: 0.647
Accuracy: 0.825



