In [17]:
"""
Script to test functionality of llmsearch
Works with /archive/runpod_dev_env_setup.sh

Does not do any kind of exhaustive search

Requires - pip install autoawq==0.2.4 autoawq_kernels==0.0.6
"""

import re
import textwrap
from pathlib import Path

import sys
sys.path.append("/workspace/llmsearch/")

import datasets
from sklearn.metrics import make_scorer
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV

from awq import AutoAWQForCausalLM
from transformers import StoppingCriteriaList, AutoTokenizer

from llmsearch.tuner import Tuner
from llmsearch.utils.mem_utils import gc_cuda
from llmsearch.utils.logging_utils import set_verbosity_debug, set_verbosity_info
from llmsearch.utils.model_downloader import download_model_from_hf
from llmsearch.scripts.stopping_criteria import MultiTokenStoppingCriteria

# Switch llmsearch logging to debug verbosity (the DEBUG lines in the cell outputs below).
set_verbosity_debug()

def print_lines(n = 2):
    """Print a vertical gap: ``n`` newline characters plus ``print``'s own line ending."""
    gap = "".join("\n" for _ in range(n))
    print(gap)


# Run configuration shared by the cells below.
seed = 42  # used for the dataset shuffle in load_dataset() and passed to the Tuner
batch_size = 1  # passed to the Tuner
bm_sample_size = 5  # number of shuffled GSM8K train samples kept by load_dataset()
model_id = "TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ"  # repo id handed to download_model_from_hf()
device = "cuda:0"  # device the quantized model is mapped to in load_model_and_tokenizer()

def load_model_and_tokenizer(model_id, temp_model_dir):
    """Download ``model_id`` into ``temp_model_dir`` and load the model and tokenizer from disk.

    Args:
        model_id: Repository id passed to ``download_model_from_hf`` (``main`` branch).
        temp_model_dir: ``pathlib.Path`` used as the download directory; created if missing.

    Returns:
        ``(model, tokenizer)``: the AWQ-quantized model mapped to the notebook-level
        ``device`` and a slow (non-fast) tokenizer configured for left padding.
    """
    temp_model_dir.mkdir(parents=True, exist_ok=True)
    local_model_path = download_model_from_hf(
        model_id,
        save_dir=temp_model_dir,
        branch="main",
    )

    # Called right before the model is loaded onto the device.
    gc_cuda()

    awq_load_kwargs = {
        "quant_path": local_model_path,
        "fuse_layers": True,
        "device_map": {"": device},
        "local_files_only": True,
    }
    model = AutoAWQForCausalLM.from_quantized(**awq_load_kwargs)

    tokenizer = AutoTokenizer.from_pretrained(
        local_model_path,
        use_fast=False,
        legacy=False,
        local_files_only=True,
    )
    # Pad on the left and reuse the unknown token as the padding token.
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.unk_token

    return model, tokenizer

def load_dataset():
    """Build the small GSM8K benchmark split used by this notebook.

    Every sample of the GSM8K ``train`` split gets two extra columns: ``X`` (the few-shot
    prompt rendered through the tokenizer's chat template) and ``actual input`` (the raw
    question). Reads the notebook-level ``tokenizer``, ``seed`` and ``bm_sample_size``,
    so the tokenizer must be loaded before this is called.

    Returns:
        The first ``bm_sample_size`` rows of the processed split after a seeded shuffle.
    """

    def attach_prompt_columns(
        dataset, tokenizer, pt, pt_cols, system_prompt, add_generation_prompt=True
    ):
        """Return ``dataset`` with the ``X`` and ``actual input`` columns added via ``map``."""

        def render_chat_prompt(sample):
            """Fill the prompt template from ``sample`` and wrap it with the chat template."""
            user_turn = {
                "role": "user",
                "content": pt.format(**{col: sample[col] for col in pt_cols}),
            }
            if system_prompt is None:
                messages = [user_turn]
            else:
                messages = [{"role": "system", "content": system_prompt}, user_turn]
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=add_generation_prompt
            )

        def raw_input(sample):
            """Return the unformatted value of the first prompt column."""
            return sample[pt_cols[0]]

        def add_columns(sample):
            """Compute both new columns for one sample."""
            return {
                "X": render_chat_prompt(sample),
                "actual input": raw_input(sample),
            }

        return dataset.map(add_columns)

    # 2-shot prompt template - https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/gsm8k/gsm8k-cot.yaml
    few_shot_template = textwrap.dedent(
    """\
    Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
    A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

    Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
    A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

    Q: {question}"""
    )
    template_columns = ["question"]
    system_message = "Solve the following math problems, end with The answer is"
    gsm8k = datasets.load_dataset("gsm8k", "main")

    train_with_prompts = attach_prompt_columns(
        gsm8k["train"],
        tokenizer,
        pt=few_shot_template,
        pt_cols=template_columns,
        system_prompt=system_message,
        add_generation_prompt=True,
    )
    # Seeded shuffle first, then keep only the leading rows as the benchmark sample.
    shuffled = train_with_prompts.shuffle(seed=seed)
    return shuffled.select(range(bm_sample_size))

def get_score(y_true, y_pred):
    """Fraction of predictions whose final numeric answer matches the GSM8K reference.

    Args:
        y_true: Iterable of samples; each is a mapping whose ``"answer"`` text ends with
            ``#### <number>``.
        y_pred: Iterable of generated strings, aligned with ``y_true``. The answer is
            taken from the first ``The answer is <number>`` occurrence.

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` when there is nothing to score.
    """
    # Optional sign / "$", digits with optional thousands separators, optional decimals.
    # Without the comma support "The answer is 1,000" would be read as "1".
    answer_pattern = re.compile(r"The answer is (-?\$?\d[\d,]*(?:\.\d+)?)")

    def normalize(number_text):
        """Drop thousands separators, currency signs and surrounding whitespace."""
        return number_text.replace(",", "").replace("$", "").strip()

    def extract_answer_from_out(s):
        """Return the normalized number following "The answer is", or None if absent."""
        match = answer_pattern.search(s)
        return normalize(match.group(1)) if match else None

    def answers_match(expected, predicted):
        """Compare numerically so "6" and "6.0" agree; fall back to text equality."""
        if predicted is None:
            return False
        try:
            return float(expected) == float(predicted)
        except ValueError:
            return expected == predicted

    scores = []

    for y_t, y_p in zip(y_true, y_pred):
        y_t_answer = normalize(y_t["answer"].split("####")[-1])
        y_p_answer = extract_answer_from_out(y_p)

        print("y_pred - ", y_p_answer)
        print("y_true - ", y_t_answer)

        scores.append(1 if answers_match(y_t_answer, y_p_answer) else 0)

    # Guard against an empty batch instead of raising ZeroDivisionError.
    if not scores:
        return 0.0
    return sum(scores) / len(scores)

# Load Model, Tokenizer, Dataset
temp_model_dir = Path("./temp_dir/")
temp_model_dir.mkdir(exist_ok=True, parents=True)

model, tokenizer = load_model_and_tokenizer(model_id, temp_model_dir)

# load_dataset() reads the global `tokenizer`, so it has to run after the model load above.
bm_samples = load_dataset()

# Stop generation once token id 32000 is produced.
# NOTE(review): presumably the end-of-turn token of this model's chat template - confirm against the tokenizer.
multi_token_stop_criteria_ob = MultiTokenStoppingCriteria(sequence_ids=[32000])
stopping_criteria = StoppingCriteriaList([multi_token_stop_criteria_ob])
# The criteria object's `reset` is registered as a callback to run after inference.
callbacks_after_inference = [multi_token_stop_criteria_ob.reset]
print(f"During init - {callbacks_after_inference}") # 0x7fa49e1d9c90, 0x7fa49049f0a0

Downloading the model to temp_dir/TheBloke_CapybaraHermes-2.5-Mistral-7B-AWQ


100%|██████████| 51.0  /51.0   230kiB/s
100%|██████████| 911   /911    270kiB/s
100%|██████████| 17.9k /17.9k  7.94MiB/s
100%|██████████| 115   /115    499kiB/s
100%|██████████| 420   /420    264kiB/s
  1%|          | 21.0M /4.15G  191MiB/s
100%|██████████| 126   /126    678kiB/s
100%|██████████| 1.80M /1.80M  3.67MiB/s
100%|██████████| 1.60k /1.60k  5.37MiB/s
100%|██████████| 493k  /493k   11.7MiB/s
100%|██████████| 4.15G /4.15G  260MiB/s


KeyboardInterrupt: 

In [2]:
# Build the tuner around the loaded model and the benchmark samples.
tuner_ob = Tuner(
    model=model,
    tokenizer=tokenizer,
    dataset=bm_samples,
    # Reuse the config constant so the tuner and the loaded model always agree on the device.
    device=device,
    batch_size=batch_size,
    tokenizer_encode_args={"padding": "longest", "add_special_tokens": False},
    tokenizer_decode_args={"spaces_between_special_tokens": False},
    scorer=get_score,
    # "X" already holds the fully rendered chat prompt produced by load_dataset().
    prompt_template="{X}",
    seed=seed,
    # "answer" is the GSM8K reference column that get_score() reads.
    column_mapping={"input_cols": ["X"], "eval_cols": ["answer"]},
    callbacks_after_inference=callbacks_after_inference,
)

2024-04-11 12:42:10.061 - llmsearch.tuner.tuner:89 - DEBUG - Initializing new estimator with generation parameters - {}


During estimator init - [<bound method MultiTokenStoppingCriteria.reset of <llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f7c826d2d40>>]


In [None]:
# Display the stopping-criteria list (the original cell held a truncated, undefined name).
stopping_criteria

In [5]:
# The bound `reset` shown here should belong to the criteria object reported when the
# Tuner was constructed (0x7f7c826d2d40 in this session).

tuner_ob.estimator.callbacks_after_inference

[<bound method MultiTokenStoppingCriteria.reset of <llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f7c826d2d40>>]

In [None]:
# Show the estimator object held by the tuner.
tuner_ob.estimator

In [4]:
from sklearn.base import clone as sklearn_clone

In [8]:
# All parameters the estimator exposes through get_params().
gen_params = tuner_ob.estimator.get_params()

# Keep only the entries whose name is listed in the estimator's generation-parameter keys.
filtered_gen_params = dict(
    filter(
        lambda item: item[0] in tuner_ob.estimator._model_generation_param_keys,
        gen_params.items(),
    )
)

In get_params
<module> - gen_params = tuner_ob.estimator.get_params()


In [21]:
# Check the concrete type of the stopping-criteria container.
type(stopping_criteria)

transformers.generation.stopping_criteria.StoppingCriteriaList

In [12]:
# Minimal parameter dict holding only the stateful stopping criteria, used to probe sklearn's clone.
params = dict(stopping_criteria=stopping_criteria)

In [13]:
# Display the dict to note the address of the criteria object before cloning.
params

{'stopping_criteria': [<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria at 0x7f7c826d2d40>]}

In [16]:
# The cloned criteria object lives at a different address (0x7f7c70312f50) than the original (0x7f7c826d2d40).
# sklearn's clone deep-copies parameters - https://github.com/scikit-learn/scikit-learn/blob/10b5c6628630d57e3025a17381f71172d2649d8d/sklearn/base.py#L104
# so a generation parameter that maintains state will not work as expected.
# TODO: find a solution.
sklearn_clone(params, safe = False)

{'stopping_criteria': [<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria at 0x7f7c70312f50>]}

In [10]:
# Single-point search space: every parameter has exactly one candidate value.
hyp_space = dict(
    max_new_tokens=[500],
    stopping_criteria=[stopping_criteria],
    generation_seed=[42],
    do_sample=[True],
    temperature=[0.9],
    top_k=[50],
    top_p=[0.95],
)

# Grid search over that space with the tuner's estimator and scorer, using 2-fold CV.
clf = GridSearchCV(
    tuner_ob.estimator,
    hyp_space,
    scoring=tuner_ob.scorer,
    n_jobs=None,
    cv=2,
    verbose=3,
)

In [11]:
# Display the search space; the criteria entry should still be the original object.
hyp_space

{'max_new_tokens': [500],
 'stopping_criteria': [[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria at 0x7f7c826d2d40>]],
 'generation_seed': [42],
 'do_sample': [True],
 'temperature': [0.9],
 'top_k': [50],
 'top_p': [0.95]}

In [7]:
# Callbacks as returned by the estimator's get_params().
gen_params['callbacks_after_inference']

[<bound method MultiTokenStoppingCriteria.reset of <llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f7c826d2d40>>]

In [7]:
# Round-trip the estimator's parameters through sklearn's clone and set them back,
# then look at which criteria object the post-inference callback is bound to.
# NOTE(review): set_params is called on tuner_ob.estimator itself, so `new_estimator`
# is presumably the same object rather than a fresh copy - confirm.
new_estimator = tuner_ob.estimator.set_params(
    **sklearn_clone(
        tuner_ob.estimator.get_params(),
        safe=False,
    )
)
new_estimator.callbacks_after_inference

2024-04-11 12:33:35.736 - llmsearch.tuner.tuner:206 - DEBUG - Attributes after setting new parameters - {'model': MistralAWQForCausalLM(
  (model): MistralForCausalLM(
    (model): LlamaLikeModel(
      (embedding): Embedding(32002, 4096, padding_idx=0)
      (blocks): ModuleList(
        (0-31): 32 x LlamaLikeBlock(
          (norm_1): FasterTransformerRMSNorm()
          (attn): QuantAttentionFused(
            (qkv_proj): WQLinear_GEMM(in_features=4096, out_features=6144, bias=False, w_bit=4, group_size=128)
            (o_proj): WQLinear_GEMM(in_features=4096, out_features=4096, bias=False, w_bit=4, group_size=128)
            (rope): RoPE()
          )
          (norm_2): FasterTransformerRMSNorm()
          (mlp): MistralMLP(
            (gate_proj): WQLinear_GEMM(in_features=4096, out_features=14336, bias=False, w_bit=4, group_size=128)
            (up_proj): WQLinear_GEMM(in_features=4096, out_features=14336, bias=False, w_bit=4, group_size=128)
            (down_proj): WQLinea

[<bound method MultiTokenStoppingCriteria.reset of <llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fa49049f0a0>>]

In [8]:
# Re-display the callbacks after the set_params round-trip.
new_estimator.callbacks_after_inference

[<bound method MultiTokenStoppingCriteria.reset of <llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fa49049f0a0>>]