In [40]:
import math
import openai
import os
import sklearn.metrics
import random
import time
from langchain.llms import OpenAI

In [41]:
# Never commit a real key here: export OPENAI_API_KEY in the shell instead.
# setdefault only fills in the placeholder when no key is configured, so an
# existing key in the environment is not clobbered.
os.environ.setdefault("OPENAI_API_KEY", "sk-xxxx")
random.seed(0)

In [42]:
from task_utils import TASKS, load_data, load_prompt, generate_prompts

In [53]:
## your LLM stack goes here


# example of inference function for OpenAI models
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
# Note: temperature=0.5 samples, so repeated runs may produce different answers.
GPT_TURBO = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.5, max_tokens=600)

def call_llm_openai(prompt):
    """Send ``prompt`` to GPT_TURBO as one human message and return the reply text."""
    messages = [HumanMessage(content=prompt)]
    reply = GPT_TURBO(messages)
    return reply.content


# example of inference function for modal hosted models
import requests

def call_llm_modal(prompt, timeout=600):
    """Query the Modal-hosted model and return its completion for ``prompt``.

    The endpoint is POSTed ``{'question': prompt}`` and its JSON response is
    read from the ``'output'`` key.

    Args:
        prompt: Full prompt text sent to the model.
        timeout: Seconds to wait for the HTTP response (``None`` waits forever).
            Requests logged in this notebook took up to ~200 s, so the default
            is deliberately generous.

    Returns:
        The generated text, with the echoed prompt removed if the response
        starts with it, stripped of surrounding whitespace.

    Raises:
        requests.HTTPError: if the endpoint answers with a 4xx/5xx status.
        requests.Timeout: if no response arrives within ``timeout`` seconds.
    """
    r = requests.post('https://xxxxxxxxx.modal.run', json={'question': prompt}, timeout=timeout)
    # Fail loudly rather than trying to parse an error page as JSON.
    r.raise_for_status()
    output = r.json()['output']
    # The endpoint presumably echoes the prompt before the completion; only drop
    # it when it is actually there so a non-echoing response is not truncated.
    if output.startswith(prompt):
        output = output[len(prompt):]
    return output.strip()


# example of inference function for baseten hosted models
import baseten
MODEL = baseten.deployed_model_version_id('xxxxx')

def call_llm_baseten(prompt, max_new_tokens=300):
    """Query the Baseten-deployed model and return its completion for ``prompt``.

    Args:
        prompt: Full prompt text sent to the model.
        max_new_tokens: Generation length limit forwarded to the model
            (defaults to the previously hard-coded 300).

    Returns:
        ``output['data']['generated_text']`` with the echoed prompt removed if
        the text starts with it, stripped of surrounding whitespace.
    """
    # do_sample=True: generation is stochastic, so repeated runs may differ.
    output = MODEL.predict({"prompt": prompt, "do_sample": True, "max_new_tokens": max_new_tokens})
    generated = output['data']['generated_text']
    # Only strip the prompt when it was actually echoed back; slicing blindly
    # would cut real completion text from a non-echoing response.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    return generated.strip()


# Backend used by evaluate(); point this at call_llm_openai or call_llm_baseten
# to benchmark a different model.
call_llm = call_llm_modal


In [54]:
def evaluate(tasks: list, tasks_dir: str, split: str = "train"):
    """Run ``call_llm`` over LegalBench tasks and score exact-match answers.

    For each task, ``base_prompt.txt`` is filled in for every row of the chosen
    split. Each prompt is sent to ``call_llm``, and the stripped output is
    compared with the row's ``'answer'`` by exact string equality.

    Args:
        tasks: Task names to evaluate (looked up under ``tasks_dir``).
        tasks_dir: Root directory holding the task data and prompts.
        split: Which split to score: ``"train"`` (default, the original
            behaviour) or ``"test"``. NOTE(review): ``load_data`` returns both
            splits but only the train split was ever scored; confirm whether
            reported numbers should come from the held-out test split.

    Returns:
        ``{task: {datapoint_id: record, ..., 'balanced_accuracy': score}}``.
        Each record holds ``prompt``, ``generated_output``, ``correct_output``
        and ``success``.

    Raises:
        ValueError: if ``split`` is not ``"train"`` or ``"test"``.
    """
    if split not in ("train", "test"):
        raise ValueError(f"split must be 'train' or 'test', got {split!r}")

    report = dict()
    for task in tasks:
        train_df, test_df = load_data(task=task, tasks_dir=tasks_dir)
        data_df = train_df if split == "train" else test_df
        prompt_template = load_prompt(prompt_name="base_prompt.txt", task=task, tasks_dir=tasks_dir)
        prompts = generate_prompts(prompt_template=prompt_template, data_df=data_df)
        report[task] = dict()
        targets = list()
        outputs = list()
        print('task', task)
        # NOTE(review): zip() stops at the shorter input; this assumes
        # generate_prompts yields exactly one prompt per row, in row order.
        for prompt, (datapoint_id, row) in zip(prompts, data_df.iterrows()):
            output = call_llm(prompt).strip()
            answer = row['answer']
            success = output == answer
            targets.append(answer)
            outputs.append(output)
            report[task][datapoint_id] = {
                'prompt': prompt,
                'generated_output': output,
                'correct_output': answer,
                'success': success,
            }
        # Exact match after strip(): any extra text around the label counts as wrong.
        report[task]['balanced_accuracy'] = sklearn.metrics.balanced_accuracy_score(targets, outputs)
        print('task balanced accuracy:', report[task]['balanced_accuracy'])
        print()

    # Macro average over the requested tasks; a NaN task score counts as 0.
    scores = [report[task]['balanced_accuracy'] for task in tasks]
    scores = [0 if math.isnan(score) else score for score in scores]
    # Guard the empty-task case instead of raising ZeroDivisionError.
    total = sum(scores) / len(scores) if scores else float('nan')
    print('Total Balanced Accuracy:', total)

    return report


In [56]:
tasks_dir = '../legalbench'
N_TASKS = 5  # how many LegalBench tasks to sample for this run

# Warnings are expected while scoring (presumably from sklearn's
# balanced_accuracy_score when predictions contain labels absent from the
# targets); they are informational and do not stop the run.
random.seed(0)  # re-seed right before sampling so the same tasks are drawn each run
sampled_tasks = random.sample(TASKS, N_TASKS)
report = evaluate(tasks=sampled_tasks, tasks_dir=tasks_dir)

task maud_change_in_law:__subject_to_"disproportionate_impact"_modifier
<Response [200]>
request took 3.805708885192871
task balanced accuracy: 1.0

task maud_fiduciary_exception:__board_determination_standard
<Response [200]>
request took 1.0872490406036377
task balanced accuracy: 1.0

task contract_nli_notice_on_compelled_disclosure
<Response [200]>
request took 147.55259203910828
<Response [200]>
request took 152.9378478527069
<Response [200]>
request took 151.59520721435547
<Response [200]>
request took 154.38736081123352
<Response [200]>
request took 152.96601271629333
<Response [200]>
request took 180.8879361152649
<Response [200]>
request took 153.9524908065796
<Response [200]>
request took 160.44596099853516
task balanced accuracy: 0.0

task diversity_6




<Response [200]>
request took 133.1644549369812
<Response [200]>
request took 134.0045039653778
<Response [200]>
request took 131.24257731437683
<Response [200]>
request took 145.99635815620422
<Response [200]>
request took 133.32334804534912
<Response [200]>
request took 133.69695782661438
task balanced accuracy: 0.0

task opp115_third_party_sharing_collection




<Response [200]>
request took 115.00699591636658
<Response [200]>
request took 111.51435327529907
<Response [200]>
request took 112.91970610618591
<Response [200]>
request took 113.22303605079651
<Response [200]>
request took 120.68412208557129
<Response [200]>
request took 112.27995800971985
<Response [200]>
request took 123.01961588859558
<Response [200]>
request took 195.04768300056458




task balanced accuracy: 0.0

Total Balanced Accuracy: 0.4
