# Imports

In [33]:
import os
import sys
import copy
import glob
import tqdm
from torch import nn
import random
import torch
import platform
from typing import Callable, List, Optional, Dict
import numpy as np
import scipy.sparse as sp

import warnings
warnings.filterwarnings('ignore')

from transformers import AutoTokenizer, AutoModel

import torch_geometric
from torch_geometric.data import (
    Data,
    InMemoryDataset,
    Batch
    )
import torch_geometric.datasets as datasets
import torch_geometric.transforms as transforms
from torch_geometric.data import Data
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.nn import global_mean_pool, global_max_pool

# Helper function for visualization.
%matplotlib inline
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

import umap
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.cluster import Birch
from sklearn.cluster import SpectralClustering

from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, precision_score, recall_score, silhouette_score

# To ensure determinism
seed = 1234
def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(seed)

# Check versions
print(torch.__version__)
print(torch.version.cuda)
print(platform.python_version())
print(torch_geometric.__version__)

1.8.1+cu101
10.1
3.8.18
1.7.0


# Process the Graph Data

In [34]:
from collections import OrderedDict

def get_nodes_edges(inTextFile, add_reverse_edges = False, api_name = None):
  # FD = 0, CD = 1
  # to support the hetero data object as suggested by the documentation 
  nodes_dict = OrderedDict()
  edge_indices_CD = []
  edge_indices_FD = []

  #to support the Data object as used by the Entities dat object as used in RGAT source code
  edge_indices = []
  edge_type = []
  
  # nodes_dict is an index_map
  node_count=0
  fp = open(inTextFile, "r")
    
  file_name = inTextFile.split("/")[-1].strip()
  Lines = fp.readlines()
  
  # Capture the API nodes first
  number_of_api_nodes = 0
  if api_name != None:
    api_name = api_name[api_name.find("."):] + "("
    for line in Lines:
      nodes = line.split('-->')
      nodes[0], nodes[1] = nodes[0].strip(), nodes[1].strip()
      
      src = nodes[0]  
      if src not in nodes_dict.keys() and api_name in src:
        nodes_dict[src] = node_count
        node_count += 1
        number_of_api_nodes += 1
        
      right_idx = nodes[1].rfind('[')
      dst = nodes[1][:right_idx].strip()
      if dst not in nodes_dict.keys() and api_name in dst:
        nodes_dict[dst] = node_count
        node_count += 1
        number_of_api_nodes += 1
    if number_of_api_nodes == 0:
      print("No API Nodes found!!!!")
    
  # Process each edge
  for line in Lines:

      N = line.split('-->')
      N[0], N[1] = N[0].strip(), N[1].strip()
      
      #t1 = N[0].split('$$')   
      src = N[0].strip()   
      if src not in nodes_dict.keys():
        nodes_dict[src] = node_count
        node_count+=1
        
      #t2 = N[1].split('$$')
      right_idx = N[1].rfind('[')
      dst = N[1][:right_idx].strip()
      if dst not in nodes_dict.keys():
        nodes_dict[dst] = node_count
        node_count+=1

      x = N[1].strip()[right_idx + 1 : -1].strip()
      if(x == 'FD'):
        y=0
        edge_type.append(y)
        edge_indices.append([nodes_dict[src], nodes_dict[dst]])
        if add_reverse_edges:
          edge_type.append(y)
          edge_indices.append([nodes_dict[dst], nodes_dict[src]])
        edge_indices_FD.append([nodes_dict[src], nodes_dict[dst]])
      else: 
        y=1
        edge_type.append(y)
        edge_indices.append([nodes_dict[src], nodes_dict[dst]])
        if add_reverse_edges:
          edge_type.append(y)
          edge_indices.append([nodes_dict[dst], nodes_dict[src]])
        edge_indices_CD.append([nodes_dict[src], nodes_dict[dst]])
     
  return nodes_dict, edge_indices_FD, edge_indices_CD, edge_indices, edge_type, file_name, number_of_api_nodes

In [35]:
import gc

#Set GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Initialize the models
codebert_tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
codebert_model = AutoModel.from_pretrained("microsoft/codebert-base")
codebert_model = codebert_model.to(device)

def get_node_embedding_from_codebert(nodes):
    list_of_embeddings = []
    for code_line in nodes.keys():
        code_line = code_line.split("$$")[1].strip()
        code_tokens = codebert_tokenizer.tokenize(code_line, truncation=True, max_length=510)
        tokens = [codebert_tokenizer.cls_token]+code_tokens+[codebert_tokenizer.eos_token]
        tokens_ids = torch.tensor(codebert_tokenizer.convert_tokens_to_ids(tokens))
        tokens_ids = tokens_ids.to(device)
        context_embeddings = codebert_model(tokens_ids[None,:])
        cls_token_embedding = context_embeddings.last_hidden_state[0,0,:]
        list_of_embeddings.append(cls_token_embedding.to("cpu"))
        del tokens_ids
        del context_embeddings
        del cls_token_embedding
    gc.collect()
    torch.cuda.empty_cache()
    return torch.stack(list_of_embeddings)

In [36]:
def create_graph_dataset(folders, add_reverse_edges = False, track_api_nodes = False):
  dataset =[]
  for label, folder in tqdm.tqdm(enumerate(folders)):
    print("\nProcessing: {}\n".format(folder))
    files = glob.glob(os.path.join(folder, '*.txt'))
    print("\nNumber of files: {}\n".format(len(files)))
    count = 0
    for file in files:

      if(count % 5 == 0):
          print("\nAt file: {}\n".format(count))
                        
      try:
          if track_api_nodes:
              api_name = folder.split("/")[-1].strip()
          else:
              api_name = None
          nodes_dict, edge_indices_FD, edge_indices_CD, edge_indices, edge_type, file_name, number_of_api_nodes = get_nodes_edges(file, add_reverse_edges = add_reverse_edges, api_name = api_name)
      except Exception as e:
          print("\nError: ", e)
          continue
                    
      if(len(nodes_dict) == 0):
          print("\nNo Data: ", file)
          continue
      #print(nodes_dict, edge_indices_CD, edge_indices_FD, edge_type)

      # Node feature matrix with shape [num_nodes, num_node_features]=(N, 768).
      try:
          with torch.no_grad():
            CodeEmbedding = get_node_embedding_from_codebert(nodes_dict)
      except Exception as e :
          print("\nError: ", e)
          print(nodes_dict)
          continue
      #print(CodeEmbedding.shape)

      # FIXING DATA FOTMATS AND SHAPE
      x = torch.tensor(CodeEmbedding)
      # print(x.shape)
  
      # data.y: Target to train against (may have arbitrary shape),
      # graph-level targets of shape [1, *]
      label = 1
      y = torch.tensor([label], dtype=torch.long)
      #print(type(y))

      # edge_index (LongTensor, optional) – Graph connectivity in COO format with shape [2, num_edges]
      edge_index_CD = torch.tensor(edge_indices_CD, dtype=torch.long).t().contiguous()
      edge_index_FD = torch.tensor(edge_indices_FD, dtype=torch.long).t().contiguous()
      edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
      edge_attr = torch.tensor(edge_type, dtype=torch.long).t().contiguous()
      #print(edge_index_CD, edge_index_FD, edge_index, edge_type)
  
      data = Data(edge_index=edge_index, edge_attr=edge_attr, x=x)
      data.id = torch.tensor([count])
      data.y = y
      data.number_of_api_nodes = number_of_api_nodes
      data.api = file_name
      dataset.append(data)
      count += 1
    
  return dataset

In [37]:
file_1 = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_pruning/NEW/BufferedReader.read/0_sample-0_BufferedReader.read_graph_dump.txt"
file_2 = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_pruning/NEW/Calendar.getTime/0_sample-11_Calendar.getTime_graph_dump.txt"

get_nodes_edges(file_2, add_reverse_edges = False, api_name = "Calendar.getTime")

(OrderedDict([('Line_5 $$ String actual = new SimpleDateFormat(format).format(c.getTime())',
               0),
              ('Line_2 $$ public void matches(Object object)', 1),
              ('Line_3 $$ if (object instanceof Calendar)', 2),
              ('Line_4 $$ Calendar c = (Calendar) object', 3),
              ('Line_6 $$ return value.equals(actual)', 4)]),
 [[1, 2], [3, 0], [0, 4]],
 [[1, 2], [2, 3], [2, 0], [2, 4]],
 [[1, 2], [1, 2], [2, 3], [2, 0], [3, 0], [2, 4], [0, 4]],
 [0, 1, 1, 1, 0, 1, 0],
 '0_sample-11_Calendar.getTime_graph_dump.txt',
 1)

In [38]:
project_folders_2 = ["/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_pruning/NEW/BufferedReader.read",
                   "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_pruning/NEW/Calendar.getTime"]

project_folders_5 = ["/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_pruning/NEW/BufferedReader.read",
                   "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_pruning/NEW/Calendar.getTime",
                   "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_pruning/NEW/ClassLoader.getResource",
                   "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_pruning/NEW/DateFormat.format",
                   "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_pruning/NEW/ExecutorService.submit"]

gnn_dataset = create_graph_dataset(project_folders_5, add_reverse_edges = True, track_api_nodes = True)
print("\nLength of the dataset: ", len(gnn_dataset))

0it [00:00, ?it/s]


Processing: /home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_pruning/NEW/BufferedReader.read


Number of files: 467


At file: 0


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250


At file: 255


At file: 260


At file: 265


At file: 270


At file: 275


At file: 

1it [01:52, 112.01s/it]


Processing: /home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_pruning/NEW/Calendar.getTime


Number of files: 534


At file: 0


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250


At file: 255


At file: 260


At file: 265


At file: 270


At file: 275


At file: 280

2it [03:52, 117.24s/it]


Processing: /home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_pruning/NEW/ClassLoader.getResource


Number of files: 92


At file: 0


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90



3it [04:14, 73.68s/it] 


Processing: /home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_pruning/NEW/DateFormat.format


Number of files: 534


At file: 0


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250


At file: 255


At file: 260


At file: 265


At file: 270


At file: 275


At file: 28

4it [06:18, 93.32s/it]


Processing: /home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_pruning/NEW/ExecutorService.submit


Number of files: 19


At file: 0


At file: 5


At file: 10


At file: 15



5it [06:22, 76.52s/it]


Length of the dataset:  1646





# Build/Load the Model

## Context-Prediction Model

In [39]:
from model import GNN, GNN_graphpred

#set up model
num_layer = 3
emb_dim = 768
gnn_type = "gcn"
num_tasks = 1
JK = "last"
dropout_ratio = 0.5
graph_pooling = "mean"
input_model_file = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/output/saved_models/gcn_1_3_5_e100_model_ck_code2seq.pth"

gnn_graphpred_model = GNN_graphpred(num_layer, emb_dim, num_tasks, JK = JK, drop_ratio = dropout_ratio, graph_pooling = graph_pooling, gnn_type = gnn_type)
gnn_graphpred_model.from_pretrained(input_model_file)

gnn_model = GNN(num_layer, emb_dim, JK, drop_ratio = dropout_ratio, gnn_type = gnn_type)
gnn_model.load_state_dict(torch.load(input_model_file))

print("Loaded the model!!")

Loaded the model!!


## Clone-Detection Model

In [44]:
class CustomGCN_Basic(torch.nn.Module):
    def __init__(self, num_node_features = 768):
        super(CustomGCN_Basic, self).__init__()
        self.conv1 = GCNConv(num_node_features, 100)
        self.conv2 = GCNConv(100, 64)
        self.conv3 = GCNConv(64, 32)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)
        return x
    
input_model_file = "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Graph-Models/MuGNN/output/saved_models/clone_detection_GCN_L3_e10_850k_model.pth"
gnn_model = CustomGCN_Basic(num_node_features= 768)
gnn_model.load_state_dict(torch.load(input_model_file))

print("Loaded the model!!")

Loaded the model!!


# Get the Embeddings

## Pool over all nodes

In [23]:
gnn_embeddings = []
model_name = "context-prediction" # "context-prediction" or "clone-detection"
for i in range(len(gnn_dataset)):
    if model_name == "clone-detection":
        node_representation = gnn_model(gnn_dataset[i].x, gnn_dataset[i].edge_index)
    else:
        node_representation = gnn_model(gnn_dataset[i].x, gnn_dataset[i].edge_index, gnn_dataset[i].edge_attr)
    graph_representation = global_mean_pool(x = node_representation, batch = torch.tensor([0]*(len(node_representation))))[0]
    gnn_dataset[i].embedding = graph_representation.detach().numpy()
    api_name = gnn_dataset[i].api.split("_")[2].strip()
    gnn_embeddings.append([gnn_dataset[i].api, api_name, gnn_dataset[i].embedding])

## Pool over only API nodes

In [45]:
gnn_embeddings = []
model_name = "clone-detection" # "context-prediction" or "clone-detection"
for i in range(len(gnn_dataset)):
    number_of_api_nodes = gnn_dataset[i].number_of_api_nodes
    if model_name == "clone-detection":
        node_representation = gnn_model(gnn_dataset[i].x, gnn_dataset[i].edge_index)
    else:
        node_representation = gnn_model(gnn_dataset[i].x, gnn_dataset[i].edge_index, gnn_dataset[i].edge_attr)
    graph_representation = global_mean_pool(x = node_representation[:number_of_api_nodes], batch = torch.tensor([0]*(number_of_api_nodes)))[0]
    gnn_dataset[i].embedding = graph_representation.detach().numpy()
    api_name = gnn_dataset[i].api.split("_")[2].strip()
    gnn_embeddings.append([gnn_dataset[i].api, api_name, gnn_dataset[i].embedding])

# Cluster the Embeddings

In [42]:
import copy
from sklearn.cluster import KMeans
from sklearn.cluster import Birch
from sklearn.mixture import GaussianMixture
from sklearn.cluster import AgglomerativeClustering

def cluster_and_compare(embeddings, ground_truth_cluster_number, clustering_algorithm = "Birch"):
    
    if(clustering_algorithm == "Birch"):
        birch_model = Birch(n_clusters = ground_truth_cluster_number)
        clusters_result = birch_model.fit_predict([emb[2] for emb in embeddings])
    elif(clustering_algorithm == "Agglomerative"):
        agglomerative_model = AgglomerativeClustering(n_clusters = ground_truth_cluster_number)
        clusters_result = agglomerative_model.fit_predict([emb[2] for emb in embeddings])
    elif(clustering_algorithm == "KMeans"):
        kmeans_model = KMeans(n_clusters = ground_truth_cluster_number)
        clusters_result = kmeans_model.fit_predict([emb[2] for emb in embeddings])
    elif(clustering_algorithm == "GM"):
        gaussian_model = GaussianMixture(n_components = ground_truth_cluster_number)
        clusters_result = gaussian_model.fit_predict([emb[2] for emb in embeddings])
        
    cluster_count = {}
    
    project_names = list(set([emb[1] for emb in embeddings]))
    cluster_mapping = {}
    for cluster_no in list(set(clusters_result)):
        cluster_mapping[cluster_no] = {}
        
    for i in range(len(clusters_result)):
        try:
            cluster_count[clusters_result[i]] += 1
        except:
            cluster_count[clusters_result[i]] = 1
            
        try:
            cluster_mapping[clusters_result[i]][embeddings[i][1]] += 1
        except:
            cluster_mapping[clusters_result[i]][embeddings[i][1]] = 1
    print("Cluster Counts: ", cluster_count)
    print("Project Names: ", project_names)
    print("Cluster Mapping: ", cluster_mapping)
    
    total_count, currect_count, wrong_count = 0, 0, 0
    both_right, both_wrong = 0, 0
    confusion_matrix = {"TP": 0, "TN": 0, "FP": 0, "FN": 0}

    original_one_final_two = []
    original_two_final_one = []
    
    for i in tqdm.tqdm(range(len(embeddings))):
        for j in range(i+1, len(embeddings)):
            total_count += 1
            if (embeddings[i][1] == embeddings[j][1]):
                if (clusters_result[i] == clusters_result[j]):
                    both_right += 1
                    currect_count += 1
                    confusion_matrix["TP"] += 1
                else:
                    #original_one_final_two.append([embeddings[i][0], embeddings[j][0], embeddings[i][1], clusters_result[i], clusters_result[j]])
                    wrong_count += 1
                    confusion_matrix["FN"] += 1
            else:
                if (clusters_result[i] != clusters_result[j]):
                    both_wrong += 1
                    currect_count += 1
                    confusion_matrix["TN"] += 1
                else:
                    #original_two_final_one.append([embeddings[i][0], embeddings[j][0], embeddings[i][1], embeddings[j][1], clusters_result[i]])
                    wrong_count += 1
                    confusion_matrix["FP"] += 1
                    
    print("total_count = {}, currect_count = {}, wrong_count = {}, both_right = {}, both_wrong = {}".format(total_count, currect_count, wrong_count, both_right, both_wrong))
    print(confusion_matrix)
    precision = float(format(confusion_matrix["TP"] / (confusion_matrix["TP"] + confusion_matrix["FP"]), ".3f"))
    recall = float(format(confusion_matrix["TP"] / (confusion_matrix["TP"] + confusion_matrix["FN"]), ".3f"))
    f1_score = float(format(2 * (precision * recall) / (precision + recall), ".3f"))
    accuracy = float(format(currect_count/total_count, ".3f"))
    print("Precision: {}, Recall: {} and F1-Score: {}".format(precision, recall, f1_score))
    print("Accuracy: {}".format(accuracy))
    
    return precision, recall, f1_score, accuracy

In [46]:
precision, recall, f1_score, accuracy = cluster_and_compare(gnn_embeddings, ground_truth_cluster_number = 5, clustering_algorithm = "Birch")

Cluster Counts:  {1: 1563, 0: 38, 2: 41, 3: 3, 4: 1}
Project Names:  ['Calendar.getTime', 'ClassLoader.getResource', 'DateFormat.format', 'BufferedReader.read', 'ExecutorService.submit']
Cluster Mapping:  {0: {'BufferedReader.read': 31, 'Calendar.getTime': 1, 'DateFormat.format': 6}, 1: {'BufferedReader.read': 392, 'Calendar.getTime': 533, 'ClassLoader.getResource': 91, 'DateFormat.format': 528, 'ExecutorService.submit': 19}, 2: {'BufferedReader.read': 40, 'ClassLoader.getResource': 1}, 3: {'BufferedReader.read': 3}, 4: {'BufferedReader.read': 1}}


100%|██████████| 1646/1646 [00:01<00:00, 902.62it/s] 

total_count = 1353835, currect_count = 459958, wrong_count = 893877, both_right = 363071, both_wrong = 96887
{'TP': 363071, 'TN': 96887, 'FP': 859158, 'FN': 34719}
Precision: 0.297, Recall: 0.913 and F1-Score: 0.448
Accuracy: 0.34





# Evalute Using CodeBERT/UnixCoder

In [26]:
import gc

#Set GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#device = torch.device('cpu')

# Initialize the models
codebert_tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
codebert_model = AutoModel.from_pretrained("microsoft/codebert-base")
codebert_model = codebert_model.to(device)
max_source_length= 512

def get_code_embeddings_from_codebert(codelines):
    gc.collect()
    torch.cuda.empty_cache()
    code = " ".join(codelines)
    source_tokens = codebert_tokenizer.tokenize(code)[:max_source_length-2]
    source_tokens = [codebert_tokenizer.cls_token]+source_tokens+[codebert_tokenizer.sep_token]
    source_ids =  codebert_tokenizer.convert_tokens_to_ids(source_tokens) 
    padding_length = max_source_length - len(source_ids)
    source_ids+=[codebert_tokenizer.pad_token_id]*padding_length
    source_ids = torch.tensor(source_ids)
    
    # tokens = []
    # for code_line in codelines:
    #     code_tokens = codebert_tokenizer.tokenize(code_line, truncation=True, max_length=510)
    #     if tokens == []:
    #         tokens = [codebert_tokenizer.cls_token] + code_tokens
    #     else:
    #         tokens = tokens + [codebert_tokenizer.sep_token] + code_tokens
    # tokens = tokens + [codebert_tokenizer.eos_token]
    # tokens_ids = torch.tensor(codebert_tokenizer.convert_tokens_to_ids(tokens))
    source_ids = source_ids.to(device)
    context_embeddings = codebert_model(source_ids[None,:])
    cls_token_embedding = context_embeddings.last_hidden_state[0,0,:]
    return cls_token_embedding

In [27]:
import gc
from unixcoder import UniXcoder

#Set GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#device = torch.device('cpu')

# Initialize the models
unixcoder_model = UniXcoder("microsoft/unixcoder-base")
unixcoder_model = unixcoder_model.to(device)
max_source_length= 512

def get_code_embeddings_from_unixcoder(codelines):
    gc.collect()
    torch.cuda.empty_cache()
    code = " ".join(codelines)
    tokens_ids = unixcoder_model.tokenize([code], max_length=512, mode="<encoder-only>")
    source_ids = torch.tensor(tokens_ids).to(device)
    tokens_embeddings, code_embedding = unixcoder_model(source_ids)
    return torch.flatten(code_embedding)

In [34]:
def get_embeddings_from_llms(folders, model):
  embeddings = []
  for label, folder in tqdm.tqdm(enumerate(folders)):
    folder_name = folder.strip().split("/")[-1]
    print("\nProcessing: {}\n".format(folder_name))
    files = glob.glob(os.path.join(folder, '*/*.java'))
    print("\nNumber of files: {}\n".format(len(files)))
    count = 1
    for file in files:
      sample_name = file.split("/")[-2].strip()
      file_name = file.split("/")[-1].strip()
      if(count % 5 == 0):
          print("\nAt file: {}\n".format(count))
                        
      fp = open(file,'r')
      lines = fp.readlines()
      lines = [line for line in lines if not line.startswith("import") and not len(line.strip('\n')) == 0]
      lines = [line.strip('\n').strip("\t").strip(" ") for line in lines]
      if model == "codebert":
        embedding = get_code_embeddings_from_codebert(lines)
      elif model == "unixcoder":
        embedding = get_code_embeddings_from_unixcoder(lines)
      embedding = embedding.detach().cpu().numpy()
      embeddings.append([file_name, folder_name, embedding])
      count += 1
    
  return embeddings

In [41]:
project_folders_2 = ["/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_preprocessing/FINAL/BufferedReader.read",
                   "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_preprocessing/FINAL/Calendar.getTime"]

project_folders_5 = ["/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_preprocessing/FINAL/BufferedReader.read",
                   "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_preprocessing/FINAL/Calendar.getTime",
                   "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_preprocessing/FINAL/ClassLoader.getResource",
                   "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_preprocessing/FINAL/DateFormat.format",
                   "/home/siddharthsa/cs21mtech12001-Tamal/API-Misuse-Prediction/PDG-gen/Repository/Code_kernel_data/after_preprocessing/FINAL/ExecutorService.submit"]

In [42]:
with torch.no_grad():
    codebert_embeddings = get_embeddings_from_llms(project_folders_5, "codebert")

precision, recall, f1_score, accuracy = cluster_and_compare(codebert_embeddings, ground_truth_cluster_number = 5, clustering_algorithm = "Birch")

0it [00:00, ?it/s]


Processing: BufferedReader.read


Number of files: 472


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250


At file: 255


At file: 260


At file: 265


At file: 270


At file: 275


At file: 280


At file: 285


At file: 290


At file: 295


At file: 300


At file: 305


At file: 310


At file: 315


At file: 320


At

1it [02:12, 132.56s/it]


Processing: Calendar.getTime


Number of files: 557


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250


At file: 255


At file: 260


At file: 265


At file: 270


At file: 275


At file: 280


At file: 285


At file: 290


At file: 295


At file: 300


At file: 305


At file: 310


At file: 315


At file: 320


At fi

2it [04:49, 146.67s/it]


Processing: ClassLoader.getResource


Number of files: 103


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100



3it [05:22, 94.93s/it] 


Processing: DateFormat.format


Number of files: 562


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250



Token indices sequence length is longer than the specified maximum sequence length for this model (913 > 512). Running this sequence through the model will result in indexing errors



At file: 255


At file: 260


At file: 265


At file: 270


At file: 275


At file: 280


At file: 285


At file: 290


At file: 295


At file: 300


At file: 305


At file: 310


At file: 315


At file: 320


At file: 325


At file: 330


At file: 335


At file: 340


At file: 345


At file: 350


At file: 355


At file: 360


At file: 365


At file: 370


At file: 375


At file: 380


At file: 385


At file: 390


At file: 395


At file: 400


At file: 405


At file: 410


At file: 415


At file: 420


At file: 425


At file: 430


At file: 435


At file: 440


At file: 445


At file: 450


At file: 455


At file: 460


At file: 465


At file: 470


At file: 475


At file: 480


At file: 485


At file: 490


At file: 495


At file: 500


At file: 505


At file: 510


At file: 515


At file: 520


At file: 525


At file: 530


At file: 535


At file: 540


At file: 545


At file: 550


At file: 555


At file: 560



4it [08:20, 127.54s/it]


Processing: ExecutorService.submit


Number of files: 21


At file: 5


At file: 10


At file: 15


At file: 20



5it [08:27, 101.54s/it]


Cluster Counts:  {0: 528, 4: 498, 3: 84, 2: 18, 1: 587}
Project Names:  ['ExecutorService.submit', 'BufferedReader.read', 'ClassLoader.getResource', 'DateFormat.format', 'Calendar.getTime']
Cluster Mapping:  {0: {'BufferedReader.read': 252, 'Calendar.getTime': 135, 'ClassLoader.getResource': 47, 'DateFormat.format': 89, 'ExecutorService.submit': 5}, 1: {'BufferedReader.read': 14, 'Calendar.getTime': 201, 'ClassLoader.getResource': 23, 'DateFormat.format': 342, 'ExecutorService.submit': 7}, 2: {'BufferedReader.read': 17, 'DateFormat.format': 1}, 3: {'BufferedReader.read': 65, 'Calendar.getTime': 5, 'ClassLoader.getResource': 3, 'DateFormat.format': 10, 'ExecutorService.submit': 1}, 4: {'BufferedReader.read': 124, 'Calendar.getTime': 216, 'ClassLoader.getResource': 30, 'DateFormat.format': 120, 'ExecutorService.submit': 8}}


100%|██████████| 1715/1715 [00:01<00:00, 964.56it/s] 

total_count = 1469755, currect_count = 932492, wrong_count = 537263, both_right = 165177, both_wrong = 767315
{'TP': 165177, 'TN': 767315, 'FP': 273334, 'FN': 263929}
Precision: 0.377, Recall: 0.385 and F1-Score: 0.381
Accuracy: 0.634





In [43]:
with torch.no_grad():
    unixcoder_embeddings = get_embeddings_from_llms(project_folders_5, "unixcoder")

precision, recall, f1_score, accuracy = cluster_and_compare(unixcoder_embeddings, ground_truth_cluster_number = 5, clustering_algorithm = "Birch")

0it [00:00, ?it/s]


Processing: BufferedReader.read


Number of files: 472


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250


At file: 255


At file: 260


At file: 265


At file: 270


At file: 275


At file: 280


At file: 285


At file: 290


At file: 295


At file: 300


At file: 305


At file: 310


At file: 315


At file: 320


At

1it [01:51, 111.41s/it]


Processing: Calendar.getTime


Number of files: 557


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250


At file: 255


At file: 260


At file: 265


At file: 270


At file: 275


At file: 280


At file: 285


At file: 290


At file: 295


At file: 300


At file: 305


At file: 310


At file: 315


At file: 320


At fi

2it [04:00, 121.86s/it]


Processing: ClassLoader.getResource


Number of files: 103


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100



3it [04:24, 77.28s/it] 


Processing: DateFormat.format


Number of files: 562


At file: 5


At file: 10


At file: 15


At file: 20


At file: 25


At file: 30


At file: 35


At file: 40


At file: 45


At file: 50


At file: 55


At file: 60


At file: 65


At file: 70


At file: 75


At file: 80


At file: 85


At file: 90


At file: 95


At file: 100


At file: 105


At file: 110


At file: 115


At file: 120


At file: 125


At file: 130


At file: 135


At file: 140


At file: 145


At file: 150


At file: 155


At file: 160


At file: 165


At file: 170


At file: 175


At file: 180


At file: 185


At file: 190


At file: 195


At file: 200


At file: 205


At file: 210


At file: 215


At file: 220


At file: 225


At file: 230


At file: 235


At file: 240


At file: 245


At file: 250


At file: 255


At file: 260


At file: 265


At file: 270


At file: 275


At file: 280


At file: 285


At file: 290


At file: 295


At file: 300


At file: 305


At file: 310


At file: 315


At file: 320


At f

4it [06:40, 100.48s/it]


Processing: ExecutorService.submit


Number of files: 21


At file: 5


At file: 10


At file: 15


At file: 20



5it [06:46, 81.24s/it] 


Cluster Counts:  {4: 245, 1: 287, 0: 675, 2: 436, 3: 72}
Project Names:  ['ExecutorService.submit', 'BufferedReader.read', 'ClassLoader.getResource', 'DateFormat.format', 'Calendar.getTime']
Cluster Mapping:  {0: {'BufferedReader.read': 2, 'Calendar.getTime': 470, 'ClassLoader.getResource': 3, 'DateFormat.format': 199, 'ExecutorService.submit': 1}, 1: {'BufferedReader.read': 225, 'Calendar.getTime': 2, 'ClassLoader.getResource': 28, 'DateFormat.format': 12, 'ExecutorService.submit': 20}, 2: {'Calendar.getTime': 85, 'DateFormat.format': 351}, 3: {'ClassLoader.getResource': 72}, 4: {'BufferedReader.read': 245}}


100%|██████████| 1715/1715 [00:01<00:00, 971.17it/s] 

total_count = 1469755, currect_count = 1151249, wrong_count = 318506, both_right = 253196, both_wrong = 898053
{'TP': 253196, 'TN': 898053, 'FP': 142596, 'FN': 175910}
Precision: 0.64, Recall: 0.59 and F1-Score: 0.614
Accuracy: 0.783



