In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '4'
import warnings
warnings.filterwarnings("ignore") 
import subprocess
from utils import load_jsonl, extract_code_block, extract_obj, change_variable_types
import numpy as np
from vllm import LLM, SamplingParams        
from transformers import AutoTokenizer                                      
from langchain.prompts import PromptTemplate
from rule_prompt_utils import gurobi_prompt_temp





In [None]:
# Load the checkpoint and tokenizer used by every evaluation cell below.
# The checkpoint location can be overridden with the SIRL_MODEL_PATH environment
# variable instead of editing this cell; the original path remains the default.
model_path = os.environ.get('SIRL_MODEL_PATH', '/DATA/disk1/cml/MScache/models/oneday88/SIRL-7B')
tensor_parallel_size = 1
solver_name = 'gurobi'  # forwarded to check_result (and from there to extract_code_block / extract_obj)
print("Loading model", model_path)
model = LLM(
    model=model_path,
    tensor_parallel_size=tensor_parallel_size,
    trust_remote_code=True,
)
print("Model initialized.")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)



In [None]:
# Prompt templates for the Gurobi code-generation task (system + user message);
# mp_worker fills both of them with `question=`.
zeroshot_prompt_system, zeroshot_prompt_user = (
    PromptTemplate.from_template(gurobi_prompt_temp[section]) for section in ('system', 'user')
)
def mp_worker(item):
    """Turn one test record into a chat-formatted prompt string.

    Both the system and the user template are filled with ``item['en_question']``
    and stripped; the tokenizer's chat template then renders the two-message
    conversation as text (``tokenize=False``) with the generation prompt appended.
    """
    question = item['en_question']
    messages = []
    for role, template in (("system", zeroshot_prompt_system), ("user", zeroshot_prompt_user)):
        messages.append({"role": role, "content": template.format(question=question).strip()})
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def generate_with_model(model, prompt, sampling_params):
    """Run batched generation and return the first completion's text for each request.

    Args:
        model: object exposing ``generate(prompt, sampling_params)`` (here a vLLM ``LLM``);
            each returned item carries an ``outputs`` sequence of completions.
        prompt: prompt(s) passed straight through to ``model.generate``.
        sampling_params: decoding settings forwarded unchanged.

    Returns:
        list of str: ``outputs[0].text`` of every returned item, in order. Any extra
        completions (when sampling ``n > 1``) are discarded.
    """
    request_outputs = model.generate(prompt, sampling_params)
    texts = []
    for request_output in request_outputs:
        first_completion = request_output.outputs[0]
        texts.append(first_completion.text)
    return texts

In [None]:
# Decoding strategy shared by the pass@1 and pass@8 evaluation cells.
# NOTE: despite its name, `topk` is forwarded as vLLM's `n` (completions per request),
# not as top-k sampling; generate_with_model keeps only outputs[0] of each request.
topk = 1
max_tokens = 8192
repetition_penalty = 1.02  # slightly above 1 to curb the occasional run of repeated tokens
stop_tokens = ["</s>"]

# Nucleus (top-p) sampling at temperature 0.5
sampling_params = SamplingParams(
    temperature=0.5,
    top_p=0.9,
    n=topk,
    max_tokens=max_tokens,
    repetition_penalty=repetition_penalty,
    stop=stop_tokens,
)


In [None]:
# Grade one model completion; shared by the pass@1 and pass@8 evaluation cells.
def _relative_gap(value, reference):
    """Return |value - reference| / (|reference| + 1), or None if either objective is missing.

    The +1 in the denominator makes the measure act like an absolute error for
    objectives near zero and like a relative error for large ones.
    """
    if value is None or reference is None:
        return None
    return np.abs(value - reference) / (np.abs(reference) + 1)


def check_result(result_str, item, solver_name='gurobi', timeout=100):
    """Execute the solver code contained in a completion and grade it against the reference.

    Args:
        result_str: raw model completion; its code block is extracted with ``extract_code_block``.
        item: test record; ``item['en_answer']`` is the reference objective. "No Best Solution"
            or any value containing "-9999" means no objective is expected, in which case an
            infeasible (or timed-out) run is graded as correct.
        solver_name: forwarded to ``extract_code_block`` / ``extract_obj``.
        timeout: seconds allowed for each ``python3 -c`` run of extracted code.

    Returns:
        int status code: 1 = correct, 0 = wrong, 2 = no code block could be extracted,
        3 = the first run exited with a non-zero return code.
    """
    tol = 1e-6
    sub_answer = item['en_answer']
    # Convert the reference to float, or None when no objective is expected
    sub_answer = None if sub_answer == "No Best Solution" or "-9999" in str(sub_answer) else float(sub_answer)

    code_snippet = extract_code_block(result_str, solver_name)
    if not code_snippet:
        return 2

    try:
        result = subprocess.run(['python3', '-c', code_snippet], capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        return 1 if sub_answer is None else 0

    if result.returncode != 0:
        return 3

    solver_result = extract_obj(result.stdout, solver_name)
    gap = _relative_gap(solver_result, sub_answer)

    # First attempt matches the reference
    if gap is not None and gap <= tol:
        return 1

    # Infeasible run or numerical mismatch: variable-type mistakes are forgiven, so retry once
    # with the variable types rewritten. ('nfeasible' matches both "infeasible" and "Infeasible".)
    if 'nfeasible' in result.stdout or (gap is not None and gap > tol):
        modified_str = change_variable_types(result_str)
        retry_snippet = extract_code_block(modified_str, solver_name) if modified_str else None
        # Without a code block there is nothing to re-run; passing it to subprocess would raise
        if retry_snippet:
            try:
                # `result` is rebound on purpose: the infeasibility check below inspects the retry output
                result = subprocess.run(['python3', '-c', retry_snippet], capture_output=True, text=True, timeout=timeout)
            except subprocess.TimeoutExpired:
                print("over_time")
                return 1 if sub_answer is None else 0
            if result.returncode == 0:
                new_result = extract_obj(result.stdout, solver_name)
                if 'nfeasible' not in result.stdout:
                    new_gap = _relative_gap(new_result, sub_answer)
                    if (new_gap is not None and new_gap <= tol) or new_result == sub_answer:
                        return 1

    # Infeasible (after any retry): correct only when no objective was expected
    if 'nfeasible' in result.stdout:
        return 1 if sub_answer is None else 0

    # Final comparison against the first run's objective
    if gap is not None:
        return 1 if gap < tol else 0
    return 1 if solver_result == sub_answer else 0

In [None]:
# Pass@1 evaluation: one sampled completion per problem, graded by check_result.
# Run this cell to check the pass@1 accuracy of the loaded checkpoint.
datapath = 'test_data'
testdataset = ['NL4OPT.jsonl', 'MAMO_EasyLP.json', 'MAMO_ComplexLP_revised.json', 'IndustryOR_fixed.json', 'OptMATH_Bench_193.jsonl', 'OptMATH_Bench_166.jsonl','OptiBench.jsonl']
for filepath in testdataset:

    # loading data
    print('Loading data', filepath)
    test_data = load_jsonl(os.path.join(datapath, filepath))
    print('Finish Loading')

    # generation: one prompt per test record
    prompt_list = [mp_worker(item) for item in test_data]
    result_strs = generate_with_model(model, prompt_list, sampling_params)

    # grade every completion; each entry is a check_result status code (1 = correct)
    snippet_package_cor = [
        check_result(result_str, item, solver_name)
        for result_str, item in zip(result_strs, test_data)
    ]
    # Count directly: np.bincount(...)[1] raised IndexError when no case passed,
    # and an empty dataset would have divided by zero.
    num_cases = len(snippet_package_cor)
    num_pass = snippet_package_cor.count(1)
    accuracy = num_pass / num_cases if num_cases else 0.0
    print(f'Numbers of test cases in dataset {filepath}: {num_cases}')
    print(f'Numbers of pass@1 cases in dataset {filepath}: {num_pass}')
    print(f'pass@1 accuracy for dataset {filepath}: {num_pass}/{num_cases} = {accuracy}')
    print('-------------------------------------------------------------------')
  

In [None]:
# Pass@k evaluation (k = pass_k = 8): a problem counts as solved if any of its sampled
# completions passes check_result. Run this cell to check pass@8 accuracy of the checkpoint.
pass_k = 8  # completions sampled per problem
datapath = 'test_data'
testdataset = ['NL4OPT.jsonl', 'MAMO_EasyLP.json', 'MAMO_ComplexLP_revised.json', 'IndustryOR_fixed.json', 'OptMATH_Bench_193.jsonl', 'OptMATH_Bench_166.jsonl','OptiBench.jsonl']
for filepath in testdataset:

    # loading data: each record is repeated pass_k times so the sampler draws pass_k completions for it
    print('Loading data', filepath)
    test_data = [i for i in load_jsonl(os.path.join(datapath, filepath)) for _ in range(pass_k)]
    print('Finish Loading')

    # generation
    prompt_list = [mp_worker(item) for item in test_data]
    result_strs = generate_with_model(model, prompt_list, sampling_params)

    # consecutive groups of pass_k entries belong to the same problem
    snippet_package_cor = []
    for start in range(0, len(test_data), pass_k):
        group = zip(result_strs[start:start + pass_k], test_data[start:start + pass_k])
        # any() stops at the first passing completion, skipping the remaining solver runs
        solved = any(check_result(result_str, item, solver_name) == 1 for result_str, item in group)
        snippet_package_cor.append(1 if solved else 0)

    # Count directly: np.bincount(...)[1] raised IndexError when no problem was solved,
    # and an empty dataset would have divided by zero.
    num_cases = len(snippet_package_cor)
    num_pass = snippet_package_cor.count(1)
    accuracy = num_pass / num_cases if num_cases else 0.0
    print(f'Numbers of test cases in dataset {filepath}: {num_cases}')
    print(f'Numbers of pass@{pass_k} cases in dataset {filepath}: {num_pass}')
    print(f'pass@{pass_k} accuracy for dataset {filepath}: {num_pass}/{num_cases} = {accuracy}')
    print('-------------------------------------------------------------------')
    
