In [2]:
import os
import pandas as pd
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from tqdm.auto import tqdm
from typing import Optional

# Export the OpenRouter API key so ChatOpenRouter can fall back to it.
# "utf-8-sig" also drops a leading byte-order mark (some Windows editors add one),
# which str.strip() would otherwise leave inside the key and break authentication.
with open("./openrouter.key", "r", encoding="utf-8-sig") as f:
    os.environ["OPENROUTER_API_KEY"] = f.read().strip()


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Number of generation passes per model (each pass writes its own result file).
ROUNDS = 1
# Model ids to evaluate, taken from the "id" column of the models table.
models = [model_id for model_id in pd.read_feather("./datasets/models.feather")["id"]]

In [4]:
class ChatOpenRouter(ChatOpenAI):
    """ChatOpenAI client pointed at OpenRouter's OpenAI-compatible endpoint.

    Args:
        model: Model id, forwarded to ChatOpenAI as ``model_name``.
        openai_api_key: Explicit API key; when falsy, the
            ``OPENROUTER_API_KEY`` environment variable is used instead.
        openai_api_base: Base URL of the API (OpenRouter by default).
        **kwargs: Any other ChatOpenAI options (e.g. ``temperature``),
            passed through unchanged.
    """

    # NOTE(review): these annotations override the parent's field declarations;
    # newer langchain_openai releases may declare the key differently (e.g. as
    # SecretStr) — confirm they are compatible with the installed version.
    openai_api_base: str
    openai_api_key: str
    model_name: str

    def __init__(
        self,
        model: str,
        openai_api_key: Optional[str] = None,
        openai_api_base: str = "https://openrouter.ai/api/v1",
        **kwargs,
    ):
        # Fall back to the key exported from ./openrouter.key when none is given.
        resolved_key = openai_api_key if openai_api_key else os.getenv("OPENROUTER_API_KEY")
        super().__init__(
            model_name=model,
            openai_api_base=openai_api_base,
            openai_api_key=resolved_key,
            **kwargs,
        )

# Single-turn template: each dataset prompt is sent as-is as the human message.
gen_prompt = ChatPromptTemplate.from_messages([("human", "{prompt}")])

In [5]:
# Prompt dataset under evaluation; its name also prefixes every result file.
dataset = "revised_dataset"
df = pd.read_feather("./datasets/" + dataset + ".feather")

In [33]:
def _error_message(err):
    """Best-effort extraction of a readable message from a returned exception.

    Some ValueErrors surfaced by the chat model carry a dict with a "message"
    key as their first argument; anything else falls back to ``str(err)``.
    """
    payload = err.args[0] if err.args else None
    if isinstance(payload, dict) and "message" in payload:
        return str(payload["message"])
    return str(err)


def generate_predictions(count, total, model, model_name):
    """Generate answers for every prompt in ``df`` with one model and pickle them.

    For each of ``ROUNDS`` rounds, sends ``df["Prompt"]`` through
    ``gen_prompt | ChatOpenRouter(model)`` and saves a copy of ``df`` with an
    extra ``answer`` column to ``./results/{dataset}_{model_name}_{round}.pkl``.
    Rounds whose result file already exists are skipped, so an interrupted run
    can be resumed. Prompts rejected with a "flagged" ValueError are recorded
    with the error message as their answer; any other failure abandons the
    remaining rounds for this model.

    Args:
        count: Prompts processed so far (across all models); progress only.
        total: Total prompts expected across all models and rounds.
        model: Model id passed to ``ChatOpenRouter``.
        model_name: File-name-safe form of ``model`` used in result paths.

    Returns:
        The updated ``count``. Ints are immutable, so the caller must reassign
        the return value for progress to accumulate across calls.
    """
    gen_llm = ChatOpenRouter(temperature=0.7, model=model, cache=False)
    generator = gen_prompt | gen_llm
    prompts = list(df["Prompt"])
    # Without this, every save fails when ./results has not been created yet.
    os.makedirs("./results", exist_ok=True)

    for round_idx in range(ROUNDS):
        run_name = f"{dataset}_{model_name}_{round_idx}"
        out_path = f"./results/{run_name}.pkl"
        # Resume support: results from a previous run are reused, not regenerated.
        if os.path.exists(out_path):
            count += len(df)
            print(f"Skipping {run_name}.pkl")
            continue
        print(run_name)
        gens = [None] * len(df)
        try:
            # batch_as_completed yields (input_index, output) pairs in completion
            # order; with return_exceptions=True failures arrive as exception objects.
            for index, output in tqdm(
                generator.batch_as_completed(prompts, return_exceptions=True),
                total=len(df),
            ):
                if isinstance(output, ValueError):
                    message = _error_message(output)
                    if "flagged" not in message:
                        raise ValueError(message) from output
                    # Moderation refusal: keep the message as this prompt's answer.
                    gens[index] = message
                elif isinstance(output, Exception):
                    # Surface the real error instead of an AttributeError on `.content`.
                    raise output
                else:
                    gens[index] = output.content
                count += 1
            print(f"Saving {run_name}.pkl")
            print(f"progress: {count}/{total}")
            res_df = df.copy()
            res_df["answer"] = gens
            res_df.to_pickle(out_path)
        except Exception as e:
            print(f"Error occurred: {str(e)}")
            print(f"skipping {run_name}")
            print(f"progress: {count}/{total}")
            # Give up on the remaining rounds for this model after a failure.
            break
    return count


count = 0
total = len(df) * len(models) * ROUNDS
for model in models:
    # Model ids may contain "/", which cannot appear in a file name.
    model_name = model.replace("/", "_")
    count = generate_predictions(count, total, model, model_name)