In [None]:
%pwd

'/workspace'

In [None]:
import os

import pandas as pd
import torch
from datasets import load_dataset
from openai import OpenAI
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Credentials must not live in the notebook: the key is taken from the
# OPENAI_API_KEY environment variable. The empty-string fallback keeps this
# cell runnable when the variable is unset; GPT-3.5 requests will fail to
# authenticate until a real key is provided.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", ""))

In [None]:
# Classification datasets, one handle per (task, prompt format) pair.
# Name prefix follows the dataset id: TEH = hate, TEO = offensive,
# climate = climate, SUBJ = subjective.
# Name suffix follows the dataset id as well:
#   _i_l = with-instruction-with-label
#   _i_s = with-instruction-with-symbol
#   _n   = no-instruction-with-symbol
#   _f   = flipped

# with instruction, with label
TEH_i_l, TEO_i_l, climate_i_l, SUBJ_i_l = map(load_dataset, (
    "AlanYky/hate-with-instruction-with-label",
    "AlanYky/offensive-with-instruction-with-label",
    "AlanYky/climate-with-instruction-with-label",
    "AlanYky/subjective-with-instruction-with-label",
))

# with instruction, with symbol
TEH_i_s, TEO_i_s, climate_i_s, SUBJ_i_s = map(load_dataset, (
    "AlanYky/hate-with-instruction-with-symbol",
    "AlanYky/offensive-with-instruction-with-symbol",
    "AlanYky/climate-with-instruction-with-symbol",
    "AlanYky/subjective-with-instruction-with-symbol",
))

# no instruction, with symbol
TEH_n, TEO_n, climate_n, SUBJ_n = map(load_dataset, (
    "AlanYky/hate-no-instruction-with-symbol",
    "AlanYky/offensive-no-instruction-with-symbol",
    "AlanYky/climate-no-instruction-with-symbol",
    "AlanYky/subjective-no-instruction-with-symbol",
))

# flipped
TEH_f, TEO_f, climate_f, SUBJ_f = map(load_dataset, (
    "AlanYky/hate-flipped",
    "AlanYky/offensive-flipped",
    "AlanYky/climate-flipped",
    "AlanYky/subjective-flipped",
))

In [None]:
data = load_dataset("AlanYky/big-bench-list-function-turing-330")


def _rows_in_category(category):
    """Return `data` filtered down to the rows whose 'category' field equals `category`."""
    return data.filter(lambda example: example['category'] == category)


# One filtered view of `data` per task category.
modify_the_list = _rows_in_category('modify_the_list')
remove_elements = _rows_in_category('remove_elements')
input_independent = _rows_in_category('input_independent')
add_elements = _rows_in_category('add_elements')
miscellaneous = _rows_in_category('miscellaneous')
Turing = _rows_in_category('Turing')

In [None]:
# Instruction text passed as `big_bench_prompt` to the prompt builders for the
# list-function evaluation runs.
big_bench_prompt = (
    "Apply a function to the final input list to generate the output list. "
    "Use any preceding inputs and outputs as examples."
)

def create_prompt(prompt, big_bench=False, big_bench_prompt=""):
    """Wrap `prompt` in the plain-text "Instruct: ... Output:" template.

    Args:
        prompt: Text of the example to be answered.
        big_bench: When truthy, `big_bench_prompt` goes on the "Instruct:"
            line and `prompt` follows on its own line. Otherwise `prompt`
            itself goes on the "Instruct:" line.
        big_bench_prompt: Instruction text; only used when `big_bench` is truthy.

    Returns:
        The prompt string, ending with "Output:" and no trailing newline.
    """
    if not big_bench:
        return f"Instruct: {prompt}\nOutput:"
    return f"Instruct: {big_bench_prompt}\n{prompt}\nOutput:"


def extract_answer(text):
    """Trim `text` down to the part that starts at its first "Output:" marker.

    The marker itself is kept and the result is stripped of surrounding
    whitespace. If `text` contains no "Output:" marker it is returned
    unchanged (not stripped).
    """
    marker = "Output:"
    if marker not in text:
        return text
    return text[text.index(marker):].strip()

def create_messages(prompt, big_bench=False, big_bench_prompt=""):
    """Build a single-turn chat message list using the "Instruct: ... Output:" template.

    Args:
        prompt: Text of the example to be answered.
        big_bench: When truthy, `big_bench_prompt` goes on the "Instruct:"
            line and `prompt` follows on its own line. Otherwise `prompt`
            itself goes on the "Instruct:" line.
        big_bench_prompt: Instruction text; only used when `big_bench` is truthy.

    Returns:
        A one-element list holding a {"role": "user", "content": ...} dict.
    """
    if big_bench:
        content = f"Instruct: {big_bench_prompt}\n{prompt}\nOutput:"
    else:
        content = f"Instruct: {prompt}\nOutput:"
    return [{"role": "user", "content": content}]


def compare_models_on_dataset(dataset, num_data, eval_tokenizer, base_model, instruct_model, symbol_model, file_name, max_new_tokens=100):
    """Run three models on the same sampled prompts and save their outputs side by side.

    Args:
        dataset: Mapping with a 'train' split whose rows have 'inputs' and
            'target' fields.
        num_data: Number of rows to evaluate, taken from the 'train' split
            after shuffling it with seed 43.
        eval_tokenizer: Tokenizer used both to encode the prompt and to decode
            every model's output.
        base_model, instruct_model, symbol_model: Models to compare; each is
            switched to eval mode and queried in this order.
        file_name: Output path without extension; results go to
            f'{file_name}.csv'.
        max_new_tokens: Generation budget forwarded to `generate`.

    Returns:
        A DataFrame with columns 'Input', 'Target', 'Base Model Generation',
        'Instruction Model Generation' and 'Symbol Model Generation'. The same
        frame is written to the CSV file without the index.
    """
    models = (base_model, instruct_model, symbol_model)
    for model in models:
        model.eval()

    samples = dataset['train'].shuffle(seed=43).select(range(num_data))
    inputs, targets, base_generations, instruct_generations, symbol_generations = [], [], [], [], []

    def generate_answer(model, model_input):
        # Shared generate -> decode -> extract pipeline for all three models.
        output_ids = model.generate(**model_input, max_new_tokens=max_new_tokens)[0]
        return extract_answer(eval_tokenizer.decode(output_ids, skip_special_tokens=True))

    for example in tqdm(samples):
        prompt = create_prompt(example['inputs'])
        # The encoded prompt is moved to "cuda"; it is fed unchanged to each model.
        model_input = eval_tokenizer(prompt, return_tensors="pt").to("cuda")

        with torch.no_grad():
            base_gen, instruct_gen, symbol_gen = [
                generate_answer(model, model_input) for model in models
            ]

        # Append the results
        inputs.append(example['inputs'])
        targets.append(example['target'])
        base_generations.append(base_gen)
        instruct_generations.append(instruct_gen)
        symbol_generations.append(symbol_gen)

    df = pd.DataFrame({
        'Input': inputs,
        'Target': targets,
        'Base Model Generation': base_generations,
        'Instruction Model Generation': instruct_generations,
        'Symbol Model Generation': symbol_generations
    })

    df.to_csv(f'{file_name}.csv', index=False)

    return df

def model_eval(dataset, num_data, eval_tokenizer, symbol_model, file_name, max_new_tokens=50, big_bench=False, big_bench_prompt=""):
    """Generate with a single model on sampled prompts and save the results.

    Args:
        dataset: Mapping with a 'train' split whose rows have an 'inputs'
            field and a reference field ('target', or 'targets' in big_bench
            mode — see below).
        num_data: Number of rows to evaluate, taken from the 'train' split
            after shuffling it with seed 43.
        eval_tokenizer: Tokenizer used to encode the prompt and decode the output.
        symbol_model: Model to evaluate; it is switched to eval mode.
        file_name: Output path without extension; results go to
            f'{file_name}.csv'.
        max_new_tokens: Generation budget forwarded to `generate`.
        big_bench: Forwarded to `create_prompt`; also selects the reference field.
        big_bench_prompt: Instruction text forwarded to `create_prompt`.

    Returns:
        A DataFrame with columns 'Input', 'Target' and
        'Symbol Model Generation'. The same frame is written to the CSV file
        without the index.
    """
    symbol_model.eval()
    samples = dataset['train'].shuffle(seed=43).select(range(num_data))
    inputs, targets, symbol_generations = [], [], []

    for example in tqdm(samples):
        prompt = create_prompt(example['inputs'], big_bench, big_bench_prompt)
        model_input = eval_tokenizer(prompt, return_tensors="pt").to("cuda")

        with torch.no_grad():
            output_ids = symbol_model.generate(**model_input, max_new_tokens=max_new_tokens)[0]
        symbol_gen = extract_answer(eval_tokenizer.decode(output_ids, skip_special_tokens=True))

        # model_eval_m and model_eval_gpt3 read the reference from 'targets'
        # in big_bench mode; do the same here, but keep working on rows that
        # only carry 'target'.
        target_key = 'targets' if big_bench and 'targets' in example else 'target'

        # Append the results
        inputs.append(example['inputs'])
        targets.append(example[target_key])
        symbol_generations.append(symbol_gen)

    df = pd.DataFrame({
        'Input': inputs,
        'Target': targets,
        'Symbol Model Generation': symbol_generations
    })

    df.to_csv(f'{file_name}.csv', index=False)

    return df

def model_eval_m(dataset, eval_tokenizer, model, file_name, max_new_tokens=50, num_data=25, big_bench=False, big_bench_prompt=""):
    """Generate with a chat-template model on a dataset and save the results.

    Args:
        dataset: Mapping with a 'train' split whose rows have an 'inputs'
            field and a reference field: 'targets' when `big_bench` is
            truthy, 'target' otherwise.
        eval_tokenizer: Tokenizer providing `apply_chat_template` and
            `batch_decode`.
        model: Model to evaluate; it is switched to eval mode.
        file_name: Output path without extension; results go to
            f'{file_name}.csv'.
        max_new_tokens: Generation budget forwarded to `generate`.
        num_data: Number of shuffled rows to evaluate. Ignored when
            `big_bench` is truthy, in which case every row is used.
        big_bench: Selects the prompt layout (see `create_messages`), the
            reference field, and whether the split is subsampled.
        big_bench_prompt: Instruction text forwarded to `create_messages`.

    Returns:
        A DataFrame with columns 'Input', 'Target' and
        'Mistral Model Generation'. The same frame is written to the CSV file
        without the index.

    Note:
        Generation runs with do_sample=True, so outputs are sampled rather
        than greedy and may differ between runs.
    """
    # Use the model that was passed in, not a notebook-level global.
    model.eval()
    samples = dataset['train'].shuffle(seed=43)

    # Use all data during big_bench eval
    if not big_bench:
        samples = samples.select(range(num_data))

    target_key = 'targets' if big_bench else 'target'
    inputs, targets, generations = [], [], []

    for example in tqdm(samples):
        messages = create_messages(example['inputs'], big_bench=big_bench, big_bench_prompt=big_bench_prompt)
        model_input = eval_tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")

        with torch.no_grad():
            generated_ids = model.generate(
                model_input,
                # A single unpadded prompt: every position is a real token.
                # Passing the mask and pad id explicitly avoids the
                # "attention mask and the pad token id were not set" warning
                # on every sample.
                attention_mask=torch.ones_like(model_input),
                pad_token_id=eval_tokenizer.eos_token_id,
                max_new_tokens=max_new_tokens,
                do_sample=True,
            )

        # Decoded without skip_special_tokens, so special tokens are kept in the text.
        generated = extract_answer(eval_tokenizer.batch_decode(generated_ids)[0])

        inputs.append(example['inputs'])
        targets.append(example[target_key])
        generations.append(generated)

    df = pd.DataFrame({
        'Input': inputs,
        'Target': targets,
        'Mistral Model Generation': generations
    })

    df.to_csv(f'{file_name}.csv', index=False)

    return df


def model_eval_gpt3(dataset, file_name, max_new_tokens=50, num_data=25, big_bench=False, big_bench_prompt=""):
    """Query "gpt-3.5-turbo" through the module-level `client` and save the results.

    Args:
        dataset: Mapping with a 'train' split whose rows have an 'inputs'
            field and a reference field: 'targets' when `big_bench` is
            truthy, 'target' otherwise.
        file_name: Output path without extension; results go to
            f'{file_name}.csv'.
        max_new_tokens: Sent as `max_tokens` with every request.
        num_data: Number of rows to evaluate after shuffling with seed 43.
            Ignored when `big_bench` is truthy, in which case every row is used.
        big_bench: Selects the prompt layout (see `create_messages`), the
            reference field, and whether the split is subsampled.
        big_bench_prompt: Instruction text forwarded to `create_messages`.

    Returns:
        A DataFrame with columns 'Input', 'Target' and 'GPT-3.5 Generation'.
        The same frame is written to the CSV file without the index.
    """
    samples = dataset['train'].shuffle(seed=43)
    # Use all data during big_bench eval
    if not big_bench:
        samples = samples.select(range(num_data))

    target_key = 'targets' if big_bench else 'target'
    records = {'Input': [], 'Target': [], 'GPT-3.5 Generation': []}

    for example in tqdm(samples):
        messages = create_messages(example['inputs'], big_bench=big_bench, big_bench_prompt=big_bench_prompt)
        completion = client.chat.completions.create(
            model="gpt-3.5-turbo",
            temperature=0,
            max_tokens=max_new_tokens,
            messages=messages)
        generated = extract_answer(completion.choices[0].message.content)

        records['Input'].append(example['inputs'])
        records['Target'].append(example[target_key])
        records['GPT-3.5 Generation'].append(generated)

    df = pd.DataFrame(records)
    df.to_csv(f'{file_name}.csv', index=False)

    return df

# Evaluation using GPT-3.5

In [None]:
# Output-file labels and the datasets they refer to, in matching order
# (one row per prompt format).
eval_sets_name = [
    "TEH_i_l", "TEO_i_l", "climate_i_l", "SUBJ_i_l",   # with instruction, with label
    "TEH_i_s", "TEO_i_s", "climate_i_s", "SUBJ_i_s",   # with instruction, with symbol
    "TEH_n", "TEO_n", "climate_n", "SUBJ_n",           # no instruction, with symbol
    "TEH_f", "TEO_f", "climate_f", "SUBJ_f",           # flipped
]

eval_sets = [
    TEH_i_l, TEO_i_l, climate_i_l, SUBJ_i_l,
    TEH_i_s, TEO_i_s, climate_i_s, SUBJ_i_s,
    TEH_n, TEO_n, climate_n, SUBJ_n,
    TEH_f, TEO_f, climate_f, SUBJ_f,
]

# 25 samples per dataset; results are written to "<label>_gpt3.csv".
for name, s in zip(eval_sets_name, eval_sets):
    model_eval_gpt3(dataset=s, file_name=name + "_gpt3", num_data=25, max_new_tokens=50)

100%|██████████| 25/25 [00:12<00:00,  2.01it/s]
100%|██████████| 25/25 [00:10<00:00,  2.33it/s]
100%|██████████| 25/25 [00:10<00:00,  2.28it/s]
100%|██████████| 25/25 [00:10<00:00,  2.46it/s]
100%|██████████| 25/25 [00:10<00:00,  2.35it/s]
100%|██████████| 25/25 [00:12<00:00,  1.96it/s]
100%|██████████| 25/25 [00:09<00:00,  2.56it/s]
100%|██████████| 25/25 [00:08<00:00,  2.79it/s]
100%|██████████| 25/25 [00:10<00:00,  2.42it/s]
100%|██████████| 25/25 [00:09<00:00,  2.60it/s]
100%|██████████| 25/25 [00:08<00:00,  2.97it/s]
100%|██████████| 25/25 [00:10<00:00,  2.39it/s]
100%|██████████| 25/25 [00:10<00:00,  2.35it/s]
100%|██████████| 25/25 [00:08<00:00,  2.88it/s]
100%|██████████| 25/25 [00:09<00:00,  2.71it/s]
100%|██████████| 25/25 [00:08<00:00,  2.79it/s]


In [None]:
# Output-file labels and the per-category datasets they refer to, in matching order.
big_bench_sets_name = [
    "modify_the_list",
    "remove_elements",
    "input_independent",
    "add_elements",
    "miscellaneous",
    "Turing",
]
big_bench_sets = [
    modify_the_list,
    remove_elements,
    input_independent,
    add_elements,
    miscellaneous,
    Turing,
]

# With big_bench=True model_eval_gpt3 evaluates every row, so num_data has no
# effect here. Results are written to "<label>_gpt3.csv".
for name, s in zip(big_bench_sets_name, big_bench_sets):
    model_eval_gpt3(
        dataset=s,
        file_name=name + "_gpt3",
        num_data=25,
        max_new_tokens=50,
        big_bench=True,
        big_bench_prompt=big_bench_prompt,
    )

  1%|          | 1/120 [00:00<01:19,  1.50it/s]

[33, 1]


  2%|▏         | 2/120 [00:01<01:24,  1.39it/s]

[5, 7, 4, 1, 0, 6, 2, 8, 9, 3]


  2%|▎         | 3/120 [00:01<01:08,  1.71it/s]

[9, 9, 0, 0, 8, 8]


  3%|▎         | 4/120 [00:02<01:05,  1.77it/s]

[5, 46, 84, 22]


  4%|▍         | 5/120 [00:03<01:15,  1.52it/s]

[9, 9, 6, 6, 8, 8, 1, 1, 2, 2,


  5%|▌         | 6/120 [00:03<01:02,  1.82it/s]

[6]


  6%|▌         | 7/120 [00:03<00:57,  1.97it/s]

[89, 89, 89]


  7%|▋         | 8/120 [00:04<00:59,  1.87it/s]

[3, 3, 3, 3, 3, 3]


  8%|▊         | 9/120 [00:05<01:03,  1.76it/s]

[5, 3, 7, 6, 0, 9, 2, 1, 8]


  8%|▊         | 10/120 [00:05<00:59,  1.84it/s]

[72, 72, 72, 72]


  9%|▉         | 11/120 [00:06<00:57,  1.90it/s]

[72, 1, 18, 2, 4, 3]


 10%|█         | 12/120 [00:06<01:04,  1.67it/s]

[30, 30, 30, 30, 30, 30]


 11%|█         | 13/120 [00:07<01:04,  1.66it/s]

[2, 4]


 12%|█▏        | 14/120 [00:08<01:01,  1.72it/s]

[64, 64, 64, 64, 64]


 12%|█▎        | 15/120 [00:08<01:06,  1.58it/s]

[4, 5, 2, 6, 9, 3, 7, 1, 0, 8]


 13%|█▎        | 16/120 [00:09<01:13,  1.41it/s]

[2, 2, 2, 2, 2, 2, 2, 2, 2, 2,


 14%|█▍        | 17/120 [00:10<01:01,  1.68it/s]

[30]


 15%|█▌        | 18/120 [00:10<01:03,  1.59it/s]

[0, 0, 5, 5, 8, 8, 1, 1]


 16%|█▌        | 19/120 [00:12<01:46,  1.05s/it]

[82, 1, 32, 2, 34, 3, 29, 4, 8


 17%|█▋        | 20/120 [00:13<01:24,  1.18it/s]

[8, 7, 2]


 18%|█▊        | 21/120 [00:14<01:29,  1.11it/s]

[54, 1, 17, 2, 29, 3, 22, 4]


 18%|█▊        | 22/120 [00:14<01:20,  1.22it/s]

[5, 5, 5, 5, 5, 5, 5, 5]


 19%|█▉        | 23/120 [00:15<01:22,  1.18it/s]

[5, 5, 5, 5, 5, 5, 5, 5, 0, 0,


 20%|██        | 24/120 [00:16<01:18,  1.22it/s]

[3, 15, 81, 83, 59, 96, 0, 6, 


 21%|██        | 25/120 [00:16<01:07,  1.40it/s]

[92, 92]


 22%|██▏       | 26/120 [00:17<01:05,  1.45it/s]

[0, 64, 86, 83, 10, 59, 18]


 22%|██▎       | 27/120 [00:18<01:03,  1.46it/s]

[8, 8, 6, 6, 7, 7, 3, 3, 5, 5,


 23%|██▎       | 28/120 [00:19<01:04,  1.43it/s]

[97, 1, 4, 2, 7, 3, 73, 4]


 24%|██▍       | 29/120 [00:19<00:54,  1.66it/s]

[6, 1]


 25%|██▌       | 30/120 [00:19<00:51,  1.74it/s]

[33, 37, 6, 57]


 26%|██▌       | 31/120 [00:20<00:54,  1.62it/s]

[20, 1, 2, 2, 25, 3, 49, 4, 7,


 27%|██▋       | 32/120 [00:21<00:51,  1.71it/s]

[3, 2, 6, 7, 5]


 28%|██▊       | 33/120 [00:21<00:47,  1.84it/s]

[3, 1, 8, 4]


 28%|██▊       | 34/120 [00:21<00:42,  2.03it/s]

[9]


 29%|██▉       | 35/120 [00:22<00:48,  1.74it/s]

[1, 1, 6, 6, 0, 0, 2, 2, 7, 7]


 30%|███       | 36/120 [00:23<00:50,  1.66it/s]

[5, 8, 0, 4, 7, 6, 1, 2, 3, 9]


 31%|███       | 37/120 [00:24<00:57,  1.45it/s]

[1, 1, 91, 2, 48, 3, 96, 4, 79


 32%|███▏      | 38/120 [00:25<01:00,  1.36it/s]

[2, 2, 2, 2, 3, 3, 3, 3, 3, 3,


 32%|███▎      | 39/120 [00:25<00:48,  1.68it/s]

[27]


 33%|███▎      | 40/120 [00:26<00:55,  1.44it/s]

[18, 9, 1, 38, 8]


 34%|███▍      | 41/120 [00:26<00:46,  1.69it/s]

[]


 35%|███▌      | 42/120 [00:27<00:52,  1.48it/s]

[3, 3, 3, 3, 3, 3, 3, 3, 3, 3,


 36%|███▌      | 43/120 [00:27<00:46,  1.67it/s]

[39, 8, 4]


 37%|███▋      | 44/120 [00:28<00:45,  1.66it/s]

[52, 52, 52, 52, 52, 52, 52, 5


 38%|███▊      | 45/120 [00:29<00:56,  1.34it/s]

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0,


 38%|███▊      | 46/120 [00:30<00:50,  1.47it/s]

[5, 8, 4, 29, 54, 31, 41, 89]


 39%|███▉      | 47/120 [00:30<00:49,  1.47it/s]

[9, 7, 1, 4, 8, 3, 6, 2]


 40%|████      | 48/120 [00:31<00:47,  1.53it/s]

[0, 6, 7, 8, 2, 3, 5, 4, 9]


 41%|████      | 49/120 [00:32<00:46,  1.54it/s]

[3, 5, 8, 4, 1, 0, 2, 7]


 42%|████▏     | 50/120 [00:32<00:41,  1.71it/s]

[]


 42%|████▎     | 51/120 [00:33<00:41,  1.64it/s]

[7, 7, 2, 2, 4, 4, 0, 0]


 43%|████▎     | 52/120 [00:33<00:42,  1.60it/s]

[2, 4, 1, 9, 0, 8, 5, 3, 7]


 44%|████▍     | 53/120 [00:34<00:44,  1.52it/s]

[38, 5, 98, 2, 71, 7, 35, 3, 5


 45%|████▌     | 54/120 [00:35<00:47,  1.38it/s]

[2, 2, 5, 5, 5, 5, 5, 0, 0, 0,


 46%|████▌     | 55/120 [00:36<00:49,  1.31it/s]

[9, 9, 4, 4, 0, 0, 1, 1, 2, 2,


 47%|████▋     | 56/120 [00:36<00:46,  1.38it/s]

[8, 2, 1, 4, 7, 9, 5, 3, 0, 6]


 48%|████▊     | 57/120 [00:37<00:40,  1.55it/s]

[0, 9, 5]


 48%|████▊     | 58/120 [00:38<00:41,  1.51it/s]

[5, 4, 9, 0, 3, 8, 1]


 49%|████▉     | 59/120 [00:38<00:44,  1.38it/s]

[5, 5, 7, 7, 4, 4, 3, 3]


 50%|█████     | 60/120 [00:39<00:43,  1.37it/s]

[78, 1, 22, 2, 29, 3, 8, 4, 7,


 51%|█████     | 61/120 [00:40<00:41,  1.41it/s]

[5, 0, 6, 9, 4, 8, 7]


 52%|█████▏    | 62/120 [00:40<00:33,  1.74it/s]

[97]


 52%|█████▎    | 63/120 [00:41<00:33,  1.73it/s]

[61, 61, 61, 61, 61, 61, 61, 6


 53%|█████▎    | 64/120 [00:41<00:32,  1.71it/s]

[75, 44, 15, 87, 2]


 54%|█████▍    | 65/120 [00:42<00:36,  1.52it/s]

[3, 3, 3, 6, 6, 6, 6, 6, 6]


 55%|█████▌    | 66/120 [00:43<00:39,  1.37it/s]

[16, 1, 17, 2, 89, 3, 8, 4, 4,


 56%|█████▌    | 67/120 [00:43<00:33,  1.58it/s]

[34, 7, 1]


 57%|█████▋    | 68/120 [00:44<00:31,  1.64it/s]

[49, 49, 49, 49, 49, 49]


 57%|█████▊    | 69/120 [00:44<00:26,  1.91it/s]

[74, 91, 87]


 58%|█████▊    | 70/120 [00:45<00:26,  1.88it/s]

[4, 4, 3, 3, 8, 8]


 59%|█████▉    | 71/120 [00:46<00:29,  1.67it/s]

[90, 1, 9, 2, 6, 3, 67, 4, 0, 


 60%|██████    | 72/120 [00:47<00:36,  1.33it/s]

[1, 1, 1, 1, 1, 1]


 61%|██████    | 73/120 [00:47<00:32,  1.44it/s]

[3, 6, 0, 1, 9, 8]


 62%|██████▏   | 74/120 [00:48<00:36,  1.27it/s]

[2, 2, 5, 5, 5, 5, 6, 6, 6, 6,


 62%|██████▎   | 75/120 [00:49<00:33,  1.33it/s]

[5, 2, 3, 7, 0, 8]


 63%|██████▎   | 76/120 [00:49<00:27,  1.59it/s]

[69, 96]


 64%|██████▍   | 77/120 [00:50<00:32,  1.34it/s]

[4, 4, 4, 4, 4, 4, 4, 7, 7, 7,


 65%|██████▌   | 78/120 [00:51<00:31,  1.32it/s]

[79, 1, 3, 2]


 66%|██████▌   | 79/120 [00:52<00:26,  1.56it/s]

[4, 4, 9, 9]


 67%|██████▋   | 80/120 [00:52<00:22,  1.76it/s]

[77, 24]


 68%|██████▊   | 81/120 [00:53<00:27,  1.41it/s]

[97, 1, 16, 44, 2, 83, 5, 88, 


 68%|██████▊   | 82/120 [00:53<00:23,  1.60it/s]

[7, 7]


 69%|██████▉   | 83/120 [00:54<00:25,  1.46it/s]

[2, 2, 2, 5, 5, 5, 3, 3, 3, 5,


 70%|███████   | 84/120 [00:55<00:22,  1.57it/s]

Since the input list is empty,


 71%|███████   | 85/120 [00:55<00:19,  1.79it/s]

[66, 1]


 72%|███████▏  | 86/120 [00:55<00:17,  1.98it/s]

[78, 17, 49]


 72%|███████▎  | 87/120 [00:56<00:15,  2.09it/s]

[43, 65, 77, 13, 54]


 73%|███████▎  | 88/120 [00:57<00:18,  1.76it/s]

[2, 2, 9, 9, 3, 3, 5, 5, 1, 1,


 74%|███████▍  | 89/120 [00:57<00:15,  1.94it/s]

[40, 40]


 75%|███████▌  | 90/120 [00:58<00:17,  1.67it/s]

[9, 2, 6, 4, 7, 0, 3, 5]


 76%|███████▌  | 91/120 [00:58<00:16,  1.72it/s]

[2, 6, 0, 5, 4, 1]


 77%|███████▋  | 92/120 [00:59<00:17,  1.64it/s]

[3, 3, 3, 3, 3, 3, 3, 3]


 78%|███████▊  | 93/120 [01:00<00:20,  1.35it/s]

[3, 3, 3, 3, 3, 3, 3, 3, 3, 3,


 78%|███████▊  | 94/120 [01:00<00:16,  1.59it/s]

[96, 68]


 79%|███████▉  | 95/120 [01:01<00:16,  1.53it/s]

[1, 1, 97, 2, 81, 3, 87, 4]


 80%|████████  | 96/120 [01:02<00:16,  1.47it/s]

[1, 4, 3, 5, 6, 0, 7, 8, 2]


 81%|████████  | 97/120 [01:02<00:14,  1.57it/s]

[8, 1, 52, 2, 70, 3]


 82%|████████▏ | 98/120 [01:03<00:12,  1.77it/s]

[65, 41, 45, 64]


 82%|████████▎ | 99/120 [01:04<00:13,  1.52it/s]

[0, 0, 7, 7, 6, 6, 4, 4, 9, 9,


 83%|████████▎ | 100/120 [01:04<00:11,  1.74it/s]

[9, 5, 0]


 84%|████████▍ | 101/120 [01:05<00:11,  1.62it/s]

[9, 1, 0, 2, 24, 3, 28, 4, 44,


 85%|████████▌ | 102/120 [01:05<00:11,  1.59it/s]

[2, 5, 1, 4, 0, 3, 6, 7]


 86%|████████▌ | 103/120 [01:06<00:10,  1.61it/s]

[4, 0, 6, 5, 2, 9, 7]


 87%|████████▋ | 104/120 [01:07<00:09,  1.75it/s]

[10, 10]


 88%|████████▊ | 105/120 [01:07<00:08,  1.87it/s]

[32, 65, 49]


 88%|████████▊ | 106/120 [01:08<00:08,  1.56it/s]

[3, 3, 1, 1, 4, 4, 8, 8, 0, 0,


 89%|████████▉ | 107/120 [01:09<00:08,  1.52it/s]

[36, 36, 36, 36, 36, 36, 36]


 90%|█████████ | 108/120 [01:09<00:07,  1.63it/s]

[7, 40, 0, 70, 33, 13]


 91%|█████████ | 109/120 [01:10<00:06,  1.69it/s]

[45, 45, 45]


 92%|█████████▏| 110/120 [01:10<00:05,  1.96it/s]

[]


 92%|█████████▎| 111/120 [01:11<00:04,  1.90it/s]

[87, 9, 7, 91, 52, 2, 29]


 93%|█████████▎| 112/120 [01:11<00:03,  2.05it/s]

[50, 50, 50, 50]


 94%|█████████▍| 113/120 [01:12<00:04,  1.73it/s]

[69, 43, 1, 79, 8, 28, 57, 84]


 95%|█████████▌| 114/120 [01:12<00:03,  1.83it/s]

[58, 53, 8, 59, 70]


 96%|█████████▌| 115/120 [01:13<00:02,  1.84it/s]

[83]


 97%|█████████▋| 116/120 [01:13<00:02,  1.86it/s]

[35, 2, 11, 42, 0]


 98%|█████████▊| 117/120 [01:14<00:01,  2.03it/s]

[0, 0, 7, 7]


 98%|█████████▊| 118/120 [01:14<00:00,  2.20it/s]

[5, 72, 8]


 99%|█████████▉| 119/120 [01:14<00:00,  2.25it/s]

[0, 4, 6]


100%|██████████| 120/120 [01:15<00:00,  1.59it/s]


[1, 13, 9, 27, 36, 35]


  1%|          | 1/90 [00:00<00:30,  2.93it/s]

[65]


  2%|▏         | 2/90 [00:00<00:27,  3.18it/s]

[7]


  3%|▎         | 3/90 [00:00<00:25,  3.45it/s]

[93]


  4%|▍         | 4/90 [00:01<00:26,  3.20it/s]

[4]


  6%|▌         | 5/90 [00:01<00:25,  3.28it/s]

[9, 15]


  7%|▋         | 6/90 [00:01<00:26,  3.22it/s]

[33, 9]


  8%|▊         | 7/90 [00:02<00:26,  3.11it/s]

[10, 95]


  9%|▉         | 8/90 [00:02<00:26,  3.07it/s]

[8]


 10%|█         | 9/90 [00:02<00:28,  2.88it/s]

[48, 24]


 11%|█         | 10/90 [00:03<00:26,  3.03it/s]

[28, 27]


 12%|█▏        | 11/90 [00:03<00:24,  3.18it/s]

[37, 3]


 13%|█▎        | 12/90 [00:03<00:23,  3.37it/s]

[5]


 14%|█▍        | 13/90 [00:04<00:30,  2.53it/s]

[24, 72, 7, 95, 73, 53]


 16%|█▌        | 14/90 [00:04<00:30,  2.52it/s]

[75]


 17%|█▋        | 15/90 [00:05<00:28,  2.62it/s]

[9]


 18%|█▊        | 16/90 [00:05<00:26,  2.77it/s]

[81, 80]


 19%|█▉        | 17/90 [00:05<00:25,  2.86it/s]

[3]


 20%|██        | 18/90 [00:06<00:28,  2.54it/s]

[9, 6]


 21%|██        | 19/90 [00:06<00:26,  2.64it/s]

[89]


 22%|██▏       | 20/90 [00:06<00:23,  2.92it/s]

[6]


 23%|██▎       | 21/90 [00:07<00:22,  3.11it/s]

[1]


 24%|██▍       | 22/90 [00:07<00:22,  3.05it/s]

[3]


 26%|██▌       | 23/90 [00:08<00:31,  2.13it/s]

[8]


 27%|██▋       | 24/90 [00:08<00:27,  2.42it/s]

[6]


 28%|██▊       | 25/90 [00:09<00:28,  2.27it/s]

[31, 0, 46, 2]


 29%|██▉       | 26/90 [00:09<00:25,  2.49it/s]

[8]


 30%|███       | 27/90 [00:09<00:25,  2.45it/s]

[2]


 31%|███       | 28/90 [00:10<00:22,  2.79it/s]

[45]


 32%|███▏      | 29/90 [00:10<00:20,  2.96it/s]

[7]


 33%|███▎      | 30/90 [00:11<00:26,  2.27it/s]

[7, 0, 38, 5, 23, 3, 1, 14]


 34%|███▍      | 31/90 [00:11<00:24,  2.43it/s]

[11]


 36%|███▌      | 32/90 [00:11<00:26,  2.17it/s]

[99]


 37%|███▋      | 33/90 [00:12<00:24,  2.36it/s]

[1]


 38%|███▊      | 34/90 [00:12<00:20,  2.72it/s]

[6]


 39%|███▉      | 35/90 [00:13<00:25,  2.18it/s]

[71, 7, 31, 40, 9, 16, 86]


 40%|████      | 36/90 [00:13<00:29,  1.81it/s]

[80]


 41%|████      | 37/90 [00:14<00:31,  1.68it/s]

[53, 34, 39, 85, 8, 87, 2]


 42%|████▏     | 38/90 [00:14<00:27,  1.91it/s]

[9]


 43%|████▎     | 39/90 [00:15<00:28,  1.79it/s]

[7, 53]


 44%|████▍     | 40/90 [00:15<00:24,  2.02it/s]

[9]


 46%|████▌     | 41/90 [00:16<00:21,  2.30it/s]

[17, 1]


 47%|████▋     | 42/90 [00:16<00:20,  2.37it/s]

[5]


 48%|████▊     | 43/90 [00:17<00:18,  2.53it/s]

[78]


 49%|████▉     | 44/90 [00:17<00:17,  2.68it/s]

[12]


 50%|█████     | 45/90 [00:17<00:16,  2.65it/s]

[78, 3]


 51%|█████     | 46/90 [00:18<00:16,  2.73it/s]

[80]


 52%|█████▏    | 47/90 [00:18<00:14,  2.97it/s]

[4]


 53%|█████▎    | 48/90 [00:18<00:13,  3.18it/s]

[72]


 54%|█████▍    | 49/90 [00:18<00:13,  3.04it/s]

[4]


 56%|█████▌    | 50/90 [00:19<00:13,  2.87it/s]

[90]


 57%|█████▋    | 51/90 [00:19<00:13,  2.92it/s]

[28]


 58%|█████▊    | 52/90 [00:20<00:13,  2.85it/s]

[8]


 59%|█████▉    | 53/90 [00:20<00:13,  2.68it/s]

[6, 36]


 60%|██████    | 54/90 [00:20<00:12,  2.93it/s]

[4]


 61%|██████    | 55/90 [00:21<00:12,  2.91it/s]

[39]


 62%|██████▏   | 56/90 [00:21<00:11,  2.86it/s]

[20]


 63%|██████▎   | 57/90 [00:21<00:11,  2.87it/s]

[0]


 64%|██████▍   | 58/90 [00:22<00:12,  2.61it/s]

[5, 38]


 66%|██████▌   | 59/90 [00:22<00:11,  2.59it/s]

[70]


 67%|██████▋   | 60/90 [00:23<00:11,  2.64it/s]

[77]


 68%|██████▊   | 61/90 [00:23<00:10,  2.85it/s]

[5]


 69%|██████▉   | 62/90 [00:23<00:10,  2.72it/s]

[6]


 70%|███████   | 63/90 [00:24<00:13,  1.95it/s]

[99]


 71%|███████   | 64/90 [00:24<00:11,  2.29it/s]

[9]


 72%|███████▏  | 65/90 [00:25<00:09,  2.56it/s]

[43]


 73%|███████▎  | 66/90 [00:35<01:24,  3.53s/it]

[1]


 74%|███████▍  | 67/90 [00:36<00:59,  2.57s/it]

[55]


 76%|███████▌  | 68/90 [00:36<00:41,  1.89s/it]

[7]


 77%|███████▋  | 69/90 [00:36<00:30,  1.44s/it]

[46, 21]


 78%|███████▊  | 70/90 [00:37<00:23,  1.15s/it]

[93, 9]


 79%|███████▉  | 71/90 [00:37<00:17,  1.11it/s]

[67]


 80%|████████  | 72/90 [00:38<00:13,  1.31it/s]

[93, 35]


 81%|████████  | 73/90 [00:38<00:10,  1.56it/s]

[6]


 82%|████████▏ | 74/90 [00:38<00:08,  1.79it/s]

[0]


 83%|████████▎ | 75/90 [00:39<00:07,  2.00it/s]

[65]


 84%|████████▍ | 76/90 [00:39<00:07,  1.87it/s]

[7, 8, 0, 5, 1, 2, 3, 9]


 86%|████████▌ | 77/90 [00:40<00:06,  2.11it/s]

[1]


 87%|████████▋ | 78/90 [00:40<00:05,  2.17it/s]

[13, 70, 27, 6]


 88%|████████▊ | 79/90 [00:40<00:04,  2.44it/s]

[8]


 89%|████████▉ | 80/90 [00:41<00:03,  2.59it/s]

[84]


 90%|█████████ | 81/90 [00:42<00:04,  1.90it/s]

[46, 4]


 91%|█████████ | 82/90 [00:42<00:03,  2.18it/s]

[74]


 92%|█████████▏| 83/90 [00:42<00:03,  2.11it/s]

[59, 49]


 93%|█████████▎| 84/90 [00:43<00:02,  2.30it/s]

[38]


 94%|█████████▍| 85/90 [00:43<00:02,  2.28it/s]

[79, 70, 29, 75]


 96%|█████████▌| 86/90 [00:44<00:01,  2.60it/s]

[27]


 97%|█████████▋| 87/90 [00:44<00:01,  2.81it/s]

[9]


 98%|█████████▊| 88/90 [00:44<00:00,  2.80it/s]

[46]


 99%|█████████▉| 89/90 [00:44<00:00,  3.01it/s]

[32]


100%|██████████| 90/90 [00:45<00:00,  1.99it/s]


[]


  3%|▎         | 1/30 [00:00<00:11,  2.47it/s]

[1, 1]


  7%|▋         | 2/30 [00:00<00:11,  2.51it/s]

[6, 1]


 10%|█         | 3/30 [00:01<00:11,  2.30it/s]

[8, 2, 7, 0, 3]


 13%|█▎        | 4/30 [00:01<00:10,  2.37it/s]

[5, 2]


 17%|█▋        | 5/30 [00:01<00:09,  2.63it/s]

[1, 1]


 20%|██        | 6/30 [00:02<00:08,  2.85it/s]

[5, 2]


 23%|██▎       | 7/30 [00:02<00:08,  2.61it/s]

[8, 2, 7, 0, 3]


 27%|██▋       | 8/30 [00:03<00:07,  2.80it/s]

[5, 2]


 30%|███       | 9/30 [00:04<00:12,  1.64it/s]

To generate the output list, a


 33%|███▎      | 10/30 [00:04<00:11,  1.73it/s]

[8, 2, 7, 0, 3]


 37%|███▋      | 11/30 [00:05<00:14,  1.32it/s]

To generate the output list, y


 40%|████      | 12/30 [00:06<00:11,  1.54it/s]

[5, 2]


 43%|████▎     | 13/30 [00:06<00:10,  1.69it/s]

[8, 2, 7, 0, 3]


 47%|████▋     | 14/30 [00:07<00:08,  1.87it/s]

[5, 2]


 50%|█████     | 15/30 [00:07<00:07,  1.93it/s]

[8, 2, 7, 0, 3]


 53%|█████▎    | 16/30 [00:07<00:06,  2.11it/s]

[5, 2]


 57%|█████▋    | 17/30 [00:08<00:06,  2.06it/s]

[7, 2]


 60%|██████    | 18/30 [00:09<00:05,  2.03it/s]

[8, 2, 7, 0, 3]


 63%|██████▎   | 19/30 [00:09<00:05,  2.00it/s]

[8, 2, 7, 0, 3]


 67%|██████▋   | 20/30 [00:10<00:04,  2.01it/s]

[8, 2, 7, 0, 3]


 70%|███████   | 21/30 [00:10<00:04,  2.10it/s]

[5, 2]


 73%|███████▎  | 22/30 [00:10<00:03,  2.24it/s]

[5, 2]


 77%|███████▋  | 23/30 [00:11<00:04,  1.70it/s]

[3, 5, 6, 4, 6, 7, 5, 1, 6, 7,


 80%|████████  | 24/30 [00:12<00:03,  1.57it/s]

[8, 2, 7, 0, 3]


 83%|████████▎ | 25/30 [00:12<00:02,  1.77it/s]

[5, 2]


 87%|████████▋ | 26/30 [00:13<00:02,  1.76it/s]

[8, 2, 7, 0, 3]


 90%|█████████ | 27/30 [00:13<00:01,  1.83it/s]

[8, 2, 7, 0, 3]


 93%|█████████▎| 28/30 [00:14<00:01,  1.88it/s]

[8, 2, 7, 0, 3]


 97%|█████████▋| 29/30 [00:14<00:00,  2.03it/s]

[5, 2]


100%|██████████| 30/30 [00:15<00:00,  1.96it/s]


[5, 2]


  3%|▎         | 1/30 [00:00<00:15,  1.90it/s]

[4, 1, 5, 9]


  7%|▋         | 2/30 [00:01<00:20,  1.38it/s]

[4, 4, 3, 6, 2, 5, 0, 7, 8, 1,


 10%|█         | 3/30 [00:01<00:16,  1.59it/s]

[8, 8, 9, 1]


 13%|█▎        | 4/30 [00:02<00:17,  1.50it/s]

[0, 7, 1, 3, 4, 5, 9, 2, 6, 9]


 17%|█▋        | 5/30 [00:03<00:17,  1.41it/s]

[2, 5, 3, 0, 6, 8, 1, 9, 9]


 20%|██        | 6/30 [00:03<00:15,  1.57it/s]

[6, 1, 4, 5, 8, 9]


 23%|██▎       | 7/30 [00:04<00:12,  1.80it/s]

[3, 3, 5]


 27%|██▋       | 8/30 [00:04<00:12,  1.72it/s]

[0, 8, 5, 3, 7, 1, 2, 9]


 30%|███       | 9/30 [00:05<00:10,  1.94it/s]

[5, 9]


 33%|███▎      | 10/30 [00:06<00:12,  1.63it/s]

[7, 7, 9, 6, 0, 3, 8, 2, 5, 4,


 37%|███▋      | 11/30 [00:06<00:12,  1.58it/s]

[6, 7, 8, 9, 3, 0, 5, 4, 1]


 40%|████      | 12/30 [00:07<00:09,  1.81it/s]

[3, 9]


 43%|████▎     | 13/30 [00:07<00:08,  1.89it/s]

[7, 7, 0, 1, 3]


 47%|████▋     | 14/30 [00:08<00:08,  1.86it/s]

[0, 3, 7, 2, 4, 9]


 50%|█████     | 15/30 [00:08<00:06,  2.21it/s]

[0]


 53%|█████▎    | 16/30 [00:09<00:07,  1.92it/s]

[0, 6, 1, 9, 2, 8, 5, 5]


 57%|█████▋    | 17/30 [00:09<00:07,  1.64it/s]

[3, 1, 4, 8, 5, 6, 2, 7, 9, 0,


 60%|██████    | 18/30 [00:10<00:06,  1.76it/s]

[0, 0, 9, 3]


 63%|██████▎   | 19/30 [00:11<00:06,  1.72it/s]

[2, 2, 9, 8, 6, 7, 0]


 67%|██████▋   | 20/30 [00:11<00:05,  1.76it/s]

[4, 4, 7, 0, 5]


 70%|███████   | 21/30 [00:11<00:04,  2.06it/s]

[9]


 73%|███████▎  | 22/30 [00:12<00:04,  1.96it/s]

[4, 6, 1, 0, 5, 8, 9]


 77%|███████▋  | 23/30 [00:13<00:04,  1.49it/s]

[6, 6, 7, 0, 4, 2, 5, 9, 1, 8,


 80%|████████  | 24/30 [00:13<00:03,  1.69it/s]

[3, 3]


 83%|████████▎ | 25/30 [00:14<00:03,  1.51it/s]

[1, 4, 8, 2, 3, 5, 9, 0, 7, 6,


 87%|████████▋ | 26/30 [00:15<00:02,  1.64it/s]

[9, 9, 7, 5, 2]


 90%|█████████ | 27/30 [00:15<00:01,  1.73it/s]

[6, 6, 5, 2]


 93%|█████████▎| 28/30 [00:16<00:01,  1.92it/s]

[1]


 97%|█████████▋| 29/30 [00:16<00:00,  1.70it/s]

[0, 5, 6, 3, 2, 1, 4, 7, 8, 9]


100%|██████████| 30/30 [00:17<00:00,  1.71it/s]


[7, 0, 5, 4, 3, 2, 9, 1, 6, 9]


  3%|▎         | 1/30 [00:00<00:23,  1.24it/s]

[8]


  7%|▋         | 2/30 [00:01<00:18,  1.55it/s]

[0, 1, 2, 3]


 10%|█         | 3/30 [00:01<00:13,  1.98it/s]

[91, 92]


 13%|█▎        | 4/30 [00:02<00:11,  2.27it/s]

[64]


 17%|█▋        | 5/30 [00:02<00:09,  2.52it/s]

[15]


 20%|██        | 6/30 [00:02<00:08,  2.70it/s]

[3]


 23%|██▎       | 7/30 [00:03<00:10,  2.10it/s]

[2, 3, 4, 7, 8, 15, 9, 14]


 27%|██▋       | 8/30 [00:03<00:09,  2.44it/s]

[8]


 30%|███       | 9/30 [00:03<00:08,  2.57it/s]

[7]


 33%|███▎      | 10/30 [00:04<00:08,  2.46it/s]

[5, 6, 7]


 37%|███▋      | 11/30 [00:04<00:07,  2.52it/s]

[64, 77]


 40%|████      | 12/30 [00:05<00:06,  2.66it/s]

[8]


 43%|████▎     | 13/30 [00:05<00:08,  1.95it/s]

[4, 5, 6, 7, 8, 9, 10]


 47%|████▋     | 14/30 [00:06<00:07,  2.15it/s]

[2]


 50%|█████     | 15/30 [00:06<00:07,  1.91it/s]

[1, 2, 3, 4, 5, 6, 7]


 53%|█████▎    | 16/30 [00:07<00:06,  2.13it/s]

[6]


 57%|█████▋    | 17/30 [00:07<00:05,  2.32it/s]

[5]


 60%|██████    | 18/30 [00:08<00:06,  1.78it/s]

[57, 58, 59, 60, 61, 62, 63, 6


 63%|██████▎   | 19/30 [00:09<00:07,  1.52it/s]

[82, 83, 84, 85, 86, 87, 88, 8


 67%|██████▋   | 20/30 [00:09<00:06,  1.63it/s]

[3, 4, 5, 6, 7, 8]


 70%|███████   | 21/30 [00:10<00:04,  1.88it/s]

[9]


 73%|███████▎  | 22/30 [00:10<00:04,  1.98it/s]

[7]


 77%|███████▋  | 23/30 [00:10<00:02,  2.34it/s]

[44]


 80%|████████  | 24/30 [00:11<00:02,  2.35it/s]

[0, 1, 2]


 83%|████████▎ | 25/30 [00:11<00:02,  2.48it/s]

[7]


 87%|████████▋ | 26/30 [00:12<00:01,  2.56it/s]

[9, 14, 16]


 90%|█████████ | 27/30 [00:12<00:01,  2.66it/s]

[8, 9]


 93%|█████████▎| 28/30 [00:12<00:00,  2.30it/s]

[0, 1, 2, 3, 4, 5, 6]


 97%|█████████▋| 29/30 [00:13<00:00,  2.44it/s]

[9]


100%|██████████| 30/30 [00:13<00:00,  2.19it/s]


[7]


  3%|▎         | 1/30 [00:00<00:09,  3.17it/s]

'111'


  7%|▋         | 2/30 [00:00<00:08,  3.35it/s]

'0'


 10%|█         | 3/30 [00:01<00:09,  2.86it/s]

'1'


 13%|█▎        | 4/30 [00:01<00:08,  2.94it/s]

'01'


 17%|█▋        | 5/30 [00:01<00:07,  3.33it/s]

1


 20%|██        | 6/30 [00:01<00:07,  3.20it/s]

10110


 23%|██▎       | 7/30 [00:02<00:07,  3.14it/s]

'1'


 27%|██▋       | 8/30 [00:02<00:07,  3.03it/s]

'1111'


 30%|███       | 9/30 [00:02<00:07,  2.83it/s]

'1'


 33%|███▎      | 10/30 [00:03<00:07,  2.52it/s]

'00'


 37%|███▋      | 11/30 [00:03<00:07,  2.65it/s]

'001'


 40%|████      | 12/30 [00:04<00:08,  2.23it/s]

0101


 43%|████▎     | 13/30 [00:04<00:07,  2.43it/s]

'1110'


 47%|████▋     | 14/30 [00:05<00:06,  2.54it/s]

'0'


 50%|█████     | 15/30 [00:05<00:05,  2.65it/s]

1


 53%|█████▎    | 16/30 [00:05<00:05,  2.71it/s]

'00'


 57%|█████▋    | 17/30 [00:06<00:04,  2.94it/s]

'0'


 60%|██████    | 18/30 [00:06<00:04,  2.95it/s]

'10100'


 63%|██████▎   | 19/30 [00:06<00:03,  2.85it/s]

'1'


 67%|██████▋   | 20/30 [00:07<00:03,  3.11it/s]

'1'


 70%|███████   | 21/30 [00:07<00:02,  3.34it/s]

'10'


 73%|███████▎  | 22/30 [00:07<00:02,  3.22it/s]

'111'


 77%|███████▋  | 23/30 [00:07<00:02,  3.41it/s]

'0'


 80%|████████  | 24/30 [00:08<00:01,  3.49it/s]

'111'


 83%|████████▎ | 25/30 [00:08<00:01,  3.70it/s]

00110


 87%|████████▋ | 26/30 [00:08<00:01,  3.86it/s]

'111'


 90%|█████████ | 27/30 [00:08<00:00,  3.86it/s]

'11'


 93%|█████████▎| 28/30 [00:09<00:00,  3.97it/s]

'1'


 97%|█████████▋| 29/30 [00:09<00:00,  3.74it/s]

10110


100%|██████████| 30/30 [00:09<00:00,  3.08it/s]

'00'





# Evaluation using Mistral-7B

In [None]:
mistral_model_id = "mistralai/Mistral-7B-Instruct-v0.2"

# Loading options for the evaluation model: 8-bit quantization config,
# float16 torch dtype and the "flash_attention_2" attention implementation.
# NOTE(review): trust_remote_code=True is passed through as-is — confirm it is
# actually required for this checkpoint.
mistral_load_kwargs = dict(
    trust_remote_code=True,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
mistral_model = AutoModelForCausalLM.from_pretrained(mistral_model_id, **mistral_load_kwargs)

# The tokenizer comes from the same checkpoint id as the model.
eval_tokenizer = AutoTokenizer.from_pretrained(mistral_model_id)

In [None]:
# Output-file labels and the datasets they refer to, in matching order
# (one row per prompt format).
eval_sets_name = [
    "TEH_i_l", "TEO_i_l", "climate_i_l", "SUBJ_i_l",   # with instruction, with label
    "TEH_i_s", "TEO_i_s", "climate_i_s", "SUBJ_i_s",   # with instruction, with symbol
    "TEH_n", "TEO_n", "climate_n", "SUBJ_n",           # no instruction, with symbol
    "TEH_f", "TEO_f", "climate_f", "SUBJ_f",           # flipped
]

eval_sets = [
    TEH_i_l, TEO_i_l, climate_i_l, SUBJ_i_l,
    TEH_i_s, TEO_i_s, climate_i_s, SUBJ_i_s,
    TEH_n, TEO_n, climate_n, SUBJ_n,
    TEH_f, TEO_f, climate_f, SUBJ_f,
]

# 25 samples per dataset; results are written to "<label>_m.csv".
for name, s in zip(eval_sets_name, eval_sets):
    model_eval_m(
        dataset=s,
        eval_tokenizer=eval_tokenizer,
        model=mistral_model,
        file_name=name + "_m",
        num_data=25,
        max_new_tokens=50,
    )

In [None]:
# BIG-bench list-function subsets (one per category filter) to run through the
# Mistral model.
big_bench_sets = [
    modify_the_list, remove_elements, input_independent,
    add_elements, miscellaneous, Turing,
]

# Names paired by position with big_bench_sets; each one (plus "_m") becomes
# the file_name passed to model_eval_m.
big_bench_sets_name = [
    "modify_the_list", "remove_elements", "input_independent",
    "add_elements", "miscellaneous", "Turing",
]

# zip() stops at the shorter input without complaint, so guard against the two
# lists drifting out of sync.
assert len(big_bench_sets) == len(big_bench_sets_name), (
    "big_bench_sets and big_bench_sets_name must have the same number of entries"
)

for s, name in zip(big_bench_sets, big_bench_sets_name):
    # big_bench=True together with big_bench_prompt selects the BIG-bench path
    # of model_eval_m; its exact effect is defined there, not in this cell.
    model_eval_m(
        dataset=s,
        num_data=25,
        eval_tokenizer=eval_tokenizer,
        model=mistral_model,
        file_name=name + "_m",
        max_new_tokens=50,
        big_bench=True,
        big_bench_prompt=big_bench_prompt,
    )

  0%|          | 0/120 [00:00<?, ?it/s]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  1%|          | 1/120 [00:08<17:14,  8.69s/it]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  2%|▏         | 2/120 [00:16<15:55,  8.09s/it]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  2%|▎         | 3/120 [00:22<14:17,  7.33s/it]The attention mask and the pad token id were not set. As a consequence, you may

# Evaluation using phi-2

In [None]:
# Hub ids: the base phi-2 checkpoint (source of the tokenizer) and the
# fine-tuned checkpoint that is evaluated.
base_model_id = "microsoft/phi-2"
symbol_model_id = "ChunB1/phi-2_OpenOrca_cot_only_symbol"

# NOTE: this rebinds eval_tokenizer, which held the Mistral tokenizer in the
# previous section; re-running Mistral cells after this one would pick up the
# phi-2 tokenizer.
eval_tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    use_fast=False,
    trust_remote_code=True,
    add_bos_token=True,
)

# Reuse the end-of-sequence token as the padding token.
eval_tokenizer.pad_token = eval_tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
# Load the fine-tuned phi-2 checkpoint named by symbol_model_id with 8-bit
# bitsandbytes quantization, fp16 as the requested torch dtype, and the
# FlashAttention-2 attention implementation.
# NOTE(review): trust_remote_code=True permits repository-supplied code to be
# executed at load time — confirm the checkpoint source is trusted.
symbol_model = AutoModelForCausalLM.from_pretrained(
    symbol_model_id,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)

In [None]:
# Bare last expression: shows the loaded model's module tree as the cell
# output (the repr is long — consider clearing the output before sharing).
symbol_model

PhiForCausalLM(
  (model): PhiModel(
    (embed_tokens): Embedding(51200, 2560)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x PhiDecoderLayer(
        (self_attn): PhiFlashAttention2(
          (q_proj): lora.Linear8bitLt(
            (base_layer): Linear8bitLt(in_features=2560, out_features=2560, bias=True)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.05, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2560, out_features=64, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=64, out_features=2560, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
          )
          (k_proj): lora.Linear8bitLt(
            (base_layer): Linear8bitLt(in_features=2560, out_features=2560, bias=True)
            (lora_dropout): ModuleDict(
  

In [None]:
# With-instruction/with-symbol datasets evaluated on the phi-2 symbol model,
# as (file_name, dataset) pairs in the order they are run.
symbol_eval_sets = [
    ("TEH_i_s", TEH_i_s),
    ("TEO_i_s", TEO_i_s),
    ("climate_i_s", climate_i_s),
    ("SUBJ_i_s", SUBJ_i_s),
]

# One model_eval call per dataset with identical settings. Whatever each call
# returns is kept by name so earlier results are not thrown away.
symbol_eval_results = {}
for symbol_set_name, symbol_set in symbol_eval_sets:
    symbol_eval_results[symbol_set_name] = model_eval(
        dataset=symbol_set,
        num_data=25,
        eval_tokenizer=eval_tokenizer,
        symbol_model=symbol_model,
        file_name=symbol_set_name,
        max_new_tokens=50,
    )

# Leave the last call's return value as the cell's final expression so the
# notebook displays it exactly as it did when the calls were written out.
symbol_eval_results["SUBJ_i_s"]

  0%|          | 0/25 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
  4%|▍         | 1/25 [00:08<03:29,  8.73s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
  8%|▊         | 2/25 [00:15<03:00,  7.85s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 12%|█▏        | 3/25 [00:23<02:46,  7.59s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 16%|█▌        | 4/25 [00:30<02:36,  7.45s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 20%|██        | 5/25 [00:37<02:27,  7.38s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 24%|██▍       | 6/25 [00:45<02:20,  7.39s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 28%|██▊       | 7/25 [00:52<02:13,  7.43s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 32%|███▏      | 8/25 [01:00<02:06,  7.44s/it]Setting `pad_token

Unnamed: 0,Input,Target,Symbol Model Generation
0,Question: Is the following sentence subjective...,FOPZ,Output: Yenix\nYenix: The following sentence i...
1,Question: Is the following sentence subjective...,Hibaz,Output: NMQA\nNMQA: The following sentence is ...
2,Question: Is the following sentence subjective...,Nahey,Output: Nahey\nHJZZ: Nahey\nNahey: Nahey\nNahe...
3,Question: Is the following sentence subjective...,KURB,Output: KURB\nQuestion: Is the following sente...
4,Question: Is the following sentence subjective...,34030,Output: 34030\nQuestion: Is the following sent...
5,Question: Is the following sentence subjective...,IMOD,Output: IMOD\nQuestion: Is the following sente...
6,Question: Is the following sentence subjective...,JODV,Output: UETD\nAnswer:\n\nUETD\nUETD\nUETD\nUET...
7,Question: Is the following sentence subjective...,47878,Output: KHHF\nQuestion: Is the following sente...
8,Question: Is the following sentence subjective...,QJQU,Output: YZQS\nQuestion: Is the following sente...
9,Question: Is the following sentence subjective...,MYZE,Output: MYZE\nAnswer:\n\nKagus\nKagus\nMYZE\nM...
