In [None]:
import functools
import os

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

def read_apis(filename):
    """Read API names from a text file, one name per line.

    Surrounding whitespace is stripped from every line and blank lines are
    skipped. Stray empty lines in the list file therefore do not become
    empty API names, which would otherwise yield prompts such as
    "calls the '' API" and output files named like ``N_MLX_.py``.

    Args:
        filename: Path to a UTF-8 encoded text file.

    Returns:
        list[str]: The non-empty API names, in file order.
    """
    with open(filename, "r", encoding="utf-8") as file:
        # Iterate the file lazily instead of materializing readlines().
        return [name for name in (line.strip() for line in file) if name]

@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer(model_id):
    """Load the tokenizer and the 4-bit quantized causal LM for ``model_id``.

    Cached with a single entry. Successive ``generate_and_save`` calls (one
    per framework) then reuse the same model instead of reloading and
    re-quantizing it from disk each time.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        device_map="auto",
    )
    return tokenizer, model


def _extract_answer(decoded_output):
    """Return the stripped text between the first "[/INST]" and the next "[INST]".

    Returns None when either marker is missing, so the caller can fall back
    to saving the full decoded output.
    """
    answer_marker = "[/INST]"
    marker_idx = decoded_output.find(answer_marker)
    # Check the miss *before* adding the marker length; otherwise -1 + 7
    # silently becomes a valid-looking index.
    if marker_idx == -1:
        return None
    ans_start_idx = marker_idx + len(answer_marker)
    inst_start_idx = decoded_output.find("[INST]", ans_start_idx)
    if inst_start_idx == -1:
        return None
    return decoded_output[ans_start_idx:inst_start_idx].strip()


def generate_and_save(
    api_name,
    save_path,
    apis,
    *,
    num_samples=1000,
    model_id="./2_finetune/combined_model",  # change this to the path of the fine-tuned model
    log_filename="./log.txt",
):
    """Generate code snippets with the fine-tuned model and save each to a file.

    For sample ``i`` in ``range(num_samples)``, the model is prompted to
    write a snippet calling ``apis[i % len(apis)]``. Cycling this way covers
    the API list repeatedly. The answer text between the first "[/INST]" and
    the following "[INST]" is written to
    ``<save_path>/<i+1>_<api_name>_<api>.py``. If those markers are not
    found, the full decoded output is saved instead, and the file path is
    appended to ``log_filename`` so such samples can be reviewed.

    Args:
        api_name: Framework label used in the progress bar and file names.
        save_path: Output directory; created if it does not exist.
        apis: Non-empty sequence of API names to prompt for.
        num_samples: Number of snippets to generate (default 1000).
        model_id: Path or hub id of the fine-tuned model to load.
        log_filename: File to which paths of full-output fallbacks are appended.

    Raises:
        ValueError: If ``apis`` is empty.
    """
    # Fail fast, before loading the model, instead of hitting a
    # ZeroDivisionError on ``i % len(apis)``.
    if not apis:
        raise ValueError(f"No APIs given for {api_name}; nothing to generate.")

    os.makedirs(save_path, exist_ok=True)
    tokenizer, model = _load_model_and_tokenizer(model_id)

    for i in tqdm(range(num_samples), desc=f"Generating for {api_name}"):
        # Cycle through the API list
        api = apis[i % len(apis)]
        prompt = f"<s>[INST] Generate code snippet that calls the '{api}' API. [/INST]"

        # Place inputs on the model's device rather than assuming "cuda",
        # since device_map="auto" decides placement.
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=200,
            do_sample=True,
            top_p=0.9,
            temperature=0.1,
            # Explicit pad id avoids the per-call "Setting pad_token_id" warning.
            pad_token_id=tokenizer.eos_token_id,
        )
        decoded_output = tokenizer.decode(outputs[0].to("cpu"))

        file_path = os.path.join(save_path, f"{i+1}_{api_name}_{api}.py")
        code_between = _extract_answer(decoded_output)
        if code_between is None:
            # tqdm.write keeps the progress bar intact, unlike print.
            tqdm.write("block not found, saving full output.")
            code_between = decoded_output
            # Record the file name so fallback outputs can be reviewed later.
            with open(log_filename, "a", encoding="utf-8") as log_file:
                log_file.write(f"{file_path}\n")

        # Save the output to a Python file. This happens in both branches;
        # previously the fallback branch never wrote it.
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(code_between)
        tqdm.write(f"Saved: {file_path}")

# Per-framework API name lists (one name per line in each source file).
mlx_apis = read_apis("./2_finetune/api_info_extraction/api_list/mlx_api.txt")
mindspore_apis = read_apis(
    "./2_finetune/api_info_extraction/api_list/mindspore_api.txt"
)
oneflow_apis = read_apis(
    "./2_finetune/api_info_extraction/api_list/oneflow_api.txt"
)

# Framework label -> its API names; generation runs in this insertion order.
apis_dict = dict(
    MLX=mlx_apis,
    MindSpore=mindspore_apis,
    OneFlow=oneflow_apis,
)

# Root folder for the generated snippets (change this to the folder where
# the generated code should be saved). Each framework gets a subfolder.
base_path = "./randomgen"

# One generation run per framework, writing into <base_path>/<framework>.
for api_name, apis in apis_dict.items():
    generate_and_save(
        api_name,
        os.path.join(base_path, api_name),
        apis,
    )
