## Imports + Setup

In [None]:
import gzip
import json
from typing import List
from collections import Counter, defaultdict
from tqdm import tqdm
import os
import inspect_ai
from openai import OpenAI, AsyncOpenAI
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from dotenv import load_dotenv
import torch as t
import subprocess
import contextlib
import shutil
import ast
import copy
from openai import OpenAI
from datetime import datetime
from dataclasses import dataclass

In [None]:
from inspect_ai import Task, task
from inspect_ai.dataset import Sample, hf_dataset
from inspect_ai.util import ExecResult, sandbox
from inspect_ai.scorer import CORRECT, INCORRECT, Score, Scorer, Target, accuracy, scorer, stderr
from inspect_ai.solver import TaskState, generate
from inspect_ai.model import get_model
from inspect_ai.log import read_eval_log
import re

In [None]:
load_dotenv()  # populate os.environ from a local .env file
login(token=os.environ['HF_TOKEN'])

# OpenAI clients constructed with no arguments, i.e. environment-driven config.
async_client = AsyncOpenAI()
client = OpenAI()

In [None]:
# Standard-library imports made available to generated code, keyed by language.
#   "python": complete import lines, all prepended by get_final_python.
#   "go":     bare package paths; get_final_go imports a package when the
#             completion contains "<last path element>." and test_setup does
#             not already mention it.
#   "cpp":    #include lines; get_final_cpp prepends those not already present
#             verbatim in the prompt.
IMPORT_HELPER = {
    "python": [
        "import math",
        "import re",
        "import sys",
        "import copy",
        "import datetime",
        "import itertools",
        "import collections",
        "import heapq",
        "import statistics",
        "import functools",
        "import hashlib",
        "import numpy",
        "import numpy as np",
        "import string",
        "from typing import *",
        "from collections import *",
    ],
    "go"    : [
        "math",
        "strings",
        "fmt",
        "strconv",
        "time",
        "bytes",
        "regexp",
        "sort",
        "math/rand",
        "crypto/md5",
    ],
    "cpp"   : [
        "#include<stdlib.h>",
        "#include<algorithm>",
        "#include<math.h>",
        "#include<stdio.h>",
        "#include<vector>",
        "#include<string>",
        "#include<climits>",
        "#include<cstring>",
        "#include<iostream>",
        "#include <numeric>",
        "#include <sstream>",
        "#include <stack>",
    ],
}

# Instruction prepended to every HumanEval-X problem; humaneval_record_to_sample
# follows it with the LANG_PREFIX banner and the problem prompt.
HUMANEVAL_INSTRUCTION = """
Read the following function signature and docstring, and fully implement
the function described. Your response should only contain the code for
this function.\n
"""

# Developer message for the LLM-based body-extraction step. It is the first
# message of PYTHON_PROMPT, which remove_signature sends ahead of each completion.
CODE_EXTRACTION_INSTRUCTION = """
Here is a section of code. Follow these instructions to extract the function body.
First, if present, remove any leading package names, import statements, and comments. 
Next, remove the function signature of the first function you see. 
If there is a `main` function present, delete that entire function. 
However, if there are any other functions present, leave them as they are.
If the code is in python, ensure the function body is indented properly.

Your response should only contain the remaining block of code. \n
"""

# Comment-style banner naming the target language, keyed by the language
# names used in this notebook ("js" and "javascript" are both accepted).
LANG_PREFIX = dict(
    cpp="// language: C++",
    java="// language: Java",
    js="// language: JavaScript",
    javascript="// language: JavaScript",
    go="// language: Go",
    python="# language: Python",
)

In [None]:
# Local Hugging Face models served through inspect_ai's "hf/" provider,
# with device and dtype both passed as "auto".
llama_model = get_model(
    'hf/meta-llama/Llama-3.1-8B-Instruct',
    torch_dtype='auto',
    device='auto',
)

qwen_coder = get_model(
    'hf/Qwen/Qwen2.5-Coder-7B-Instruct',
    torch_dtype='auto',
    device='auto',
)

# Currently disabled:
# deepseek_coder = get_model(
#     'hf/deepseek-ai/deepseek-coder-6.7b-instruct',
#     device='auto',
#     torch_dtype='auto',
# )

### Few-Shot Prompting

In [None]:
# Empty skeleton with the same fenced-block shape as the few-shot strings below.
# NOTE(review): not referenced anywhere in the visible code.
prompt_template = """
```python

    ```
    """

# Few-shot example 1: import + signature + docstring + body. The expected
# extraction keeps only the indented body (the import is dropped).
prompt_1 = """
```python
import math

def right_angle_triangle(a, b, c):
    '''
    Given the lengths of the three sides of a triangle. Return True if the three
    sides form a right-angled triangle, False otherwise.
    A right-angled triangle is a triangle in which one angle is right angle or 
    90 degree.
    Example:
    right_angle_triangle(3, 4, 5) == True
    right_angle_triangle(1, 2, 3) == False
    '''
    sides = sorted([a, b, c])
    return math.isclose(sides[0]**2 + sides[1]**2, sides[2]**2)
    ```
    """

response_1 = """
```python
    sides = sorted([a, b, c])
    return math.isclose(sides[0]**2 + sides[1]**2, sides[2]**2)
    ```
    """

# Few-shot example 2: signature + long docstring + commented body. The
# expected extraction keeps the body, including its inline comments.
prompt_2 = """
```python
def sorted_list_sum(lst):
    '''
    Deletes strings with odd lengths from a list and returns the resulted list 
    sorted in ascending order by string length and then alphabetically.

    Args:
        lst (list): A list of strings.

    Returns:
        list: A list of strings with odd lengths removed, sorted by length and then alphabetically.
    '''
    # Filter out strings with odd lengths
    even_length_strings = [string for string in lst if len(string) % 2 == 0]
    
    # Sort the list first by length and then alphabetically
    sorted_list = sorted(even_length_strings, key=lambda x: (len(x), x))
    
    return sorted_list
    ```
    """

response_2 = """
```python
    # Filter out strings with odd lengths
    even_length_strings = [string for string in lst if len(string) % 2 == 0]
    
    # Sort the list first by length and then alphabetically
    sorted_list = sorted(even_length_strings, key=lambda x: (len(x), x))
    
    return sorted_list
    ```
    """

# Few-shot example 3: a function with a nested helper. The expected
# extraction drops the signature and docstring but keeps the nested helper.
# The outer literal is delimited by ''' so the docstring inside must use """;
# a run of single quotes there would terminate the outer literal early.
prompt_3 = '''
```python
def words_in_sentence(sentence):
    """
    Returns a string containing the words from the original sentence whose lengths are prime numbers.
    
    Args:
        sentence (str): A string representing a sentence.
    
    Returns:
        str: A string containing the words from the original sentence whose lengths are prime numbers.
    """
    def is_prime(n):
        """Checks if a number is prime."""
        if n < 2:
            return False
        for i in range(2, int(n ** 0.5) + 1):
            if n % i == 0:
                return False
        return True

    words = sentence.split()
    return ' '.join(word for word in words if is_prime(len(word)))
    ```
    '''

response_3 = '''
```python
    def is_prime(n):
        """Checks if a number is prime."""
        if n < 2:
            return False
        for i in range(2, int(n ** 0.5) + 1):
            if n % i == 0:
                return False
        return True

    words = sentence.split()
    return ' '.join(word for word in words if is_prime(len(word)))
    ```
    '''

# Few-shot example 4: docstring + body with no signature line. The expected
# extraction also drops the docstring, keeping only the code and its comments.
prompt_4 = '''
```python
    """Evaluate whether the given number n can be written as the sum of exactly 4 positive even numbers
    Example
    is_equal_to_sum_even(4) == False
    is_equal_to_sum_even(6) == False
    is_equal_to_sum_even(8) == True
    """
    # The smallest sum of 4 positive even numbers (2 + 2 + 2 + 2) is 8
    if n < 8:
        return False
    
    # Since the sum of even numbers is even, n must be even
    if n % 2 != 0:
        return False
    
    # If n is at least 8 and is even, it can be expressed as the sum of 4 positive even numbers
    return True
    ```
    '''

response_4 = '''
```python
    # The smallest sum of 4 positive even numbers (2 + 2 + 2 + 2) is 8
    if n < 8:
        return False
    
    # Since the sum of even numbers is even, n must be even
    if n % 2 != 0:
        return False
    
    # If n is at least 8 and is even, it can be expressed as the sum of 4 positive even numbers
    return True
    ```
    '''

# Few-shot example 0: the input is already a bare body, so the expected
# output is identical to the input.
prompt_0 = """
```python
    txt = txt.rstrip()  # Remove trailing spaces
    if not txt:  # Check if the string is empty after stripping
        return False
    last_char = txt[-1]  # Get the last character of the string
    return last_char.isalpha() and (len(txt) == 1 or txt[-2] == ' ')  # Check conditions
    ```
    """

response_0 = """
```python
    txt = txt.rstrip()  # Remove trailing spaces
    if not txt:  # Check if the string is empty after stripping
        return False
    last_char = txt[-1]  # Get the last character of the string
    return last_char.isalpha() and (len(txt) == 1 or txt[-2] == ' ')  # Check conditions
    ```
    """


# Number of few-shot (prompt_i, response_i) pairs, i = 0 .. num_prompts - 1.
num_prompts = 5

# Message list for the extraction model: the developer instruction followed
# by alternating user/assistant few-shot turns in index order. The loop is
# driven by num_prompts so adding an example only requires bumping the count.
PYTHON_PROMPT = [{
    'role': 'developer',
    'content': CODE_EXTRACTION_INSTRUCTION,
}]

for i in range(num_prompts):
    PYTHON_PROMPT.append({'role': 'user', 'content': globals()[f'prompt_{i}']})
    PYTHON_PROMPT.append({'role': 'assistant', 'content': globals()[f'response_{i}']})


### Language Setup

In [None]:
# JavaScript
#     apt-get update
#     apt-get install -y curl
#     curl -fsSL https://deb.nodesource.com/setup_lts.x | bash -
#     apt-get install -y nodejs
#     verify installation: `node -v`, `npm -v`, `node -e "console.log('Node.js is working')"`

In [None]:
# Java
#     apt-get update
#     apt-get install -y openjdk-21-jdk
# Verify with:  
#     java -version
#     javac -version

In [None]:
# Go
#     cd ~/
#     GO_VERSION=1.24.5
#     wget https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz
#     rm -rf /usr/local/go
#     tar -C /usr/local -xzf go${GO_VERSION}.linux-amd64.tar.gz
#     echo 'export PATH=$PATH:/usr/local/go/bin' >> ~/.bashrc
#     export PATH=$PATH:/usr/local/go/bin
#     go version

In [None]:
# C++
#     apt-get update
#     apt-get install -y g++
#     apt-get install -y build-essential
#     apt-get install -y libboost-all-dev
#     apt-get install -y libssl-dev


## Data Processing

In [None]:
def stream_jsonl_all(filename: str) -> list:
    """Read every JSON record from a gzip-compressed JSON Lines file.

    Args:
        filename: Path to a gzip file containing one JSON value per line.

    Returns:
        The decoded records in file order. Blank and whitespace-only lines
        are skipped.
    """
    results = []
    # Open by path so gzip owns, and closes, the underlying file handle, even
    # if json.loads raises. JSON Lines is UTF-8, so don't rely on the locale.
    with gzip.open(filename, "rt", encoding="utf-8") as fp:
        for line in fp:
            if line.strip():
                results.append(json.loads(line))
    return results

In [None]:
# Load the per-language HumanEval-X data files from ./data.
(
    python_content,
    cpp_content,
    go_content,
    java_content,
    js_content,
) = (
    stream_jsonl_all(data_path)
    for data_path in (
        'data/python_data.gz',
        'data/cpp_data.gz',
        'data/go_data.gz',
        'data/java_data.gz',
        'data/js_data.gz',
    )
)
# Order: python, cpp, go, java, js.
content = [python_content, cpp_content, go_content, java_content, js_content]

In [None]:
# Load a file of Python generations (presumably earlier model completions —
# confirm). The bare expression on the last line is shown as the cell output.
generations = stream_jsonl_all('data/python_generations.gz')
generations[0]['generation']

In [None]:
# Print the field names of the first record in each language's dataset.
for dataset in content:
    first_record = dataset[0]
    print(first_record.keys())

## Eval Pipeline

### Extract Code

In [None]:
# Fenced block with an explicit language tag. The tag is matched
# case-insensitively and may be surrounded by spaces/tabs and followed by CRLF,
# so ```js, ```C++, ```golang and "```python \r\n" are all recognised.
_TAGGED_FENCE = re.compile(
    r"```[ \t]*(?:python|py|javascript|js|java|cpp|c\+\+|go|golang)[ \t]*\r?\n(.*?)```",
    re.DOTALL | re.IGNORECASE,
)
# Fenced block with no language tag.
_PLAIN_FENCE = re.compile(r"```[ \t]*\r?\n(.*?)```", re.DOTALL)


def identify_codeblock(completion: str) -> str:
    """Return the contents of the first fenced code block in ``completion``.

    Language-tagged fences take priority over untagged ones. If no fenced
    block is found, ``completion`` is returned unchanged.

    Args:
        completion: Raw model output, possibly mixing prose and fenced code.

    Returns:
        The code inside the chosen fence (without the fence lines), or the
        original text.
    """
    matches = _TAGGED_FENCE.findall(completion) + _PLAIN_FENCE.findall(completion)
    return matches[0] if matches else completion

In [None]:
async def remove_signature(completion, model='gpt-4.1-mini'):
    """Ask an OpenAI model to strip the signature and preamble from code.

    The request is PYTHON_PROMPT (extraction instruction plus few-shot turns)
    followed by ``completion`` as the final user message.

    Args:
        completion: Code to reduce to a bare function body.
        model: OpenAI model used for extraction (default ``'gpt-4.1-mini'``).

    Returns:
        The text of the response's last output item. It may itself be wrapped
        in a fenced code block, as the few-shot responses are.
    """
    messages = [*PYTHON_PROMPT, {'role': 'user', 'content': completion}]
    response = await async_client.responses.create(model=model, input=messages)
    # Assumes the last output item is the assistant message whose first
    # content part holds the text.
    return response.output[-1].content[0].text

### Languages

#### CPP

In [None]:
def balance_brackets_cpp(completion):
    """Make a C++ function body close the enclosing function exactly once.

    The body is appended after a prompt that already opens the function, so a
    complete body has one more '}' than '{'. A '{' within the first three
    characters (a re-emitted opening brace) is dropped first. If the body is
    then balanced, a closing brace is appended on its own line. Any other
    imbalance is only reported.

    Returns:
        The adjusted body.
    """
    if '{' in completion[:3]:
        completion = completion.replace('{', "", 1)

    difference = completion.count('{') - completion.count('}')
    if difference == 0:
        # Put the brace on a new line so a trailing `// comment` cannot
        # swallow it (matches balance_brackets_go).
        completion += '\n}'
    elif difference != -1:
        print('brackets ain\'t balancing')

    return completion

async def find_code_cpp(completion: str) -> str:
    """Extract a C++ function body from a raw model completion.

    Steps (the same as find_code_python, plus brace fix-up):
    1. take the first fenced code block;
    2. have the LLM extractor drop includes and the signature;
    3. take the fenced block out of the extractor's answer;
    4. make sure the body closes the function exactly once.

    ``remove_signature`` is a coroutine and must be awaited, so this function
    is async too. ``get_final`` awaits every ``find_code_*`` function.
    """
    processed = identify_codeblock(completion)
    processed = await remove_signature(processed)
    processed = identify_codeblock(processed)
    return balance_brackets_cpp(processed)

In [None]:
def get_final_cpp(state, completion):
    """Assemble a compilable C++ program: headers, prompt, body, then tests.

    Each header in IMPORT_HELPER['cpp'] that does not already appear verbatim
    in the prompt is prepended on its own line.
    """
    prompt = state.metadata['prompt']
    header_block = ''.join(
        f'{header}\n' for header in IMPORT_HELPER['cpp'] if header not in prompt
    )
    return f"{header_block}\n{prompt}{completion}\n{state.metadata['test']}"

#### Go

In [None]:
def balance_brackets_go(completion):
    """Make a Go function body close the enclosing function exactly once.

    Drops a '{' found in the first three characters, then appends a closing
    brace on a new line if braces are balanced. Any other imbalance besides
    one extra '}' only prints a warning.
    """
    if '{' in completion[:3]:
        completion = completion.replace('{', "", 1)

    opened = completion.count('{')
    closed = completion.count('}')
    if opened == closed:
        return completion + '\n}'
    if opened - closed != -1:
        print('brackets ain\'t balancing')
    return completion

async def find_code_go(completion: str) -> str:
    """Extract a Go function body from a raw model completion.

    Steps (the same as find_code_python, plus brace fix-up):
    1. take the first fenced code block;
    2. have the LLM extractor drop the package/imports and the signature;
    3. take the fenced block out of the extractor's answer;
    4. make sure the body closes the function exactly once.

    ``remove_signature`` is a coroutine and must be awaited, so this function
    is async too. ``get_final`` awaits every ``find_code_*`` function.
    """
    processed = identify_codeblock(completion)
    processed = await remove_signature(processed)
    processed = identify_codeblock(processed)
    return balance_brackets_go(processed)

In [None]:
def get_final_go(state, completion):
    """Assemble a Go test file: test setup, extra imports, prompt, body, tests.

    The prompt's own import text (metadata['import']) is removed from it;
    metadata['test_setup'] is placed first instead (presumably it carries the
    package clause and imports — confirm against the data files). A package
    from IMPORT_HELPER['go'] gets its own import block when test_setup does not
    mention it but the completion contains "<last path element>.".
    """
    meta = state.metadata
    prompt = meta['prompt'].replace(meta['import'], '')
    test_setup = meta['test_setup']

    extra_imports = [
        f"\"{pkg}\""
        for pkg in IMPORT_HELPER['go']
        if pkg not in test_setup and pkg.split('/')[-1] + '.' in completion
    ]

    sections = [test_setup]
    if extra_imports:
        sections.append("import (\n" + "    ".join(f"{imp}\n" for imp in extra_imports) + ")")
    sections.append(prompt + completion)
    sections.append(meta['test'])
    return "\n".join(sections)

#### Java

In [None]:
def remove_unindented(lines, full_func):
    """Drop leading blank or column-0 lines from a Java completion.

    Leading lines that are empty or unindented (e.g. ``import`` statements or
    a class header) indicate the model emitted more than a bare method body.

    Args:
        lines: The completion split into lines.
        full_func: Incoming flag; returned as True if any line was dropped.

    Returns:
        ``(remaining_lines, full_func)``. The input list is not modified.
    """
    # Advance an index rather than re-testing lines[0]: the loop must consume
    # lines to terminate, and must stop at the end of an all-unindented list.
    start = 0
    while start < len(lines) and (not lines[start] or lines[start] == lines[start].lstrip()):
        start += 1
    if start:
        full_func = True
    return lines[start:], full_func

def remove_signature_java(lines, full_func):
    """Strip everything up to and including the first ``public`` method header.

    A header is looked for only when ``full_func`` is set, or when the first
    line itself starts with ``public``. This covers comment lines before the
    header and a header on the first line. If no ``public`` line exists, the
    lines are returned unchanged.

    Args:
        lines: Completion lines (leading unindented lines already removed).
        full_func: True if the completion looked like more than a method body.

    Returns:
        The lines after the header, or ``lines`` unchanged. The input list is
        not mutated.
    """
    # Guard lines[0]: an empty completion previously raised IndexError here.
    starts_with_header = bool(lines) and lines[0].lstrip().startswith('public')
    if not (full_func or starts_with_header):
        return lines

    for i, line in enumerate(lines):
        if line.lstrip().startswith('public'):
            return lines[i + 1:]
    return lines

def balance_brackets_java(processed):
    """Append the closing braces a Java method body still needs.

    The prompt opens both a class and a method, so the body must end up with
    two more '}' than '{'. Bodies that already close zero, one or both are
    topped up to two.

    Raises:
        ValueError: if the body would need a negative number of braces or
            more than two.
    """
    difference = processed.count('{') - processed.count('}') + 2
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if difference not in (0, 1, 2):
        raise ValueError('brackets ain\'t balancing')
    return processed + ('}' * difference)

async def find_code_java(completion: str) -> str:
    """Extract a Java method body from a raw model completion.

    Takes the first fenced block, drops leading unindented lines and the
    method header, then appends any missing closing braces.

    Declared ``async`` because ``get_final`` awaits every ``find_code_*``
    function. Awaiting the plain string a synchronous version returns raises
    TypeError.
    """
    code = identify_codeblock(completion)
    lines = code.splitlines()
    lines, full_func = remove_unindented(lines, False)
    lines = remove_signature_java(lines, full_func)
    return balance_brackets_java('\n'.join(lines))

In [None]:
def get_final_java(state, completion):
    """Build one Java source: prompt + body, a blank line, tests, trailing newline."""
    meta = state.metadata
    return f"{meta['prompt']}{completion}\n\n{meta['test']}\n"

#### Java [archived]

In [None]:
# def remove_method_header(lines,):
#     '''
#     deals with three cases
#     a) first lines are comments, before method header
#     b) first line is method-header
#     c) b), but with incorrect indentation? so header had accidentally been removed?
#     '''
#     removed = False
#     og_lines = copy.deepcopy(lines)
    
#     while not removed:
#         if lines == []:
#             return og_lines
        
#         line = lines.pop(0)
#         if line.lstrip()[:6] == 'public':
#             removed = True
        
#     return lines

# def balance_brackets(lines):
#     final_code = '\n'.join(lines)
#     difference = final_code.count('{') - final_code.count('}') + 2
#     assert difference in [0, 1, 2], 'brackets ain\'t balancing'
#     return final_code + ('}' * difference)

# def remove_unindented(lines, full_func):
#     import_statements = []
#     while (not lines[0]) or (lines[0] == lines[0].lstrip()):
#         line = lines.pop(0)
#         full_func = True
#         if 'import' in line: import_statements.append(line)
    
#     return lines, import_statements, full_func

# def find_code_java(completion: str) -> str:
#     code = identify_codeblock(completion)
#     lines = code.splitlines()

#     lines, import_statements, full_func = remove_unindented(lines, False)

#     if full_func or lines[0].lstrip()[:6] == 'public':
#         lines = remove_method_header(lines)

#     processed_completion = balance_brackets(lines)
#     import_statements = '\n'.join(import_statements)
    
#     return processed_completion, import_statements

In [None]:
# def get_final_java(imports, state, completion):
#     final_code = imports + '\n' + state.metadata['prompt'] + completion + "\n\n" + state.metadata['test'] + "\n"
#     return final_code

#### JavaScript

In [None]:
def remove_signature_js(code: str) -> str:
    """Delete the first line whose stripped text starts with ``const``.

    In HumanEval-X JavaScript that line is normally the arrow-function header.
    Returns the remaining lines joined with newlines. If no such line exists,
    prints a warning and returns the sentinel ``'errormsg'``.
    """
    lines = code.splitlines()
    header_idx = next(
        (i for i, line in enumerate(lines) if line.lstrip().startswith('const')),
        None,
    )
    if header_idx is None:
        print('error extracting function body')
        return 'errormsg'
    del lines[header_idx]
    return "\n".join(lines)

async def find_code_js(completion: str) -> str:
    """Extract a JavaScript function body: first fenced block minus its ``const`` header.

    Declared ``async`` because ``get_final`` awaits every ``find_code_*``
    function. Awaiting the plain string a synchronous version returns raises
    TypeError.
    """
    processed = identify_codeblock(completion)
    return remove_signature_js(processed)

In [None]:
def get_final_js(state, completion):
    """Build one JS program: prompt + body, a blank line, tests, trailing newline."""
    pieces = (state.metadata['prompt'], completion, "\n\n", state.metadata['test'], "\n")
    return "".join(pieces)

#### Python

In [None]:
def remove_signature_python(code: str) -> str:
    """Return the body source of the first top-level function in ``code``.

    The span runs from the first body statement's line (a docstring, if
    present, is included) through the last line of the last body statement.

    Returns:
        The body lines joined with newlines. If ``code`` does not parse or
        defines no top-level function, prints a message and returns the
        sentinel ``'errormsg'``.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError, RecursionError) as e:
        print(f"Error extracting function body: {e}")
        return "errormsg"

    func = next((node for node in tree.body if isinstance(node, ast.FunctionDef)), None)
    if func is None:
        # Return the sentinel instead of an implicit None, which would break
        # the caller's string concatenation.
        print("Error extracting function body: no top-level function found")
        return "errormsg"

    code_lines = code.splitlines()
    start = func.body[0].lineno - 1
    end = func.body[-1].end_lineno
    return "\n".join(code_lines[start:end])
    
async def find_code_python(completion):
    """Extract a Python function body via the LLM extractor.

    Fenced code is pulled from the raw completion, the extractor strips the
    signature, and the extractor's own fenced answer is unwrapped. (An
    AST-based alternative, remove_signature_python, exists but is not used.)
    """
    fenced = identify_codeblock(completion)
    extracted = await remove_signature(fenced)
    return identify_codeblock(extracted)

In [None]:
def get_final_python(state, completion):
    """Build the Python program: helper imports, prompt, body, then tests."""
    header = "\n".join(IMPORT_HELPER["python"]) + "\n"
    meta = state.metadata
    return f"{header}{meta['prompt']}{completion}\n{meta['test']}\n"

### Routing

In [None]:
async def get_final(state, lang, task_id):
    """Turn a model's raw completion into a complete, runnable program.

    Dispatches by name to ``find_code_{lang}`` (body extraction) and
    ``get_final_{lang}`` (program assembly). Extractors may be coroutine
    functions or plain functions.

    Returns:
        ``(model_completion, processed_completion, final_code)``.
    """
    import inspect

    model_completion = state.output.completion

    process_code = globals()[f'find_code_{lang}']
    processed_completion = process_code(model_completion)
    # Only await real awaitables: awaiting a plain str raises TypeError.
    if inspect.isawaitable(processed_completion):
        processed_completion = await processed_completion

    assemble = globals()[f'get_final_{lang}']
    final_code = assemble(state, processed_completion)

    if 'errormsg' in final_code:
        # Sentinel returned by the extraction helpers when they fail.
        print(f'error in sample: {task_id}')

    return model_completion, processed_completion, final_code

### Scorers

In [None]:
async def python_scorer(final_code, *args):
    """Run ``final_code`` with ``python -c`` in the sandbox (30 s limit).

    Extra positional arguments are accepted for a uniform scorer signature and
    ignored. A timeout becomes a failed ExecResult instead of an exception.
    """
    command = ["python", "-c", final_code]
    try:
        return await sandbox().exec(cmd=command, timeout=30)
    except TimeoutError:
        return ExecResult(False, 1, "", "Verification timed out.")

In [None]:
async def js_scorer(final_code, *args):
    """Run ``final_code`` with ``node -e`` in the sandbox (30 s limit).

    Extra positional arguments are accepted for a uniform scorer signature and
    ignored. A timeout becomes a failed ExecResult instead of an exception.
    """
    command = ["node", "-e", final_code]
    try:
        return await sandbox().exec(cmd=command, timeout=30)
    except TimeoutError:
        return ExecResult(False, 1, "", "Verification timed out.")

In [None]:
async def go_scorer(final_code, idx, tmp_dir):
    """Run a Go test file and return the ``go test`` result.

    Writes ``final_code`` to ``<tmp_dir>/main_test.go`` and creates a module
    (``example.com/tmpmod_<idx>``) if there is no go.mod yet. It then runs
    ``go mod tidy`` on the host and ``go test`` in the sandbox (30 s limit).
    ``tmp_dir`` is removed afterwards, even on failure.
    """
    go_bin = '/usr/local/go/bin/go'
    os.makedirs(tmp_dir, exist_ok=True)
    file = os.path.join(tmp_dir, 'main_test.go')

    try:
        # Close (and thereby flush) the file before the go tool reads it.
        with open(file, 'w') as f:
            f.write(final_code)

        # Pass cwd= instead of chdir-ing: the working directory is
        # process-wide state shared with any concurrently running code.
        if not os.path.exists(os.path.join(tmp_dir, 'go.mod')):
            try:
                subprocess.run(
                    [go_bin, 'mod', 'init', f'example.com/tmpmod_{idx}'],
                    cwd=tmp_dir,
                    check=True,
                    capture_output=True,
                )
            except subprocess.CalledProcessError as e:
                print("Error running go mod init:")
                print(e.stderr)
        try:
            subprocess.run(
                [go_bin, 'mod', 'tidy'],
                cwd=tmp_dir,
                check=True,
                capture_output=True,
            )
        except subprocess.CalledProcessError as e:
            # Don't abort scoring: `go test` below reports the underlying
            # problem (e.g. an unresolvable import) in its own stderr.
            print(f"Error running go mod tidy (task {idx}):")
            print(e.stderr)

        try:
            result = await sandbox().exec(
                cmd=[go_bin, "test", file],
                timeout=30,
                cwd=tmp_dir,
            )
        except TimeoutError:
            result = ExecResult(False, 1, "", "Verification timed out.")
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)

    return result

In [None]:
async def java_scorer(final_code, idx, tmp_dir):
    """Compile ``final_code`` as Main.java and run ``Main`` in the sandbox.

    javac runs on the host (30 s limit). If compilation fails or times out, a
    failed ExecResult carrying the reason is returned without trying to run.
    Otherwise ``java -cp <tmp_dir> Main`` runs in the sandbox (30 s limit).
    ``tmp_dir`` is always removed afterwards.
    """
    os.makedirs(tmp_dir, exist_ok=True)
    file = os.path.join(tmp_dir, 'Main.java')

    try:
        # Close (and thereby flush) the source before javac reads it.
        with open(file, 'w') as f:
            f.write(final_code)

        try:
            compile_proc = subprocess.run(
                ["javac", "Main.java"],
                cwd=tmp_dir,
                capture_output=True,
                text=True,
                timeout=30,
            )
        except subprocess.TimeoutExpired:
            print("Compilation timed out!")
            return ExecResult(False, 1, "", "Compilation timed out.")
        except Exception as e:
            print("Compilation error:", e)
            return ExecResult(False, 1, "", f"Compilation error: {e}")

        if compile_proc.returncode != 0:
            print(f"Compilation failed! Idx: {idx}")
            print("stderr:", compile_proc.stderr)
            # Running `java` now would only fail with a class-loading error
            # unrelated to the real problem; report the compiler output.
            return ExecResult(False, 1, "", compile_proc.stderr or "Compiler Error")

        try:
            return await sandbox().exec(
                cmd=["java", "-cp", tmp_dir, "Main"],
                timeout=30,
            )
        except TimeoutError:
            return ExecResult(False, 1, "", "Verification timed out.")
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)

In [None]:
async def cpp_scorer(final_code, idx, tmp_dir):
    """Compile ``final_code`` with g++ (C++17, linking OpenSSL) and run it.

    Compilation runs on the host (30 s limit); the binary runs in the sandbox
    (30 s limit). Compile failures, timeouts and execution errors are all
    returned as failed ExecResults. ``tmp_dir`` is always removed afterwards.
    """
    os.makedirs(tmp_dir, exist_ok=True)
    file = os.path.join(tmp_dir, 'test.cpp')
    executable = os.path.join(tmp_dir, 'test.out')

    try:
        # Close (and thereby flush) the source before g++ reads it.
        with open(file, 'w') as f:
            f.write(final_code)

        compiled = False
        try:
            compile_proc = subprocess.run(
                ["g++", "-std=c++17", file, "-o", executable, '-lssl', '-lcrypto'],
                capture_output=True,
                text=True,
                timeout=30,  # seconds, adjust as needed
            )
            compiled = compile_proc.returncode == 0
            if not compiled:
                print(f"Compilation failed! task number: {idx}")
                print("stderr:", compile_proc.stderr)
        except subprocess.TimeoutExpired:
            print("Compilation timed out!")

        # Require a successful compile, not just an existing file, so a stale
        # test.out left in tmp_dir is never executed.
        if not (compiled and os.path.exists(executable)):
            return ExecResult(False, 1, "", "Compiler Error")

        try:
            return await sandbox().exec(cmd=[executable], timeout=30)
        except TimeoutError:
            return ExecResult(False, 1, "", "Verification timed out.")
        except Exception as e:
            # Report as a failed run instead of leaving the result unset.
            print(f'execution failed cuz of: {e}')
            return ExecResult(False, 1, "", f"Execution failed: {e}")
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)

### Inspect Pipeline

In [None]:
def get_lang_idx(task_id: str):
    """Split a HumanEval-X task id such as ``"JavaScript/12"`` into ``("js", 12)``.

    The language is lower-cased and ``javascript`` is shortened to ``js``,
    matching the suffixes of the ``find_code_*``/``get_final_*`` functions and
    the prefixes of the ``*_scorer`` functions.
    """
    raw_lang, raw_idx = task_id.split('/')
    lowered = raw_lang.lower()
    lang = {'javascript': 'js'}.get(lowered, lowered)
    return lang, int(raw_idx)

@scorer(metrics=[accuracy(), stderr()])
def main_scorer(tmp_root: str = '/root/srf-project/test_humaneval-x/tmp') -> Scorer:
    """Build the HumanEval-X scorer: extract, assemble, execute, grade.

    For each sample:
    1. the language and problem index are parsed from the sample id;
    2. ``get_final`` turns the completion into a full program;
    3. the matching ``<lang>_scorer`` runs it.

    A sample is CORRECT only if the run succeeds AND writes nothing to stderr.

    Args:
        tmp_root: Parent directory for per-sample scratch directories used by
            the compiled-language scorers.
    """
    async def score(state: TaskState, target: Target) -> Score:
        task_id = state.sample_id
        lang, idx = get_lang_idx(task_id)

        model_completion, processed_completion, final_code = await get_final(state, lang, task_id)

        # One scratch dir per (problem, epoch). With epochs > 1 the same problem
        # can be scored concurrently, and the go/java/cpp scorers delete their
        # directory when finished, so a shared path would race.
        tmp_dir = os.path.join(tmp_root, f'test_{idx}_epoch{state.epoch}')
        run_program = globals()[f'{lang}_scorer']
        result = await run_program(final_code, idx, tmp_dir)

        success = result.success and (result.stderr == '')

        # Explain every INCORRECT verdict, including runs that exit successfully
        # but still write to stderr.
        explanation = ""
        if not success:
            explanation = (
                "The following verification code was executed:\n\n"
                + final_code
                + f"\nThe submission was incorrect\n\n{result.stderr}"
            )

        return Score(
            value=CORRECT if success else INCORRECT,
            explanation=explanation,
            metadata={
                'completion': model_completion,
                'processed': processed_completion,
                'final_code': final_code,
                'idx': idx,
                'task_id': task_id,
            },
        )

    return score

In [None]:
# HumanEval-X subset to evaluate; also selects the LANG_PREFIX banner.
language = 'python'

def humaneval_record_to_sample(record):
    """Convert a HumanEval-X record into an inspect_ai ``Sample``.

    The model input is HUMANEVAL_INSTRUCTION, then the language banner, then
    the prompt. Prompt and test code travel in metadata for program assembly.
    For Go, the ``import`` and ``test_setup`` fields are also copied from
    ``go_content``, indexed by the numeric part of the task id.
    """
    prompt = record['prompt']
    model_input = f"{HUMANEVAL_INSTRUCTION}{LANG_PREFIX[language]}\n{prompt}"
    idx = int(record['task_id'].split('/')[-1])

    metadata = dict(prompt=record["prompt"], test=record["test"])
    if language == 'go':
        go_record = go_content[idx]
        metadata.update({'import': go_record['import'], 'test_setup': go_record['test_setup']})

    return Sample(
        id=record["task_id"],
        input=model_input,
        target=record["canonical_solution"],
        metadata=metadata,
    )

# Test split of the selected language from THUDM/humaneval-x, with each record
# mapped through humaneval_record_to_sample.
humaneval_dataset = hf_dataset(
    path='THUDM/humaneval-x',
    split='test',
    name=language,
    trust=True,
    sample_fields=humaneval_record_to_sample,
)

In [None]:
# Evaluate only the last `samples` problems of the dataset.
samples = 30

@task
def humaneval():
    """inspect_ai task: generate on the last `samples` problems and score with main_scorer locally."""
    subset = humaneval_dataset[-samples:]
    return Task(
        dataset=subset,
        solver=generate(),
        scorer=main_scorer(),
        sandbox='local',
    )

## Experimentation

In [None]:
# Run the task against gpt-4o-mini and write logs under the pipeline-check dir.
epochs = 1
inspect_ai.eval(
    humaneval(),
    epochs=epochs,
    model='openai/gpt-4o-mini',
    log_dir='/root/srf-project/test_humaneval-x/pipeline_check/python/',
)

In [None]:
# Load a previously written eval log (also binds `path`) for inspection.
log = read_eval_log(
    path := '/root/srf-project/test_humaneval-x/pipeline_check/python/llama_8b.eval'
)

def eval_duration(log):
    """Return an eval run's wall-clock duration as ``"<N> seconds"``.

    Reads ``log.stats.started_at`` and ``log.stats.completed_at``, either as
    ISO-8601 strings or as datetimes. Both are parsed as full timestamps, so
    the date and UTC offset are kept. This makes runs spanning midnight, runs
    longer than a day, and negative offsets such as ``-07:00`` come out right.
    """
    def _as_datetime(stamp):
        return stamp if isinstance(stamp, datetime) else datetime.fromisoformat(stamp)

    start = _as_datetime(log.stats.started_at)
    end = _as_datetime(log.stats.completed_at)
    # total_seconds() includes whole days, which timedelta.seconds drops.
    duration = int((end - start).total_seconds())
    return f'{duration} seconds'

def eval_score(log):
    """Format the first scorer's ``accuracy`` metric as a percentage with two decimals."""
    accuracy_metric = log.results.scores[0].metrics['accuracy']
    return f'{accuracy_metric.value:.2%}'
    
# Print the run's duration, then its accuracy.
for summarize in (eval_duration, eval_score):
    print(summarize(log))