In [1]:
"""
clone of nb 10 - check inf of slm and benchmark on dataset running on run pod

- copybara openhermes 2.5 model llmsearch on gsm8k

install exllama
wget https://github.com/turboderp/exllamav2/releases/download/v0.0.14/exllamav2-0.0.14+cu121-cp310-cp310-linux_x86_64.whl
pip install -q exllamav2-0.0.14+cu121-cp310-cp310-linux_x86_64.whl
"""


# Autocompletion
%config Completer.use_jedi = False

# Autoreload
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('/workspace/llmsearch')

In [3]:
from llmsearch.tuner import Tuner

import gc
import torch
import ctypes

import nltk
import torch
import random
import evaluate
import datasets
import langchain
import numpy as np
import transformers
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import PreTrainedModel, PretrainedConfig, GenerationConfig, StoppingCriteria, AutoTokenizer, StoppingCriteriaList

import os
import gc
import ctypes
import traceback
from pathlib import Path
from typing import Any, Dict, Optional, Union, List

def seed_everything(seed):
    """Make runs reproducible by seeding Python, NumPy and PyTorch RNGs.

    When CUDA is available, every GPU generator is seeded as well and cuDNN is
    switched to deterministic, non-autotuned kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class SingleTokenStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as a given token id has just been produced.

    Only the first sequence of the batch is inspected, so this is meant for
    batch size 1 (batched stopping is not supported yet).
    """

    def __init__(self, token_id):
        """
        Args:
            token_id: id of the token that ends generation (e.g. an end-of-turn token).
        """
        super().__init__()
        self.token_id = token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence equals `token_id`.

        `**kwargs` is accepted (and ignored) because `StoppingCriteriaList`
        forwards extra keyword arguments to each criterion.
        """
        return bool(input_ids[0, -1] == self.token_id)

Monkey-patching the `.generate()` function of the `transformers` library


In [4]:
from llmsearch.tuner import Tuner

import gc
import torch
import ctypes

import nltk
import torch
import random
import evaluate
import datasets
import langchain
import numpy as np
import transformers
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import PreTrainedModel, PretrainedConfig, GenerationConfig, StoppingCriteria, AutoTokenizer, StoppingCriteriaList


In [5]:
seed = 42
device = "cuda:0"
seed_everything(seed=seed)
# Don't clobber a token already exported in the environment (e.g. a pod secret);
# never hard-code a real token in the notebook.
os.environ.setdefault('HF_TOKEN', "")

In [9]:
from llmsearch.tuner import Tuner

import gc
import torch
import ctypes

import nltk
import torch
import random
import evaluate
import datasets
import langchain
import numpy as np
import transformers
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import PreTrainedModel, PretrainedConfig, GenerationConfig, StoppingCriteria, AutoTokenizer, StoppingCriteriaList
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Config
)

import os
import gc
import ctypes
import traceback
from pathlib import Path
from typing import Any, Dict, Optional, Union, List


class Exllamav2HF(PreTrainedModel):
    """Adapter that exposes an ExLlamaV2 model through the `transformers` model interface.

    It lets `transformers`' `.generate()` drive an ExLlamaV2 model. Each forward call
    receives the full token sequence. The wrapper reuses the ExLlamaV2 cache for the
    longest prefix shared with the previous call, so only the new tokens are fed.

    Loader options are read from the module-level `shared.args` object
    (gpu_split, cache_8bit, cfg_cache, max_seq_len, compress_pos_emb,
    alpha_value, no_flash_attn).
    """

    def __init__(self, config: ExLlamaV2Config):
        """Load the ExLlamaV2 weights described by `config` and allocate the cache(s).

        Reads `shared.args.gpu_split` (comma-separated allocations), which is
        parsed to a list of floats and passed to `ExLlamaV2.load`. Reads
        `shared.args.cache_8bit`, which selects `ExLlamaV2Cache_8bit` over
        `ExLlamaV2Cache`. Reads `shared.args.cfg_cache`, which allocates a
        second cache for the negative/CFG branch.
        """
        super().__init__(PretrainedConfig())
        self.ex_config = config
        self.ex_model = ExLlamaV2(config)
        split = None
        if shared.args.gpu_split:
            # "a,b,..." -> [float(a), float(b), ...]
            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]

        self.ex_model.load(split)
        self.generation_config = GenerationConfig()
        self.loras = None

        if shared.args.cache_8bit:
            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model)
        else:
            self.ex_cache = ExLlamaV2Cache(self.ex_model)

        # Token ids processed by the previous call; compared with the next input
        # to find how much of the cache can be reused.
        self.past_seq = None
        if shared.args.cfg_cache:
            if shared.args.cache_8bit:
                self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)
            else:
                self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)

            self.past_seq_negative = None

    def _validate_model_class(self):
        """No-op override of transformers' model-class validation hook."""
        pass

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        """No-op override: extra kwargs given to `.generate()` are not rejected."""
        pass

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Pass the full, unsliced `input_ids` plus all kwargs straight to `__call__`.

        Prefix reuse is handled in `__call__` via the ExLlamaV2 cache.
        """
        return {'input_ids': input_ids, **kwargs}

    @property
    def device(self) -> torch.device:
        """Always reports GPU 0 (`torch.device(0)`)."""
        return torch.device(0)

    def __call__(self, *args, **kwargs):
        """Run one forward pass and return next-token logits (or a loss).

        When called with keyword `input_ids`, this is the normal branch. A
        positional call (`self(input_ids, past_key_values=...)`) is treated as
        the negative (CFG) branch and uses the separate negative cache. That
        branch requires `shared.args.cfg_cache`; otherwise a message is printed
        and None is returned.

        Only the first batch row is used.

        Without `labels`, the cache is rewound to the longest prefix shared with
        the previous call. Only the remaining tokens are fed, and logits for the
        last position are returned.

        With `labels`, the cache is reset and logits for every position are
        computed. A shifted causal-LM cross-entropy loss is also returned.

        Returns:
            CausalLMOutputWithPast. Its `past_key_values` is the plain list of
            token ids processed (not real KV tensors), or None when
            `use_cache=False`.
        """
        use_cache = kwargs.get('use_cache', True)
        labels = kwargs.get('labels', None)
        past_key_values = kwargs.get('past_key_values', None)

        if len(args) > 0:
            if not shared.args.cfg_cache:
                print("Please enable the cfg-cache option to use CFG with ExLlamav2_HF.")
                return

            input_ids = args[0]
            is_negative = True
            past_seq = self.past_seq_negative
            ex_cache = self.ex_cache_negative
        else:
            input_ids = kwargs['input_ids']
            is_negative = False
            past_seq = self.past_seq
            ex_cache = self.ex_cache

        seq = input_ids[0].tolist()
        if is_negative and past_key_values is not None:
            seq = past_key_values + seq

        seq_tensor = torch.tensor(seq)
        reset = True

        # Make the forward call
        if labels is None:
            if past_seq is not None:
                # Length of the common prefix between the cached sequence and the new one
                min_length = min(past_seq.shape[0], seq_tensor.shape[0])
                indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
                if len(indices) > 0:
                    longest_prefix = indices[0].item()
                else:
                    longest_prefix = min_length

                if longest_prefix > 0:
                    reset = False
                    ex_cache.current_seq_len = longest_prefix
                    if len(seq_tensor) - longest_prefix > 1:
                        self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
                    elif len(seq_tensor) == longest_prefix:
                        # Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
                        # because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
                        ex_cache.current_seq_len -= 1

            if reset:
                ex_cache.current_seq_len = 0
                if len(seq_tensor) > 1:
                    self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)

            logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float()
        else:
            ex_cache.current_seq_len = 0
            logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float()

        if is_negative:
            self.past_seq_negative = seq_tensor
        else:
            self.past_seq = seq_tensor

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            # CrossEntropyLoss was never imported in this notebook; use the torch.nn path.
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, logits.shape[-1])
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        """Build a wrapper from a local ExLlamaV2-compatible model directory.

        Only the directory argument is supported; any extra positional or
        keyword arguments raise AssertionError. Sequence length,
        position-embedding scaling and flash-attention settings are taken
        from `shared.args` after `config.prepare()` has run.
        """
        assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported"
        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)

        config = ExLlamaV2Config()
        config.model_dir = str(pretrained_model_name_or_path)
        config.prepare()

        config.max_seq_len = shared.args.max_seq_len
        config.scale_pos_emb = shared.args.compress_pos_emb
        config.scale_alpha_value = shared.args.alpha_value
        config.no_flash_attn = shared.args.no_flash_attn

        return cls(config)

class SingleTokenStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as a given token id has just been produced.

    Only the first sequence of the batch is inspected, so this is meant for
    batch size 1 (batched stopping is not supported yet).
    """

    def __init__(self, token_id):
        """
        Args:
            token_id: id of the token that ends generation (e.g. an end-of-turn token).
        """
        super().__init__()
        self.token_id = token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence equals `token_id`.

        `**kwargs` is accepted (and ignored) because `StoppingCriteriaList`
        forwards extra keyword arguments to each criterion.
        """
        return bool(input_ids[0, -1] == self.token_id)


def cm():
    """Best-effort release of host and GPU memory between experiments.

    This function:
    - runs the Python garbage collector;
    - asks glibc to return freed heap pages to the OS (`malloc_trim`);
    - releases PyTorch's cached CUDA blocks.

    The `malloc_trim` step is skipped on platforms without glibc's
    `libc.so.6` (e.g. macOS, Windows, musl-based images) instead of crashing.
    """
    gc.collect()
    try:
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except (OSError, AttributeError):
        # libc.so.6 is missing or lacks malloc_trim: nothing to trim on this platform.
        pass
    torch.cuda.empty_cache()

def seed_everything(seed):
    """Make runs reproducible by seeding Python, NumPy and PyTorch RNGs.

    When CUDA is available, every GPU generator is seeded as well and cuDNN is
    switched to deterministic, non-autotuned kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class Shared:
    """Tiny stand-in for a global settings object; loader options live in `.args`."""

    class Args:
        """Attribute bag for loader options.

        Only `gpu_split` is initialised here. The other options read by the
        model loader are assigned on the instance after construction.
        """

        def __init__(self):
            self.gpu_split = None  # None -> no explicit GPU split

    def __init__(self):
        args = Shared.Args()
        self.args = args

def setup_dataset(tokenizer):
    """Build chat-formatted SAMSum validation and test splits for summarisation eval.

    NOTE(review): this appears to be a leftover from the SAMSum notebook this one
    was cloned from; the rest of this notebook benchmarks GSM8K.

    Args:
        tokenizer: tokenizer whose `apply_chat_template` renders the prompts.

    Returns:
        (valid_dataset, test_dataset): `DatasetWrapper`s over the SAMSum
        validation and test splits. They hold prompts only (reference summaries
        are not appended as assistant turns) and are rendered with
        `add_generation_prompt=True`.
    """
    # Load once; the previous version loaded the same dataset three times and also
    # built a with-summary validation split that was never returned.
    samsum = datasets.load_dataset("samsum")

    valid_dataset = DatasetWrapper(samsum['validation'], tokenizer, input_key="dialogue", output_key="summary", add_output=False)
    valid_dataset.apply_chat_template(add_gen_prompt=True)

    test_dataset = DatasetWrapper(samsum['test'], tokenizer, input_key="dialogue", output_key="summary", add_output=False)
    test_dataset.apply_chat_template(add_gen_prompt=True)

    return valid_dataset, test_dataset

class DatasetWrapper:
    """Wrap a Hugging Face dataset and add chat-formatted prompt columns.

    Construction adds a `chat_format` column: a list of chat messages made of
    - an optional system message;
    - a user message built from `prompt_template`;
    - (if `add_output`) the reference output as an assistant message.

    `apply_chat_template` later renders that list into a `formatted_chat`
    string using the tokenizer's chat template.
    """

    def __init__(self, hf_dataset, tokenizer, prompt_template="Summarize : {dialogue}", input_key="", output_key="", system_prompt="", add_output=True):
        self.tokenizer = tokenizer

        def build_messages(row):
            messages = []
            if system_prompt:
                messages.append({'role': "system", "content": system_prompt})
            user_text = prompt_template.format(**{input_key: row[input_key]})
            messages.append({'role': "user", "content": user_text})
            if add_output:
                messages.append({'role': 'assistant', "content": row[output_key]})
            return {"chat_format": messages}

        self.hf_dataset = hf_dataset.map(build_messages)

    def apply_chat_template(self, add_gen_prompt=True):
        """Render `chat_format` into an untokenized `formatted_chat` string column.

        Args:
            add_gen_prompt: forwarded as `add_generation_prompt`, asking the
                template to append the assistant generation prompt.
        """
        def render(row):
            text = self.tokenizer.apply_chat_template(row["chat_format"], tokenize=False, add_generation_prompt=add_gen_prompt)
            return {"formatted_chat": text}

        self.hf_dataset = self.hf_dataset.map(render)

def postprocess_text(preds, labels):
    """Normalise predictions and references for ROUGE scoring.

    Every text is stripped, then re-joined with one sentence per line
    (`nltk.sent_tokenize`), which is the layout rougeLsum expects.
    NOTE(review): `sent_tokenize` presumably needs NLTK's punkt data downloaded.

    Returns:
        (preds, labels): two new lists in the same order as the inputs.
    """
    # Strip everything first, then sentence-split each group in turn.
    stripped = ([text.strip() for text in preds], [text.strip() for text in labels])
    preds_out, labels_out = (
        ["\n".join(nltk.sent_tokenize(text)) for text in group] for group in stripped
    )
    return preds_out, labels_out

def perform_single_example_inference(example, model, gen_kwargs, tokenizer=None, device="cuda:0"):
    """Generate a completion for one prompt string and print token statistics.

    Args:
        example: fully formatted prompt text (chat template already applied).
            It is tokenized without adding special tokens.
        model: model exposing a `transformers`-style `.generate()`.
        gen_kwargs: extra keyword arguments forwarded to `model.generate`.
        tokenizer: tokenizer used to encode/decode. Defaults to the notebook's
            global `tokenizer`, which was the previous implicit behaviour.
        device: device the encoded input tensors are moved to.

    Returns:
        The decoded first output sequence: prompt tokens followed by the
        generated completion.
    """
    if tokenizer is None:
        tokenizer = globals()["tokenizer"]

    encoded = tokenizer(example, return_tensors="pt", add_special_tokens=False)
    # Move every returned tensor (input_ids, attention_mask, ...) to the target device.
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    model_out = model.generate(**model_inputs, **gen_kwargs)
    prompt_tokens = len(model_inputs["input_ids"][0])
    print(f"Prompt tokens - {prompt_tokens}")

    # Only the first (and, for batch size 1, only) sequence is decoded.
    output_token_ids = model_out[0].tolist()
    decoded_output = tokenizer.decode(output_token_ids, spaces_between_special_tokens=False)
    print(decoded_output)

    completion_tokens = len(output_token_ids) - prompt_tokens
    print(f"Completion Tokens - {completion_tokens}")

    return decoded_output


def get_rouge_score(y_true: List, y_pred: List, metric=None):
    """Mean per-example ROUGE-2 score of predictions against references.

    Both lists are first normalised with `postprocess_text`: stripped, with one
    sentence per line.

    Args:
        y_true: reference texts.
        y_pred: predicted texts, aligned with `y_true`.
        metric: an `evaluate` ROUGE metric. If omitted, a global
            `rouge_metric` is used when one exists. Otherwise
            `evaluate.load("rouge")` is loaded once and cached in that global.

    Returns:
        Mean of the per-example 'rouge2' values (computed with
        `use_aggregator=False`, `use_stemmer=True`).
    """
    if metric is None:
        metric = globals().get("rouge_metric")
    if metric is None:
        # `rouge_metric` is never defined in this notebook: load it on first use and cache it.
        metric = evaluate.load("rouge")
        globals()["rouge_metric"] = metric

    preds, gts = postprocess_text(preds=y_pred, labels=y_true)
    result = metric.compute(predictions=preds, references=gts, use_stemmer=True, use_aggregator=False)
    return np.mean(result['rouge2'])

def load_model(model_dir):
    """Placeholder for a model-loading helper; not implemented yet.

    The body is `pass`, so `model_dir` is ignored and the call returns None.
    TODO: implement (model loading is currently done inline in a later cell).
    """
    pass

# Loader options read by Exllamav2HF.__init__ / Exllamav2HF.from_pretrained.
# Flags are explicit booleans; they are only ever tested for truthiness.
shared = Shared()
shared.args.gpu_split = None        # no explicit split string -> ExLlamaV2.load(None)
shared.args.cache_8bit = False      # use ExLlamaV2Cache rather than ExLlamaV2Cache_8bit
shared.args.cfg_cache = False       # no separate cache for the negative/CFG branch
shared.args.max_seq_len = 2048      # written to ExLlamaV2Config.max_seq_len
shared.args.compress_pos_emb = 1    # written to ExLlamaV2Config.scale_pos_emb
shared.args.alpha_value = 1         # written to ExLlamaV2Config.scale_alpha_value
shared.args.no_flash_attn = True    # written to ExLlamaV2Config.no_flash_attn

In [10]:
import torch
# Sanity-check the PyTorch build; the exllamav2 wheel in the setup notes targets cu121.
torch.__version__

'2.2.0+cu121'

In [13]:
# Load the GPTQ CapybaraHermes-2.5 (Mistral-7B) checkpoint through the ExLlamaV2 wrapper,
# then the slow (use_fast=False) tokenizer stored in the same directory.
model_dir = '/workspace/capybarahermes-2.5-gptq/TheBloke_CapybaraHermes-2.5-Mistral-7B-GPTQ/'
model = Exllamav2HF.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir, legacy=False, use_fast=False)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [14]:
# Inspect the Jinja chat template that apply_chat_template will use for prompts.
tokenizer.chat_template

"{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

In [16]:
from datasets import load_dataset

# GSM8K grade-school math word problems ("main" config: train and test splits).
gsm8k_dataset = load_dataset(path="gsm8k", name="main")

Downloading readme:   0%|          | 0.00/7.94k [00:00<?, ?B/s]

Downloading data: 100%|██████████| 2.31M/2.31M [00:00<00:00, 6.41MB/s]
Downloading data: 100%|██████████| 419k/419k [00:00<00:00, 427kB/s]


Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

In [18]:
# Stop as soon as the model emits the chat template's end-of-turn token `<|im_end|>`.
# Look the id up (it was hard-coded as 32000) so it follows whichever tokenizer is loaded.
im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
assert im_end_token_id != tokenizer.unk_token_id, "tokenizer has no <|im_end|> token"
stopping_criteria = StoppingCriteriaList([SingleTokenStoppingCriteria(token_id=im_end_token_id)])

In [66]:
import langchain

text = """\
Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

Q: {question}"""

"""
- decide which metric to use
- add in evaluation for that metric
- run dummy eval on a small set
- then run search
"""


pt = langchain.PromptTemplate.from_template(text)

idx = 3
sample = gsm8k_dataset['train'][idx]

question = sample['question']
answer = sample['answer']

formatted_pt = pt.format(question=question)

print(formatted_pt)

Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

Q: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?


In [67]:
# Chat prompt: fixed system instruction + the few-shot GSM8K prompt built in the previous cell.
messages = [
    {"role": "system", "content": "You are a friendly assistant who can solve math problems"},
    {"role": "user", "content": formatted_pt},
]

# Render to text (no tokenization) and append the assistant generation prompt.
ct_sample = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
print(ct_sample)
print(f"Answer - {answer}\n")

<|im_start|>system
You are a friendly assistant who can solve math problems<|im_end|>
<|im_start|>user
Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

Q: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?<|im_end|>
<|im_start|>assistant

Answer - Maila read 12 x 2 = <<12*2=24>>24 pages today.
So she was able to read a total of 12 + 24 = <<12+24=36>>36 pages since ye

In [68]:
%%time
gen_params1 = {
    'max_new_tokens' : 200,
    'generation_seed' : 42,
    'stopping_criteria' : stopping_criteria,
#     'temperature' : 0.1
#     'do_sample' : True,
}
output = perform_single_example_inference(ct_sample, model, gen_params1)

Prompt tokens - 240
<|im_start|>system
You are a friendly assistant who can solve math problems<|im_end|>
<|im_start|>user
Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

Q: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?<|im_end|>
<|im_start|>assistant
A: Julie read 12 pages yesterday. Today, she read twice as many pages, which is 12 * 2 = 24 pages. So far, she 

In [69]:
# Decoded sequence returned by the inference call (prompt followed by the completion).
print(output)

<|im_start|>system
You are a friendly assistant who can solve math problems<|im_end|>
<|im_start|>user
Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

Q: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?<|im_end|>
<|im_start|>assistant
A: Julie read 12 pages yesterday. Today, she read twice as many pages, which is 12 * 2 = 24 pages. So far, she has read 12 + 24 = 3

In [None]:
def perform_single_example_inference(example, model, gen_kwargs, tokenizer=None, device="cuda:0"):
    """Generate a completion for one prompt string and print token statistics.

    Args:
        example: fully formatted prompt text (chat template already applied).
            It is tokenized without adding special tokens.
        model: model exposing a `transformers`-style `.generate()`.
        gen_kwargs: extra keyword arguments forwarded to `model.generate`.
        tokenizer: tokenizer used to encode/decode. Defaults to the notebook's
            global `tokenizer`, which was the previous implicit behaviour.
        device: device the encoded input tensors are moved to.

    Returns:
        The decoded first output sequence: prompt tokens followed by the
        generated completion.
    """
    if tokenizer is None:
        tokenizer = globals()["tokenizer"]

    encoded = tokenizer(example, return_tensors="pt", add_special_tokens=False)
    # Move every returned tensor (input_ids, attention_mask, ...) to the target device.
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    model_out = model.generate(**model_inputs, **gen_kwargs)
    prompt_tokens = len(model_inputs["input_ids"][0])
    print(f"Prompt tokens - {prompt_tokens}")

    # Only the first (and, for batch size 1, only) sequence is decoded.
    output_token_ids = model_out[0].tolist()
    decoded_output = tokenizer.decode(output_token_ids, spaces_between_special_tokens=False)
    print(decoded_output)

    completion_tokens = len(output_token_ids) - prompt_tokens
    print(f"Completion Tokens - {completion_tokens}")

    return decoded_output

In [None]:
idx = 0  # reset the example index (not used in this cell)
gsm8k_dataset  # display the DatasetDict (splits and sizes)

In [None]:
# `valid_dataset` is never created elsewhere in this notebook, so build it on a fresh kernel.
# NOTE(review): setup_dataset() builds SAMSum splits (leftover from the cloned notebook),
# not GSM8K.
if "valid_dataset" not in globals():
    valid_dataset, _ = setup_dataset(tokenizer)
example = valid_dataset.hf_dataset[3]['formatted_chat']
print(example)