## Setup

### Imports etc

In [None]:
import gzip
import json
from typing import List
from collections import Counter, defaultdict
from tqdm import tqdm
import os
import inspect_ai
from openai import OpenAI, AsyncOpenAI
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from dotenv import load_dotenv
import subprocess
import contextlib
import shutil
from datetime import datetime

In [615]:
from inspect_ai import Task, task
from inspect_ai.dataset import Sample, hf_dataset
from inspect_ai.util import ExecResult, sandbox
from inspect_ai.scorer import CORRECT, INCORRECT, Score, Scorer, Target, accuracy, scorer, stderr
from inspect_ai.solver import TaskState, generate
from inspect_ai.model import get_model
from inspect_ai.log import read_eval_log
import re

In [616]:
# Read variables from a .env file (if present) into the process environment.
load_dotenv()
# Log in to the Hugging Face Hub with the token stored in HF_TOKEN
# (raises KeyError if the variable is not set).
login(token = os.environ['HF_TOKEN'])
# OpenAI clients built with default arguments; `async_client` is the one used
# by remove_signature further down.
client = OpenAI()
async_client = AsyncOpenAI()

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [None]:
# Per-language import/include lines used when assembling the final test program:
#   - "python": every line is prepended by get_final_python.
#   - "go":     package paths; get_final_go imports those the completion uses
#               and test_setup does not already contain.
#   - "cpp":    include lines; get_final_cpp prepends those not already present
#               verbatim in the prompt.
IMPORT_HELPER = {
    "python": [
        "import math",
        "import re",
        "import sys",
        "import copy",
        "import datetime",
        "import itertools",
        "import collections",
        "import heapq",
        "import statistics",
        "import functools",
        "import hashlib",
        "import numpy",
        "import numpy as np",
        "import string",
        "from typing import *",
        "from collections import *",
    ],
    "go"    : [
        "math",
        "strings",
        "fmt",
        "strconv",
        "time",
        "bytes",
        "regexp",
        "sort",
        "math/rand",
        "crypto/md5",
        "encoding/hex",
    ],
    "cpp"   : [
        "#include<stdlib.h>",
        "#include<algorithm>",
        "#include<math.h>",
        "#include<stdio.h>",
        "#include<vector>",
        "#include<string>",
        "#include<climits>",
        "#include<cstring>",
        "#include<iostream>",
        "#include <numeric>",
        "#include <sstream>",
        "#include <stack>",
        "#include <cctype>",
        "#include <set>",
        "#include <unordered_set>",
        "#include <iomanip>",
    ],
}

# Instruction text prepended to every code problem by humaneval_record_to_sample.
# Only this assignment is live; the two triple-quoted strings that follow are
# bare expressions (not bound to any name), kept as alternates to paste into
# HUMANEVAL_INSTRUCTION by hand.
HUMANEVAL_INSTRUCTION = """
Read the following function signature and docstring, and fully implement
the function described. Include a function signature, that's exactly the 
same as the one provided in the prompt. Do NOT include any explanations. 
Also do NOT include the `main` function\n
"""

# JAVA variant (unassigned string; has no effect at runtime)
"""
Read the following function signature and docstring, and fully implement
the method described. Ensure your method signature is exactly the same as 
the one provided in the prompt. Ensure you include the `class Solution {}` 
classname. Your response should only contain the code, no explanations. 
Example outputs are provided below. 
```java
class Solution {
    public void testMethod1() {
        // method code
    }
}
```

```java
class Solution {
    public bool testMethod2(int a, int b) {
        // method code
    }

    public int helperMethod() {
        // helper method code.
    }
}
```\n
"""

# EVERYTHING ELSE variant (unassigned string; same text as the live value above)
"""
Read the following function signature and docstring, and fully implement
the function described. Include a function signature, that's exactly the 
same as the one provided in the prompt. Do NOT include any explanations. 
Also do NOT include the `main` function\n
"""


# Instruction text that remove_signature prepends to a model completion before
# sending it to the clean-up model. Only this assignment is live; the
# triple-quoted strings below it are unassigned alternates for other languages.
CODE_EXTRACTION_INSTRUCTION = """
Here is a section of code. Remove any leading package names, import statements, and comments. 
Leave any indentation as it is (ie. don't remove leading whitespace for any line).\n
"""

# JAVA variant (unassigned string; has no effect at runtime)
"""
Here is a section of code. Remove any import statements and block comments. 
Leave any indentation as it is (ie. don't remove leading whitespace for any line).\n
"""

# GO variant (unassigned string; has no effect at runtime)
"""
Here is a section of code. Remove any leading package names and comments. 
Leave any indentation as it is (ie. don't remove leading whitespace for any line).
Also leave any import statements, if present.\n
"""

# EVERYTHING ELSE variant (unassigned string; same text as the live value above)
"""
Here is a section of code. Remove any leading package names, import statements, and comments. 
Leave any indentation as it is (ie. don't remove leading whitespace for any line).\n
"""


# One-line comment naming the target language, written in that language's own
# comment syntax. humaneval_record_to_sample inserts it between the instruction
# and the problem prompt.
LANG_PREFIX = {
    "cpp"          : "// language: C++",
    "java"         : "// language: Java",
    "js"           : "// language: JavaScript",
    "javascript"   : "// language: JavaScript",
    "go"           : "// language: Go",
    "python"       : "# language: Python",
}

In [618]:
# Local Hugging Face models resolved through inspect_ai's `hf/` provider prefix.
# These objects are passed as `model=` to inspect_ai.eval in the experiment cells.
llama_model = get_model(
        'hf/meta-llama/Llama-3.1-8B-Instruct', 
        device='auto',
        torch_dtype="auto"
)

qwen_coder = get_model(
    'hf/Qwen/Qwen2.5-Coder-7B-Instruct',
    device="auto",
    torch_dtype="auto",
)

### Language Setup

In [619]:
# JavaScript
#     apt-get update
#     apt-get install -y curl
#     curl -fsSL https://deb.nodesource.com/setup_lts.x | bash -
#     apt-get install -y nodejs
#     verify installation: `node -v` // `npm -v` // `node -e "console.log('Node.js is working')"`

In [620]:
# Java
#     apt-get update
#     apt-get install -y openjdk-21-jdk
# Verify with:  
#     java -version
#     javac -version

In [621]:
# Go
#     cd ~/
#     GO_VERSION=1.24.5
#     wget https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz
#     rm -rf /usr/local/go
#     tar -C /usr/local -xzf go${GO_VERSION}.linux-amd64.tar.gz
#     echo 'export PATH=$PATH:/usr/local/go/bin' >> ~/.bashrc
#     export PATH=$PATH:/usr/local/go/bin
#     go version

In [622]:
# C++
#     apt-get update
#     apt-get install -y g++
#     apt-get install -y build-essential
#     apt-get install libboost-all-dev -y
#     apt-get install libssl-dev


## Data Processing

In [643]:
def stream_jsonl_all(filename: str):
    """Read every record from a gzip-compressed JSON Lines file.

    Args:
        filename: Path to a gzip file holding one JSON document per line.

    Returns:
        A list with the parsed JSON value of each non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened or is not valid gzip data.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Open by path inside a context manager: wrapping an already-open file
    # object in gzip.open() leaves that inner handle open after close(), and
    # nothing at all was closed when a line failed to parse.
    with gzip.open(filename, "rt", encoding="utf-8") as fp:
        # Whitespace-only lines are skipped rather than handed to json.loads.
        return [json.loads(line) for line in fp if line.strip()]

In [644]:
# Load the per-language problem files from the local data/ directory.
python_content = stream_jsonl_all('data/python_data.gz')
cpp_content = stream_jsonl_all('data/cpp_data.gz')
go_content = stream_jsonl_all('data/go_data.gz')
java_content = stream_jsonl_all('data/java_data.gz')
js_content = stream_jsonl_all('data/js_data.gz')
# Order: python, cpp, go, java, js (matches the key listings printed below).
content = [python_content, cpp_content, go_content, java_content, js_content]

In [645]:
# Peek at the 'generation' field of the first record in the Python generations file.
generations = stream_jsonl_all('data/python_generations.gz')
generations[0]['generation']

'    for idx, elem in enumerate(numbers):\n        for idx2, elem2 in enumerate(numbers):\n            if idx != idx2:\n                distance = abs(elem - elem2)\n                if distance < threshold:\n                    return True\n\n    return False\n'

In [646]:
# Show which fields the first record of each language provides
# (the printed output shows the Go records carrying 'import' and 'test_setup').
for lang in content:
    print(lang[0].keys())

dict_keys(['task_id', 'prompt', 'canonical_solution', 'test', 'text', 'declaration', 'example_test'])
dict_keys(['task_id', 'prompt', 'canonical_solution', 'test', 'declaration', 'example_test'])
dict_keys(['task_id', 'prompt', 'import', 'docstring', 'declaration', 'canonical_solution', 'test', 'test_setup', 'example_test'])
dict_keys(['task_id', 'prompt', 'canonical_solution', 'test', 'text', 'declaration', 'example_test'])
dict_keys(['task_id', 'prompt', 'canonical_solution', 'test', 'declaration', 'example_test'])


## Eval Pipeline

### Extract Code

In [647]:
def identify_codeblock(completion: str) -> str:
    """Return the body of the first fenced Markdown code block in a completion.

    A fence that carries a language tag is preferred over an untagged fence.
    Any tag is accepted (``python``, ``js``, ``c++``, ``golang``, ...), not only
    a fixed list, and both LF and CRLF line endings after the opening fence are
    handled.

    Args:
        completion: Raw text that may contain a fenced code block.

    Returns:
        The text between the opening fence line and the closing fence, or
        ``completion`` unchanged when no fenced block is found.
    """
    tagged_fence = re.compile(
        r"```[ \t]*[A-Za-z][\w+#.-]*[ \t]*\r?\n(.*?)```", re.DOTALL
    )
    bare_fence = re.compile(r"```[ \t]*\r?\n(.*?)```", re.DOTALL)

    # Tagged blocks take priority; an untagged block is only a fallback.
    for fence in (tagged_fence, bare_fence):
        found = fence.search(completion)
        if found:
            return found.group(1)

    return completion

In [648]:
async def remove_signature(completion, lang):
    """Send a completion to gpt-4.1-mini with the code clean-up instruction.

    Args:
        completion: Text appended to ``CODE_EXTRACTION_INSTRUCTION``.
        lang: Language key. Accepted for signature symmetry; not used here.

    Returns:
        The ``text`` of the first content part of the last output item in the
        model response.
    """
    request_text = CODE_EXTRACTION_INSTRUCTION + completion

    reply = await async_client.responses.create(
        input=request_text,
        model='gpt-4.1-mini',
    )

    last_item = reply.output[-1]
    return last_item.content[0].text

In [None]:
async def find_code(completion, lang):
    """Clean a raw completion and pull out its code.

    The completion is first passed through ``remove_signature``; the result is
    then reduced to its fenced code block (if any) by ``identify_codeblock``.
    """
    cleaned = await remove_signature(completion, lang)
    return identify_codeblock(cleaned)

### Languages

In [None]:
# CPP

def get_final_cpp(state, completion):
    """Assemble the C++ source that is compiled and run for one sample.

    Layout: helper includes, the prompt text preceding the function header,
    the model's completion, then the test harness.

    Args:
        state: Object whose ``metadata`` dict provides 'prompt', 'declaration'
            and 'test'.
        completion: Processed model output (expected to contain the function
            signature itself, as requested by the instruction prompt).

    Returns:
        The full program text as a single string.
    """
    prompt = state.metadata['prompt']
    declaration = state.metadata['declaration']

    # Add only the helper includes whose exact text is not already in the prompt.
    imports = ''.join(s + '\n' for s in IMPORT_HELPER['cpp'] if s not in prompt)

    # The last line of the declaration is the function header.
    header = declaration.strip().split('\n')[-1]
    # Keep everything before the LAST occurrence of the header. The previous
    # split/join form also deleted any earlier occurrence of the header text.
    # rpartition yields '' when the header is absent, as the original did.
    updated_prompt = prompt.rpartition(header)[0]

    return imports + "\n" + updated_prompt + completion + "\n" + state.metadata['test']

In [None]:
# GO

def get_final_go(state, completion):
    """Assemble the Go test file for one sample.

    Layout: ``test_setup``, an optional extra import block, the prompt text
    preceding the last ``func `` (with the prompt's own import string removed),
    the model's completion, then the test body.

    Args:
        state: Object whose ``metadata`` dict provides 'import', 'prompt',
            'test' and 'test_setup'.
        completion: Processed model output.

    Returns:
        The full Go source as a single string.
    """
    import_string = state.metadata['import']
    prompt = state.metadata['prompt'].replace(import_string, '')
    # Keep everything before the LAST 'func '. The previous split/join form
    # removed 'func ' from any helper function that precedes the target one.
    prompt = prompt.rpartition('func ')[0]

    test = state.metadata['test']
    test_setup = state.metadata['test_setup']

    other_pkgs = []
    for pkg in IMPORT_HELPER['go']:
        quoted = f"\"{pkg}\""
        # Compare the quoted import path so that e.g. "math/rand" being present
        # in test_setup does not hide a missing "math" import.
        if quoted in test_setup:
            continue
        short_name = pkg.split('/')[-1]
        # Word boundary: `fmt.` must be the package qualifier itself, not the
        # tail of a longer identifier (Go rejects unused imports).
        if re.search(rf"\b{re.escape(short_name)}\.", completion):
            other_pkgs.append(quoted)

    if other_pkgs:
        import_other_pkgs = "import (\n" + "".join(f"    {p}\n" for p in other_pkgs) + ")"
        return test_setup + "\n" + import_other_pkgs + "\n" + prompt + completion + "\n" + test

    return test_setup + "\n" + prompt + completion + "\n" + test

In [None]:
# JAVA

def get_final_java(state, completion):
    """Assemble the Java source for one sample.

    Layout: the prompt text preceding the last ``class Solution``, the model's
    completion, a blank line, then the test harness.

    Args:
        state: Object whose ``metadata`` dict provides 'prompt' and 'test'.
        completion: Processed model output (expected to contain the
            ``class Solution`` definition).

    Returns:
        The full Java source as a single string, ending in a newline.
    """
    # rpartition keeps the text before the LAST marker intact; the previous
    # split/join form deleted every earlier occurrence of the marker text.
    # When the marker is absent the preamble is '', as before.
    preamble = state.metadata['prompt'].rpartition('class Solution')[0]

    return preamble + completion + "\n\n" + state.metadata['test'] + "\n"

In [None]:
# JS

def get_final_js(state, completion):
    """Assemble the JavaScript source for one sample.

    Layout: the prompt text preceding the last ``const ``, the model's
    completion, a blank line, then the test harness.

    Args:
        state: Object whose ``metadata`` dict provides 'prompt' and 'test'.
        completion: Processed model output (expected to contain the target
            function's own ``const`` declaration).

    Returns:
        The full JavaScript source as a single string, ending in a newline.
    """
    # Keep everything before the LAST 'const '. The previous split/join form
    # stripped 'const ' from helper declarations that precede the target
    # function, leaving them as assignments to undeclared names.
    # When 'const ' is absent the preamble is '', as before.
    preamble = state.metadata['prompt'].rpartition('const ')[0]

    return preamble + completion + "\n\n" + state.metadata['test'] + "\n"

In [None]:
# PYTHON

def get_final_python(state, completion):
    """Assemble the Python source for one sample.

    Layout: all helper imports, the prompt text preceding the last ``def ``,
    the model's completion, then the test harness.

    Args:
        state: Object whose ``metadata`` dict provides 'prompt' and 'test'.
        completion: Processed model output (expected to contain the target
            function's own ``def`` line).

    Returns:
        The full Python source as a single string, ending in a newline.
    """
    imports = "\n".join(IMPORT_HELPER["python"]) + "\n"

    # Keep everything before the LAST 'def '. The previous split/join form
    # removed 'def ' from helper functions defined earlier in the prompt,
    # turning them into syntax errors.
    # When 'def ' is absent the preamble is '', as before.
    preamble = state.metadata['prompt'].rpartition('def ')[0]

    return imports + preamble + completion + "\n" + state.metadata['test'] + "\n"

### Routing

In [None]:
async def get_final(state, lang, task_id):
    """Turn a model completion into the final program for its language.

    Args:
        state: Task state; ``state.output.completion`` is the raw model text.
        lang: Language key used to look up ``get_final_<lang>`` in globals().
        task_id: Sample identifier, used only in the diagnostic print.

    Returns:
        A tuple ``(raw completion, processed completion, final program)``.
    """
    raw_completion = state.output.completion
    extracted = await find_code(raw_completion, lang)

    # Dispatch by name to the per-language assembler defined in this notebook.
    assembler = globals()["get_final_{}".format(lang)]
    assembled = assembler(state, extracted)

    if 'errormsg' in assembled:
        print(f'error in sample: {task_id}')

    return raw_completion, extracted, assembled

### Scorers

In [655]:
async def python_scorer(final_code, *args):
    """Run the assembled Python program in the sandbox via ``python -c``.

    Extra positional arguments are accepted and ignored so that every
    per-language scorer can be called the same way.

    Returns:
        The sandbox's exec result, or a failed ``ExecResult`` with the message
        "Verification timed out." if the 30-second timeout is hit.
    """
    command = ["python", "-c", final_code]
    try:
        return await sandbox().exec(cmd=command, timeout=30)
    except TimeoutError:
        return ExecResult(False, 1, "", "Verification timed out.")

In [656]:
async def js_scorer(final_code, *args):
    """Run the assembled JavaScript program in the sandbox via ``node -e``.

    Extra positional arguments are accepted and ignored so that every
    per-language scorer can be called the same way.

    Returns:
        The sandbox's exec result, or a failed ``ExecResult`` with the message
        "Verification timed out." if the 30-second timeout is hit.
    """
    command = ["node", "-e", final_code]
    try:
        return await sandbox().exec(cmd=command, timeout=30)
    except TimeoutError:
        return ExecResult(False, 1, "", "Verification timed out.")

In [657]:
async def go_scorer(final_code, idx, tmp_dir):
    """Write the Go test file, prepare a module for it and run ``go test``.

    Args:
        final_code: Complete Go source to test.
        idx: Numeric sample index; used in the throw-away module name.
        tmp_dir: Scratch directory for this sample. It is created if missing
            and removed before returning.

    Returns:
        The sandbox's exec result for ``go test``, or a failed ``ExecResult``
        when ``go mod tidy`` fails or the 30-second test timeout is hit.
    """
    go_bin = '/usr/local/go/bin/go'
    os.makedirs(tmp_dir, exist_ok=True)
    file = os.path.join(tmp_dir, 'main_test.go')

    try:
        # Context manager so the handle is flushed and closed before go reads it.
        with open(file, 'w') as fh:
            fh.write(final_code)

        # Subprocesses get cwd=tmp_dir instead of a process-wide chdir, which
        # also makes a relative tmp_dir resolve correctly.
        if not os.path.exists(os.path.join(tmp_dir, 'go.mod')):
            try:
                subprocess.run(
                    [go_bin, 'mod', 'init', f'example.com/tmpmod_{idx}'],
                    cwd=tmp_dir,
                    check=True,
                    capture_output=True,
                )
            except subprocess.CalledProcessError as e:
                print("Error running go mod init:")
                print(e.stderr)

        try:
            subprocess.run(
                [go_bin, 'mod', 'tidy'],
                cwd=tmp_dir,
                check=True,
                capture_output=True,
                text=True,
            )
        except subprocess.CalledProcessError as e:
            # Previously this propagated and aborted the scorer; report it as a
            # failed verification for this sample instead.
            return ExecResult(False, 1, "", f"go mod tidy failed:\n{e.stderr}")

        try:
            return await sandbox().exec(
                cmd=[go_bin, "test", file],
                timeout=30,
                cwd=tmp_dir
            )
        except TimeoutError:
            return ExecResult(False, 1, "", "Verification timed out.")
    finally:
        # Clean up on every exit path, including unexpected exceptions.
        shutil.rmtree(tmp_dir, ignore_errors=True)

In [658]:
async def java_scorer(final_code, idx, tmp_dir):
    """Compile ``Main.java`` with javac and run the ``Main`` class.

    Args:
        final_code: Complete Java source, written to ``Main.java``.
        idx: Numeric sample index; used only in diagnostic output.
        tmp_dir: Scratch directory for this sample. It is created if missing
            and removed before returning.

    Returns:
        The sandbox's exec result for ``java -cp tmp_dir Main``, or a failed
        ``ExecResult`` describing a compilation failure/timeout or a
        30-second run timeout.
    """
    os.makedirs(tmp_dir, exist_ok=True)
    file = os.path.join(tmp_dir, 'Main.java')

    try:
        # Context manager so the handle is flushed and closed before javac runs.
        with open(file, 'w') as fh:
            fh.write(final_code)

        try:
            compile_proc = subprocess.run(
                ["javac", "Main.java"],
                cwd=tmp_dir,
                capture_output=True,
                text=True,
                timeout=30
            )
        except subprocess.TimeoutExpired:
            print("Compilation timed out!")
            return ExecResult(False, 1, "", "Compilation timed out.")
        except Exception as e:
            print("Compilation error:", e)
            return ExecResult(False, 1, "", f"Compilation error: {e}")

        if compile_proc.returncode != 0:
            print(f"Compilation failed! Idx: {idx}")
            print("stderr:", compile_proc.stderr)
            # Nothing to run; hand back the compiler diagnostics instead of
            # the JVM's "could not find main class" message.
            return ExecResult(False, 1, "", "Compiler Error\n" + compile_proc.stderr)

        try:
            return await sandbox().exec(
                cmd=["java", "-cp", tmp_dir, "Main"],
                timeout=30,
            )
        except TimeoutError:
            return ExecResult(False, 1, "", "Verification timed out.")
    finally:
        # Clean up on every exit path, including unexpected exceptions.
        shutil.rmtree(tmp_dir, ignore_errors=True)

In [659]:
async def cpp_scorer(final_code, idx, tmp_dir):
    """Compile the C++ program with g++ and run the resulting executable.

    Args:
        final_code: Complete C++ source, written to ``test.cpp``.
        idx: Numeric sample index; used only in diagnostic output.
        tmp_dir: Scratch directory for this sample. It is created if missing
            and removed before returning.

    Returns:
        The sandbox's exec result for the executable, or a failed
        ``ExecResult`` when no executable was produced ("Compiler Error"),
        the 30-second run timeout is hit, or execution raises.
    """
    os.makedirs(tmp_dir, exist_ok=True)
    file = os.path.join(tmp_dir, 'test.cpp')
    executable = os.path.join(tmp_dir, 'test.out')

    try:
        # Context manager so the handle is flushed and closed before g++ runs.
        with open(file, 'w') as fh:
            fh.write(final_code)

        try:
            compile_proc = subprocess.run(
                ["g++", "-std=c++17", file, "-o", executable, '-lssl', '-lcrypto'],
                capture_output=True,
                text=True,
                timeout=30  # seconds, adjust as needed
            )
            if compile_proc.returncode != 0:
                print(f"Compilation failed! task number: {idx}")
                print("stderr:", compile_proc.stderr)
        except subprocess.TimeoutExpired:
            print("Compilation timed out!")

        # A missing executable covers both compile errors and compile timeouts.
        if not os.path.exists(executable):
            return ExecResult(False, 1, "", "Compiler Error")

        try:
            return await sandbox().exec(
                cmd=[executable],
                timeout=30
            )
        except TimeoutError:
            return ExecResult(False, 1, "", "Verification timed out.")
        except Exception as e:
            print(f'execution failed cuz of: {e}')
            # Previously `result` stayed unbound on this path and the function
            # died with UnboundLocalError; report a failed run instead.
            return ExecResult(False, 1, "", f"Execution failed: {e}")
    finally:
        # Clean up on every exit path, including unexpected exceptions.
        shutil.rmtree(tmp_dir, ignore_errors=True)

### Inspect Pipeline

In [660]:
def get_lang_idx(task_id: str):
    """Split a ``'<Language>/<number>'`` task id into ``(language key, index)``.

    The language part is lower-cased and ``'javascript'`` is shortened to
    ``'js'``; the number part is converted to ``int``.
    """
    name, number = task_id.split('/')
    key = name.lower()
    return ('js' if key == 'javascript' else key), int(number)

@scorer(metrics=[accuracy(), stderr()])
def main_scorer() -> Scorer:
    """Build the scorer that assembles, executes and grades one sample.

    The returned coroutine derives the language and index from the sample id,
    builds the final program with ``get_final``, runs it with the matching
    ``<lang>_scorer`` and marks the sample CORRECT only when execution
    succeeded and nothing was written to stderr.
    """
    async def score(state: TaskState, target: Target) -> Score:
        task_id = state.sample_id
        lang, idx = get_lang_idx(task_id)

        model_completion, processed_completion, final_code = await get_final(state, lang, task_id)

        tmp_dir = f'/root/srf-project/test_humaneval-x/tmp/test_{idx}/'
        # Dispatch by name to the per-language scorer defined in this notebook.
        my_scorer = globals()[f'{lang}_scorer']
        result = await my_scorer(final_code, idx, tmp_dir)

        success = result.success and (result.stderr == '')

        # Key the explanation off the same `success` flag as the score. It used
        # to test `result.success`, so a run that exited cleanly but wrote to
        # stderr was graded INCORRECT with an empty explanation.
        if success:
            explanation = ""
        else:
            explanation = "".join(
                [
                    "The following verification code was executed:\n\n",
                    final_code,
                    f"\nThe submission was incorrect\n\n{result.stderr}",
                ]
            )

        return Score(
            value=CORRECT if success else INCORRECT,
            explanation=explanation,
            metadata={
                'completion': model_completion,
                'processed': processed_completion,
                'final_code': final_code,
                'idx': idx,
                'task_id': task_id,
            },
        )

    return score

In [666]:
# Language evaluated by this run. Read by humaneval_record_to_sample (prompt
# prefix, Go-only metadata) and used as the dataset config name below.
language = 'python'

def humaneval_record_to_sample(record):
    """Convert one dataset record into an inspect_ai ``Sample``.

    The model input is the instruction, the language prefix line for the
    module-level ``language``, and the record's prompt. Metadata always carries
    'prompt', 'test' and 'declaration'; for Go it also carries 'import' and
    'test_setup', looked up in ``go_content`` by the numeric part of the task id.
    """
    model_input = "".join(
        (HUMANEVAL_INSTRUCTION, LANG_PREFIX[language], '\n', record['prompt'])
    )

    # Numeric suffix of ids such as 'Python/12'.
    position = int(record['task_id'].split('/')[-1])

    metadata = {field: record[field] for field in ("prompt", "test", "declaration")}
    if language == 'go':
        go_record = go_content[position]
        metadata['import'] = go_record['import']
        metadata['test_setup'] = go_record['test_setup']

    return Sample(
        id=record["task_id"],
        input=model_input,
        target=record["canonical_solution"],
        metadata=metadata,
    )

# Test split of THUDM/humaneval-x for the selected language; each record is
# converted to a Sample by humaneval_record_to_sample.
# NOTE(review): `trust = True` presumably allows the dataset's loading script to run — confirm.
humaneval_dataset = hf_dataset(
    path = 'THUDM/humaneval-x',
    name = language,
    split = 'test',
    sample_fields = humaneval_record_to_sample,
    trust = True,
)

In [667]:
# samples = 30

@task
def humaneval():
    """Task definition: generate once per sample and grade with ``main_scorer``.

    Uses the module-level ``humaneval_dataset`` and the 'local' sandbox.
    """
    task_config = dict(
        dataset=humaneval_dataset,
        solver=generate(),
        scorer=main_scorer(),
        sandbox='local',
    )
    return Task(**task_config)

## Experimentation

In [668]:
# Baseline run 1: hosted OpenAI model, one epoch. Logs go to the shared
# baseline_performance/python directory.
epochs = 1
model = 'openai/gpt-4o-mini'

inspect_ai.eval(
    humaneval(), 
    model = model, 
    epochs = epochs,
    log_dir = '/root/srf-project/test_humaneval-x/baseline_performance/python/'
)

Output()

In [669]:
# Baseline run 2: the local Llama model object loaded above, one epoch.
# Same task and log directory as the other baseline runs.
epochs = 1
model = llama_model

inspect_ai.eval(
    humaneval(), 
    model = model, 
    epochs = epochs,
    log_dir = '/root/srf-project/test_humaneval-x/baseline_performance/python/'
)

Output()

In [670]:
# Baseline run 3: the local Qwen coder model object loaded above, one epoch.
# Same task and log directory as the other baseline runs.
epochs = 1
model = qwen_coder

inspect_ai.eval(
    humaneval(), 
    model = model, 
    epochs = epochs,
    log_dir = '/root/srf-project/test_humaneval-x/baseline_performance/python/'
)

Output()

In [None]:
# Load a previously written eval log for inspection.
# Note: this path is under pipeline_check/, not the baseline_performance/
# directory the runs above write to.
path = '/root/srf-project/test_humaneval-x/pipeline_check/python/llama_8b.eval'
log = read_eval_log(path)

def eval_duration(log):
    """Return the wall-clock duration of an eval run as ``'<n> seconds'``.

    Args:
        log: Eval log whose ``stats.started_at`` and ``stats.completed_at`` are
            ISO-8601 timestamp strings.

    Returns:
        A string with the whole number of elapsed seconds, e.g. ``'65 seconds'``.
    """
    # Parse the full timestamps (date, time, fraction and UTC offset). The
    # earlier time-of-day-only parsing raised on negative offsets or fractional
    # seconds and, together with `.seconds`, dropped whole days from long runs.
    started = datetime.fromisoformat(log.stats.started_at)
    completed = datetime.fromisoformat(log.stats.completed_at)

    duration = int((completed - started).total_seconds())

    return f'{duration} seconds'

def eval_score(log):
    """Format the accuracy metric of the log's first score as a percentage.

    Reads ``log.results.scores[0].metrics['accuracy'].value`` and renders it
    multiplied by 100 with two decimals and a trailing ``%``.
    """
    accuracy_metric = log.results.scores[0].metrics['accuracy']
    percent = accuracy_metric.value * 100
    return '{:.2f}%'.format(percent)
    
# Summarise the loaded log: elapsed time, then accuracy as a percentage.
print(eval_duration(log))
print(eval_score(log))