N-Gram

In [None]:
#!pip3 install pydriller javalang nltk pygments pandas numpy


In [18]:
import os
import pandas as pd
import numpy as np
import re
from nltk.lm import MLE, Laplace
from nltk.lm.preprocessing import padded_everygram_pipeline, everygrams
from nltk.tokenize import word_tokenize
from pydriller import Repository
import javalang
import random
import csv

from javalang.parse import parse
from javalang.tree import MethodDeclaration

from pygments.lexers.jvm import JavaLexer
from pygments.lexers import get_lexer_by_name
from pygments.token import Token



### Step 1: Extract Java Methods from Repositories
1. Use PyDriller to traverse commits in the master branch of each repo.
2. Extract method names and their source code using javalang.
3. Store the extracted data in a CSV file.

In [19]:
repo_names = """
dustin/java-memcached-client
davidb/scala-maven-plugin
cyberfox/jbidwatcher
"""

repoList = ["https://www.github.com/" + repo for repo in repo_names.strip().split('\n')]
repoList[0:5]



['https://www.github.com/dustin/java-memcached-client',
 'https://www.github.com/davidb/scala-maven-plugin',
 'https://www.github.com/cyberfox/jbidwatcher']

In [23]:
# ------------------------------- DATA EXTRACTION -------------------------------
def extract_methods_from_java(code):
    """
    Extract methods from Java source code using javalang parser.

    Args:
        code (str): The Java source code.

    Returns:
        list: A list of tuples containing method names and their full source code.
    """
    methods = []
    try:
        # Parse the code into an Abstract Syntax Tree (AST)
        tree = javalang.parse.parse(code)
        lines = code.splitlines()

        # Traverse the tree to find method declarations
        for _, node in tree.filter(javalang.tree.MethodDeclaration):
            method_name = node.name

            # Determine the start and end lines of the method
            start_line = node.position.line - 1
            end_line = None

            # Use the body of the method to determine its end position
            if node.body:
                last_statement = node.body[-1]
                if hasattr(last_statement, 'position') and last_statement.position:
                    end_line = last_statement.position.line

            # Extract method code
            if end_line:
                method_code = "\n".join(lines[start_line:end_line+1])
            else:
                # If end_line couldn't be determined, extract up to the end of the file
                method_code = "\n".join(lines[start_line:])

            methods.append((method_name, method_code))

    except Exception as e:
        print(f"Error parsing Java code: {e}")
    return methods

def extract_methods_to_csv_from_master(repo_path, output_csv):
    """
    Extract methods from Java files in the master branch and save them in a CSV file.

    Args:
        repo_path (str): Path to the Git repository.
        output_csv (str): Path to the output CSV file.
    """
    with open(output_csv, mode='w', newline='', encoding='utf-8') as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(["Commit_Hash", "File_Name", "Method_Name", "Method_Code", "Commit_Link"])

        for commit in Repository(repo_path, only_in_branch="master").traverse_commits():
            print(f"Processing commit: {commit.hash}")

            #We only look into the modified files. In other words, we are looking into the history of the software system by traversing each commit.
            #Various Generative AI methods for SD have been trained on data collected in this way; for example bug fixing.
            for modified_file in commit.modified_files:
                if modified_file.filename.endswith(".java") and modified_file.source_code:
                    methods = extract_methods_from_java(modified_file.source_code)

                    for method_name, method_code in methods:
                        commit_link = f"{repo_path}/commit/{commit.hash}"
                        csv_writer.writerow([commit.hash, modified_file.filename, method_name, method_code, commit_link])

                    print(f"Extracted methods from {modified_file.filename} in commit {commit.hash}")

def extract_methods_to_csv(repo_path, output_csv):
    """
    Extract methods from Java files in a repository and save them in a CSV file.

    Args:
        repo_path (str): Path to the Git repository.
        output_csv (str): Path to the output CSV file.
    """
    with open(output_csv, mode='w', newline='', encoding='utf-8') as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(["Branch_Name", "Commit_Hash", "File_Name", "Method_Name", "Method_Code", "Commit_Link"])

        branch_name = "master"
        for commit in Repository(repo_path, only_in_branch=branch_name).traverse_commits():
            print(f"Processing commit: {commit.hash}")

            for modified_file in commit.modified_files:
                if modified_file.filename.endswith(".java") and modified_file.source_code:
                    methods = extract_methods_from_java(modified_file.source_code)

                    for method_name, method_code in methods:
                        commit_link = f"{repo_path}/commit/{commit.hash}"
                        csv_writer.writerow([branch_name, commit.hash, modified_file.filename, method_name, method_code, commit_link])

                    print(f"Extracted methods from {modified_file.filename} in commit {commit.hash}")



for repo in repoList[0:1]:

    fileNameToSave = ''.join(repo.split('github.com')[1:])
    fileNameToSave = fileNameToSave.replace('/','_')

    # Specify the path to the output CSV file
    output_csv_file = "extracted_methods_{}.csv".format(fileNameToSave)
    # Run the extraction
    extract_methods_to_csv_from_master(repo, output_csv_file)


    print(repo)


Processing commit: c0772446a20d62bcaf7104527c0ea512b817a7b1
Processing commit: afe8e17710e8dc5993aadaf6043afdb281a1cbdd
Extracted methods from MemcachedClient.java in commit afe8e17710e8dc5993aadaf6043afdb281a1cbdd
Extracted methods from MemcachedConnection.java in commit afe8e17710e8dc5993aadaf6043afdb281a1cbdd
Extracted methods from MemcachedTestClient.java in commit afe8e17710e8dc5993aadaf6043afdb281a1cbdd
Extracted methods from DeleteOperation.java in commit afe8e17710e8dc5993aadaf6043afdb281a1cbdd
Extracted methods from FlushOperation.java in commit afe8e17710e8dc5993aadaf6043afdb281a1cbdd
Extracted methods from GetOperation.java in commit afe8e17710e8dc5993aadaf6043afdb281a1cbdd
Extracted methods from MutatorOperation.java in commit afe8e17710e8dc5993aadaf6043afdb281a1cbdd
Extracted methods from Operation.java in commit afe8e17710e8dc5993aadaf6043afdb281a1cbdd
Extracted methods from StatsOperation.java in commit afe8e17710e8dc5993aadaf6043afdb281a1cbdd
Extracted methods from Stor

### Step 2: Preprocess the Extracted Data
1. Remove duplicates (Type-1 Clones)
2. Filter ASCII methods (ensure only valid characters are retained)
3. Remove outliers based on method length (5th-95th percentile filtering)
4. Remove boilerplate methods (setters/getters)
5. Remove comments using Pygments lexer

In [26]:
# ------------------------------- DATA CLEANING -------------------------------
### Type 1 Clones ###
def remove_duplicates(data):
    """Remove duplicate methods based on method content.
      Almost Type-1 with the exception of comments
    """
    return data.drop_duplicates(subset="Method_Code", keep="first")

def filter_ascii_methods(data):
    """Filter methods to include only those with ASCII characters."""
    data = data[data["Method_Code"].apply(lambda x: all(ord(char) < 128 for char in x))]
    return data

# Three Approaches:
# 	1.	Data Distribution-Based Filtering: We eliminate outliers by analyzing the original data distribution, as demonstrated below.
# 	2.	Literature-Driven Filtering: We follow best practices outlined in research, such as removing methods exceeding 512 tokens in length.
# 	3.	Hybrid Approach: We combine elements from both the distribution-based and literature-driven methods.

def remove_outliers(data, lower_percentile=5, upper_percentile=95):
    """Remove outliers based on method length."""
    method_lengths = data["Method_Code"].apply(len)
    lower_bound = method_lengths.quantile(lower_percentile / 100)
    upper_bound = method_lengths.quantile(upper_percentile / 100)
    return data[(method_lengths >= lower_bound) & (method_lengths <= upper_bound)]


def remove_boilerplate_methods(data):
    """Remove boilerplate methods like setters and getters."""
    boilerplate_patterns = [
        r"\bset[A-Z][a-zA-Z0-9_]*\s*{",  # Setter methods
        r"\bget[A-Z][a-zA-Z0-9_]*\s*{",  # Getter methods
    ]
    boilerplate_regex = re.compile("|".join(boilerplate_patterns))
    data = data[~data["Method_Code"].apply(lambda x: bool(boilerplate_regex.search(x)))]
    return data

def remove_comments_from_dataframe(df: pd.DataFrame, method_column: str, language: str) -> pd.DataFrame:
    """
    Removes comments from Java methods in a DataFrame and adds a new column with cleaned methods.

    Args:
        df (pd.DataFrame): DataFrame containing the methods.
        method_column (str): Column name containing the raw Java methods.
        language (str): Programming language for the lexer (e.g., 'java').

    Returns:
        pd.DataFrame: Updated DataFrame with a new column 'Java Method No Comments'.
    """
    # Define a function to remove comments from a single method
    def remove_comments(code):
        lexer = get_lexer_by_name(language)
        tokens = lexer.get_tokens(code)
        # Filter out comments using a lambda function
        clean_code = ''.join(token[1] for token in tokens if not (lambda t: t[0] in Token.Comment)(token))


        return clean_code

    # Apply the function to the specified column and add a new column with the results
    df["Method_Java_No_Comments"] = df[method_column].apply(remove_comments)
    return df


def clean_methods(data):
    data = remove_duplicates(data)
    data = filter_ascii_methods(data)
    data = remove_outliers(data)
    data = remove_boilerplate_methods(data)
    data = remove_comments_from_dataframe(data, "Method_Code", "java")
    return data

# Load the extracted methods from the CSV file
data = pd.read_csv("extracted_methods__dustin_java-memcached-client.csv")
# Clean the data
data = clean_methods(data)
# Display the cleaned data
data.head(3)



Unnamed: 0,Commit_Hash,File_Name,Method_Name,Method_Code,Commit_Link,Method_Java_No_Comments
0,afe8e17710e8dc5993aadaf6043afdb281a1cbdd,MemcachedClient.java,storeAsync,\tpublic void storeAsync(StoreOperation.StoreT...,https://www.github.com/dustin/java-memcached-c...,\tpublic void storeAsync(StoreOperation.StoreT...
1,afe8e17710e8dc5993aadaf6043afdb281a1cbdd,MemcachedClient.java,storeSync,\tpublic String storeSync(StoreOperation.Store...,https://www.github.com/dustin/java-memcached-c...,\tpublic String storeSync(StoreOperation.Store...
2,afe8e17710e8dc5993aadaf6043afdb281a1cbdd,MemcachedClient.java,storeResult,\t\t\t\t\tpublic void storeResult(String val) ...,https://www.github.com/dustin/java-memcached-c...,\t\t\t\t\tpublic void storeResult(String val) ...


### Step 3: Train n-gram Models for Code Completion
1. Tokenize the Java method source code using Pygments.
2. Build n-gram language models for n=3 and n=5 using NLTK or KenLM.
3. Train on extracted and cleaned methods.

In [32]:
# ------------------------------- N-GRAM MODELING -------------------------------
def tokenize_java_code(code):
    """Tokenize Java code using Pygments lexer."""
    lexer = JavaLexer()
    tokens = [t[1] for t in lexer.get_tokens(code) if t[0] not in Token.Comment]  # Exclude comments
    return tokens


def train_ngram_model(methods, n, smoothing="laplace"):
    tokenized_methods = [tokenize_java_code(method) for method in methods]
    train_data, vocab = padded_everygram_pipeline(n, tokenized_methods)
    
    model = Laplace(n) if smoothing == "laplace" else MLE(n)
    model.fit(train_data, vocab)
    return model

### Step 4: Evaluate and Select the Best Model
1. Test both models on 100 Java methods.
2. Measure perplexity (lower is better) or accuracy (prediction correctness).
3. Select the best-performing model for final use.


In [33]:
def calculate_perplexity(model, test_methods, n):
    test_tokenized = [tokenize_java_code(method) for method in test_methods]
    test_ngrams = [list(everygrams(tokens, max_len=n)) for tokens in test_tokenized]
    valid_ngrams = [ngrams for ngrams in test_ngrams if len(ngrams) > 0]

    if not valid_ngrams:
        return float("inf")

    perplexities = [model.perplexity(ngrams) for ngrams in valid_ngrams]
    return np.mean(perplexities)

In [34]:
# ------------------------------- EXECUTION -------------------------------
output_csv = "methods.csv"
methods = data["Method_Code"].tolist()

# Step 2: Train N-Gram Models & Evaluate Perplexity
n_values = [3, 5]
best_n = None
best_perplexity = float("inf")
perplexities = {}

for n in n_values:
    model = train_ngram_model(methods, n)
    perplexity = calculate_perplexity(model, methods[:100], n)  # Using 100 Java methods for testing
    perplexities[n] = perplexity

    print(f"n={n} Perplexity: {perplexity}")
    if perplexity < best_perplexity:
        best_n = n
        best_perplexity = perplexity

print(f"\nBest model: n={best_n} with perplexity={best_perplexity}")

# Step 3: Generate Code vs. Ground Truth Example
def generate_code(model, start_tokens, length=20):
    """Generates code based on the trained N-Gram model."""
    generated_tokens = start_tokens[:]
    for _ in range(length):
        next_word = model.generate(text_seed=generated_tokens[-2:])  # Use last two words as seed
        if next_word is None:
            break
        generated_tokens.append(next_word)
    return " ".join(generated_tokens)

if best_n:
    best_model = train_ngram_model(methods, best_n)
    example_method = random.choice(methods)
    start_tokens = tokenize_java_code(example_method)[:2]  # Take first two words as seed
    generated_code = generate_code(best_model, start_tokens, length=20)

    print("\n==== CODE COMPLETION EXAMPLE ====")
    print("ðŸŽ¯ Ground Truth:\n", example_method[:200], "...")
    print("ðŸ¤– Generated Code:\n", generated_code)


n=3 Perplexity: 43.40415284876779
n=5 Perplexity: 60.023154987730514

Best model: n=3 with perplexity=43.40415284876779

==== CODE COMPLETION EXAMPLE ====
ðŸŽ¯ Ground Truth:
   public V remove(Object key) {
    V rv = null;
    try {
      rv = get(key);
      client.delete(getKey((String) key));
    } catch (ClassCastException e) {
      // Not a string key. Ignore.
    } ...
ðŸ¤– Generated Code:
    public   void   testTapBucketDoesNotExist ( ) ; 
 
    boolean   waitForQueues ( long   def ) ; 

