### setup

In [None]:
import gzip
import json
from typing import List
from collections import Counter, defaultdict
from tqdm import tqdm
import os
import inspect_ai
from openai import OpenAI
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from dotenv import load_dotenv
import torch as t
import subprocess
import contextlib
import shutil
import ast
import textwrap
import copy
from openai import OpenAI

In [None]:
from inspect_ai import Task, task
from inspect_ai.dataset import Sample, hf_dataset
from inspect_ai.util import ExecResult, sandbox
from inspect_ai.scorer import CORRECT, INCORRECT, Score, Scorer, Target, accuracy, scorer, stderr
from inspect_ai.solver import TaskState, generate
from inspect_ai.model import get_model
from inspect_ai.log import read_eval_log
import re

In [None]:
# Environment and credential setup for the rest of the notebook.
# load_dotenv() is expected to populate os.environ from a local .env file
# (python-dotenv behaviour -- confirm the file location used in this project).
load_dotenv()
# Log in to the Hugging Face Hub; raises KeyError if HF_TOKEN is not set.
login(token = os.environ['HF_TOKEN'])
# OpenAI client built with default arguments; remove_fxn_signature calls it.
client = OpenAI()

In [None]:
# Per-language lists of import / include lines (Python, C++) or package
# paths (Go). In this file only the "go" list is read, by get_final_go, to
# decide which extra packages a completion needs imported.
IMPORT_HELPER = {
    "python": [
        "import math", "import re", "import sys", "import copy",
        "import datetime", "import itertools", "import collections",
        "import heapq", "import statistics", "import functools",
        "import hashlib", "import numpy", "import numpy as np",
        "import string", "from typing import *",
        "from collections import *",
    ],
    "go": [
        "math", "strings", "fmt", "strconv", "time",
        "bytes", "regexp", "sort", "math/rand", "crypto/md5",
    ],
    "cpp": [
        "#include<stdlib.h>", "#include<algorithm>", "#include<math.h>",
        "#include<stdio.h>", "#include<vector>", "#include<string>",
        "#include<climits>", "#include<cstring>", "#include<iostream>",
    ],
}

# Instruction text placed in front of the language tag and the problem prompt
# when humaneval_record_to_sample builds the model input. The literal's exact
# whitespace (leading newline, trailing blank line) is part of the prompt.
HUMANEVAL_INSTRUCTION = """
Read the following function signature and docstring, and fully implement
the function described. Your response should only contain the code for
this function.\n
"""

# Instruction text that remove_fxn_signature prepends to a model completion
# before sending it to the OpenAI client, asking for only the function body.
# The literal (including its trailing spaces) is sent verbatim.
CODE_EXTRACTION_INSTRUCTION = """
Here is a section of code. Follow these instructions to extract the function body.
First, if present, remove any initial package names, import statements, and comments. 
Next, remove the function signature of the first function you see. 
If there is a `main` function present, delete that entire function. 
However, if there are any other functions present, leave them as they are.

Your response should only contain the remaining block of code. \n
"""

# One-line source comment naming the language, keyed by language id.
# humaneval_record_to_sample inserts the entry for the active language
# between the instruction text and the problem prompt.
LANG_PREFIX = {
    "cpp": "// language: C++",
    "java": "// language: Java",
    "js": "// language: JavaScript",
    "javascript": "// language: JavaScript",
    "go": "// language: Go",
    "python": "# language: Python",
}

In [None]:
# Model handle passed to the first inspect_ai.eval call further down.
# The 'hf/' prefix presumably selects inspect_ai's Hugging Face provider, with
# the extra keyword arguments forwarded to model loading -- confirm against
# the inspect_ai model documentation.
model = get_model(
        'hf/meta-llama/Llama-3.1-8B-Instruct', 
        device = 'auto',
        torch_dtype=t.bfloat16,
)

### Data Processing

In [None]:
def stream_jsonl_all(filename: str) -> list:
    """Read every JSON record from a gzip-compressed JSON Lines file.

    Args:
        filename: Path to a gzip file whose decompressed text holds one JSON
            document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.
        Lines consisting only of whitespace are skipped.

    Raises:
        OSError: If the file cannot be opened or is not valid gzip data.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Let gzip open the path itself, inside a ``with`` block: the underlying
    # file handle is then owned by gzip and is closed together with the text
    # wrapper, even when a line fails to parse. UTF-8 is pinned so decoding
    # does not depend on the machine's locale.
    with gzip.open(filename, "rt", encoding="utf-8") as fp:
        return [json.loads(line) for line in fp if line.strip()]

In [None]:
# Load each language's records from its gzip JSONL file under data/.
python_content = stream_jsonl_all('data/python_data.gz')
cpp_content = stream_jsonl_all('data/cpp_data.gz')
go_content = stream_jsonl_all('data/go_data.gz')
java_content = stream_jsonl_all('data/java_data.gz')
js_content = stream_jsonl_all('data/js_data.gz')
# All five record lists together, in the order loaded above.
content = [python_content, cpp_content, go_content, java_content, js_content]

In [None]:
# Load the records stored in data/python_generations.gz and display the
# 'generation' field of the first one as the cell's output.
generations = stream_jsonl_all('data/python_generations.gz')
generations[0]['generation']

In [None]:
# Print the field names of the first record of each language's dataset.
# The loop variable is deliberately not called `lang`: that notebook-global
# name is the language selector read by humaneval_record_to_sample, and
# rebinding it to a list of records here would break that function whenever
# cells are executed out of order.
for records in content:
    print(records[0].keys())
    print()

### LLM output => code  
*Goal: inspect the raw LLM output and transform it into runnable code.*

In [None]:
# Installing Go
#     `cd ~/`  
#     `GO_VERSION=1.24.5`  
#     `wget https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz`  
#     `rm -rf /usr/local/go`  
#     `tar -C /usr/local -xzf go${GO_VERSION}.linux-amd64.tar.gz`  
#     `echo 'export PATH=$PATH:/usr/local/go/bin' >> ~/.bashrc`  
#     `export PATH=$PATH:/usr/local/go/bin`  
#     `go version`  

In [None]:
# FIND CODE

def remove_fxn_signature(completion):
    """Ask the OpenAI model to reduce a completion to its function body.

    The request input is CODE_EXTRACTION_INSTRUCTION followed by
    ``completion``, sent to 'gpt-4.1-mini' through the module-level
    ``client``.

    Args:
        completion: Raw text produced by the model under evaluation.

    Returns:
        The text of the first content part of the last item in the
        response's ``output`` list.
    """
    reply = client.responses.create(
        model='gpt-4.1-mini',
        input=CODE_EXTRACTION_INSTRUCTION + completion,
    )
    last_item = reply.output[-1]
    return last_item.content[0].text

def balance_brackets(completion):
    """Adjust a Go function body so it has exactly one surplus closing brace.

    Steps, in order:
      1. If a '{' occurs within the first three characters, drop that brace.
      2. If '{' and '}' counts are then equal, append a newline and one '}'.
      3. If the text is neither balanced nor exactly one '}' ahead, print a
         warning and leave the text as it is.

    The one surplus '}' presumably closes the function whose signature and
    opening brace come from the prompt -- confirm against get_final_go.

    Args:
        completion: Candidate function-body text.

    Returns:
        The possibly modified text.
    """
    head, tail = completion[:3], completion[3:]
    if '{' in head:
        # The first '{' of the whole text lies in `head`, so removing it
        # there is the same as removing the first '{' overall.
        completion = head.replace('{', "", 1) + tail

    opens_minus_closes = completion.count('{') - completion.count('}')
    if opens_minus_closes == 0:
        return completion + '\n}'
    if opens_minus_closes != -1:
        print('brackets ain\'t balancing')
    return completion


def identify_codeblock(completion):
    """Return the contents of the first fenced code block, if there is one.

    A fence opened with ```go is preferred: its first occurrence wins even
    when an untagged ``` fence appears earlier in the text. If there is no
    go-tagged fence, the first untagged fence is used. With no fence at all
    the input is returned unchanged.

    Args:
        completion: Text that may wrap code in Markdown fences.

    Returns:
        The text between the chosen opening fence line and the next ```,
        or ``completion`` itself when nothing matches.
    """
    go_fence = re.search(r"```go\n(.*?)```", completion, re.DOTALL)
    if go_fence is not None:
        return go_fence.group(1)

    bare_fence = re.search(r"```\n(.*?)```", completion, re.DOTALL)
    if bare_fence is not None:
        return bare_fence.group(1)

    return completion

def find_code_go(completion: str) -> str:
    """Turn a raw model completion into a Go function body.

    Runs three steps in sequence: remove_fxn_signature (LLM-based
    extraction), identify_codeblock (strip a Markdown fence), and
    balance_brackets (fix up the closing brace).

    Args:
        completion: Raw text returned by the model under evaluation.

    Returns:
        The processed function-body text.
    """
    body_only = remove_fxn_signature(completion)
    unfenced = identify_codeblock(body_only)
    return balance_brackets(unfenced)

In [None]:
def get_final_go(state, completion):
    """Assemble the full Go test source for one sample.

    Layout of the result: test setup, an optional extra import block, the
    prompt with its own import text removed, the completion, and the test.

    A package from IMPORT_HELPER['go'] goes into the extra import block when
    its path does not occur as a substring of the test setup and the text
    "<last path segment>." occurs in the completion.

    Args:
        state: Object whose ``metadata`` mapping provides the 'import',
            'prompt', 'test' and 'test_setup' strings.
        completion: Processed function-body text.

    Returns:
        The concatenated Go source as one string.
    """
    meta = state.metadata
    import_string = meta['import']
    signature = meta['prompt'].replace(import_string, '')
    test = meta['test']
    test_setup = meta['test_setup']

    extra_pkgs = [
        f"\"{pkg}\""
        for pkg in IMPORT_HELPER['go']
        if pkg not in test_setup and pkg.split('/')[-1] + '.' in completion
    ]

    pieces = [test_setup, "\n"]
    if extra_pkgs:
        import_block = "import (\n" + "    ".join(q + "\n" for q in extra_pkgs) + ")"
        pieces.extend([import_block, "\n"])
    pieces.extend([signature, completion, "\n", test])

    return "".join(pieces)

#### Running Eval

In [None]:
@scorer(metrics=[accuracy(), stderr()])
def go_scorer() -> Scorer:
    """Build a scorer that runs a Go completion against the sample's tests.

    Returns:
        An async ``score(state, target)`` callable. It extracts code from the
        model completion with ``find_code_go``, assembles a test file with
        ``get_final_go``, runs ``go test`` on that file, and yields CORRECT
        when the command succeeds and INCORRECT otherwise (a timeout counts
        as a failure).

    Raises (from the returned callable):
        subprocess.CalledProcessError: If ``go mod tidy`` exits non-zero.
    """
    go_bin = '/usr/local/go/bin/go'

    async def score(state: TaskState, target: Target) -> Score:
        task_id = state.sample_id
        # Assumes ids of the form "<prefix>/<number>"; the number names the
        # scratch directory and the throwaway Go module.
        idx = int(task_id.split('/')[-1])

        model_completion = state.output.completion
        processed_completion = find_code_go(model_completion)
        final_code = get_final_go(state, processed_completion)

        tmp_dir = f'/root/srf-project/tmp/test_{idx}/'
        test_file = os.path.join(tmp_dir, 'main_test.go')
        os.makedirs(tmp_dir, exist_ok=True)

        # Everything that touches the scratch directory sits inside
        # try/finally so the directory is removed even when a step raises.
        try:
            # Go source is UTF-8; write it explicitly and close the handle.
            with open(test_file, 'w', encoding='utf-8') as fh:
                fh.write(final_code)

            # Commands get cwd= rather than changing the process-wide working
            # directory, which is shared by every coroutine in the process.
            # NOTE(review): subprocess.run is synchronous and blocks the event
            # loop while the go tool runs.
            if not os.path.exists(os.path.join(tmp_dir, 'go.mod')):
                try:
                    subprocess.run(
                        [go_bin, 'mod', 'init', f'example.com/tmpmod_{idx}'],
                        check=True,
                        capture_output=True,
                        cwd=tmp_dir,
                    )
                except subprocess.CalledProcessError as e:
                    # Best effort: report and carry on to `go mod tidy`.
                    print("Error running go mod init:")
                    print(e.stderr)
            subprocess.run(
                [go_bin, 'mod', 'tidy'],
                check=True,
                capture_output=True,
                cwd=tmp_dir,
            )

            try:
                result = await sandbox().exec(
                    cmd=[go_bin, "test", test_file],
                    timeout=30,
                    cwd=tmp_dir,
                )
            except TimeoutError:
                result = ExecResult(False, 1, "", "Verification timed out.")
        finally:
            # A cleanup failure must not mask the real outcome or exception.
            shutil.rmtree(tmp_dir, ignore_errors=True)

        if result.success:
            explanation = ""
        else:
            explanation = "".join([
                "The following verification code was executed:\n\n",
                final_code,
                f"\nThe submission was incorrect\n\n{result.stderr}",
            ])

        return Score(
            value=CORRECT if result.success else INCORRECT,
            explanation=explanation,
            metadata={
                'completion': model_completion,
                'processed': processed_completion,
                'final_code': final_code,
                'idx': idx,
                'task_id': task_id,
            },
        )

    return score

In [None]:
# Active language id. Read by humaneval_record_to_sample (as the LANG_PREFIX
# key and to enable the Go-only metadata) and passed as the dataset `name`.
lang = 'go'

def humaneval_record_to_sample(record):
    """Convert one dataset record into an inspect_ai Sample.

    The model input is HUMANEVAL_INSTRUCTION, the LANG_PREFIX line for the
    module-level ``lang``, a newline, and the record's prompt. When ``lang``
    is 'go', the 'import' and 'test_setup' fields are copied from
    ``go_content`` at the position given by the numeric suffix of
    ``task_id`` (assumes go_content is ordered by that number -- confirm).

    Args:
        record: Mapping with 'prompt', 'task_id', 'test' and
            'canonical_solution' entries.

    Returns:
        A Sample whose id is the task id, whose target is the canonical
        solution, and whose metadata carries the prompt, the test and, for
        Go, the import and test-setup text.
    """
    model_input = "".join(
        (HUMANEVAL_INSTRUCTION, LANG_PREFIX[lang], '\n', record['prompt'])
    )

    # Parsed unconditionally, so a non-numeric suffix fails for any language.
    task_number = int(record['task_id'].split('/')[-1])

    extras = {"prompt": record["prompt"], "test": record["test"]}
    if lang == 'go':
        for field in ('import', 'test_setup'):
            extras[field] = go_content[task_number][field]

    return Sample(
        id=record["task_id"],
        input=model_input,
        target=record["canonical_solution"],
        metadata=extras,
    )

# Load the 'test' split of THUDM/humaneval-x for the active language and
# convert every record with humaneval_record_to_sample.
# NOTE(review): `trust = True` presumably permits running the dataset
# repository's loading code -- confirm against the inspect_ai documentation.
humaneval_dataset = hf_dataset(
    path = 'THUDM/humaneval-x',
    name = lang,
    split = 'test',
    sample_fields = humaneval_record_to_sample,
    trust = True,
)

In [None]:
# Number of leading dataset entries to evaluate; the task slices
# humaneval_dataset with this value.
samples = 164

@task
def humaneval():
    """Define the evaluation task over the first ``samples`` dataset entries.

    Returns:
        A Task that generates one completion per sample, scores it with
        go_scorer, and uses the 'local' sandbox.
    """
    subset = humaneval_dataset[:samples]
    return Task(
        dataset=subset,
        solver=generate(),
        scorer=go_scorer(),
        sandbox='local',
    )

In [None]:
# Run the task twice with the same epoch count: first with the model handle
# created earlier in the notebook, then with 'openai/gpt-4o-mini'.
epochs = 1
inspect_ai.eval(humaneval(), model = model, epochs = epochs)
inspect_ai.eval(humaneval(), model = 'openai/gpt-4o-mini', epochs = epochs)

#### Checking outputs

In [None]:
# Position in log.samples of the sample to inspect.
idx = 1

# Read a saved eval log, then pull out the metadata recorded by go_scorer
# (completion, processed code, final code, ...) and the sample's target.
log = read_eval_log('/root/srf-project/logs/test_3.eval')
data = log.samples[idx].scores['go_scorer'].metadata
canonical = log.samples[idx].target

In [None]:
# Show what fence stripping alone yields for the logged raw completion.
print(identify_codeblock(data['completion']))

In [None]:
# Show the processed code that the scorer stored for this sample.
print(data['processed'])

In [None]:
# Re-run the full extraction pipeline on the logged completion. This makes a
# fresh OpenAI request via remove_fxn_signature, so the output may differ
# from the processed text stored in the log.
processed = find_code_go(data['completion'])
print(processed)

In [None]:
# Find the logged sample whose numeric task id equals target_id. On a match,
# store its position in log.samples and the content of its last message,
# print both, and stop. Without a match idx stays 0 and model_out stays ''.
target_id = 145
idx = 0
model_out = ''

for i, sample in enumerate(log.samples):
    # Named sample_num rather than `id` so the builtin id() is not shadowed
    # for the rest of the notebook session.
    sample_num = int(sample.id.split('/')[-1])
    if sample_num == target_id:
        idx = i
        model_out = sample.messages[-1].content
        print(i)
        print()
        print(model_out)
        break
