In [5]:
"""
Part of this code is adapted from "https://github.com/tu-nv/ibn_llm/tree/master".
In particular: the use of the datasets, the associated prompts, the accuracy computation, and the use of the MaxMarginalRelevanceExampleSelector tool.
We acknowledge and thank the original authors.
The datasets used here (formal specification and NFV configuration) are the existing datasets referred to in our paper.
"""

import sys
import time
import pandas as pd
import os
from langchain_chroma import Chroma
from langchain_core.example_selectors import MaxMarginalRelevanceExampleSelector
#from langchain_community.embeddings import OllamaEmbeddings
from langchain_ollama import OllamaEmbeddings
from ollama import Client
import json
import numpy as np

In [None]:
# ---------------------------------------------------------------------------
# Endpoints: embeddings (used for few-shot example retrieval) and text
# generation are served by two separate Ollama instances.
# ---------------------------------------------------------------------------
ollama_embedding_url = "http://localhost:11434"
ollama_server_url = "http://localhost:11435"
client = Client(host=ollama_server_url, timeout=120)

# ---------------------------------------------------------------------------
# Experiment settings
# ---------------------------------------------------------------------------
# Few-shot sizes evaluated for every model, in increasing order.
context_examples = [0, 1, 3, 6, 9]
use_case = "formal_spec"  # or "nfv_conf"

# Every use case provides its own dataset, system prompt and scoring function;
# they are imported under shared names so the evaluation cell stays generic.
if use_case == "nfv_conf":
    from nfv_configuration.dataset import trainset, testset
    from nfv_configuration.prompts import SYSTEM_PROMPT
    from nfv_configuration.utils import compare_result
elif use_case == "formal_spec":
    from formal_specification.dataset import trainset, testset
    from formal_specification.prompts import SYSTEM_PROMPT
    from formal_specification.utils import compare_result
else:
    raise ValueError("Invalid use case")

# Models evaluated, in evaluation order.
# Excluded on purpose:
#   - "deepseek-coder-v2": does not support kshift
#   - "mixtral", "qwen2.5-coder", "llava."
#   - embedding-only models (not supported here):
#     "mxbai-embed-large", "nomic-embed-text", "snowflake-arctic-embed"
my_models = [
    "codegemma", "starcoder2", "dolphin-mistral", "wizardlm2",
    "phi", "yi", "command-r",
    "orca-mini", "llava-llama3", "zephyr",
    "starcoder", "codestral",
    "codellama:34b", "codellama",
    "llama2", "llama3", "llama3.1", "llama3.2",
    "qwen", "qwen2", "qwen2.5",
    "gemma2:27b", "openchat", "marco-o1", "mistral", "phi3",
    "huihui_ai/qwq-abliterated", "huihui_ai/qwq-fusion", "qwq",
    "mistral-nemo", "tinyllama", "deepseek-coder",
]

# Not referenced by the evaluation loop; presumably evaluated separately.
my_models_large = ["llama3.3"]

# The embedding model backs the few-shot example selector.
default_model = "llama2"
ollama_emb = OllamaEmbeddings(model=default_model, base_url=ollama_embedding_url)

# Start a fresh results file holding only the header row; note that this
# overwrites the results of any previous run for the same use case.
csv_file = f"translate_result_{use_case}.csv"
results_header = pd.DataFrame(columns=["model", "num_examples", "accuracy", "avg_time"])
results_header.to_csv(csv_file, index=False)

In [None]:
# Evaluation loop: for every model in `my_models` and every few-shot size in
# `context_examples`, translate each test intent with the LLM, score the JSON
# answer against the expected output with `compare_result`, print a summary
# and append one row to `csv_file`.
# Relies on globals from the configuration cell: my_models, context_examples,
# trainset, testset, SYSTEM_PROMPT, compare_result, client, ollama_emb,
# use_case, csv_file.

# Upper bound on generation attempts per test case. An unbounded retry loop
# hangs the whole benchmark when an error is persistent (e.g. the model is
# not available on the Ollama server).
MAX_GENERATE_ATTEMPTS = 10


def _build_example_selector(num_examples):
    """Return an MMR example selector over `trainset` that yields `num_examples` examples."""
    selector = MaxMarginalRelevanceExampleSelector.from_examples(
        [trainset[0]],
        ollama_emb,
        Chroma,
        input_keys=["instruction"],
        k=num_examples,
        vectorstore_kwargs={"fetch_k": min(num_examples, len(trainset))},
    )
    # from_examples needs at least one example to create the store; the
    # collection is then wiped (presumably so nothing from a previous
    # iteration lingers) and re-populated with the full training set.
    selector.vectorstore.reset_collection()
    for example in trainset:
        selector.add_example(example)
    return selector


def _generate_with_retry(model, intent, example_selector, num_examples):
    """Ask `model` to translate `intent`, retrying when the call raises.

    Returns:
        (response_text, seconds) for the first successful attempt, where
        `seconds` covers example selection plus generation; or (None, None)
        when all MAX_GENERATE_ATTEMPTS attempts raised.
    """
    for _ in range(MAX_GENERATE_ATTEMPTS):
        try:
            time.sleep(0.1)
            current_time = time.time()
            # Build the prompt from scratch on every attempt: appending to a
            # prompt shared across retries duplicated the few-shot examples.
            system_prompt = SYSTEM_PROMPT
            if num_examples > 0:
                examples = example_selector.select_examples({"instruction": intent})
                example_str = "\n\n\n".join(
                    "Input: " + x["instruction"] + "\n\nOutput: " + x["output"]
                    for x in examples
                )
                system_prompt += example_str + "\n\n\n"

            response = client.generate(model=model,
                options={
                    'temperature': 0.6,
                    'num_ctx': 8192,
                    'top_p': 0.3,
                    'num_predict': 1024,
                    'num_gpu': 99,
                    },
                stream=False,
                system=system_prompt,
                prompt=intent,
                format='json'
            )
            actual_output = response['response']
            return actual_output, time.time() - current_time
        except Exception as e:
            print("Exception on Input: ", e)
            sys.stdout.flush()
    print(f"Giving up after {MAX_GENERATE_ATTEMPTS} attempts on input: {intent}")
    return None, None


for model in my_models:

    for num_examples in context_examples:

        example_selector = _build_example_selector(num_examples)

        correct = 0
        total = 0
        processing_times = []
        # Test cases that contributed nothing to `total` (no response, or a
        # response that could not be parsed/compared).
        unscored = 0

        for testcase in testset:
            intent = testcase["instruction"]
            actual_output, proc_time_s = _generate_with_retry(
                model, intent, example_selector, num_examples
            )
            if actual_output is None:
                unscored += 1
                continue
            processing_times.append(proc_time_s)

            try:
                expected_output = json.loads(testcase["output"])
                actual_output = json.loads(actual_output)
                num_correct_translation, total_translation = compare_result(expected_output, actual_output)
                correct += num_correct_translation
                total += total_translation
            except Exception as e:
                # NOTE(review): such cases are excluded from `total`, so they
                # do not lower the reported accuracy; they are only counted in
                # `unscored`. Confirm whether they should count as failures.
                unscored += 1
                print("Exception on comparing result: ", e)

        # Guard against division by zero when nothing was scored, which
        # previously aborted the whole benchmark.
        accuracy_pct = (correct / total) * 100 if total else float("nan")
        avg_time = np.average(processing_times) if processing_times else float("nan")

        print("=====================================")
        print(f"Finish eval on use case: {use_case}, model: {model}, num context examples: {num_examples}, testcases: {total}, accuracy: {round(accuracy_pct, 3)}  avg proc time: {round(avg_time, 1)}, unscored: {unscored}")

        try:
            my_result = {
                        "model": model,
                        "num_examples": num_examples,
                        "accuracy": round(accuracy_pct, 2),
                        "avg_time": round(avg_time, 1)
                    }

            # Append the new result to the CSV file
            pd.DataFrame([my_result]).to_csv(csv_file, mode='a', header=False, index=False)

        except Exception as e:
            print(f"Error processing model {model} with {num_examples} examples: {e}")

        sys.stdout.flush()

        # A perfect score makes larger few-shot sizes redundant for this model.
        # Require total > 0 so an empty evaluation is not mistaken for 100%.
        if total > 0 and correct == total:
            break