In [17]:
import mlflow
import pandas as pd
from openai import OpenAI
import os
from wandb.sdk.data_types.trace_tree import Trace
import wandb
import configparser
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

### Run this notebook with caution — it costs money. This notebook contains automated model runs: all 5 models are run on all input sentences, and the results are logged as CSV files and in Weights & Biases.

### Config Parameters

In [22]:
# Load credentials and data locations from a local, untracked config.ini
# (expects a [credentials] section with the keys read below).
config = configparser.ConfigParser()
config_path = 'config.ini'
# ConfigParser.read() silently skips files it cannot open and returns the list of
# files it actually parsed -- fail loudly here instead of with a confusing
# NoSectionError on the first config.get() call.
if not config.read(config_path):
    raise FileNotFoundError(f"Could not read configuration file '{config_path}'")
api_key_openai = config.get('credentials', 'api_key_openai')
api_key_mistral = config.get('credentials', 'api_key_mistral')
surfdrive_url_input_sentences = config.get('credentials', 'surfdrive_url_input_sentences')
surfdrive_url_prompts = config.get('credentials', 'surfdrive_url_prompts')
# Folder that receives one CSV per (prompt, model) run
output_chat_data_folder_path = 'output_llm_data/'

### Input & Prompts

In [6]:
# Neutral input sentences to be rewritten; the source CSV is ';'-separated
sentences_raw = pd.read_csv(surfdrive_url_input_sentences, sep=';')
neutral_sentences = sentences_raw['sentences']
neutral_sentences

0     I’m all about that food. I usually kick off th...
1     Just getting my vitamins in at the school cant...
2     De Pizzabakkers sell this vegan pizza with che...
3     I entered the world of vegan foods lately. Nex...
4     This vegan fried chicken from KFC is on the sp...
5     Just having this vegan hotdog from the school ...
6     Just had my first vegan cake at "groene bakker...
7     This vegan chocolate is on point. Its with oat...
8     I'm all about fast food. I'm having some fries...
9     I went to the "Groene Burger" fast food restau...
10    I am at the food feestje in "Koningsplein". I ...
11    I just ate a vegan schnitzel. The flavors and ...
12    The school canteen has vegan Buddha bowls late...
13    I am all over this Italian restaurant. I had s...
14    I am having some new food: vegan Magnum. Just ...
15    I'm going through the city. I am going for som...
16    I am having a midday snack later outside. Soy ...
17    I am all about hot chocolate and some bake

In [18]:
# Prompt templates (';'-separated); promptContent holds a '{}' placeholder for the sentence.
# reset_index() materialises the current RangeIndex as an extra 'index' column.
prompts_raw = pd.read_csv(surfdrive_url_prompts, sep=';')
df_prompts = prompts_raw.reset_index()
df_prompts

Unnamed: 0,index,promptID,promptContent
0,0,0,Here is some text {}. Here is a rewrite of the...


### Automated Model Runs

In [28]:
# Mistral pre-setup
mistral_model = ["mistral-tiny", "mistral-small", "mistral-medium"]
mistral_client = MistralClient(api_key = api_key_mistral)

# GPT pre-setup
gtp_client = OpenAI(api_key = api_key_openai)
gpt_models = ["gpt-3.5-turbo","gpt-4"]
gpt_system_msg = "You are an expert in text style transfer."
gpt_temperature=0.2
gpt_max_tokens=256
gpt_frequency_penalty=0.0

# W&B project that receives one run per (prompt, model) pair
wandb_project = "lmm-evaluate"

# getting the relevant prompts (only the first prompt row is used here)
prompt_content = df_prompts['promptContent'][0]
prompt_id = str(df_prompts['promptID'][0])


def _run_name(prompt_id, model):
    """Return the shared run / file-name stem for a (prompt, model) pair.

    e.g. ``_run_name('0', 'gpt-4') == 'run-promptID_0_model_gpt-4'``.
    """
    return "run-" + "promptID_" + prompt_id + '_model_' + model


def _build_query(prompt_content, original):
    """Substitute the sentence for the ``{}`` placeholder in the prompt template.

    The sentence is inserted wrapped in literal braces (``{sentence}``),
    exactly as the notebook has always done.
    """
    # f'{{{original}}}' renders as '{' + original + '}'
    return prompt_content.replace('{}', f'{{{original}}}')


def _response_record(original, chat_response, prompt_id, **extra):
    """Flatten one chat response into a row of the output table.

    Both the Mistral and the OpenAI responses are read through the same
    attributes below. Any ``extra`` columns (e.g. ``temperature``) are
    appended after the common ones.
    """
    record = {
        'original': original,
        'output': chat_response.choices[0].message.content,
        "model": chat_response.model,
        "prompt_tokens": chat_response.usage.prompt_tokens,
        "completion_tokens": chat_response.usage.completion_tokens,
        "object": chat_response.object,
        "promptID": prompt_id,
    }
    record.update(extra)
    return record


def _save_and_log(records, prompt_id, model):
    """Save one model's outputs to CSV and log them to Weights & Biases.

    Writes ``<output folder>/<run name>_output.csv``, logs the frame as a W&B
    table and the CSV as a ``dataset`` artifact, then returns the DataFrame.
    """
    run_name = _run_name(prompt_id, model)
    df_output = pd.DataFrame(records)

    # to_csv() does not create missing folders; failing there would discard
    # API results that have already been paid for.
    if output_chat_data_folder_path:
        os.makedirs(output_chat_data_folder_path, exist_ok=True)
    csv_path = os.path.join(output_chat_data_folder_path, run_name + '_output.csv')
    df_output.to_csv(csv_path, index=False)

    wandb.init(project=wandb_project, name=run_name)
    try:
        # log df as a table to W&B for interactive exploration
        wandb.log({"promptID_" + prompt_id + '_model_' + model: wandb.Table(dataframe=df_output)})
        # log csv file as a dataset artifact to W&B for later use
        artifact = wandb.Artifact('df_' + "promptID_" + prompt_id + '_model_' + model + '_output', type="dataset")
        artifact.add_file(csv_path)
        wandb.log_artifact(artifact)
    finally:
        # always close the run so an error does not leave it open for the next model
        wandb.finish()
    return df_output


# Mistral runs
for mistral_m in mistral_model:
    print(_run_name(prompt_id, mistral_m))
    final_output = []
    # iterate over the values directly so every sentence (including the last) is sent
    for original in neutral_sentences:
        messages = [ChatMessage(role="user", content=_build_query(prompt_content, original))]
        # No streaming
        chat_response = mistral_client.chat(model=mistral_m, messages=messages)
        final_output.append(_response_record(original, chat_response, prompt_id))
    df_mistral_output = _save_and_log(final_output, prompt_id, mistral_m)


# GPT runs
for gpt_m in gpt_models:
    print(_run_name(prompt_id, gpt_m))
    final_output = []
    for original in neutral_sentences:
        message = [
            {"role": "system", "content": gpt_system_msg},
            {"role": "user", "content": _build_query(prompt_content, original)},
        ]
        chat_response = gtp_client.chat.completions.create(
            model=gpt_m,
            messages=message,
            temperature=gpt_temperature,
            max_tokens=gpt_max_tokens,
            frequency_penalty=gpt_frequency_penalty,
        )
        final_output.append(
            _response_record(original, chat_response, prompt_id, temperature=gpt_temperature)
        )
    df_gpt_output = _save_and_log(final_output, prompt_id, gpt_m)



