In [1]:
# uses autoawq==0.2.4 autoawq_kernels==0.0.6
"""
model - 7B model - https://huggingface.co/TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ
dataset - gsm8k train split, llmsearch is run on a subset and evaluated on another


Requires:
autoawq==0.2.4 & autoawq_kernels==0.0.6
"""

import awq
import torch
import transformers

import llmsearch

# Print the versions of the key libraries used in this run.
print(awq.__version__, torch.__version__, transformers.__version__, llmsearch.__version__)

  from .autonotebook import tqdm as notebook_tqdm


Monkey Patching .generate function of `transformers` library
0.2.4 2.2.0+cu121 4.38.2 0.1.0


In [None]:
import re
import textwrap
from pathlib import Path

import datasets

from awq import AutoAWQForCausalLM
from sklearn.model_selection import GridSearchCV
from transformers import StoppingCriteriaList, AutoTokenizer

from llmsearch.tuner import Tuner
from llmsearch.utils.mem_utils import gc_cuda
from llmsearch.utils.common_utils import json_load, json_dump
from llmsearch.utils.model_downloader import download_model_from_hf
from llmsearch.scripts.stopping_criteria import MultiTokenStoppingCriteria

In [3]:
# Experiment configuration shared by the cells below.
seed = 42  # used to shuffle the dataset and passed to the Tuner
batch_size = 1  # batch size passed to the Tuner
num_tune_samples = 150  # number of shuffled examples used for the parameter search
num_test_samples = 500  # number of the following shuffled examples held out for evaluation
model_id = "TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ"  # Hugging Face repo id to download
device = "cuda:0"  # device the model is mapped to in load_model_and_tokenizer

In [8]:
def load_model_and_tokenizer(model_id, temp_model_dir):
    """Fetch the checkpoint into ``temp_model_dir`` and load the model and tokenizer from it.

    Args:
        model_id: Hugging Face repo id handed to ``download_model_from_hf``
            (the ``main`` branch is requested).
        temp_model_dir: ``pathlib.Path`` of the download directory; it is
            created, including parents, when it does not exist.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is the AWQ-quantized
        checkpoint loaded with fused layers onto the notebook-level ``device``.
        The tokenizer is the slow (non-fast) tokenizer, configured to pad on
        the left using the unknown token as the padding token.
    """
    temp_model_dir.mkdir(parents=True, exist_ok=True)
    local_model_path = download_model_from_hf(
        model_id,
        save_dir=temp_model_dir,
        branch="main",
    )

    # Presumably releases cached GPU memory before the new weights are loaded
    # (behaviour of llmsearch's gc_cuda is not shown here).
    gc_cuda()

    awq_model = AutoAWQForCausalLM.from_quantized(
        quant_path=local_model_path,
        fuse_layers=True,
        device_map={"": device},
        local_files_only=True,
    )

    tok = AutoTokenizer.from_pretrained(
        local_model_path,
        local_files_only=True,
        legacy=False,
        use_fast=False,
    )
    # No dedicated pad token is set up here, so the unknown token is reused;
    # padding goes on the left of each prompt.
    tok.pad_token = tok.unk_token
    tok.padding_side = "left"

    return awq_model, tok

def load_dataset():
    """Build the tuning and evaluation subsets of GSM8K as chat-formatted prompts.

    Reads the notebook-level ``tokenizer``, ``seed``, ``num_tune_samples`` and
    ``num_test_samples``. Every example of the GSM8K ``train`` split gets two
    extra columns:

    * ``"X"``: the 2-shot prompt with the example's question filled in,
      wrapped in the tokenizer's chat template (system + user turn, with the
      generation prompt appended) and returned as a string.
    * ``"actual input"``: the raw question text.

    The split is then shuffled with ``seed`` and cut into two non-overlapping
    slices.

    Returns:
        ``(samples_to_tune_on, test_dataset)``: the first ``num_tune_samples``
        shuffled examples and the next ``num_test_samples`` examples.
    """
    # 2-shot prompt template - https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/gsm8k/gsm8k-cot.yaml
    prompt_template = textwrap.dedent(
    """\
    Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
    A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

    Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
    A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

    Q: {question}"""
    )
    prompt_columns = ["question"]
    system_prompt = "Solve the following math problems, end with The answer is"

    def render_prompt(sample):
        """Fill the prompt template from ``sample`` and apply the chat template, returning a string."""
        chat = []
        if system_prompt is not None:
            chat.append({"role": "system", "content": system_prompt})
        filled_template = prompt_template.format(
            **{column: sample[column] for column in prompt_columns}
        )
        chat.append({"role": "user", "content": filled_template})
        return tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

    def add_prompt_columns(sample):
        """Return the new columns for one example: the rendered prompt and the raw question."""
        return {
            "X": render_prompt(sample),
            "actual input": sample[prompt_columns[0]],
        }

    gsm8k = datasets.load_dataset("gsm8k", "main")
    shuffled = gsm8k["train"].map(add_prompt_columns).shuffle(seed=seed)

    # Tuning samples come first; the evaluation slice starts right after them,
    # so the two subsets never overlap.
    tune_end = num_tune_samples
    test_end = tune_end + num_test_samples
    tune_subset = shuffled.select(range(tune_end))
    test_subset = shuffled.select(range(tune_end, test_end))
    return tune_subset, test_subset

def get_score(y_true, y_pred):
    """Exact-match accuracy of generated answers against GSM8K references.

    Args:
        y_true: iterable of mappings with an ``"answer"`` string; the reference
            answer is the text after the last ``####`` marker.
        y_pred: iterable of generated strings; the predicted answer is the
            number following the first ``The answer is `` in the text.

    Returns:
        Fraction of pairs whose normalized answers are identical strings.
        A generation with no extractable answer counts as wrong.

    Raises:
        ZeroDivisionError: if there are no pairs to score.
    """
    answer_regex = re.compile(r"The answer is ((\d|\-)((\d|\,|\.)+)?\d?)")

    def _normalize(text):
        # Drop thousands separators and a single trailing period, then trim.
        if text is None:
            return None
        text = text.replace(",", "")
        text = text[:-1] if text.endswith(".") else text
        return text.strip()

    def _predicted_answer(generation):
        # None signals that the generation never stated "The answer is <number>".
        found = answer_regex.search(generation)
        return found.group(1).strip() if found else None

    hits = []
    for reference, generation in zip(y_true, y_pred):
        expected = _normalize(reference["answer"].split("####")[-1].strip())
        predicted = _normalize(_predicted_answer(generation))
        hits.append(int(expected == predicted))

    return sum(hits) / len(hits)


In [5]:
# Load Model, Tokenizer, Dataset
temp_model_dir = Path("./temp_dir/")
temp_model_dir.mkdir(exist_ok=True, parents=True)

model, tokenizer = load_model_and_tokenizer(model_id, temp_model_dir)

# Dataset we will use to find the best generation parameters, plus a
# non-overlapping slice of the same shuffled data for out-of-sample evaluation.
samples_to_tune_on, test_dataset = load_dataset()

# Custom stopping criteria keyed on token id 32000.
# NOTE(review): 32000 is presumably the id of the `<|im_end|>` end-of-turn token
# used by the chat template - confirm with tokenizer.convert_tokens_to_ids("<|im_end|>").
multi_token_stop_criteria_ob = MultiTokenStoppingCriteria(sequence_ids=[32000])
stopping_criteria = StoppingCriteriaList([multi_token_stop_criteria_ob])
# The criteria object's `reset` method is handed to the Tuner as an
# after-inference callback.
callbacks_after_inference = [multi_token_stop_criteria_ob.reset]

Model already exists in temp_dir/TheBloke_CapybaraHermes-2.5-Mistral-7B-AWQ. Checking the model files...
Checksum validated: model.safetensors  645dfc7f09074aaf25e642f3c6a4f7ea399a0ff2605fa650e4e74078832546de
Checksum validated: tokenizer.model  dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
[+] Validated checksums of all model files!


Replacing layers...: 100%|██████████| 32/32 [00:03<00:00,  9.06it/s]
Fusing layers...: 100%|██████████| 32/32 [00:03<00:00,  9.62it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Map: 100%|██████████| 7473/7473 [00:00<00:00, 11001.58 examples/s]


In [9]:
# Wrap the model, tokenizer and tuning samples in a Tuner, which provides the
# estimator and scorer given to GridSearchCV below.
tuner_ob = Tuner(
    model=model,
    tokenizer=tokenizer,
    dataset=samples_to_tune_on,
    # Reuse the config constant so the Tuner targets the same device the model
    # was loaded onto, instead of repeating the literal.
    device=device,
    batch_size=batch_size,
    # Prompts in "X" already contain the chat template's special tokens as
    # text, so the tokenizer is told not to add its own.
    tokenizer_encode_args={"padding": "longest", "add_special_tokens": False},
    tokenizer_decode_args={"spaces_between_special_tokens": False, "skip_special_tokens": True},
    scorer=get_score,
    # "X" is already the fully formatted prompt, so the template is a pass-through.
    prompt_template="{X}",
    seed=seed,
    # "X" is the model input; "answer" is what get_score reads from each y_true item.
    column_mapping={"input_cols": ["X"], "eval_cols": ["answer"]},
    callbacks_after_inference=callbacks_after_inference,
)

In [10]:
# Baseline generation parameters: only the length cap, stopping criteria and
# generation seed are set; no sampling parameters are passed.
gen_params1 = {
    "max_new_tokens": 500,
    "stopping_criteria": stopping_criteria,
    "generation_seed": 42,
}

# Score the baseline on the tuning samples (the Tuner's default dataset).
scores_before, outputs_before = tuner_ob.get_score(gen_params1)

100%|██████████| 150/150 [05:59<00:00,  2.39s/it]


In [11]:
# Search space for the generation parameters. Single-element lists are held
# fixed; only top_k (5 values) and top_p (4 values) vary, giving 20 candidates.
hyp_space = {
    'max_new_tokens' : [500],
    'stopping_criteria' : [stopping_criteria],
    'generation_seed' : [42],
    'do_sample' : [True],

    'top_k': [10,50,60,70,80],
    'top_p' : [0.7,0.75,0.8,0.95],
    'no_repeat_ngram_size': [0],

}

# Exhaustive grid search over hyp_space using the Tuner's estimator and scorer.
# cv=2 scores every candidate on two folds (40 fits in total for this grid).
clf = GridSearchCV(
    estimator = tuner_ob.estimator,
    param_grid=hyp_space,
    scoring = tuner_ob.scorer,
    cv = 2,
    n_jobs = None,
    verbose=3,
)


In [13]:
# Run the grid search on the Tuner's dataset columns "_X" (inputs) and "_y" (targets).
# NOTE(review): these columns are presumably derived by the Tuner from column_mapping - confirm.
clf.fit(X=tuner_ob.dataset["_X"], y=tuner_ob.dataset['_y'])

Fitting 2 folds for each of 20 candidates, totalling 40 fits


  0%|          | 0/75 [00:00<?, ?it/s]

100%|██████████| 75/75 [03:12<00:00,  2.56s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=10, top_p=0.7;, score=0.707 total time= 3.2min


100%|██████████| 75/75 [02:51<00:00,  2.29s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=10, top_p=0.7;, score=0.693 total time= 2.9min


100%|██████████| 75/75 [03:09<00:00,  2.52s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=10, top_p=0.75;, score=0.667 total time= 3.2min


100%|██████████| 75/75 [02:50<00:00,  2.27s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=10, top_p=0.75;, score=0.747 total time= 2.8min


100%|██████████| 75/75 [03:03<00:00,  2.45s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=10, top_p=0.8;, score=0.733 total time= 3.1min


100%|██████████| 75/75 [02:51<00:00,  2.29s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=10, top_p=0.8;, score=0.707 total time= 2.9min


100%|██████████| 75/75 [03:10<00:00,  2.54s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=10, top_p=0.95;, score=0.747 total time= 3.2min


100%|██████████| 75/75 [02:54<00:00,  2.32s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=10, top_p=0.95;, score=0.680 total time= 2.9min


100%|██████████| 75/75 [03:08<00:00,  2.51s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=50, top_p=0.7;, score=0.720 total time= 3.1min


100%|██████████| 75/75 [02:52<00:00,  2.30s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=50, top_p=0.7;, score=0.693 total time= 2.9min


100%|██████████| 75/75 [03:09<00:00,  2.53s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=50, top_p=0.75;, score=0.640 total time= 3.2min


100%|██████████| 75/75 [02:54<00:00,  2.33s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=50, top_p=0.75;, score=0.720 total time= 2.9min


100%|██████████| 75/75 [03:06<00:00,  2.48s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=50, top_p=0.8;, score=0.680 total time= 3.1min


100%|██████████| 75/75 [02:47<00:00,  2.23s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=50, top_p=0.8;, score=0.733 total time= 2.8min


100%|██████████| 75/75 [03:10<00:00,  2.54s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=50, top_p=0.95;, score=0.720 total time= 3.2min


100%|██████████| 75/75 [02:51<00:00,  2.29s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=50, top_p=0.95;, score=0.653 total time= 2.9min


100%|██████████| 75/75 [03:08<00:00,  2.51s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=60, top_p=0.7;, score=0.720 total time= 3.1min


100%|██████████| 75/75 [02:51<00:00,  2.29s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=60, top_p=0.7;, score=0.693 total time= 2.9min


100%|██████████| 75/75 [03:08<00:00,  2.52s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=60, top_p=0.75;, score=0.640 total time= 3.2min


100%|██████████| 75/75 [02:54<00:00,  2.32s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=60, top_p=0.75;, score=0.720 total time= 2.9min


100%|██████████| 75/75 [03:05<00:00,  2.48s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=60, top_p=0.8;, score=0.680 total time= 3.1min


100%|██████████| 75/75 [02:46<00:00,  2.22s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=60, top_p=0.8;, score=0.733 total time= 2.8min


100%|██████████| 75/75 [03:10<00:00,  2.54s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=60, top_p=0.95;, score=0.720 total time= 3.2min


100%|██████████| 75/75 [02:51<00:00,  2.28s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=60, top_p=0.95;, score=0.653 total time= 2.9min


100%|██████████| 75/75 [03:08<00:00,  2.51s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=70, top_p=0.7;, score=0.720 total time= 3.1min


100%|██████████| 75/75 [02:51<00:00,  2.29s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=70, top_p=0.7;, score=0.693 total time= 2.9min


100%|██████████| 75/75 [03:08<00:00,  2.52s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=70, top_p=0.75;, score=0.640 total time= 3.1min


100%|██████████| 75/75 [02:53<00:00,  2.32s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=70, top_p=0.75;, score=0.720 total time= 2.9min


100%|██████████| 75/75 [03:05<00:00,  2.47s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=70, top_p=0.8;, score=0.680 total time= 3.1min


100%|██████████| 75/75 [02:45<00:00,  2.21s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=70, top_p=0.8;, score=0.733 total time= 2.8min


100%|██████████| 75/75 [03:10<00:00,  2.54s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=70, top_p=0.95;, score=0.720 total time= 3.2min


100%|██████████| 75/75 [02:51<00:00,  2.29s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=70, top_p=0.95;, score=0.653 total time= 2.9min


100%|██████████| 75/75 [03:08<00:00,  2.51s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=80, top_p=0.7;, score=0.720 total time= 3.1min


100%|██████████| 75/75 [02:53<00:00,  2.31s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=80, top_p=0.7;, score=0.693 total time= 2.9min


100%|██████████| 75/75 [03:09<00:00,  2.52s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=80, top_p=0.75;, score=0.640 total time= 3.2min


100%|██████████| 75/75 [02:53<00:00,  2.32s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=80, top_p=0.75;, score=0.720 total time= 2.9min


100%|██████████| 75/75 [03:05<00:00,  2.48s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=80, top_p=0.8;, score=0.680 total time= 3.1min


100%|██████████| 75/75 [02:46<00:00,  2.22s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=80, top_p=0.8;, score=0.733 total time= 2.8min


100%|██████████| 75/75 [03:09<00:00,  2.53s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=80, top_p=0.95;, score=0.720 total time= 3.2min


100%|██████████| 75/75 [02:52<00:00,  2.30s/it]

[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], top_k=80, top_p=0.95;, score=0.653 total time= 2.9min





In [14]:
# Re-score the tuning samples with the best parameters found by the grid search.
scores_after, outputs_after = tuner_ob.get_score(clf.best_params_)

 10%|█         | 15/150 [00:43<05:49,  2.59s/it]

100%|██████████| 150/150 [05:57<00:00,  2.38s/it]


In [15]:
# Accuracy on the tuning samples: baseline parameters vs. best grid-search parameters.
print(scores_before, scores_after)

0.72 0.7266666666666667


In [16]:
# Display the winning parameter combination as a string.
str(clf.best_params_)

"{'do_sample': True, 'generation_seed': 42, 'max_new_tokens': 500, 'no_repeat_ngram_size': 0, 'stopping_criteria': [<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], 'top_k': 10, 'top_p': 0.8}"

In [17]:

# Collect the tuning-set scores and outputs plus the best parameters.
d = {
    'scores_before' : scores_before,
    'scores_after' : scores_after,
    'outputs_before' : outputs_before,
    'outputs_after' : outputs_after,
    # Stored as a string: the value includes the stopping-criteria object,
    # which is saved only as its repr.
    'best_params' : str(clf.best_params_),
}

# Persist the results to a JSON file next to the notebook.
f = "./gsm-8k-best-params-150s-capybara-7b.json"
json_dump(d, f)

In [18]:
# Read the saved results back from disk (rebinds `d`).
d = json_load("./gsm-8k-best-params-150s-capybara-7b.json")

In [20]:
# Show the best parameters as stored in the results file.
print(d['best_params'])

{'do_sample': True, 'generation_seed': 42, 'max_new_tokens': 500, 'no_repeat_ngram_size': 0, 'stopping_criteria': [<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f8f9e357c40>], 'top_k': 10, 'top_p': 0.8}


In [21]:
# Hardcoding the best parameters from the results file above, so they are
# available after a notebook re-run without repeating the grid search.
# The file only holds the repr of the stopping criteria, so the live
# `stopping_criteria` object is re-attached here.

best_params = {
    'do_sample' : True,
    'generation_seed' : 42,
    'max_new_tokens' : 500,
    'no_repeat_ngram_size' : 0,
    'stopping_criteria' : stopping_criteria,
    'top_k' : 10,
    'top_p' : 0.8
}

In [24]:
# Evaluate on the held-out test samples: a slice of the shuffled GSM8K train
# split that does not overlap the tuning samples.

# Same baseline generation parameters as used on the tuning samples.
gen_params1 = {
    "max_new_tokens": 500,
    "stopping_criteria": stopping_criteria,
    "generation_seed": 42,
}

# Out-of-sample baseline score and outputs.
oos_scores_before, oos_outputs_before = tuner_ob.get_score(gen_params1,test_dataset)

Map: 100%|██████████| 500/500 [00:00<00:00, 9677.50 examples/s]
100%|██████████| 500/500 [20:25<00:00,  2.45s/it]


In [25]:
# Baseline accuracy on the held-out samples.
oos_scores_before

0.564

In [26]:
# Inspect one held-out example: raw question, reference answer and the rendered prompt "X".
test_dataset[1]

{'question': 'John climbs up 9 flights of stairs.  Each flight is 10 feet.  If each step is 18 inches, how many steps does he climb up?',
 'answer': 'He has to climb 9*10=<<9*10=90>>90 feet\nThat means he needs to climb 90*12=<<90*12=1080>>1080 inches\nThat means he needs to climb 1080/18=<<1080/18=60>>60 stairs\n#### 60',
 'X': '<|im_start|>system\nSolve the following math problems, end with The answer is<|im_end|>\n<|im_start|>user\nQ: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\nA: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.\n\nQ: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\nA: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.\n\nQ: John climbs up 9 flights of stairs.  Each

In [27]:

# Evaluate the tuned generation parameters on the held-out samples.
# Prefer the parameters from a grid search fitted in this session. After a
# kernel restart `clf` is undefined (NameError) or unfitted (AttributeError),
# so fall back to the hard-coded `best_params` copy instead of requiring the
# roughly two-hour grid search to be repeated.
try:
    tuned_gen_params = clf.best_params_
except (NameError, AttributeError):
    tuned_gen_params = best_params

oos_scores_after, oos_outputs_after = tuner_ob.get_score(tuned_gen_params, test_dataset)

Map: 100%|██████████| 500/500 [00:00<00:00, 9790.44 examples/s]
100%|██████████| 500/500 [21:14<00:00,  2.55s/it]


In [28]:
# Accuracy on the held-out samples: baseline parameters vs. tuned parameters.
print(oos_scores_before, oos_scores_after)

0.564 0.584


In [29]:
# Final results: tuning-set and out-of-sample (oos) scores and outputs,
# together with the best parameters.
d = {
    'scores_before' : scores_before,
    'scores_after' : scores_after,
    'outputs_before' : outputs_before,
    'outputs_after' : outputs_after,

    'oos_scores_before' : oos_scores_before,
    'oos_scores_after' : oos_scores_after,
    'oos_outputs_before' : oos_outputs_before,
    'oos_outputs_after' : oos_outputs_after,
    # Stored as a string: the value includes the stopping-criteria object,
    # which is saved only as its repr.
    'best_params' : str(clf.best_params_),
}

# Overwrites the results file written earlier in the notebook (same path).
f = "./gsm-8k-best-params-150s-capybara-7b.json"
json_dump(d, f)