In [1]:
"""

Check how the model pads input
Check how different encoding & decoding params affect the encoding & decoding

"""

# Autocompletion
%config Completer.use_jedi = False

# Autoreload
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('/workspace/llmsearch')

import gc
import torch
import ctypes
import json
import nltk
import math
import torch
import random
import evaluate
import datasets
import langchain
import numpy as np
import pandas as pd
import transformers
from transformers import GPTQConfig, BitsAndBytesConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import PreTrainedModel, PretrainedConfig, GenerationConfig, StoppingCriteria, AutoTokenizer, StoppingCriteriaList, AutoModel, AutoModelForCausalLM

import os
import gc
import ctypes
import traceback
from pathlib import Path
from typing import Any, Dict, Optional, Union, List

import time
import textwrap
from tqdm.auto import tqdm

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Config
)

from datasets import load_dataset
from llmsearch.model_downloader import download_model_from_hf
from llmsearch.utils.model_utils import batcher, decoder_parser

import awq

from awq import AutoAWQForCausalLM

def pretty_print_dict(d, indent = 4):
    """Print `d` as indented JSON.

    Values that JSON cannot serialise natively are rendered through `str`.

    Args:
        d: object to serialise
        indent: number of spaces per nesting level
    """
    rendered = json.dumps(d, default = str, indent = indent)
    print(rendered)

Monkey Patching .generate function of `transformers` library


In [2]:
# Load the "main" configuration of the GSM8K dataset; a later cell uses its 'train' split
gsm8k_dataset = load_dataset("gsm8k", 'main')

# Show the installed torch and awq versions as the cell output
torch.__version__, awq.__version__

('2.2.0+cu121', '0.2.4')

In [3]:

def seed_everything(seed):
    """Seed the Python, NumPy and torch random number generators for reproducibility.

    When CUDA is available the CUDA generators are seeded as well and cuDNN is
    switched to deterministic mode with benchmarking turned off.

    Args:
        seed: integer seed applied to every generator
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


class SingleTokenStoppingCriteria(StoppingCriteria):
    """End generation as soon as the most recent token is the stop token.

    Only the first sequence of the batch is inspected, so batched generation
    is not supported yet.
    """

    def __init__(self, token_id):
        """
        Args:
            token_id: id of the token that ends generation
        """
        super().__init__()
        self.token_id = token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence equals `token_id`.

        Args:
            input_ids: token ids generated so far; indexed as [sequence][position]
            scores: unused, accepted to match the stopping-criteria call signature
            **kwargs: unused; accepted so that extra keyword arguments passed by
                the caller do not raise a TypeError
        """
        last_token_id = input_ids[0][-1]
        return bool(last_token_id == self.token_id)


def cm():
    """Release as much cached memory as possible.

    Runs the Python garbage collector, asks the C allocator in `libc.so.6` to
    trim its heap, and empties the torch CUDA cache.
    """
    gc.collect()
    libc = ctypes.CDLL("libc.so.6")
    libc.malloc_trim(0)
    torch.cuda.empty_cache()

# NOTE(review): `seed_everything` is also defined earlier in this cell; this later
# definition is the one that stays bound. Consider keeping a single copy.
def seed_everything(seed):
    """Seed the Python, NumPy and torch random number generators for reproducibility.

    Args:
        seed: integer seed applied to every generator
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        # Seed the current device and all devices, then make cuDNN deterministic
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False



def perform_single_example_inference(example, model, tokenizer,gen_kwargs, device = 'cuda:0'):
    """Generate from a single prompt and report token counts.

    Args:
        example: prompt text; tokenized without adding special tokens
        model: object exposing `generate(**inputs, **gen_kwargs)`
        tokenizer: callable tokenizer that also exposes `decode`
        gen_kwargs: keyword arguments forwarded to `model.generate`
        device: device the input ids and attention mask are moved to before
            generation (defaults to the previously hard-coded 'cuda:0')

    Returns:
        Tuple of (decoded text of the first returned sequence, number of prompt
        tokens, number of returned tokens minus the number of prompt tokens).
    """
    tokenized_input = tokenizer(example, return_tensors = "pt", add_special_tokens = False)
    for key in ('input_ids', 'attention_mask'):
        tokenized_input[key] = tokenized_input[key].to(device)

    model_out = model.generate(**tokenized_input, **gen_kwargs)
    prompt_tokens = len(tokenized_input['input_ids'][0])
    print(f"Prompt tokens - {prompt_tokens}")

    # Only the first sequence of the returned batch is decoded
    output_token_ids = model_out.tolist()[0]
    decoded_output = tokenizer.decode(output_token_ids, spaces_between_special_tokens = False)

    print(decoded_output)
    # Assumes `generate` returns prompt + completion, so the difference is the completion length
    completion_tokens = len(output_token_ids) - prompt_tokens

    print(f"Completion Tokens - {completion_tokens}")

    return decoded_output, prompt_tokens, completion_tokens

In [4]:
# loaders

class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence.

    Generation is reported as finished only once every sequence of the batch
    has, at some step, ended with `sequence_ids`. Per-batch bookkeeping is kept
    on the instance between calls.

    This code is not thread safe. The same object cannot be used simultaneously in multiple threads.
    """

    def __init__(
        self,
        sequence_ids : List[int],
    ) -> None:
        """
        Args:
            sequence_ids: token ids that make up the stop sequence (at least one id)

        Raises:
            ValueError: if `sequence_ids` is empty
        """
        super().__init__()
        if len(sequence_ids) == 0:
            raise ValueError("sequence_ids must contain at least one token id")
        # Created on CPU and moved to the device of `input_ids` on first use, so the
        # criteria is not tied to one particular device
        self.sequence_ids = torch.tensor(sequence_ids, dtype = torch.int32)
        # Stop-sequence length plus 2 extra tokens of lookback. The extra tokens were meant to
        # tolerate a stop string that the model emits in a different tokenization
        # (e.g. `['\n', '\n']` instead of `['\n\n']`). The comparison in `__call__` only looks at
        # the exact tail, so this value is kept purely as an attribute.
        self.sequence_id_len = self.sequence_ids.shape[0] + 2
        self.batch_size = None
        self.input_length = None
        self.last_seen_length = None
        self.done_tracker = []
        self.state_initialized = False

    def set_state(self, batch_size, input_length):
        """Start tracking a new batch.

        Args:
            batch_size: number of sequences in the batch
            input_length: number of leading tokens that are never compared
                against the stop sequence
        """
        self.batch_size = batch_size
        self.input_length = input_length
        self.last_seen_length = None
        self.done_tracker = [False] * batch_size
        self.state_initialized = True

    def reset(self):
        """Forget the current batch so the next call starts from a clean state."""
        self.batch_size = None
        self.input_length = None
        self.last_seen_length = None
        self.done_tracker = []
        self.state_initialized = False

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        """Return True once every sequence of the batch has produced the stop sequence.

        Args:
            input_ids: 2-D tensor of token ids, one row per sequence
            scores: unused
            **kwargs: unused
        """
        batch_size, current_length = input_ids.shape[0], input_ids.shape[1]

        # `reset` only runs when this criteria itself ends generation. If a previous batch ended
        # for another reason (e.g. the token limit), the old state is still around. Best effort
        # detection of a new batch: the batch size changed, or the sequence did not grow.
        stale_state = self.state_initialized and (
            batch_size != self.batch_size
            or (self.last_seen_length is not None and current_length <= self.last_seen_length)
        )
        if not self.state_initialized or stale_state:
            # 1st call to __call__ for this batch. Every token already present at this point is
            # treated as input and is never compared against the stop sequence.
            self.set_state(batch_size, current_length)
        self.last_seen_length = current_length

        # IDs of all the tokens except the input
        generated_ids = input_ids[:, self.input_length :]
        stop_len = self.sequence_ids.shape[0]

        # Not enough generated tokens yet to contain the whole stop sequence
        if generated_ids.shape[1] < stop_len:
            return False

        if self.sequence_ids.device != input_ids.device:
            self.sequence_ids = self.sequence_ids.to(input_ids.device)

        # Compare the last `stop_len` generated tokens of every row with the stop sequence and
        # reduce to one plain bool per row
        tail_matches = (generated_ids[:, -stop_len:] == self.sequence_ids).all(dim = 1).tolist()
        # A row that matched once stays done
        self.done_tracker = [done or hit for done, hit in zip(self.done_tracker, tail_matches)]

        ret_val = all(self.done_tracker)
        if ret_val:
            self.reset()
        return ret_val


def load_model_with_awq_backend(model_id, model_loader_kwargs, tokenizer_kwargs,temp_model_dir, model_branch = "main"):
    """Download an AWQ-quantized model and load it together with its tokenizer.

    Args:
        model_id: model identifier handed to `download_model_from_hf`
        model_loader_kwargs: keyword arguments for `AutoAWQForCausalLM.from_quantized`;
            a 'pretrained_model_name_or_path' entry, if present, is ignored
        tokenizer_kwargs: keyword arguments for `AutoTokenizer.from_pretrained`;
            'pretrained_model_name_or_path' is set to the downloaded folder
        temp_model_dir: directory passed to the downloader as `save_dir`
        model_branch: branch passed to the downloader

    Returns:
        Tuple of (model, tokenizer). The tokenizer's pad token is set to its eos token.
    """
    output_folder = download_model_from_hf(model_id, save_dir = temp_model_dir, branch = model_branch)

    # Work on copies so the caller's dictionaries are left untouched
    model_loader_kwargs = dict(model_loader_kwargs)
    tokenizer_kwargs = dict(tokenizer_kwargs)

    # The model path always comes from the download, never from the caller
    model_loader_kwargs.pop('pretrained_model_name_or_path', None)
    tokenizer_kwargs['pretrained_model_name_or_path'] = output_folder

    model = AutoAWQForCausalLM.from_quantized(
        quant_path=output_folder,
        **model_loader_kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(**tokenizer_kwargs, local_files_only=True)

    # pad token is null in config -https://huggingface.co/TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ/blob/eb64c310c44905321d012962db9ac0d47c3a64fa/tokenizer_config.json#L53
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

# Maps a backend name to the function that loads a (model, tokenizer) pair with that backend.
# Only the AWQ loader is active; the other entries are commented out.
model_loader_backend_map = {
    # "exllama_2_hf": load_model_with_exllama_2_hf_backend,
    # "hf": load_model_with_hf_backend,
    # 'auto_gptq' : load_model_with_autogptq_backend,
    'awq' : load_model_with_awq_backend,
}

In [5]:
# https://huggingface.co/TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ
model_id = "TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ"

# Directory handed to the loader as the download location; created if it does not exist
temp_model_dir = Path(f"/workspace/temp_model_dir/")
temp_model_dir.mkdir(exist_ok = True, parents = True)

# Keyword arguments the loader forwards to AutoAWQForCausalLM.from_quantized
model_loader_kwargs = {
    'device_map' : {'' : 0},  # presumably places the whole model on device 0 — TODO confirm
    'fuse_layers' : True,
}

# Keyword arguments the loader forwards to AutoTokenizer.from_pretrained
tokenizer_loader_kwargs = {
    'use_fast' : False,
    'legacy' : False,
    'padding_side' : 'left',
}

# Download (or reuse) the model files and load the model and tokenizer
model, tokenizer = load_model_with_awq_backend(model_id, model_loader_kwargs, tokenizer_loader_kwargs,temp_model_dir, model_branch = "main")

Model already exists in /workspace/temp_model_dir/TheBloke_CapybaraHermes-2.5-Mistral-7B-AWQ. Checking the model files...
Checksum validated: model.safetensors  645dfc7f09074aaf25e642f3c6a4f7ea399a0ff2605fa650e4e74078832546de
Checksum validated: tokenizer.model  dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
[+] Validated checksums of all model files!


Replacing layers...: 100%|██████████| 32/32 [00:03<00:00,  8.63it/s]
Fusing layers...: 100%|██████████| 32/32 [00:02<00:00, 12.92it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
def preprocess_dataset(dataset, tokenizer, pt, pt_cols, system_prompt, add_generation_prompt = True):
    """Add the chat-formatted model input and the raw input to every sample.

    Args:
        dataset: object exposing `map(function)`
        tokenizer: object exposing `apply_chat_template`
        pt: prompt template; formatted with the `pt_cols` values of each sample
        pt_cols: names of the sample fields substituted into `pt`
        system_prompt: text of the system message, or None for no system message
        add_generation_prompt: forwarded to `tokenizer.apply_chat_template`

    Returns:
        Whatever `dataset.map` returns for a function that yields two keys per
        sample: "X" (the chat-templated prompt string) and 'actual input'
        (the sample's value for the first column in `pt_cols`).
    """

    def build_prompt(sample):
        """Fill the prompt template for `sample` and wrap it with the chat template"""
        messages = []
        if system_prompt is not None:
            messages.append({"role": "system", "content": system_prompt})
        template_fields = {col : sample[col] for col in pt_cols}
        messages.append({"role": "user", "content": pt.format(**template_fields)})
        return tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = add_generation_prompt)

    def add_columns(sample):
        """Build the two new columns for `sample`"""
        return {
            "X" : build_prompt(sample),
            'actual input' : sample[pt_cols[0]],
        }

    return dataset.map(add_columns)

In [7]:
# Few-shot prompt template: two worked examples followed by the question to solve.
# `{question}` is filled from each sample by `preprocess_dataset`.
pt = textwrap.dedent("""\
    Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
    A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

    Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
    A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

    Q: {question}""")
# Sample fields substituted into the template
pt_cols = ['question']
# Sent as the system message of every prompt
system_prompt = "Solve the following math problems, end with The answer is"

# Add prompt template
processed_dataset = preprocess_dataset(gsm8k_dataset['train'], tokenizer,pt = pt, pt_cols = pt_cols, system_prompt = system_prompt, add_generation_prompt = True)

In [8]:
# Seed used for shuffling here and passed to the Tuner in a later cell
seed = 42
# Number of samples kept for the benchmark
bm_sample_size = 10
# Shuffle with the fixed seed and keep the first `bm_sample_size` samples
bm_samples = processed_dataset.shuffle(seed = seed).select(range(bm_sample_size))

In [9]:
# Inspect the first benchmark sample, including the generated "X" prompt
bm_samples[0]

{'question': 'Mimi picked up 2 dozen seashells on the beach.  Kyle found twice as many shells as Mimi and put them in his pocket. Leigh grabbed one-third of the shells that Kyle found.  How many seashells did Leigh have?',
 'answer': 'Mimi has 2 x 12 = <<2*12=24>>24 sea shells.\nKyle has 24 x 2 = <<24*2=48>>48 sea shells.\nLeigh has 48 / 3 = <<48/3=16>>16 sea shells.\n#### 16',
 'X': '<|im_start|>system\nSolve the following math problems, end with The answer is<|im_end|>\n<|im_start|>user\nQ: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\nA: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.\n\nQ: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\nA: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answ

In [10]:
import re

def extract_answer_from_out(s):
    """Extract the number that follows the first "The answer is " in `s`.

    Accepts an optional leading "$", an optional minus sign, thousands
    separators ("1,250") and a decimal part. Thousands separators are removed
    from the returned value.

    Args:
        s: generated text

    Returns:
        The number as a string (e.g. "16", "3.5", "-4", "1250"), or None when
        `s` contains no "The answer is <number>".
    """
    pattern = re.compile(r"The answer is \$?(-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?)")
    match = pattern.search(s)
    if match is None:
        return None
    return match.group(1).replace(",", "")

def get_score(y_true, y_pred):
    """Exact-match accuracy of the extracted predictions against the references.

    Args:
        y_true: iterable of dicts; the reference answer is the text after the
            last "####" of each dict's 'answer' value
        y_pred: iterable of generated texts, aligned with `y_true`

    Returns:
        Fraction of pairs whose reference and extracted prediction are equal as
        strings, or 0.0 when there are no pairs to score.
    """
    scores = []

    for y_t, y_p in zip(y_true, y_pred):
        # Drop thousands separators so a reference such as "1,000" can match "1000"
        y_t_answer = y_t['answer'].split("####")[-1].strip().replace(",", "")
        y_p_answer = extract_answer_from_out(y_p)
        scores.append(1 if y_t_answer == y_p_answer else 0)

    # Nothing to score: avoid dividing by zero
    if not scores:
        return 0.0
    return sum(scores)/len(scores)

In [11]:
from llmsearch.utils.logging_utils import set_verbosity_info, set_verbosity_debug, set_verbosity_warning
# Switch llmsearch logging to debug verbosity for the cells below
set_verbosity_debug()

In [12]:
from llmsearch.tuner import Tuner

# Free cached memory before building the tuner
cm()

# Bundle the model, tokenizer, benchmark samples and scoring function into a Tuner
tuner_ob = Tuner(
    model = model,
    tokenizer = tokenizer,
    dataset = bm_samples,
    device = 'cuda:0',
    batch_size = 2,
    tokenizer_encoding_kwargs={'padding': 'longest', 'add_special_tokens' : False},
    tokenizer_decoding_kwargs={'spaces_between_special_tokens' : False},
    scorer = get_score,
    # "X" already holds the fully formatted prompt, so the template passes it through unchanged
    prompt_template = langchain.PromptTemplate.from_template("{X}"),
    is_encoder_decoder = False,
    seed = seed,
    # "X" is the model input column; 'answer' is the column handed to the scorer
    column_mapping = {'input_cols' : ["X"],'eval_cols' : ['answer']},
)

2024-03-25 02:43:11.752 - llmsearch.tuner.tuner:82 - DEBUG - Initializing new estimator with generation parameters - {}


In [23]:
# Stop generation on the single-token stop sequence [32000]
# NOTE(review): 32000 is presumably the id of the `<|im_end|>` token — confirm against the tokenizer
stopping_criteria = StoppingCriteriaList([MultiTokenEOSCriteria(sequence_ids = [32000])])

# Generation parameters handed to the tuner
gen_params1 = {
    'max_new_tokens' : 500,
    'stopping_criteria' : stopping_criteria,
    'generation_seed' : 42,
}

# Run inference with these parameters and score the outputs
scores_before, outputs_before = tuner_ob.get_score(gen_params1)

2024-03-25 02:28:45.648 - llmsearch.utils.mem_utils:153 - INFO - Starting inference with generation parameters - {'max_new_tokens': 500, 'stopping_criteria': [<__main__.MultiTokenEOSCriteria object at 0x7fe9144da3e0>], 'generation_seed': 42}
2024-03-25 02:28:45.649 - llmsearch.utils.mem_utils:157 - INFO - Performing inference with batch_size - 1
2024-03-25 02:28:45.651 - llmsearch.utils.model_utils:97 - INFO - Detected generation type - Greedy Decoding


  0%|          | 0/10 [00:00<?, ?it/s]

2024-03-25 02:29:18.198 - llmsearch.utils.model_utils:132 - DEBUG - Input - '<|im_start|>system\nSolve the following math problems, end with The answer is<|im_end|>\n<|im_start|>user\nQ: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\nA: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.\n\nQ: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\nA: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.\n\nQ: Mimi picked up 2 dozen seashells on the beach.  Kyle found twice as many shells as Mimi and put them in his pocket. Leigh grabbed one-third of the shells that Kyle found.  How many seashells did Leigh have?<|im_end|>\n<|im_start|>assistant\n'
2024-03-25 02:29:18.199 - llmsearch.utils.model_utils

In [25]:
# Display the score of the run above (its log reports batch_size 1)
scores_before

0.6

In [16]:
# decoder parser is working as expected
# TODO : check scores at different bs then llmsearch

# NOTE(review): this cell repeats the earlier scoring cell with the same parameters. Its log
# reports batch_size 2 (the earlier run logged batch_size 1) and the displayed score is 0.4
# instead of 0.6, so the result appears to depend on the batch size — TODO investigate
# (left padding and the stopping criteria's per-batch state are candidates).
# NOTE(review): 32000 is presumably the id of the `<|im_end|>` token — confirm against the tokenizer
stopping_criteria = StoppingCriteriaList([MultiTokenEOSCriteria(sequence_ids = [32000])])

gen_params1 = {
    'max_new_tokens' : 500,
    'stopping_criteria' : stopping_criteria,
    'generation_seed' : 42,
}

# Run inference with these parameters and score the outputs
scores_before, outputs_before = tuner_ob.get_score(gen_params1)

scores_before

2024-03-25 02:50:32.804 - llmsearch.utils.mem_utils:153 - INFO - Starting inference with generation parameters - {'max_new_tokens': 500, 'stopping_criteria': [<__main__.MultiTokenEOSCriteria object at 0x7f8a44537d00>], 'generation_seed': 42}
2024-03-25 02:50:32.806 - llmsearch.utils.mem_utils:157 - INFO - Performing inference with batch_size - 2
2024-03-25 02:50:32.808 - llmsearch.utils.model_utils:97 - INFO - Detected generation type - Greedy Decoding


  0%|          | 0/5 [00:00<?, ?it/s]

here
<|im_start|>system
Solve the following math problems, end with The answer is<|im_end|>
<|im_start|>user
Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

Q: Mimi picked up 2 dozen seashells on the beach.  Kyle found twice as many shells as Mimi and put them in his pocket. Leigh grabbed one-third of the shells that Kyle found.  How many seashells did Leigh have?<|im_end|>
<|im_start|>assistant
A: Mimi picked up 2 dozen seashells, which is 2 * 12 = 24 seashells. Kyle found twice as many shells as Mimi, so he found 24 * 2 = 48 seas

2024-03-25 02:50:53.971 - llmsearch.utils.model_utils:135 - DEBUG - Input - '<|im_start|>system\nSolve the following math problems, end with The answer is<|im_end|>\n<|im_start|>user\nQ: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\nA: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.\n\nQ: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\nA: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.\n\nQ: Mimi picked up 2 dozen seashells on the beach.  Kyle found twice as many shells as Mimi and put them in his pocket. Leigh grabbed one-third of the shells that Kyle found.  How many seashells did Leigh have?<|im_end|>\n<|im_start|>assistant\n'
2024-03-25 02:50:53.973 - llmsearch.utils.model_utils

<|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|> <|im_start|>system
Solve the following math problems, end with The answer is<|im_end|>
<|im_start|>user
Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

Q: Pauly is making omelets for his family. There are three dozen eggs, and he plans to use the

0.4