In [None]:
import openai
import os
import json
import requests
import jsonlines
from tqdm import tqdm
import time
from multiprocessing import Pool
from functools import partial

from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
    RetryError
)

In [None]:
# Retry-related settings. Neither constant is referenced by the cells visible in
# this notebook; the retry policy actually applied is the one configured on the
# `completion_with_backoff` decorator.
# presumably the maximum number of API attempts per request — TODO confirm / wire in
MAX_API_RETRY = 10
# presumably seconds to sleep between retries; the name looks like a typo of
# "LIMIT" — TODO confirm before renaming
LLM_MIT_RETRY_SLEEP = 5

In [None]:
# Load every input record into memory so the worker pool can iterate over a
# plain list. The context manager closes the underlying file handle, which a
# bare `list(jsonlines.open(...))` would leave open for the kernel's lifetime.
with jsonlines.open("../data/wiki_roleplay_multilingual_test_input_w_evidence.jsonl") as reader:
    data = list(reader)

In [None]:
# Model identifier forwarded as `model=` to the chat-completions call. Switch
# which assignment is uncommented to target a different model. The part after
# the "/" is also used to build the output filename.
# model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# model = "openchat/openchat-3.5-1210"
model = "mistralai/Mistral-7B-Instruct-v0.2"

In [None]:
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
def completion_with_backoff(client, **kwargs):
    """Call the chat-completions endpoint, retrying failures with random exponential backoff.

    Args:
        client: An OpenAI-compatible client exposing ``chat.completions.create``.
        **kwargs: Forwarded verbatim to ``client.chat.completions.create``.

    Returns:
        Whatever ``client.chat.completions.create`` returns.

    Raises:
        tenacity.RetryError: If all 5 attempts raise.
    """
    try:
        return client.chat.completions.create(**kwargs)
    except Exception as e:
        # Log the failure, then re-raise. tenacity only retries when the wrapped
        # call raises; swallowing the exception here made the function return
        # None on the first failure and disabled the backoff entirely.
        print(e)
        raise

In [None]:
def process(item, model):
    """Run one multi-turn conversation for `item` and attach the transcript.

    The system prompt and each entry of ``item['prompts']`` are sent in order;
    every model reply is appended to the running message list before the next
    prompt is sent, so later turns see the earlier answers.

    Args:
        item: Dict with a ``'system'`` string and a ``'prompts'`` iterable of
            user-turn strings. It is mutated in place on success.
        model: Model identifier forwarded to the chat-completions call.

    Returns:
        ``item`` with a new ``'messages'`` key holding the transcript, or
        ``None`` when no assistant reply could be obtained (including when
        there are no prompts at all).

    Raises:
        KeyError: If ``item`` has no ``'system'`` key (raised before the
            best-effort section).
    """
    client = openai.OpenAI(
        # Prefer a key from the environment so the secret does not have to be
        # stored in the notebook; the placeholder remains the fallback.
        api_key=os.environ.get("TOGETHER_API_KEY", "your together api key"),
        base_url='https://api.together.xyz',
    )

    message = [{"role": "system", "content": item['system']}]
    try:
        for prompt in item['prompts']:
            message.append({"role": "user", "content": prompt})
            answer = completion_with_backoff(client, messages=message, model=model, max_tokens=8192)
            answer = answer.choices[0].message.content
            message.append({"role": "assistant", "content": answer})
    except Exception as e:
        # Best effort: keep the turns completed so far, but report the failure
        # instead of hiding it.
        print(f"process: conversation aborted after {len(message)} message(s): {e!r}")
        # The user turn is appended before the API call, so a failed call
        # leaves an unanswered prompt at the end. Drop it; otherwise a
        # first-turn failure would never be detected by the check below.
        if message[-1]["role"] == "user":
            message.pop()

    # Only the system prompt is left: nothing was generated for this item.
    if len(message) == 1:
        return None

    item['messages'] = message
    return item
    
# Bind the selected model so the worker pool can map a single-argument
# callable over `data`.
func = partial(process, model=model)

In [None]:
# Fan the records out over 32 worker processes. `imap_unordered` yields each
# result as soon as it is ready, so `results` is not in input order.
# The progress bar is a context manager so it is closed even if a worker raises.
results = []
with Pool(32) as p, tqdm(total=len(data)) as pbar:
    for item in p.imap_unordered(func, data):
        pbar.update(1)
        # `process` returns None for records that produced no conversation;
        # keep those out of the collected results.
        if item is not None:
            results.append(item)

In [None]:
# Persist the collected conversations, one JSON object per line. The filename
# carries the model name without its organisation prefix.
output_path = f"data/results/wiki_roleplay_multilingual_test_input_w_evidence_{model.split('/')[1]}.jsonl"
# Create the target directory first so a long generation run is not lost to a
# FileNotFoundError at the very last step.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with jsonlines.open(output_path, "w") as f:
    for each in results:
        # Failed records are represented as None; skip them rather than
        # emitting `null` lines.
        if each is not None:
            f.write(each)