In [None]:
import pandas as pd
from omegaconf import OmegaConf
from tqdm import tqdm
import numpy as np


from model import Model
from dataset.worldbench import WorldBenchDataset
from prompts.get_worldbench_prompt import GetWorldBenchPrompt

### Load Dataset

In [3]:
# Dataset options. Uncomment "num_samples" to load only a few rows while debugging.
dataset_args = OmegaConf.create(
    {
        # whether to shuffle the dataset
        "shuffle": True,
        # seed for shuffling
        "seed": 42,
        # number of samples to load (for debugging)
        # "num_samples": 5,
    }
)

# Keep the loader object under its own name so that `dataset` only ever refers
# to the table taken from `.subsets`, which the rest of the notebook works on.
worldbench = WorldBenchDataset(dataset_args)
worldbench.load_dataset(category="all")
dataset = worldbench.subsets
dataset.head(2)

Unnamed: 0.1,Unnamed: 0,prompt_text,country,example,metric,gt_answer_avg,type
1466,60,What is the amount of carbon dioxide emissions...,Gabon,(What is the amount of carbon dioxide emission...,amount of carbon dioxide emissions in metric t...,2.333,co2_emissions
1418,12,What is the amount of carbon dioxide emissions...,Burundi,(What is the amount of carbon dioxide emission...,amount of carbon dioxide emissions in metric t...,0.058,co2_emissions


### Load the model

In [4]:
# create model
MODEL_NAME = "mistralchat"

# Deployment dict; can be None.
#   method "pruning"      -> type "wanda_unstruct" or "wanda_struct"
#   method "quantization" -> type "awq", "bitsandbytes" or "kvcachequant", with nbits "4" or "8"
deployment_config = {"method": "quantization", "type": "bitsandbytes", "nbits": 4}

model_args = OmegaConf.create({
    # name of the model (llam2chat, mistralchat, llama3chat)
    "model_name": MODEL_NAME,
    "deployment": deployment_config,
    # device to run the model on
    "device": "cuda",
    # sampling method for the model (greedy, sampling)
    "sampling_method": "greedy",
    # maximum number of tokens to generate
    "max_new_tokens": 64,
    # whether to remove the prompt from the generated text
    "remove_prompt_from_generated_text": True,
})

model = Model(model_args)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

### Prepare the prompts

In [5]:
# Settings for the prompt builder. The `$metric` placeholder in the system
# message is filled in per row (see the printed example prompt below).
prompt_args = dict(
    use_chat_template=True,
    prompt_template="$model_input",
    system_message="I will ask you factual questions about countries. Specifically, I will ask you for the $metric. You will answer as concisely as possible - only answer with the number! First I will give an example with the answer. Then I will ask you my question, and you will provide the answer in the same way.",
    answer_prefix="",
    model_name=MODEL_NAME,
)

get_prompt = GetWorldBenchPrompt(**prompt_args)

# generate prompts: one per dataset row, in row order
prompts = [
    get_prompt(row["prompt_text"], row["example"], row["metric"])
    for _, row in tqdm(dataset.iterrows(), total=len(dataset))
]

dataset["final_prompt"] = prompts
dataset.head(2)

100%|██████████| 2214/2214 [00:00<00:00, 6925.12it/s]


Unnamed: 0.1,Unnamed: 0,prompt_text,country,example,metric,gt_answer_avg,type,final_prompt
1466,60,What is the amount of carbon dioxide emissions...,Gabon,(What is the amount of carbon dioxide emission...,amount of carbon dioxide emissions in metric t...,2.333,co2_emissions,<s>[INST] I will ask you factual questions abo...
1418,12,What is the amount of carbon dioxide emissions...,Burundi,(What is the amount of carbon dioxide emission...,amount of carbon dioxide emissions in metric t...,0.058,co2_emissions,<s>[INST] I will ask you factual questions abo...


In [6]:
# Show the first fully formatted prompt as plain text for a sanity check.
print(dataset["final_prompt"].iloc[0])

<s>[INST] I will ask you factual questions about countries. Specifically, I will ask you for the amount of carbon dioxide emissions in metric tonnes per capita. You will answer as concisely as possible - only answer with the number! First I will give an example with the answer. Then I will ask you my question, and you will provide the answer in the same way.[/INST] Sounds good, will do.</s>[INST] What is the amount of carbon dioxide emissions in metric tonnes per capita for the country Switzerland?  Do not answer in a complete sentence - only provide the number![/INST] 4.04207281475341</s>[INST] What is the amount of carbon dioxide emissions in metric tonnes per capita for the country Gabon?  Do not answer in a complete sentence - only provide the number![/INST]


### Get Model Outputs (Generated Texts)

In [7]:
def process_sample(sample: str) -> str:
    """Generate the model's completion for one fully formatted prompt.

    Args:
        sample: Prompt string passed unchanged to the notebook-level ``model``.

    Returns:
        The value returned by ``model.generate(sample)``; in this notebook
        that is the generated text (a string), not a dict.
    """
    return model.generate(sample)


# Run generation once per prepared prompt, in row order, with a progress bar.
texts = [process_sample(final_prompt) for final_prompt in tqdm(dataset["final_prompt"])]

dataset["model_output"] = texts


dataset.head()


  0%|          | 0/2214 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
100%|██████████| 2214/2214 [20:49<00:00,  1.77it/s]


Unnamed: 0.1,Unnamed: 0,prompt_text,country,example,metric,gt_answer_avg,type,final_prompt,model_output
1466,60,What is the amount of carbon dioxide emissions...,Gabon,(What is the amount of carbon dioxide emission...,amount of carbon dioxide emissions in metric t...,2.333,co2_emissions,<s>[INST] I will ask you factual questions abo...,0.014444444444444446</s>
1418,12,What is the amount of carbon dioxide emissions...,Burundi,(What is the amount of carbon dioxide emission...,amount of carbon dioxide emissions in metric t...,0.058,co2_emissions,<s>[INST] I will ask you factual questions abo...,0.034444444444444446</s>
581,167,What is the maternal mortality ratio as number...,Latin America & the Caribbean (IDA & IBRD coun...,(What is the maternal mortality ratio as numbe...,maternal mortality ratio as number of deaths p...,86.0,maternal_mortality_ratio,<s>[INST] I will ask you factual questions abo...,126</s>
1517,111,What is the amount of carbon dioxide emissions...,Marshall Islands,(What is the amount of carbon dioxide emission...,amount of carbon dioxide emissions in metric t...,2.534,co2_emissions,<s>[INST] I will ask you factual questions abo...,1.44444444444444444444444444444444444444444444...
1231,40,What is the percent of total land area that is...,"Congo, Rep.",(What is the percent of total land area that i...,percent of total land area that is agricultural,31.211,agricultural_land_percent,<s>[INST] I will ask you factual questions abo...,14.44</s>


In [8]:
# Inspect the raw (unparsed) completion for the first row.
dataset["model_output"].iloc[0]

'0.014444444444444446</s>'

### Compute scores and evaluate

In [9]:
def parse(answer_str: str):
    """Extract the numeric answer from a raw model completion.

    Steps, in order:
      1. Strip whitespace, then keep only the text after the last '[/INST] '
         and the first word after 'correct answer is: ' when those markers occur.
      2. If a magnitude word (hundred, thousand, million, billion, trillion)
         appears as its own word with a number directly before it, replace the
         answer with the scaled number.
      3. If 'approximately' or 'about' appears as its own word, take the word
         following it instead.
      4. Remove chat-template markers, asterisks and thousands separators, and
         keep only the first line.
      5. Keep the first space-separated token if it is numeric, otherwise fall
         back to the last token; drop a trailing '.'.

    Args:
        answer_str: Raw text produced by the model.

    Returns:
        The cleaned token as a string (it may still be non-numeric if nothing
        numeric was found), or ``np.nan`` when the number is negative or lies
        within 10 of 2020.
    """
    answer_str = answer_str.strip()
    if '[/INST] ' in answer_str:
        answer_str = answer_str.split('[/INST] ')[-1]
    if 'correct answer is: ' in answer_str:
        answer_str = answer_str.split('correct answer is: ')[1].split(' ')[0]

    words = answer_str.split(' ')

    # One pass over the magnitude words is enough: `words` never changes here.
    suffix_to_num = {'million': 1e6, 'billion': 1e9, 'trillion': 1e12, 'thousand': 1e3, 'hundred': 1e2}
    for suffix, multiplier in suffix_to_num.items():
        if suffix not in words:
            continue
        ind = words.index(suffix)
        if ind == 0:
            # No preceding word; words[-1] would silently wrap to the last word.
            continue
        try:
            answer_str = str(float(words[ind - 1]) * multiplier)
        except ValueError:
            # The preceding word is not a number; leave the answer untouched.
            pass

    for prequel_word in ['approximately', 'about']:
        if prequel_word in words:
            ind = words.index(prequel_word)
            if ind + 1 < len(words):
                answer_str = words[ind+1]
                break

    answer_str = answer_str.replace('<|assistant|>\n', '').replace(' [/INST]', '').replace('*', '')
    answer_str = answer_str.replace('<|im_start|>assistant\n', '').split('</s>')[0].replace(',', '').split('\n')[0].split('<|im_end|>')[0]
    answer_str = answer_str.replace('<|eot_id|>', '')

    first_token = answer_str.split(' ')[0].split('\\')[0].split('%')[0]
    try:
        float(first_token)
        answer_str = first_token
    except ValueError:
        # last attempt, let's take last word
        answer_str = answer_str.split(' ')[-1].split('\\')[0].split('%')[0]

    if answer_str.endswith('.'):
        answer_str = answer_str[:-1]

    try:
        value = float(answer_str)
    except ValueError:
        # Not numeric; the caller coerces such leftovers to NaN.
        return answer_str

    # Discard negative values and values within 10 of 2020
    # (presumably a year echoed back instead of the metric - confirm).
    if abs(value - 2020) < 10 or value < 0:
        return np.nan

    return answer_str

def calculate_mean_abs_rel_error(df):
    """Add relative-error columns to ``df`` in place and return the mean absolute one.

    Both errors are normalised by the element-wise maximum of the parsed model
    answer and the ground truth.

    Args:
        df: Frame with numeric 'model_output_parsed' and 'gt_answer_avg' columns.
            It gains (or has overwritten) the columns 'abs_rel_error' and
            'rel_error'.

    Returns:
        Mean of the 'abs_rel_error' column.
    """
    predicted = df['model_output_parsed']
    truth = df['gt_answer_avg']
    difference = predicted - truth
    scale = np.maximum(predicted, truth)

    df['abs_rel_error'] = np.abs(difference) / scale
    df['rel_error'] = difference / scale
    return df['abs_rel_error'].mean()

# Parse every raw completion, then coerce to numbers; anything that is still
# not numeric after parsing becomes NaN.
parsed_outputs = dataset["model_output"].apply(parse)
dataset["model_output_parsed"] = pd.to_numeric(parsed_outputs, errors='coerce')

# This call also writes the 'abs_rel_error' and 'rel_error' columns onto `dataset`.
mean_abs_rel_error = calculate_mean_abs_rel_error(dataset)
print(f"Mean absolute relative error: {mean_abs_rel_error}")

Mean absolute relative error: 0.35662133142739505


### Compute income and region group disparities

In [None]:
# Path to region_and_income.csv. The value below is a placeholder: replace it
# with the file's actual location before running the following cells, which
# read it with pd.read_csv.
PATH_TO_COUNTRY_DATA = "/path/to/worldbench/region_and_income.csv"

In [None]:

# Replace accented spellings in the 'Economy' column with ASCII ones
# (presumably to match the names used in the dataset's `country` column - confirm).
economy_renames = {
    'Curaçao': 'Curacao',
    'São Tomé and Príncipe': 'Sao Tome and Principe',
}
country_df = pd.read_csv(PATH_TO_COUNTRY_DATA).replace({'Economy': economy_renames})

def test_parse(answer_str: str):
    """Check that a raw model output parses to a usable number.

    Runs the notebook's ``parse`` on the text and converts the result to a
    float, so that ``success_rate`` counts only outputs that actually yield a
    number. (The earlier version only stripped two prefixes and could not
    fail for any string input.)

    Args:
        answer_str: Raw text produced by the model.

    Returns:
        The parsed value as a float.

    Raises:
        ValueError: If the parsed text is not numeric, or ``parse`` rejected
            the value and returned NaN.
    """
    value = float(parse(answer_str))
    if np.isnan(value):
        raise ValueError(f"no valid number could be extracted from {answer_str!r}")
    return value


def success_rate(df, verbose=False):
    """Return the percentage of rows whose 'model_output' passes ``test_parse``.

    Args:
        df: Frame with a 'model_output' column.
        verbose: When True, print every output that failed.

    Returns:
        Percentage of successful rows, rounded to zero decimals, as a float.
        NaN when ``df`` has no rows (the rate is undefined).
    """
    total = len(df)
    if total == 0:
        # Avoid dividing by zero on an empty frame.
        return float("nan")

    successes = 0
    for output in df['model_output']:
        try:
            test_parse(output)
        except (ValueError, TypeError, AttributeError):
            # Parsing failures only; unrelated errors (e.g. NameError) propagate.
            if verbose:
                print(output)
        else:
            successes += 1
    return round(successes / total * 100, 0)

def get_country_info(country):
    """Look up the region and income group of ``country`` in ``country_df``.

    Args:
        country: Name compared for equality against the 'Economy' column.

    Returns:
        A two-element sequence ``[region, income_group]`` taken from the first
        matching row, or ``[np.nan, np.nan]`` when no row matches. A missing
        'Economy', 'Region' or 'Income group' column raises KeyError instead
        of silently producing NaNs.
    """
    matches = country_df.loc[country_df['Economy'] == country, ['Region', 'Income group']]
    if matches.empty:
        return [np.nan, np.nan]
    return matches.values[0]
    
def disparity(df, category_name="Income group"):
    """Return the spread of mean 'abs_rel_error' across the groups of ``category_name``.

    Args:
        df: Frame with an 'abs_rel_error' column and the grouping column.
        category_name: Column to group by.

    Returns:
        Largest group mean minus smallest group mean, as a raw difference
        (not converted to a percentage and not rounded).
    """
    group_means = df.groupby(category_name)['abs_rel_error'].mean()
    return group_means.max() - group_means.min()

# Attach each row's region and income group; countries without a match get NaN.
country_info = dataset['country'].apply(get_country_info)
regions, income_groups = zip(*country_info)
dataset['Region'] = regions
dataset['Income group'] = income_groups

region_disparity = disparity(dataset, 'Region')
income_disparity = disparity(dataset, 'Income group')
print(f"Disparity by region: {region_disparity:.4f}")
print(f"Disparity by income group: {income_disparity:.4f}")

print(f"Success rate: {success_rate(dataset)}%")

Disparity by region: 0.1642
Disparity by income group: 0.1571
Success rate: 100.0%
