# 🎓 FrugalGPT Experiment on 5 Datasets: Performance and Cost Tradeoffs against ThriftLLM

This notebook illustrates the FrugalGPT framework for _building LLM Applications with budget constraints._

In particular, we will focus on evaluating the performance and cost tradeoffs enabled by FrugalGPT.

NB: We strongly recommend using accelerated hardware (GPU/TPU) to run this notebook.

## Installation

In [1]:
%load_ext autoreload
%autoreload 2
import sys, json, copy
import pandas as pd
import logging
logging.disable(logging.CRITICAL)
sys.path.append("src/")

## Setup
Next, let us set up the environment and list the LLM services that FrugalGPT supports. You do _not_ need API keys to run this notebook! They are only needed if you want to use FrugalGPT for your own queries.

NB: Even for your own queries, not every API key is needed. If you only want to use LLMs from, e.g., OpenAI and AI21, setting up API keys for those providers is sufficient.

In [2]:
import os
from IPython.display import display
import FrugalGPT
import numpy
from tqdm import tqdm

supported_LLM = FrugalGPT.getservicename()
print("supported LLMs:",supported_LLM)
supported_LLM_names = [llm.split("/")[1] for llm in supported_LLM]
print("supported_LLM_names:", supported_LLM_names)



supported LLMs: ['google/gemini-1.5-flash-002', 'google/gemini-1.5-pro-002', 'google/gemini-1.0-pro', 'openaichat/gpt-4o-mini', 'openaichat/gpt-4o', 'azure/Phi-3-mini-4k-instruct', 'azure/Phi-3.5-mini-instruct', 'azure/Phi-3-small-8k-instruct', 'azure/Phi-3-medium-4k-instruct', 'deepinfra/llama-3-8B', 'deepinfra/llama-3-70B', 'deepinfra/mixtral-8x7B']
supported_LLM_names: ['gemini-1.5-flash-002', 'gemini-1.5-pro-002', 'gemini-1.0-pro', 'gpt-4o-mini', 'gpt-4o', 'Phi-3-mini-4k-instruct', 'Phi-3.5-mini-instruct', 'Phi-3-small-8k-instruct', 'Phi-3-medium-4k-instruct', 'llama-3-8B', 'llama-3-70B', 'mixtral-8x7B']
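Each service name has the form `provider/model`, which is why the cell above splits on `/`. Since API keys are only needed for the providers you actually call, a small helper can group a chosen subset of services by provider. This is a sketch; `providers_needed` is a hypothetical helper, not part of FrugalGPT.

```python
from collections import defaultdict


def providers_needed(service_names):
    """Group 'provider/model' service names by provider.

    Hypothetical helper: returns a dict mapping each provider prefix to the
    list of model names requested from it, preserving input order.
    """
    by_provider = defaultdict(list)
    for service in service_names:
        # Split only on the first '/', so model names keep any later slashes.
        provider, model = service.split("/", 1)
        by_provider[provider].append(model)
    return dict(by_provider)


# Example: a run restricted to these services needs only two providers' keys.
subset = ["openaichat/gpt-4o-mini", "openaichat/gpt-4o", "deepinfra/llama-3-8B"]
print(providers_needed(subset))
# → {'openaichat': ['gpt-4o-mini', 'gpt-4o'], 'deepinfra': ['llama-3-8B']}
```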


Generating the tradeoffs involves three major steps: (i) prepare the dataset, (ii) train the FrugalGPT strategy, and (iii) evaluate and save the performance.

## Step 1: Prepare the dataset

In [3]:
# dataname = "HEADLINES"
# dataname = "OVERRULING"
dataname = "AGNEWS"


In [4]:
# Load the training split from data/{dataname}/Queried_{dataname}_all_models_clean_train.csv
dataset_df = pd.read_csv(f'data/{dataname}/Queried_{dataname}_all_models_clean_train.csv', header=0)
dataset_df.head()

Unnamed: 0,query_raw,query,ref_answer,gpt-4o-mini,gpt-4o,llama-3-8B,llama-3-70B,mixtral-8x7B,gemini-1.5-flash-002,gemini-1.0-pro,gemini-1.5-pro-002,Phi-3.5-mini-instruct,Phi-3-small-8k-instruct,Phi-3-mini-4k-instruct,Phi-3-medium-4k-instruct
0,Q: #39;Breakthrough #39; on hydrogen fuel US ...,"Please answer which category (World, Sports, B...",sci/tech,sci/tech,sci/tech,sci/tech,sci/tech,sci/tech,sci/tech,sci/tech,sci/tech,sci/tech,sci/tech,sci/tech,sci/tech
1,Q: Firefox - Ready To Take On Internet Explore...,"Please answer which category (World, Sports, B...",sci/tech,sci/tech,sci/tech,sci/tech,sci/tech,business,sci/tech,sci/tech,sci/tech,business,sports,sci/tech,sci/tech
2,"Q: Facing a fund gap Lucent Technologies"" popu...","Please answer which category (World, Sports, B...",business,business,business,business,business,business,business,business,business,business,business,business,business
3,Q: PeopleSofts big bash See you next year in L...,"Please answer which category (World, Sports, B...",business,business,business,business,business,business,business,business,business,business,business,business,business
4,"Q: Attackers shoot, burn villagers in east Con...","Please answer which category (World, Sports, B...",world,world,world,world,world,world,world,world,world,world,world,world,world


In [5]:
train_data = []
for index, row in dataset_df.iterrows():
    query = row['query']
    ref_answer = row['ref_answer']
    _id = index
    model_answer = {}
    for model_name in supported_LLM_names:
        model_answer[model_name] = row[model_name]
    train_data.append([query, ref_answer, _id, model_answer])
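The loop above is repeated verbatim for the test split in Step 3. A small helper can build the same `[query, ref_answer, _id, model_answer]` records for either split and fail early if a model column is missing. This is a sketch; `to_records` is a hypothetical helper, not a FrugalGPT function.

```python
import pandas as pd


def to_records(df, model_names):
    """Convert a queried-answers DataFrame into [query, ref_answer, _id, answers] records.

    Hypothetical helper mirroring the notebook's row loop.
    """
    # Fail early with a readable message instead of a KeyError mid-loop.
    missing = [c for c in ["query", "ref_answer", *model_names] if c not in df.columns]
    if missing:
        raise ValueError(f"missing columns: {missing}")

    records = []
    for index, row in df.iterrows():
        model_answer = {name: row[name] for name in model_names}
        records.append([row["query"], row["ref_answer"], index, model_answer])
    return records


# Toy frame with two models (illustrative values, not from the dataset).
toy = pd.DataFrame({
    "query": ["q0", "q1"],
    "ref_answer": ["sports", "world"],
    "gpt-4o": ["sports", "world"],
    "llama-3-8B": ["sports", "business"],
})
print(to_records(toy, ["gpt-4o", "llama-3-8B"])[1])
# → ['q1', 'world', 1, {'gpt-4o': 'world', 'llama-3-8B': 'business'}]
```

With this helper, `train_data = to_records(dataset_df, supported_LLM_names)` would produce the same records as the loop above.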

In [6]:
train_data[3]

['Please answer which category (World, Sports, Business or Sci/Tech) a provided news follows into.\n\nQ: Five-year ban for Blackburn fan One of the two Blackburn Rovers Football Club fans charged with public disorder for racially abusing Dwight Yorke has been handed a five-year ban.\nA: Sports\n\nQ: Major software pirates caught A multimillion-euro software piracy ring has been broken following synchronized raids in Athens and London yesterday, Attica police said.\nA: Sci/Tech\n\nQ: PeopleSofts big bash See you next year in Las Vegas , proclaimed a marquee at the PeopleSoft user conference in San Francisco in late September. It was one of many not-so-subtle attempts by the company to reassure its customers \nA:',
 'business',
 3,
 {'gemini-1.5-flash-002': 'business',
  'gemini-1.5-pro-002': 'business',
  'gemini-1.0-pro': 'business',
  'gpt-4o-mini': 'business',
  'gpt-4o': 'business',
  'Phi-3-mini-4k-instruct': 'business',
  'Phi-3.5-mini-instruct': 'business',
  'Phi-3-small-8k-inst

In [7]:
# get the answer of the model llama-3-8B
train_data[3][3]['llama-3-8B']

'business'

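## Step 2: Train the FrugalGPT strategy for different budgets

The cells below configure the candidate LLM services, the generation parameters, and the list of budgets. In Step 3, the strategy for each budget is loaded from `strategy/{name}/`.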

In [8]:
service_names = ['openaichat/gpt-4o-mini',
                'openaichat/gpt-4o',
                'google/gemini-1.5-flash-002',
                'google/gemini-1.5-pro-002',
                'google/gemini-1.0-pro',
                'azure/Phi-3-mini-4k-instruct',
                'azure/Phi-3.5-mini-instruct',
                'azure/Phi-3-small-8k-instruct',
                'azure/Phi-3-medium-4k-instruct',
                'deepinfra/llama-3-8B',
                'deepinfra/llama-3-70B',
                'deepinfra/mixtral-8x7B',
                ]

In [9]:
genparams=FrugalGPT.GenerationParameter(max_tokens=50, temperature=0.1, stop=['\n'])
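The prompts are few-shot blocks of the form `Q: ... A:` (see `train_data[3]` above), so a model may keep going past its one-line answer and start a new `Q:` block; `stop=['\n']` cuts generation at the first newline. The sketch below mimics that truncation on a raw completion and then normalizes it so it can be compared with lowercase references such as `sci/tech`. The strip-and-lowercase normalization is an assumption about how answers are compared, not documented FrugalGPT behavior, and real APIs apply stop sequences server-side.

```python
def apply_stop(text, stop):
    """Cut text at the earliest occurrence of any stop sequence."""
    cut = len(text)
    for s in stop:
        pos = text.find(s)
        if pos != -1:
            cut = min(cut, pos)
    return text[:cut]


def normalize(answer):
    """Assumed normalization: trim surrounding whitespace and lowercase."""
    return answer.strip().lower()


# A completion that runs on into the next few-shot example.
raw = " Sports\n\nQ: Next headline...\nA: World"
print(repr(normalize(apply_stop(raw, ["\n"]))))
# → 'sports'
```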

In [10]:
name = f'{dataname}_1015'
budget_list = [0.0001, 0.0005, 0.001]  # smaller budgets (0.00001, 0.00005) are left out of this run

In [11]:
print(len(train_data))

6080
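FrugalGPT's LLM cascade sends a query to a sequence of models and stops at the first answer that a scorer judges reliable enough; each budget bounds the average cost per query. The sketch below evaluates one fixed cascade on records in the `train_data` layout. The per-query costs, the thresholds, and the `score` function are hypothetical placeholders: FrugalGPT learns its scorer and searches over model orders and thresholds, which this sketch does not do.

```python
def run_cascade(records, cascade, thresholds, cost, score):
    """Evaluate one fixed LLM cascade on precomputed answers.

    records:    [query, ref_answer, _id, {model: answer}] lists (train_data layout).
    cascade:    model names, tried in this order.
    thresholds: acceptance threshold for every model except the last one.
    cost:       hypothetical cost per query for each model.
    score:      hypothetical scorer, score(query, answer) -> reliability in [0, 1].
    Returns (accuracy, average cost per query).
    """
    if not cascade or len(thresholds) != len(cascade) - 1:
        raise ValueError("need one threshold per model except the last")
    correct, total_cost = 0, 0.0
    for query, ref, _id, answers in records:
        for i, model in enumerate(cascade):
            answer = answers[model]
            total_cost += cost[model]  # every model that is queried is paid for
            is_last = i == len(cascade) - 1
            # Stop at the first answer the scorer trusts; the last model always answers.
            if is_last or score(query, answer) >= thresholds[i]:
                break
        correct += answer == ref
    n = len(records)
    return correct / n, total_cost / n


def toy_score(query, answer):
    """Toy scorer: trusts 'sports' answers only (illustration, not a learned model)."""
    return 0.9 if answer == "sports" else 0.1


toy_records = [
    ["q0", "sports", 0, {"small": "sports", "large": "sports"}],
    ["q1", "world", 1, {"small": "business", "large": "world"}],
]
toy_cost = {"small": 1.0, "large": 10.0}
print(run_cascade(toy_records, ["small", "large"], [0.5], toy_cost, toy_score))
# → (1.0, 6.0)
```

In the FrugalGPT paper's formulation, a budget from `budget_list` is satisfied when the average cost returned here is at most that budget, and training looks for the most accurate cascade that satisfies it.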


## Step 3: Evaluate and save the performance

In [12]:
# Load the test split from data/{dataname}/Queried_{dataname}_all_models_clean_test.csv
dataset_df_test = pd.read_csv(f'data/{dataname}/Queried_{dataname}_all_models_clean_test.csv', header=0)
dataset_df_test.head()

Unnamed: 0,query_raw,query,ref_answer,gpt-4o-mini,gpt-4o,llama-3-8B,llama-3-70B,mixtral-8x7B,gemini-1.5-flash-002,gemini-1.0-pro,gemini-1.5-pro-002,Phi-3.5-mini-instruct,Phi-3-small-8k-instruct,Phi-3-mini-4k-instruct,Phi-3-medium-4k-instruct
0,Q: America West Backs Away From ATA Bid Americ...,"Please answer which category (World, Sports, B...",business,business,business,business,business,business,business,business,business,business,business,business,business
1,"Q: Compete against your friends, SI experts an...","Please answer which category (World, Sports, B...",sports,sports,sports,sports,sports,sports,sports,sports,sports,sports,sports,sports,sports
2,Q: Oracle expected to push on content manageme...,"Please answer which category (World, Sports, B...",sci/tech,business,sci/tech,business,business,business,business,business,sci/tech,business,business,business,sci/tech
3,"Q: Bosox strike deal with Mirabelli; Yanks, Fl...","Please answer which category (World, Sports, B...",sports,sports,sports,sports,sports,sports,sports,sports,sports,business,sports,business,sports
4,Q: Bonds deserves a quot;C quot; for historic...,"Please answer which category (World, Sports, B...",sports,sports,sports,sports,sports,sports,sports,sports,sports,world,sports,sports,sports
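Because the train and test splits come from separate CSV files, it is worth checking that no raw query appears in both before comparing train and test accuracy. A minimal sketch using the `query_raw` column shown above; `overlapping_queries` is a hypothetical helper.

```python
import pandas as pd


def overlapping_queries(train_df, test_df, column="query_raw"):
    """Return the sorted values of `column` that occur in both splits."""
    return sorted(set(train_df[column]) & set(test_df[column]))


# Toy splits (illustrative values only).
train_toy = pd.DataFrame({"query_raw": ["Q: a", "Q: b", "Q: c"]})
test_toy = pd.DataFrame({"query_raw": ["Q: c", "Q: d"]})
print(overlapping_queries(train_toy, test_toy))
# → ['Q: c']
```

On the real data, `overlapping_queries(dataset_df, dataset_df_test)` returns an empty list when no raw query is shared between the splits.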


In [13]:
test_data = []
for index, row in dataset_df_test.iterrows():
    query = row['query']
    ref_answer = row['ref_answer']
    _id = index
    model_answer = {}
    for model_name in supported_LLM_names:
        model_answer[model_name] = row[model_name]
    test_data.append([query, ref_answer, _id, model_answer])

In [14]:
test_data[3]

['Please answer which category (World, Sports, Business or Sci/Tech) a provided news follows into.\n\nQ: Five-year ban for Blackburn fan One of the two Blackburn Rovers Football Club fans charged with public disorder for racially abusing Dwight Yorke has been handed a five-year ban.\nA: Sports\n\nQ: Major software pirates caught A multimillion-euro software piracy ring has been broken following synchronized raids in Athens and London yesterday, Attica police said.\nA: Sci/Tech\n\nQ: Bosox strike deal with Mirabelli; Yanks, Flaherty close The Boston Red Sox have signed backup catcher Doug Mirabelli to a two-year deal worth \\$3 million, making him the first of the World Series champions #39; 16 free agents to re-sign.\nA:',
 'sports',
 3,
 {'gemini-1.5-flash-002': 'sports',
  'gemini-1.5-pro-002': 'sports',
  'gemini-1.0-pro': 'sports',
  'gpt-4o-mini': 'sports',
  'gpt-4o': 'sports',
  'Phi-3-mini-4k-instruct': 'business',
  'Phi-3.5-mini-instruct': 'business',
  'Phi-3-small-8k-instru

In [15]:
# get the answer of the model llama-3-8B
test_data[3][3]['llama-3-8B']

'sports'

In [16]:
print(len(test_data))

1520
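Every record already carries each model's answer, so the accuracy of each single model on the test split can be computed without any API calls. This gives a reference point for FrugalGPT's `Test_acc`. A sketch, assuming plain exact string match against `ref_answer`; `per_model_accuracy` is a hypothetical helper.

```python
def per_model_accuracy(records):
    """Exact-match accuracy of each model over [query, ref, _id, answers] records."""
    if not records:
        return {}
    hits = {}
    for _query, ref, _id, answers in records:
        for model, answer in answers.items():
            hits[model] = hits.get(model, 0) + (answer == ref)
    n = len(records)
    # Highest accuracy first.
    return dict(sorted(((m, h / n) for m, h in hits.items()),
                       key=lambda kv: kv[1], reverse=True))


toy_records = [
    ["q0", "sports", 0, {"gpt-4o": "sports", "llama-3-8B": "sports"}],
    ["q1", "world", 1, {"gpt-4o": "world", "llama-3-8B": "business"}],
]
print(per_model_accuracy(toy_records))
# → {'gpt-4o': 1.0, 'llama-3-8B': 0.5}
```

Calling `per_model_accuracy(test_data)` ranks all supported models on the test split.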


In [17]:
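def generate_dataframe_from_cascade(MyCascade, budget_list, train_data, test_data, genparams, name):
    """Evaluate the saved FrugalGPT strategy for each budget and return one summary row per budget."""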
    # Initialize an empty list to store the rows for the DataFrame
    data = []

    # Iterate through the budget list
    for budget in tqdm(budget_list):
        # Load the strategy for the given budget
        MyCascade.load(loadpath=f"strategy/{name}/", budget=budget)
        print("loaded from path:",f"strategy/{name}/")
        print("now the budget is:",budget)

        # Get the completion batch for train data
        print("start train data")
        train_result = MyCascade.get_completion_batch(queries=train_data, genparams=genparams)
        print("train_result:",train_result)
        # Compute the ACC and cost for train data
        train_acc_cost = FrugalGPT.compute_score(train_result)

        # Get the completion batch for test data
        test_result = MyCascade.get_completion_batch(queries=test_data, genparams=genparams)

        # Compute the ACC and cost for test data
        test_acc_cost = FrugalGPT.compute_score(test_result)

        # Create a row with the schema
        row = {
            "Test_acc": test_acc_cost['em'],
            "Test_cost": test_acc_cost['cost'],
            "Test_size": len(test_data),
            "Train_acc": train_acc_cost['em'],
            "Train_cost": train_acc_cost['cost'],
            "Train_size": len(train_data),
            "Budget": budget,
            "Method": "FrugalGPT",
            "Provider": "FrugalGPT",
            "Marker": 1,  # Marker is always 1 for this function
        }

        # Append the row to the data list
        data.append(row)
        display(row)

    # Create the DataFrame from the data list
    df = pd.DataFrame(data)

    return df
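`FrugalGPT.compute_score` turns the per-query result table (columns `_id`, `answer`, `ref_answer`, `cost`, as printed in the cell output further down) into the `em` and `cost` values read by the function above. Its exact definitions are not shown in this notebook; the sketch below assumes `em` is the fraction of exact matches and `cost` is the mean cost per query. `summarize_result` is a hypothetical stand-in, not FrugalGPT's implementation.

```python
import pandas as pd


def summarize_result(result):
    """Assumed equivalent of compute_score: exact-match rate and mean cost per query.

    `result` must have 'answer', 'ref_answer' and 'cost' columns.
    """
    em = (result["answer"] == result["ref_answer"]).mean()
    return {"em": float(em), "cost": float(result["cost"].mean())}


# Toy result table (illustrative values only).
toy_result = pd.DataFrame({
    "_id": [0, 1, 2, 3],
    "answer": ["sci/tech", "business", "world", "sports"],
    "ref_answer": ["sci/tech", "business", "sports", "sports"],
    "cost": [0.0, 0.0002, 0.0, 0.0002],
})
print(summarize_result(toy_result))
```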

In [18]:
MyCascade_eval = FrugalGPT.LLMCascade()
# MyCascade_eval.prefix = prefix

frugalgpt_df = generate_dataframe_from_cascade(MyCascade_eval,
                                               budget_list, train_data, test_data, genparams,
                                               name)
display(frugalgpt_df)
os.makedirs("summary", exist_ok=True)  # to_csv does not create missing directories
frugalgpt_df.to_csv(f"summary/summary_{dataname}_e8_frugalgpt_2024.csv")



loaded from path: strategy/AGNEWS_1015/
now the budget is: 0.0001
start train data


train_result:        _id    answer ref_answer  cost
0        0  sci/tech   sci/tech     0
1        1  sci/tech   sci/tech     0
2        2  business   business     0
3        3  business   business     0
4        4     world      world     0
...    ...       ...        ...   ...
6075  6075     world      world     0
6076  6076  business   business     0
6077  6077    sports     sports     0
6078  6078     world      world     0
6079  6079     world      world     0

[6080 rows x 4 columns]


{'Test_acc': 0.8796052631578948,
 'Test_cost': 0.0,
 'Test_size': 1520,
 'Train_acc': 0.9039473684210526,
 'Train_cost': 0.0,
 'Train_size': 6080,
 'Budget': 0.0001,
 'Method': 'FrugalGPT',
 'Provider': 'FrugalGPT',
 'Marker': 1}



loaded from path: strategy/AGNEWS_1015/
now the budget is: 0.0005
start train data


train_result:        _id    answer ref_answer  cost
0        0  sci/tech   sci/tech     0
1        1  sci/tech   sci/tech     0
2        2  business   business     0
3        3  business   business     0
4        4     world      world     0
...    ...       ...        ...   ...
6075  6075     world      world     0
6076  6076  business   business     0
6077  6077    sports     sports     0
6078  6078     world      world     0
6079  6079     world      world     0

[6080 rows x 4 columns]


{'Test_acc': 0.8868421052631579,
 'Test_cost': 0.0,
 'Test_size': 1520,
 'Train_acc': 0.9082236842105263,
 'Train_cost': 0.0,
 'Train_size': 6080,
 'Budget': 0.0005,
 'Method': 'FrugalGPT',
 'Provider': 'FrugalGPT',
 'Marker': 1}



loaded from path: strategy/AGNEWS_1015/
now the budget is: 0.001
start train data


train_result:        _id    answer ref_answer  cost
0        0  sci/tech   sci/tech     0
1        1  sci/tech   sci/tech     0
2        2  business   business     0
3        3  business   business     0
4        4     world      world     0
...    ...       ...        ...   ...
6075  6075     world      world     0
6076  6076  business   business     0
6077  6077    sports     sports     0
6078  6078     world      world     0
6079  6079     world      world     0

[6080 rows x 4 columns]


{'Test_acc': 0.8868421052631579,
 'Test_cost': 0.0,
 'Test_size': 1520,
 'Train_acc': 0.9085526315789474,
 'Train_cost': 0.0,
 'Train_size': 6080,
 'Budget': 0.001,
 'Method': 'FrugalGPT',
 'Provider': 'FrugalGPT',
 'Marker': 1}

100%|██████████| 3/3 [4:57:59<00:00, 5959.68s/it]  


Unnamed: 0,Test_acc,Test_cost,Test_size,Train_acc,Train_cost,Train_size,Budget,Method,Provider,Marker
0,0.879605,0.0,1520,0.903947,0.0,6080,0.0001,FrugalGPT,FrugalGPT,1
1,0.886842,0.0,1520,0.908224,0.0,6080,0.0005,FrugalGPT,FrugalGPT,1
2,0.886842,0.0,1520,0.908553,0.0,6080,0.001,FrugalGPT,FrugalGPT,1
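To compare FrugalGPT with other methods (such as ThriftLLM, per the title) on the same cost-accuracy plane, one can keep only the Pareto-optimal rows: those for which no other row has lower-or-equal cost and higher-or-equal accuracy while being strictly better in at least one. A sketch over a DataFrame with the `Method`, `Test_cost`, and `Test_acc` columns of the summary table above; `pareto_frontier` is a hypothetical helper and the example rows are illustrative values only.

```python
import pandas as pd


def pareto_frontier(df, cost_col="Test_cost", acc_col="Test_acc"):
    """Return the rows not dominated by any other row, sorted by cost."""
    keep = []
    for i, row in df.iterrows():
        # A row is dominated if another is no worse on both axes and better on one.
        dominated = (
            (df[cost_col] <= row[cost_col])
            & (df[acc_col] >= row[acc_col])
            & ((df[cost_col] < row[cost_col]) | (df[acc_col] > row[acc_col]))
        ).any()
        if not dominated:
            keep.append(i)
    return df.loc[keep].sort_values(cost_col)


points = pd.DataFrame({
    "Method": ["A", "A", "B", "B"],
    "Test_cost": [0.0001, 0.0005, 0.0002, 0.0005],
    "Test_acc": [0.85, 0.88, 0.84, 0.90],
})
print(pareto_frontier(points)[["Method", "Test_cost", "Test_acc"]])
```

Rows with identical cost and accuracy do not dominate each other, so ties are all kept.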
