In [3]:
import awq
import torch
import transformers

import llmsearch

print(awq.__version__, torch.__version__, transformers.__version__, llmsearch.__version__)

0.2.4 2.2.0+cu121 4.38.2 0.1.0


In [11]:
import re
import textwrap
from pathlib import Path

import datasets

from awq import AutoAWQForCausalLM
from sklearn.model_selection import GridSearchCV
from transformers import StoppingCriteriaList, AutoTokenizer

from llmsearch.tuner import Tuner
from llmsearch.utils.mem_utils import gc_cuda
from llmsearch.utils.model_downloader import download_model_from_hf
from llmsearch.scripts.stopping_criteria import MultiTokenStoppingCriteria

In [5]:
seed = 42
batch_size = 1
bm_sample_size = 150
model_id = "TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ"
device = "cuda:0"

In [6]:
def load_model_and_tokenizer(model_id, temp_model_dir):
    temp_model_dir.mkdir(exist_ok=True, parents=True)
    output_folder = download_model_from_hf(model_id, save_dir=temp_model_dir, branch="main")

    gc_cuda()

    model = AutoAWQForCausalLM.from_quantized(
        quant_path=output_folder, fuse_layers=True, device_map={"": device}, local_files_only=True
    )

    tokenizer = AutoTokenizer.from_pretrained(
        output_folder, local_files_only=True, legacy=False, use_fast=False
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.padding_side = "left"

    return model, tokenizer

def load_dataset():

    def preprocess_dataset(
        dataset, tokenizer, pt, pt_cols, system_prompt, add_generation_prompt=True
    ):

        def wrapper(sample):
            """Takes in a sample, formats it using prompt template, applies chat template and returns the formatted string"""
            messages = (
                []
                if system_prompt is None
                else [{"role": "system", "content": system_prompt}]
            )
            formatted_pt = pt.format(**{pt_col: sample[pt_col] for pt_col in pt_cols})
            messages.append(
                {
                    "role": "user",
                    "content": formatted_pt,
                }
            )
            formatted_pt_with_ct = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=add_generation_prompt
            )
            return formatted_pt_with_ct

        def actual_input(sample):
            """Takes in a sample, formats it using prompt template, applies chat template and returns the formatted string"""
            return sample[pt_cols[0]]

        pt_dataset = dataset.map(
            lambda sample: {
                "X": wrapper(sample),
                "actual input": actual_input(sample),
            }
        )

        return pt_dataset


    # 2-shot prompt template - https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/gsm8k/gsm8k-cot.yaml
    pt = textwrap.dedent(
    """\
    Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
    A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

    Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
    A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

    Q: {question}"""
    )
    pt_cols = ["question"]
    system_prompt = "Solve the following math problems, end with The answer is"
    gsm8k_dataset = datasets.load_dataset("gsm8k", "main")


    processed_dataset = preprocess_dataset(
        gsm8k_dataset["train"],
        tokenizer,
        pt=pt,
        pt_cols=pt_cols,
        system_prompt=system_prompt,
        add_generation_prompt=True,
    )
    bm_samples = processed_dataset.shuffle(seed=seed).select(range(bm_sample_size))
    return bm_samples

def get_score(y_true, y_pred):
    def extract_answer_from_out(s):
        pattern = re.compile(r"The answer is (\d+(?:\.\d+)?)")
        match = pattern.search(s)
        if match:
            return match.group(1).strip()
        else:
            return None

    scores = []

    for y_t, y_p in zip(y_true, y_pred):
        y_t_answer = y_t["answer"].split("####")[-1].strip()
        y_p_answer = extract_answer_from_out(y_p)

        # print("y_pred - ", y_p_answer)
        # print("y_true - ", y_t_answer)

        if y_t_answer == y_p_answer:
            scores.append(1)
        else:
            scores.append(0)
    return sum(scores) / len(scores)


In [7]:
# Load Model, Tokenizer, Dataset
temp_model_dir = Path(f"./temp_dir/")
temp_model_dir.mkdir(exist_ok=True, parents=True)

model, tokenizer = load_model_and_tokenizer(model_id, temp_model_dir)

# Dataset we will use to find the best generation parameters
bm_samples = load_dataset()

multi_token_stop_criteria_ob = MultiTokenStoppingCriteria(sequence_ids=[32000])
stopping_criteria = StoppingCriteriaList([multi_token_stop_criteria_ob])
callbacks_after_inference = [multi_token_stop_criteria_ob.reset]

Downloading the model to temp_dir/TheBloke_CapybaraHermes-2.5-Mistral-7B-AWQ


100%|██████████| 17.9k /17.9k  41.7MiB/s
100%|██████████| 115   /115    279kiB/s
100%|██████████| 51.0  /51.0   219kiB/s
100%|██████████| 911   /911    3.73MiB/s
100%|██████████| 126   /126    284kiB/s
100%|██████████| 420   /420    678kiB/s
100%|██████████| 1.80M /1.80M  18.7MiB/s
100%|██████████| 1.60k /1.60k  6.10MiB/s
100%|██████████| 493k  /493k   12.7MiB/s
100%|██████████| 4.15G /4.15G  373MiB/s
Replacing layers...: 100%|██████████| 32/32 [00:03<00:00,  8.86it/s]
Fusing layers...: 100%|██████████| 32/32 [00:02<00:00, 13.38it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Map: 100%|██████████| 7473/7473 [00:00<00:00, 11066.82 examples/s]


In [8]:
tuner_ob = Tuner(
    model=model,
    tokenizer=tokenizer,
    dataset=bm_samples,
    device="cuda:0",
    batch_size=batch_size,
    tokenizer_encode_args={"padding": "longest", "add_special_tokens": False},
    tokenizer_decode_args={"spaces_between_special_tokens": False},
    scorer=get_score,
    prompt_template="{X}",
    seed=seed,
    column_mapping={"input_cols": ["X"], "eval_cols": ["answer"]},
    callbacks_after_inference=callbacks_after_inference,
)

Map: 100%|██████████| 50/50 [00:00<00:00, 3671.61 examples/s]


In [9]:
gen_params1 = {
    "max_new_tokens": 500,
    "stopping_criteria": stopping_criteria,
    "generation_seed": 42,
}

scores_before, outputs_before = tuner_ob.get_score(gen_params1)

100%|██████████| 25/25 [01:26<00:00,  3.47s/it]


In [13]:
hyp_space = {
    'max_new_tokens' : [500],
    'stopping_criteria' : [stopping_criteria],
    'generation_seed' : [42],
    'do_sample' : [True],

    'top_k': [10,50,60,70,80],
    'top_p' : [0.7,0.75,0.8,0.95],
    'no_repeat_ngram_size': [0],

}

clf = GridSearchCV(
    estimator = tuner_ob.estimator,
    param_grid=hyp_space,
    scoring = tuner_ob.scorer,
    cv = 2,
    n_jobs = None,
    verbose=3,
)

In [14]:
clf.fit(X=tuner_ob.dataset["_X"], y=tuner_ob.dataset['_y'])

Fitting 2 folds for each of 20 candidates, totalling 40 fits


100%|██████████| 13/13 [00:49<00:00,  3.83s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f24e4523730>], temperature=1.0, top_k=10, top_p=0.7;, score=0.560 total time=  50.0s


100%|██████████| 13/13 [00:39<00:00,  3.06s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f24e4523730>], temperature=1.0, top_k=10, top_p=0.7;, score=0.680 total time=  39.9s


100%|██████████| 13/13 [00:50<00:00,  3.89s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f24e4523730>], temperature=1.0, top_k=10, top_p=0.75;, score=0.640 total time=  50.7s


100%|██████████| 13/13 [00:41<00:00,  3.23s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f24e4523730>], temperature=1.0, top_k=10, top_p=0.75;, score=0.720 total time=  42.1s


100%|██████████| 13/13 [00:53<00:00,  4.15s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f24e4523730>], temperature=1.0, top_k=10, top_p=0.8;, score=0.680 total time=  54.0s


100%|██████████| 13/13 [00:38<00:00,  3.00s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f24e4523730>], temperature=1.0, top_k=10, top_p=0.8;, score=0.680 total time=  39.1s


100%|██████████| 13/13 [00:53<00:00,  4.09s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f24e4523730>], temperature=1.0, top_k=10, top_p=0.95;, score=0.560 total time=  53.3s


100%|██████████| 13/13 [00:37<00:00,  2.89s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f24e4523730>], temperature=1.0, top_k=10, top_p=0.95;, score=0.720 total time=  37.7s


100%|██████████| 13/13 [00:49<00:00,  3.82s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f24e4523730>], temperature=1.0, top_k=50, top_p=0.7;, score=0.640 total time=  49.8s


100%|██████████| 13/13 [00:39<00:00,  3.05s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f24e4523730>], temperature=1.0, top_k=50, top_p=0.7;, score=0.760 total time=  39.8s


100%|██████████| 13/13 [00:52<00:00,  4.00s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f24e4523730>], temperature=1.0, top_k=50, top_p=0.75;, score=0.600 total time=  52.2s


100%|██████████| 13/13 [00:41<00:00,  3.23s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f24e4523730>], temperature=1.0, top_k=50, top_p=0.75;, score=0.720 total time=  42.1s


100%|██████████| 13/13 [00:54<00:00,  4.19s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f24e4523730>], temperature=1.0, top_k=50, top_p=0.8;, score=0.680 total time=  54.7s


100%|██████████| 13/13 [00:38<00:00,  2.99s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f24e4523730>], temperature=1.0, top_k=50, top_p=0.8;, score=0.680 total time=  39.1s


100%|██████████| 13/13 [00:52<00:00,  4.02s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=500, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f24e4523730>], temperature=1.0, top_k=50, top_p=0.95;, score=0.560 total time=  52.5s


  8%|▊         | 1/13 [00:03<00:37,  3.08s/it]


KeyboardInterrupt: 

In [None]:
scores_after, outputs_after = tuner_ob.get_score(clf.best_params_)

In [None]:
print(scores_before, scores_after)

In [None]:
len(bm_samples) * scores_before, len(bm_samples) * scores_after

In [None]:
str(clf.best_params_)

In [None]:
import json

def json_dump(ob : dict, file_path: Path):
    with open(file_path, 'w', encoding="utf-8") as json_file:
        json.dump(ob, json_file, indent=4)

d = {
    'scores_before' : scores_before,
    'scores_after' : scores_after,
    'outputs_before' : outputs_before,
    'outputs_after' : outputs_after,
    'best_params' : str(clf.best_params_),
}

f = "./gsm-8k-best-params-100s-capybara-7b.json"
json_dump(d, f)