"""
Author: Michael Drob
Contact: https://www.linkedin.com/in/michael-drob/
License: Apache 2.0

This program is free software: you can redistribute it and/or modify
it under the terms of the Apache License as published by
the Apache Software Foundation, either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Apache License for more details.

You should have received a copy of the Apache License
along with this program. If not, see <http://www.apache.org/licenses/LICENSE-2.0>.
"""

In [7]:
from model_config import MODELS_TOP1

In [8]:
MODELS = MODELS_TOP1

In [9]:
# Initialize the LoRAX client against the first endpoint URL of the first
# entry in MODELS. This single module-level client is the one used by
# generate_solutions() for every request.
from lorax import Client
client = Client(MODELS[0]["endpoints"][0]["url"])

In [10]:
def generate_solutions(task, tests, k=10, model_index=0, use_adapter=False, adapter_id=None):
    """Generate ``k`` candidate solutions for a task, one LLM request per solution.

    Args:
        task: Natural-language task description; substituted for ``{task}`` in
            the selected model's ``chatPromptTemplate``.
        tests: Test cases rendered as a string; substituted for ``{tests}``.
        k: Number of generations to request.
        model_index: Index into ``MODELS`` selecting the prompt template and
            the generation parameters.
        use_adapter: When true (and ``adapter_id`` is truthy), ``adapter_id``
            is forwarded to ``client.generate``.
        adapter_id: Adapter identifier to forward when ``use_adapter`` is set.

    Returns:
        A list with the ``generated_text`` of each of the ``k`` responses, in
        request order.

    NOTE(review): requests always go through the module-level ``client``;
    ``model_index`` only selects the template and parameters, not the
    endpoint -- confirm this is intended.
    """
    model = MODELS[model_index]  # Select the model configuration based on the provided index

    # Work on a copy: adding/removing "adapter_id" must not leak into the
    # shared MODELS configuration and affect later calls.
    generate_params = dict(model["parameters"])
    if use_adapter and adapter_id:
        generate_params["adapter_id"] = adapter_id
    else:
        generate_params.pop("adapter_id", None)  # Never send a stale adapter_id

    # The prompt is identical for every sample, so build it once.
    prompt = model["chatPromptTemplate"].format(task=task, tests=tests)

    # Each call to the client yields exactly one solution.
    return [
        client.generate(prompt, **generate_params).generated_text
        for _ in range(k)
    ]

In [11]:
from datasets import load_dataset

In [12]:
# Load the MBPP dataset; the cells below read their problems from its
# 'test' split (dataset_full['test']).
dataset_full = load_dataset("mbpp")

In [13]:
from utils import eval_llm_top_k, eval_llm_output

In [17]:
# Pick one problem from the MBPP test split and request a single solution for it
index = 1
sample = dataset_full['test'][index]
prompt, test_cases = sample['text'], sample['test_list']
solutions = generate_solutions(prompt, str(test_cases), k=1)

[INST] You are an expert Python programmer, and here is your task: Write a function to sort a given matrix in ascending order according to the sum of its rows.
Your code should pass these tests:

['assert sort_matrix([[1, 2, 3], [2, 4, 5], [1, 1, 1]])==[[1, 1, 1], [1, 2, 3], [2, 4, 5]]', 'assert sort_matrix([[1, 2, 3], [-2, 4, -5], [1, -1, 1]])==[[-2, 4, -5], [1, -1, 1], [1, 2, 3]]', 'assert sort_matrix([[5,8,9],[6,4,3],[2,1,4]])==[[2, 1, 4], [6, 4, 3], [5, 8, 9]]']
Your code should start with a [PYTHON] tag and end with a [/PYTHON] tag.
 [/INST]


In [18]:
dataset_full['test'][index]

{'task_id': 12,
 'text': 'Write a function to sort a given matrix in ascending order according to the sum of its rows.',
 'code': 'def sort_matrix(M):\r\n    result = sorted(M, key=sum)\r\n    return result',
 'test_list': ['assert sort_matrix([[1, 2, 3], [2, 4, 5], [1, 1, 1]])==[[1, 1, 1], [1, 2, 3], [2, 4, 5]]',
  'assert sort_matrix([[1, 2, 3], [-2, 4, -5], [1, -1, 1]])==[[-2, 4, -5], [1, -1, 1], [1, 2, 3]]',
  'assert sort_matrix([[5,8,9],[6,4,3],[2,1,4]])==[[2, 1, 4], [6, 4, 3], [5, 8, 9]]'],
 'test_setup_code': '',
 'challenge_test_list': []}

In [19]:
print(prompt)

Write a function to sort a given matrix in ascending order according to the sum of its rows.


In [20]:
print(dataset_full['test'][index]['code'])

def sort_matrix(M):
    result = sorted(M, key=sum)
    return result


In [21]:
print(solutions[0])

 
[PYTHON]
def sort_matrix(matrix):
    return sorted(matrix, key=lambda x: sum(x))
[/PYTHON]



In [22]:
test_cases

['assert sort_matrix([[1, 2, 3], [2, 4, 5], [1, 1, 1]])==[[1, 1, 1], [1, 2, 3], [2, 4, 5]]',
 'assert sort_matrix([[1, 2, 3], [-2, 4, -5], [1, -1, 1]])==[[-2, 4, -5], [1, -1, 1], [1, 2, 3]]',
 'assert sort_matrix([[5,8,9],[6,4,3],[2,1,4]])==[[2, 1, 4], [6, 4, 3], [5, 8, 9]]']

In [23]:
eval_llm_output(solutions[0], test_cases)

True

In [14]:
from tqdm.notebook import tqdm  # Import tqdm for Jupyter Notebook

In [16]:
# Initialize counters for pass@k metrics (only pass@1 is measured in this cell)
pass_at_1 = pass_at_10 = pass_at_100 = 0

# Rarely the predibase server times out, so each problem is attempted up to
# this many times before moving on to the next sample.
MAX_ATTEMPTS = 3

# Wrapping dataset_full['test'] with tqdm for progress visualization
for i, problem in tqdm(enumerate(dataset_full['test']), desc="Processing Problems", total=len(dataset_full['test'])):
    task = problem['text']
    test_cases = problem['test_list']

    # Generate solutions using the base LLM or a LoRA adapter.
    # A named loop variable is used on purpose: assigning to `_` at notebook
    # top level would overwrite IPython's last-output variable.
    for attempt in range(MAX_ATTEMPTS):
        try:
            solutions = generate_solutions(task, str(test_cases), k=1)
            break  # If call is successful, stop retrying
        except Exception as e:
            print('Attempt failed with error:', e)
    else:
        # for/else: reached only when no attempt hit `break`, i.e. all failed.
        # The skipped problem still counts in the (i+1) denominator below.
        print(f'Could not get response from LLM server after {MAX_ATTEMPTS} attempts.')
        continue

    print(f"evaluating problem {i}")
    print("task: ", task)
    print("solutions: ", solutions)

    # Evaluate solutions: index 0 of the result is read as the pass@1 outcome
    if eval_llm_top_k(solutions, test_cases, [1])[0]:
        pass_at_1 += 1
        print('passed')
    else:
        print('failed')

    print(f'running average {pass_at_1/(i+1)}, {pass_at_1} correct, {i+1} total')
    print()

Processing Problems:   0%|          | 0/500 [00:00<?, ?it/s]

evaluating problem 0
task:  Write a python function to remove first and last occurrence of a given character from the string.
solutions:  [" Here's my solution:\n[PYTHON]\ndef remove_occ(string, char):\n    return string[:string.index(char)] + string[string.index(char) + 1:] if char in string else string\n[/PYTHON]"]
test case failed: 
failed
running average 0.0, 0 correct, 1 total

evaluating problem 1
task:  Write a function to sort a given matrix in ascending order according to the sum of its rows.
solutions:  [' \n[PYTHON]\ndef sort_matrix(matrix):\n    return sorted(matrix, key=sum)\n[/PYTHON]\n']
passed
running average 0.5, 1 correct, 2 total



KeyboardInterrupt: 

In [42]:
# Final pass@1 over the whole test split; use the real split size rather than
# a hard-coded 500 so the score stays correct if the dataset changes.
pass_at_1 / len(dataset_full['test'])

0.29

In [None]:
# TODO: problems 58 and 76 are interesting -- investigate why they fail.
