Notebook created by Rosa Filgueira - r.filgueira@epcc.ed.ac.uk

# Introduction:

This notebook demonstrates the use of transformer models to embed and compare Java code snippets based on their semantic similarity.

Additionally, it explores parsing Java code to extract key elements such as classes, methods, and associated comments for further analysis.

# Table of Contents:

- Java Snippets Embeddings and Similarity
  - Overview of encoding Java snippets into embeddings for semantic similarity analysis.
  - Use of transformer models (UniXcoder) for code understanding.
- Parsing Java Code
  - Extracting classes, methods, fields, and comments from Java files using the javalang library.



## Java Snippets Embeddings and Similarity


Embedding a Java snippet means converting a piece of Java code (e.g., a method or a class) into a dense numerical vector. The embedding is intended to capture what the code does rather than how it is written, so two snippets can be compared by comparing their vectors. Snippets with similar logic should receive similar embeddings even when their structure or variable names differ.

**Key Steps in the Process**:

- **Tokenization**: The model's tokenizer splits the Java code into subword tokens (pieces of keywords, identifiers, and symbols) and maps them to integer IDs, which serve as input for the transformer model.
- **Embedding Generation**: A pretrained transformer model such as UniXcoder processes the tokenized code and produces one vector per token; these are pooled (here, averaged) into a single embedding for the snippet.
- **Similarity Calculation**: The embeddings of two or more snippets are compared using a metric such as cosine similarity, which ranges from -1 to 1. A high score suggests that the snippets perform similar tasks or have related functionality.
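
The similarity step can be illustrated without any model. The following is a minimal sketch in plain Python, assuming the embeddings are already available as lists of floats; the three toy vectors are made-up values, not model output.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_u * norm_v)

# Toy 3-dimensional "embeddings" (made-up values, not model output)
add_vec = [1.0, 2.0, 3.0]
sum_vec = [2.0, 4.0, 6.0]      # same direction as add_vec, different magnitude
other_vec = [3.0, 0.0, -1.0]   # orthogonal to add_vec

print(cosine_similarity(add_vec, sum_vec))    # close to 1: same direction
print(cosine_similarity(add_vec, other_vec))  # close to 0: unrelated direction
```

Because the score depends only on direction, scaling a vector does not change it. This is why the notebook normalizes embeddings to unit length before taking the dot product.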

In [None]:
!pip install transformers torch



In [None]:
from transformers import AutoTokenizer, AutoModel
import torch

# Load UniXcoder.
# Base checkpoint (alternative):
# tokenizer = AutoTokenizer.from_pretrained("microsoft/unixcoder-base")
# model = AutoModel.from_pretrained("microsoft/unixcoder-base")

# The UniXcoder checkpoint used in reposim4py; it gives a slightly better result on this example.
tokenizer = AutoTokenizer.from_pretrained("Lazyhope/unixcoder-nine-advtest")
model = AutoModel.from_pretrained("Lazyhope/unixcoder-nine-advtest")

# Assumed cap on input length. The tokenizer defines no maximum of its own,
# so truncation=True has no effect unless max_length is given.
MAX_LENGTH = 512

def normalize_embeddings(embeddings):
    """Scale each row of a (batch, hidden) tensor to unit L2 norm."""
    return embeddings / torch.norm(embeddings, dim=1, keepdim=True)

def encode_code(code):
    """Embed one code snippet as a (1, hidden) tensor."""
    tokens = tokenizer(code, return_tensors="pt", padding=True,
                       truncation=True, max_length=MAX_LENGTH)
    with torch.no_grad():
        # Mean pooling over all token positions. With a single snippet there
        # is no padding, so every position is a real token.
        embeddings = model(**tokens).last_hidden_state.mean(dim=1)
    return embeddings

def cosine_similarity(embedding1, embedding2):
    """Cosine similarity of two (1, hidden) embeddings, as a Python float."""
    embedding1 = normalize_embeddings(embedding1)
    embedding2 = normalize_embeddings(embedding2)
    return torch.matmul(embedding1, embedding2.T).item()

# Example Java code snippets
code_snippet_1 = """public int add(int a, int b) { return a + b; }"""
code_snippet_2 = """public int sum(int x, int y) { return x + y; }"""

# Encode the code snippets
embedding1 = encode_code(code_snippet_1)
embedding2 = encode_code(code_snippet_2)

# Calculate similarity
similarity = cosine_similarity(embedding1, embedding2)
print(f"Cosine Similarity: {similarity}")






Cosine Similarity: 0.6810680627822876
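
The `encode_code` function averages over every token position. That is safe for a single snippet, but if several snippets of different lengths were encoded in one batch, the padded positions would also enter the mean. A common remedy is to pool only over positions where the attention mask is 1. The following is a minimal sketch of that idea in plain Python, with nested lists standing in for tensors; `masked_mean_pool` is a hypothetical helper name and the numbers are made up.

```python
def masked_mean_pool(hidden_states, attention_mask):
    """Average token vectors per sequence, ignoring padded positions.

    hidden_states:  list of sequences, each a list of token vectors
    attention_mask: list of sequences of 1 (real token) or 0 (padding)
    """
    pooled = []
    for token_vectors, mask in zip(hidden_states, attention_mask):
        # Keep only the vectors whose mask entry is 1
        kept = [vec for vec, keep in zip(token_vectors, mask) if keep]
        if not kept:
            raise ValueError("sequence has no unmasked tokens")
        dim = len(kept[0])
        pooled.append([sum(vec[d] for vec in kept) / len(kept) for d in range(dim)])
    return pooled

# Toy batch: two sequences padded to length 3, hidden size 2 (made-up values)
hidden = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],   # three real tokens
    [[2.0, 2.0], [4.0, 6.0], [9.0, 9.0]],   # last position is padding
]
mask = [[1, 1, 1], [1, 1, 0]]

print(masked_mean_pool(hidden, mask))  # [[3.0, 4.0], [3.0, 4.0]]
```

With tensors, the same computation is usually written by multiplying `last_hidden_state` by the mask expanded to the hidden dimension, summing over the token axis, and dividing by the number of unmasked tokens.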


In [None]:

# Example Java: this snippet is identical to code_snippet_1, so the similarity should be 1.0 up to float32 rounding error
code_snippet_3 = """public int add(int a, int b) { return a + b; }"""
# Encode the code snippets
embedding1 = encode_code(code_snippet_1)
embedding3 = encode_code(code_snippet_3)

# Calculate similarity
similarity = cosine_similarity(embedding1, embedding3)
print(f"Cosine Similarity: {similarity}")


Cosine Similarity: 1.000000238418579


## Parsing Java Code

This section demonstrates how to extract key elements from Java code using the **javalang** library:

* Installation:
Install the javalang library using pip.

* Code Example:
A sample Java file is created with classes, methods, fields, and documentation comments.

* Parsing Process:
  * Parse the Java code to extract:
    * Classes: Identifies class names, methods, fields, and associated documentation.
    * Methods: Extracts method names, parameters, return types, **code, and comments**.

* Output:
A detailed breakdown of i
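
Documentation in Java source usually comes in two forms: `//` line comments and `/** ... */` Javadoc blocks. The following is a minimal sketch of a lookup that handles both by scanning upward from a declaration's line. It assumes that comments sit on their own lines directly above the declaration. `extract_preceding_doc` and `_strip_block_markers` are hypothetical helper names, not part of javalang.

```python
def _strip_block_markers(text):
    """Remove `/*`, `/**`, `*/` and the leading `*` decoration from one line."""
    text = text.strip()
    if text.startswith("/*"):
        text = text[2:]
    if text.endswith("*/"):
        text = text[:-2]
    return text.strip().lstrip("*").strip()

def extract_preceding_doc(lines, start_line):
    """Return the comment directly above a 1-based line number, or None."""
    i = start_line - 2                      # index of the line above the declaration
    while i >= 0 and not lines[i].strip():  # skip blank lines
        i -= 1
    if i < 0:
        return None

    # Case 1: a block comment (`/* ... */` or `/** ... */`) ends on that line
    if lines[i].strip().endswith("*/"):
        block = []
        while i >= 0:
            block.insert(0, lines[i])
            if lines[i].strip().startswith("/*"):
                break
            i -= 1
        cleaned = [_strip_block_markers(text) for text in block]
        return " ".join(part for part in cleaned if part) or None

    # Case 2: one or more consecutive `//` line comments
    comment = []
    while i >= 0 and lines[i].strip().startswith("//"):
        comment.insert(0, lines[i].strip()[2:].strip())
        i -= 1
    return " ".join(comment) or None

java_source = '''\
// A simple utility class
public class Example {
    /**
     * A method to sum two integers.
     * @param a the first integer.
     * @param b the second integer.
     * @return the sum of a and b.
     */
    public int add(int a, int b) {
        return a + b;
    }
}
'''
lines = java_source.splitlines()
print(extract_preceding_doc(lines, 2))  # class declaration on line 2
print(extract_preceding_doc(lines, 9))  # method declaration on line 9
```

javalang declaration nodes also carry a `documentation` attribute. It may be a simpler route to the Javadoc text, though how it is populated is worth checking against the library before relying on it.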

In [None]:
!pip install javalang

Collecting javalang
  Downloading javalang-0.13.0-py3-none-any.whl.metadata (805 bytes)
Downloading javalang-0.13.0-py3-none-any.whl (22 kB)
Installing collected packages: javalang
Successfully installed javalang-0.13.0


In [None]:
# Write example Java code to a file
java_code = """
// Example.java

// Import statements
import java.util.List;
import java.util.ArrayList;

// A simple utility class
public class Example {
    // A private field
    private String name;

    // Constructor
    public Example(String name) {
        this.name = name;
    }

    /**
     * A method to get the name.
     * @return the name of the object.
     */
    public String getName() {
        return this.name;
    }

    /**
     * A method to sum two integers.
     * @param a the first integer.
     * @param b the second integer.
     * @return the sum of a and b.
     */
    public int add(int a, int b) {
        return a + b;
    }

    /**
     * A method that returns a list of strings.
     * @return a list of sample strings.
     */
    public List<String> getSampleList() {
        List<String> samples = new ArrayList<>();
        samples.add("Sample 1");
        samples.add("Sample 2");
        return samples;
    }
}

// A utility class for mathematical and printing functions
class Utility {
    /**
     * A method to print a welcome message.
     */
    public void printWelcome() {
        System.out.println("Welcome to the utility method demonstration!");
    }

    /**
     * A method to calculate the square of a number.
     * @param num the number to square.
     * @return the square of the number.
     */
    public int square(int num) {
        return num * num;
    }
}

// Another helper class
class Helper {
    /**
     * A helper method to print a message.
     * @param message the message to print.
     */
    public void printMessage(String message) {
        System.out.println("Message: " + message);
    }
}
"""

# Save it as example_java.java
with open("example_java.java", "w") as file:
    file.write(java_code)

In [None]:
!cat example_java.java



In [None]:
import javalang

def extract_code(file_path, start_line, end_line):
    """Extract code from a Java file given start and end lines."""
    with open(file_path, 'r') as file:
        lines = file.readlines()
    return ''.join(lines[start_line - 1:end_line]) if start_line and end_line else ""

def extract_preceding_line_comment(lines, start_line):
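    """Extract single-line comments (`//`) directly preceding a line of code.

    Javadoc blocks (`/** ... */`) are not recognised: their last line ends the
    upward scan, so a Javadoc-documented declaration gets the fallback string.
    """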
    comment_lines = []
    for i in range(start_line - 2, -1, -1):  # Traverse upwards from the start line
        line = lines[i].strip()
        if line.startswith("//"):
            comment_lines.insert(0, line[2:].strip())  # Remove `//` and strip whitespace
        elif line:  # Stop at non-empty, non-comment line
            break
    return " ".join(comment_lines) if comment_lines else "No documentation provided."

def parse_java_file(file_path):
    """Parse a Java file and extract its classes, fields, methods, and their code."""
    with open(file_path, 'r') as file:
        code = file.read()

    # Keep the source as a list of lines for the comment lookup below
    lines = code.splitlines(keepends=True)

    # Parse the code using javalang
    tree = javalang.parse.parse(code)

    # Extract classes and their methods
    classes = []

    for path, node in tree:
        if isinstance(node, javalang.tree.ClassDeclaration):
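            # Get the class code. class_end is the start line of the last
            # member that has a position, so the extracted text stops there
            # and omits the rest of that member and the class's closing brace.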
            class_start = node.position.line if node.position else None
            class_end = None
            if node.position and hasattr(node, 'body') and node.body:
                class_end = max(
                    (child.position.line for child in node.body if child.position),
                    default=None
                )
            class_code = extract_code(file_path, class_start, class_end)

            class_info = {
                'name': node.name,
                'fields': [
                    variable.name
                    for field in node.body
                    if isinstance(field, javalang.tree.FieldDeclaration)
                    for variable in field.declarators
                ],
                'documentation': extract_preceding_line_comment(lines, class_start),
                'methods': [],
                'code': class_code
            }

            # Extract methods within the class
            for member in node.body:
                if isinstance(member, javalang.tree.MethodDeclaration):
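                    # method_end (below) is the start line of the last statement in
                    # the method body, so the extracted code omits everything after
                    # that line, including the method's closing brace.
                    method_start = member.position.line if member.position else None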
                    method_end = None
                    if member.position and member.body:
                        method_end = max(
                            (statement.position.line for statement in member.body if statement.position),
                            default=None
                        )
                    method_code = extract_code(file_path, method_start, method_end)

                    method_info = {
                        'name': member.name,
                        'parameters': [
                            f"{param.type.name} {param.name}" for param in member.parameters
                        ],
                        'return_type': member.return_type.name if member.return_type else "void",
                        'documentation': extract_preceding_line_comment(lines, method_start),
                        'is_static': 'static' in member.modifiers,
                        'code': method_code
                    }
                    class_info['methods'].append(method_info)

            classes.append(class_info)

    return {
        'classes': classes
    }
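
Because the end line is taken from the start of the last statement or member, the extracted method and class code lacks its closing `}`. One way to recover the full span is to count braces forward from the declaration's start line. The following is a minimal sketch. It assumes that no `{` or `}` appears inside string literals or comments. `find_block_end` is a hypothetical helper, not part of javalang.

```python
def find_block_end(lines, start_line):
    """Return the 1-based line number of the `}` that closes the first `{`
    found at or after start_line, or None if the braces never balance.

    Naive: every `{` and `}` is counted, including any that appear inside
    string literals or comments.
    """
    depth = 0
    seen_open = False
    for line_no in range(start_line, len(lines) + 1):
        for ch in lines[line_no - 1]:
            if ch == "{":
                depth += 1
                seen_open = True
            elif ch == "}":
                depth -= 1
                if seen_open and depth == 0:
                    return line_no
    return None

java_source = '''\
class Utility {
    public int square(int num) {
        if (num == 0) {
            return 0;
        }
        return num * num;
    }
}
'''
lines = java_source.splitlines()

method_end = find_block_end(lines, 2)   # method declared on line 2
class_end = find_block_end(lines, 1)    # class declared on line 1
print(method_end, class_end)            # 7 8
print("\n".join(lines[1:method_end]))   # full method, closing brace included
```

The returned line number could be passed as `end_line` to `extract_code` in place of the statement-based estimate. A tokenizer-aware scan would be needed for sources that contain braces in strings or comments.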


In [None]:
# Example usage
java_file_path = './example_java.java'  # Path to your Java file
result = parse_java_file(java_file_path)



# Print classes and their methods
print("\nClasses:")
for cls in result['classes']:
    print(f"Class: {cls['name']}")
    print(f"  Documentation: {cls['documentation']}")
    #print(f"  Fields: {', '.join(cls['fields']) if cls['fields'] else 'No fields'}")
    #print(f"  Code:\n{cls['code']}")
    for method in cls['methods']:
        print(f"  Method: {method['name']}")
        print(f"    Parameters: {', '.join(method['parameters']) if method['parameters'] else 'None'}")
        print(f"    Return Type: {method['return_type']}")
        print(f"    Documentation: {method['documentation']}")
        print(f"    Code:\n{method['code']}")


Classes:
Class: Example
  Documentation: A simple utility class
  Method: getName
    Parameters: None
    Return Type: String
    Documentation: No documentation provided.
    Code:
    public String getName() {
        return this.name;

  Method: add
    Parameters: int a, int b
    Return Type: int
    Documentation: No documentation provided.
    Code:
    public int add(int a, int b) {
        return a + b;

  Method: getSampleList
    Parameters: None
    Return Type: List
    Documentation: No documentation provided.
    Code:
    public List<String> getSampleList() {
        List<String> samples = new ArrayList<>();
        samples.add("Sample 1");
        samples.add("Sample 2");
        return samples;

Class: Utility
  Documentation: A utility class for mathematical and printing functions
  Method: printWelcome
    Parameters: None
    Return Type: void
    Documentation: No documentation provided.
    Code:
    public void printWelcome() {
        System.out.println("Welc