# **Step 1:** install required Dependencies

In [1]:
pip install -q accelerate==0.34.2 peft==0.6.2 bitsandbytes transformers trl==0.9.6 torch datasets

Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install tensorboardX

Note: you may need to restart the kernel to use updated packages.


# **Step 2:** import required packages

In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import os 
import json
from datasets import load_dataset, Dataset

  from .autonotebook import tqdm as notebook_tqdm


# **Step 3:** define required functions

In [4]:
## Load the fine-tuned CodeLLaMA 7B model with 4-bit quantization
def load_model_and_tokenizer(model_path):
    print("Loading model and tokenizer with 4-bit quantization...")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Configure quantization
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True
    )
    # Load the model
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        use_safetensors=True,
        quantization_config=quant_config,
        device_map="cuda:0",
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)  # Move model to GPU or CPU based on availability
    tokenizer.pad_token = tokenizer.eos_token  # Ensure padding token is set
    print(f"Model and tokenizer loaded successfully on {device}.")
    return model, tokenizer, device

# Generate a response from the model
def generate_response(prompt, model, tokenizer, device, max_length=1024, temperature=0.7, top_p=0.9):
    print("Generating response...")
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)  # Move inputs to the same device as the model
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.pad_token_id
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("Response generated.")
    return response

# **Step 4:** response generation

In [5]:
# Specify the path to your fine-tuned model directory
model_path = "./resources/trained_model_adapt_param"

# Load the model and tokenizer
model, tokenizer, device = load_model_and_tokenizer(model_path)

Loading model and tokenizer with 4-bit quantization...


Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.23s/it]
You shouldn't move a model that is dispatched using accelerate hooks.


Model and tokenizer loaded successfully on cuda.


In [6]:
# Test the model with some prompts
prompts = [
    "[INST]<<SYS>>Generate unit tests for the following method or function:\n<</SYS>> public Set<String> getOutputResourceFields( T meta ) { return null; } [/INST]",
    "[INST]<<SYS>>Generate unit tests for the following method or function:\n<</SYS>> @Override public Long queryFrom(MonetaryAmount amount) { Objects.requireNonNull(amount, \"Amount required.\"); return amount.with(downRounding).getNumber().longValueExact(); } [/INST] "
]

for i, prompt in enumerate(prompts):
    print(f"\nPrompt {i+1}: {prompt}")
    response = generate_response(prompt, model, tokenizer, device)
    print(f"Response {i+1}: {response}")


Prompt 1: [INST]<<SYS>>Generate unit tests for the following method or function:
<</SYS>> public Set<String> getOutputResourceFields( T meta ) { return null; } [/INST]
Generating response...




Response generated.
Response 1: [INST]<<SYS>>Generate unit tests for the following method or function:
<</SYS>> public Set<String> getOutputResourceFields( T meta ) { return null; } [/INST]
 @Test public void testGetOutputResourceFields() { assertNull( step.getOutputResourceFields( meta ) ); } 

 @Test public void testGetOutputResourceFields_Null() { assertNull( step.getOutputResourceFields( null ) ); } 

 @Test public void testGetOutputResourceFields_Empty() { assertNull( step.getOutputResourceFields( new TextFileInputMeta() ) ); }  @Test public void testGetOutputResourceFields_NotEmpty() { TextFileInputMeta meta = new TextFileInputMeta(); meta.setOutputFields( new TextFileInputField[] { new TextFileInputField() } ); assertNull( step.getOutputResourceFields( meta ) ); } 

 @Test public void testGetOutputResourceFields_NotEmpty_Null() { TextFileInputMeta meta = new TextFileInputMeta(); meta.setOutputFields( new TextFileInputField[] { new TextFileInputField() } ); meta.setOutputFields( 

# **Step 5:** generate test dataset

In [7]:
DATASET_NAME = "jitx/Methods2Test_java_unit_test_code"

training_dataset = load_dataset(DATASET_NAME, split="train")

# Set a seed for deterministic sorting
seed = 85
part = 0.00001

test_partion = training_dataset.shuffle(seed=seed).select(range(int(len(training_dataset) * part)))

In [8]:
INPUT_FIELD = "src_fm"
OUTPUT_FIELD = "target"

# Function to convert each example
def convert_to_llama_format(focal_method, target_test_case):
    # Define the system prompt
    system_prompt = "Generate unit tests for the following method or function:\n"

    # Format the example into LLaMA format
    formatted_example = f"<s>[INST]<<SYS>>{system_prompt}<</SYS>> {focal_method} [/INST]\n {target_test_case} </s>"

    return formatted_example

# Convert the entire dataset
converted_data = [{"text": convert_to_llama_format(entry[INPUT_FIELD], entry[OUTPUT_FIELD])} for entry in test_partion]

# Save the converted data to a JSON file
output_file = './resources/dataset/llama_format_dataset_test.json'
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, 'w') as f:
    json.dump(converted_data, f, indent=4)

# Print a few examples to verify the result
print(f"Converted dataset Training saved to {output_file}")
for example in converted_data[:5]:
    print(example)

# Function to convert each example
def convert_to_llama_format_prompt(focal_method):
    # Define the system prompt
    system_prompt = "Generate unit tests for the following method or function:\n"

    # Format the example into LLaMA format
    formatted_example = f"[INST]<<SYS>>{system_prompt}<</SYS>> {focal_method} [/INST]"

    return formatted_example

# prompts
prompts = [convert_to_llama_format_prompt(entry[INPUT_FIELD]) for entry in test_partion]
print(prompts)


Converted dataset Training saved to ./codellama7b_finetuning/dataset/llama_format_dataset_test.json
{'text': '<s>[INST]<<SYS>>Generate unit tests for the following method or function:\n<</SYS>> @Override public String toSqlConstraint(String quoteString, DbProduct dbProduct) { if (quoteString == null) { throw new RuntimeException("Quote string cannot be null"); } return generateRangeConstraint( quoteString + column + quoteString, Stream.of(boundaries).map(b -> b == null ? null : b.toString()).toArray(String[]::new) ); } [/INST]\n @Test public void testLeftBounded() { IntPartition partition = new IntPartition(COL_RAW, 0L, null); String constraint = partition.toSqlConstraint(QUOTE, dbProduct); assertEquals(COL + " >= 0", constraint); } </s>'}
{'text': '<s>[INST]<<SYS>>Generate unit tests for the following method or function:\n<</SYS>> public Object invoke(Object controller, Context context) { Object[] arguments = new Object[argumentExtractors.length]; for (int i = 0; i < argumentExtractor

# **Step 6:** Test Model with the Test Dataset

In [9]:
for i, prompt in enumerate(prompts):
    print(f"\nPrompt {i+1}: {prompt}")
    response = generate_response(prompt, model, tokenizer, device)
    print(f"Response {i+1}: {response}")


Prompt 1: [INST]<<SYS>>Generate unit tests for the following method or function:
<</SYS>> @Override public String toSqlConstraint(String quoteString, DbProduct dbProduct) { if (quoteString == null) { throw new RuntimeException("Quote string cannot be null"); } return generateRangeConstraint( quoteString + column + quoteString, Stream.of(boundaries).map(b -> b == null ? null : b.toString()).toArray(String[]::new) ); } [/INST]
Generating response...




Response generated.
Response 1: [INST]<<SYS>>Generate unit tests for the following method or function:
<</SYS>> @Override public String toSqlConstraint(String quoteString, DbProduct dbProduct) { if (quoteString == null) { throw new RuntimeException("Quote string cannot be null"); } return generateRangeConstraint( quoteString + column + quoteString, Stream.of(boundaries).map(b -> b == null ? null : b.toString()).toArray(String[]::new) ); } [/INST]
 @Test public void testToSqlConstraint() { String quoteString = "\""; String column = "column"; String lowerBound = "lowerBound"; String upperBound = "upperBound"; String lowerBoundInclusive = "lowerBoundInclusive"; String upperBoundInclusive = "upperBoundInclusive"; String lowerBoundExclusive = "lowerBoundExclusive"; String upperBoundExclusive = "upperBoundExclusive"; String lowerBoundInclusiveInclusive = "lowerBoundInclusiveInclusive"; String upperBoundInclusiveInclusive = "upperBoundInclusiveInclusive"; String lowerBoundExclusiveInclusive =