In [1]:
"""

Check how the model pads input
Check how different encoding & decoding params affect the encoding & decoding

"""

# Autocompletion
%config Completer.use_jedi = False

# Autoreload
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('/workspace/llmsearch')

import gc
import torch
import ctypes
import json
import nltk
import math
import torch
import random
import evaluate
import datasets
import langchain
import numpy as np
import pandas as pd
import transformers
from transformers import GPTQConfig, BitsAndBytesConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import PreTrainedModel, PretrainedConfig, GenerationConfig, StoppingCriteria, AutoTokenizer, StoppingCriteriaList, AutoModel, AutoModelForCausalLM

import os
import gc
import ctypes
import traceback
from pathlib import Path
from typing import Any, Dict, Optional, Union, List

import time
import textwrap
from tqdm.auto import tqdm

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Config
)

from datasets import load_dataset
from llmsearch.model_downloader import download_model_from_hf
from llmsearch.utils.model_utils import batcher, decoder_parser

import awq

from awq import AutoAWQForCausalLM

def pretty_print_dict(d, indent = 4):
    """Print `d` as indented JSON.

    Values that `json` cannot serialise are converted with `str`.

    Args:
        d: object to print (typically a dict).
        indent: number of spaces per nesting level.
    """
    rendered = json.dumps(d, default = str, indent = indent)
    print(rendered)

Monkey Patching .generate function of `transformers` library


In [2]:
# Load the "main" configuration of the gsm8k dataset
gsm8k_dataset = load_dataset("gsm8k", 'main')

# Display the installed torch and awq versions as the cell output
torch.__version__, awq.__version__

Downloading readme:   0%|          | 0.00/7.94k [00:00<?, ?B/s]

Downloading data: 100%|██████████| 2.31M/2.31M [00:00<00:00, 5.57MB/s]
Downloading data: 100%|██████████| 419k/419k [00:00<00:00, 1.97MB/s]


Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

('2.2.0+cu121', '0.2.4')

In [3]:

def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random generators with `seed`.

    When CUDA is available the CUDA generators are seeded too, cuDNN is put
    in deterministic mode and cuDNN benchmarking is switched off.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


class SingleTokenStoppingCriteria(StoppingCriteria):
    """Stop generation once the last generated token equals `token_id`.

    Only the first sequence of the batch (`input_ids[0]`) is inspected, so
    this criterion does not support batched generation.
    """

    def __init__(self, token_id):
        """Store the id of the token that ends generation."""
        super().__init__()
        self.token_id = token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the most recent token of the first sequence is the stop token.

        Args:
            input_ids: token ids generated so far, indexed as `input_ids[0][-1]`.
            scores: unused; accepted because the caller passes it.
            **kwargs: ignored; accepted so extra keyword arguments passed by the
                caller do not raise (same signature shape as MultiTokenEOSCriteria).
        """
        last_token_id = input_ids[0][-1]
        # bool(...) turns the 0-dim comparison result into a plain Python bool
        return bool(last_token_id == self.token_id)


def cm():
    """Clean up memory.

    Runs the Python garbage collector, calls `malloc_trim(0)` from
    `libc.so.6`, then empties the PyTorch CUDA cache.
    """
    gc.collect()
    libc = ctypes.CDLL("libc.so.6")
    libc.malloc_trim(0)
    torch.cuda.empty_cache()

def seed_everything(seed):
    """Make runs reproducible by seeding every random generator in use.

    Seeds `random`, NumPy and PyTorch; if CUDA is available it also seeds the
    CUDA generators, enables deterministic cuDNN and disables cuDNN benchmarking.
    """
    # NOTE(review): `seed_everything` is defined twice in this cell with the same
    # logic; this later definition is the one that stays bound.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False



def perform_single_example_inference(example, model, tokenizer, gen_kwargs):
    """Generate a completion for one prompt string and report token counts.

    Args:
        example: prompt text, tokenized without adding special tokens.
        model: object exposing `generate(**inputs, **gen_kwargs)`.
        tokenizer: tokenizer used to encode the prompt and decode the output.
        gen_kwargs: keyword arguments forwarded to `model.generate`.

    Returns:
        Tuple `(decoded_output, prompt_tokens, completion_tokens)` where
        `decoded_output` is the decoded first output sequence,
        `prompt_tokens` is the length of the encoded prompt and
        `completion_tokens` is the output length minus the prompt length.
    """
    encoded = tokenizer(example, return_tensors = "pt", add_special_tokens = False)
    # both tensors are moved to the first GPU before generation
    for key in ('input_ids', 'attention_mask'):
        encoded[key] = encoded[key].to('cuda:0')

    generated = model.generate(**encoded, **gen_kwargs)

    prompt_tokens = len(encoded['input_ids'][0])
    print(f"Prompt tokens - {prompt_tokens}")

    # only the first sequence of the output is decoded
    output_token_ids = generated.tolist()[0]
    decoded_output = tokenizer.decode(output_token_ids, spaces_between_special_tokens = False)
    print(decoded_output)

    completion_tokens = len(output_token_ids) - prompt_tokens
    print(f"Completion Tokens - {completion_tokens}")

    return decoded_output, prompt_tokens, completion_tokens

In [4]:
# loaders

class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence.

    Generation is reported as finished only once every sequence in the batch
    has ended with `sequence_ids` at some step. Per-batch state (batch size,
    prompt length, per-sequence done flags) is captured on the first call and
    cleared when the criterion fires.

    This code is not thread safe. The same object cannot be used simultaneously in multiple threads.
    """

    def __init__(
        self,
        sequence_ids : List[int],
        device : Optional[Union[str, torch.device]] = None,
    ) -> None:
        """
        Args:
            sequence_ids: token ids of the stop sequence.
            device: device on which the stop sequence tensor is created.
                Defaults to "cuda:0" when CUDA is available, otherwise "cpu".
        """
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.sequence_ids = torch.tensor(sequence_ids, dtype = torch.int32, device = device)
        # we look back for 2 more tokens than it takes to encode our stop sequence
        # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
        # and we don't want to mistakenly not stop a generation because our
        # (string) stop sequence was output in a different tokenization
        # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model,
        # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized
        # Additionally, in lookback_ids_batch we should prevent ever looking back into the inputs as described.
        self.sequence_id_len = self.sequence_ids.shape[0] + 2
        self.batch_size = None
        self.input_length = None
        self.done_tracker = []
        self.state_initialized = False

    def set_state(self, batch_size, input_length):
        """Record the batch size and the length treated as the prompt, and clear the done flags."""
        self.batch_size = batch_size
        self.input_length = input_length
        self.done_tracker = [False] * batch_size
        self.state_initialized = True

    def reset(self):
        """Forget all per-batch state so the next call starts a new batch."""
        self.batch_size = None
        self.input_length = None
        self.done_tracker = []
        self.state_initialized = False


    def __call__(self, input_ids, scores, **kwargs) -> bool:
        """Return True once every sequence in the batch has produced the stop sequence.

        Args:
            input_ids: 2-D tensor of token ids, indexed as `[batch, position]`.
            scores: unused.
            **kwargs: ignored.
        """
        # A generation that ended without the criterion firing (for example on a
        # length limit) never calls reset(). A changed batch size, or a sequence
        # that is not longer than the recorded prompt length, can only come from
        # a new generation, so stale state is dropped in those cases.
        if self.state_initialized and (
            input_ids.shape[0] != self.batch_size or input_ids.shape[1] <= self.input_length
        ):
            self.reset()

        if not self.state_initialized:
            # 1st call to __call__ for this batch
            # NOTE(review): everything present at this point is treated as prompt,
            # so tokens already generated before the first call are never matched - confirm this is intended
            self.set_state(input_ids.shape[0], input_ids.shape[1])

        # IDs of all the tokens except the prompt
        lookback_ids_batch = input_ids[:, self.input_length :]
        # look back for 2 more tokens than it takes to encode our stop sequence
        lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :]

        # no elements yet to look back
        if lookback_ids_batch.nelement() == 0:
            return False

        stop_len = self.sequence_ids.shape[0]
        # compare on the device of the generated ids to avoid cross-device errors
        stop_ids = self.sequence_ids.to(lookback_ids_batch.device)

        for i, done in enumerate(self.done_tracker):
            if done:
                continue
            # look back only as far as the last token of the stop sequence
            tail = lookback_ids_batch[i][-stop_len:]
            # fewer generated tokens than the stop sequence cannot match yet;
            # the comparison is reduced to a single Python bool so that
            # multi-token stop sequences work
            if tail.shape[0] == stop_len:
                self.done_tracker[i] = bool((tail == stop_ids).all())

        ret_val = all(self.done_tracker)
        if ret_val:
            self.reset()
        return ret_val


def load_model_with_awq_backend(model_id, model_loader_kwargs, tokenizer_kwargs,temp_model_dir, model_branch = "main"):
    """Download an AWQ-quantized model and load it together with its tokenizer.

    Args:
        model_id: model identifier passed to `download_model_from_hf`.
        model_loader_kwargs: extra keyword arguments for
            `AutoAWQForCausalLM.from_quantized`. Any
            'pretrained_model_name_or_path' entry is ignored in favour of the
            downloaded folder.
        tokenizer_kwargs: extra keyword arguments for
            `AutoTokenizer.from_pretrained`.
        temp_model_dir: directory the model is downloaded into.
        model_branch: branch of the model repository to download.

    Returns:
        Tuple `(model, tokenizer)`; the tokenizer's pad token is set to its EOS token.
    """
    output_folder = download_model_from_hf(model_id, save_dir = temp_model_dir, branch = model_branch)

    # Work on copies so the caller's dicts are not mutated, and use the
    # `tokenizer_kwargs` parameter rather than a same-named global.
    model_loader_kwargs = dict(model_loader_kwargs)
    tokenizer_kwargs = dict(tokenizer_kwargs)

    # the model path is passed as `quant_path`, so it must not also appear in the kwargs
    model_loader_kwargs.pop('pretrained_model_name_or_path', None)
    tokenizer_kwargs['pretrained_model_name_or_path'] = output_folder

    model = AutoAWQForCausalLM.from_quantized(
        quant_path=output_folder,
        **model_loader_kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(**tokenizer_kwargs, local_files_only=True)

    # pad token is null in config -https://huggingface.co/TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ/blob/eb64c310c44905321d012962db9ac0d47c3a64fa/tokenizer_config.json#L53
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

# Maps a backend name to the function that loads a model with that backend.
# Only the 'awq' entry is active; the commented-out entries name loaders that
# are not defined in the visible part of this notebook.
model_loader_backend_map = {
    # "exllama_2_hf": load_model_with_exllama_2_hf_backend,
    # "hf": load_model_with_hf_backend,
    # 'auto_gptq' : load_model_with_autogptq_backend,
    'awq' : load_model_with_awq_backend,
}

In [5]:
# Model card: https://huggingface.co/TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ
model_id = "TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ"

# Download target for the model files; created if missing
temp_model_dir = Path("/workspace/temp_model_dir/")
temp_model_dir.mkdir(parents = True, exist_ok = True)

# Keyword arguments forwarded to the AWQ model loader
model_loader_kwargs = dict(
    device_map = {'' : 0},
    fuse_layers = True,
)

# Keyword arguments forwarded to the tokenizer loader
tokenizer_loader_kwargs = dict(
    use_fast = False,
    legacy = False,
    padding_side = 'left',
)

model, tokenizer = load_model_with_awq_backend(
    model_id,
    model_loader_kwargs,
    tokenizer_loader_kwargs,
    temp_model_dir,
    model_branch = "main",
)

Downloading the model to /workspace/temp_model_dir/TheBloke_CapybaraHermes-2.5-Mistral-7B-AWQ


100%|██████████| 17.9k /17.9k  29.9MiB/s
100%|██████████| 911   /911    2.65MiB/s
100%|██████████| 51.0  /51.0   47.3kiB/s
100%|██████████| 115   /115    14.4kiB/s
100%|██████████| 126   /126    271kiB/s
100%|██████████| 420   /420    935kiB/s
  0%|          | 0.00  /1.80M  ?iB/s 
100%|██████████| 1.60k /1.60k  4.44MiB/s
 58%|█████▊    | 1.05M /1.80M  5.02MiB/s
100%|██████████| 1.80M /1.80M  5.43MiB/s
100%|██████████| 493k  /493k   27.5MiB/s
100%|██████████| 4.15G /4.15G  265MiB/s
Replacing layers...: 100%|██████████| 32/32 [00:04<00:00,  7.39it/s]
Fusing layers...: 100%|██████████| 32/32 [00:02<00:00, 11.71it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
def preprocess_dataset(dataset, tokenizer, pt, pt_cols, system_prompt, add_generation_prompt = True):
    """Map every sample of `dataset` to a chat-formatted prompt.

    Args:
        dataset: object exposing `map(function)`; the function receives one sample.
        tokenizer: object exposing `apply_chat_template`.
        pt: prompt template, filled with `str.format` from the sample's `pt_cols` values.
        pt_cols: names of the sample fields used to fill `pt`.
        system_prompt: content of the system message, or None for no system message.
        add_generation_prompt: forwarded to `apply_chat_template`.

    Returns:
        The result of `dataset.map`, called with a function that produces the
        keys "X" (chat-formatted prompt) and 'actual input' (value of the
        first column in `pt_cols`).
    """

    def build_chat_prompt(sample):
        """Fill the prompt template from `sample` and apply the chat template (untokenized)."""
        user_content = pt.format(**{col : sample[col] for col in pt_cols})

        messages = []
        if system_prompt is not None:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_content})

        return tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt=add_generation_prompt)

    def add_prompt_columns(sample):
        """Build the two output fields for one sample."""
        return {
            "X" : build_chat_prompt(sample),
            # raw value of the first template column, before any formatting
            'actual input' : sample[pt_cols[0]],
        }

    return dataset.map(add_prompt_columns)

In [7]:
pt = textwrap.dedent("""\
    Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
    A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

    Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
    A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

    Q: {question}""")
# Sample fields substituted into the `{question}` placeholder of the prompt template `pt`
pt_cols = ['question']
# Content of the system message placed before the user prompt
system_prompt = "Solve the following math problems, end with The answer is"

# Apply the prompt template and chat template to every training sample
processed_dataset = preprocess_dataset(gsm8k_dataset['train'], tokenizer,pt = pt, pt_cols = pt_cols, system_prompt = system_prompt, add_generation_prompt = True)

Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

In [8]:
# Seed passed to the shuffle below
seed = 42
# Number of samples kept for benchmarking
bm_sample_size = 100
# Shuffle the processed dataset with `seed`, then keep the first `bm_sample_size` samples
bm_samples = processed_dataset.shuffle(seed = seed).select(range(bm_sample_size))

In [9]:
tokenizer.special_tokens_map, tokenizer.clean_up_tokenization_spaces

({'bos_token': '<s>',
  'eos_token': '<|im_end|>',
  'unk_token': '<unk>',
  'pad_token': '<|im_end|>'},
 False)

In [10]:
# Print the formatted prompt of the first five samples, separated by a dashed rule
print("Processed Dataset:\n")
separator = '---' * 10
for sample in (processed_dataset[idx] for idx in range(5)):
    print(sample['X'])
    # emits the same blank lines around the rule as three separate print calls would
    print('\n', separator, '\n', sep = '\n')

Processed Dataset:

<|im_start|>system
Solve the following math problems, end with The answer is<|im_end|>
<|im_start|>user
Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?<|im_end|>
<|im_start|>assistant



------------------------------


<|im_start|>system
Solve the following math problems, end with The answer is<|im_end|>
<|im_start|>user
Q: There are 15 trees in the grove

In [11]:
bm_samples[:2]

{'question': ['Mimi picked up 2 dozen seashells on the beach.  Kyle found twice as many shells as Mimi and put them in his pocket. Leigh grabbed one-third of the shells that Kyle found.  How many seashells did Leigh have?',
  "Frankie's parents let him have many pets. He has six more snakes than he has cats. He has one less parrot than cats. Six of his pets have four legs. He has 2 dogs. How many pets does he have in total?"],
 'answer': ['Mimi has 2 x 12 = <<2*12=24>>24 sea shells.\nKyle has 24 x 2 = <<24*2=48>>48 sea shells.\nLeigh has 48 / 3 = <<48/3=16>>16 sea shells.\n#### 16',
  'He has 6 - 2 = <<6-2=4>>4 cats.\nHe has 4 - 1 = <<4-1=3>>3 parrots.\nHe has 4 + 6 = <<4+6=10>>10 snakes.\nHe has a total of 2 + 4 + 3 + 10 = <<2+4+3+10=19>>19 pets.\n#### 19'],
 'X': ['<|im_start|>system\nSolve the following math problems, end with The answer is<|im_end|>\n<|im_start|>user\nQ: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there w

In [45]:
from llmsearch.utils.model_utils import decoder_parser

def perform_inference_from_encoded_input(encoded_input, model, tokenizer, gen_kwargs):
    """Run generation on an already tokenized prompt and parse the decoded output.

    Args:
        encoded_input: mapping with 'input_ids' and 'attention_mask'; both
            entries are replaced in place by their 'cuda:0' copies.
        model: object exposing `generate(**inputs, **gen_kwargs)`.
        tokenizer: tokenizer used to decode the prompt and the output.
        gen_kwargs: keyword arguments forwarded to `model.generate`.

    Returns:
        Whatever `decoder_parser` returns for the decoded output and decoded
        prompt (presumably the completion with the prompt removed - TODO confirm).
    """
    for key in ('input_ids', 'attention_mask'):
        encoded_input[key] = encoded_input[key].to('cuda:0')

    generated = model.generate(**encoded_input, **gen_kwargs)

    # length of the first encoded sequence, including any padding it contains
    prompt_ids = encoded_input['input_ids'][0]
    prompt_tokens = len(prompt_ids)
    print(f"Prompt tokens - {prompt_tokens}")

    # only the first sequence of the output is decoded
    output_token_ids = generated.tolist()[0]
    decoded_output = tokenizer.decode(output_token_ids, spaces_between_special_tokens = False)
    decoded_input = tokenizer.decode(prompt_ids, spaces_between_special_tokens = False)
    print(decoded_output)

    completion_tokens = len(output_token_ids) - prompt_tokens

    out = decoder_parser(outputs = [decoded_output], formatted_prompts = [decoded_input], prepoc = lambda x : x.strip())

    print(f"Completion Tokens - {completion_tokens}")

    return out


def perform_single_example_inference(example, model, tokenizer, gen_kwargs):
    """Tokenize one prompt string, generate, and parse the decoded output.

    NOTE(review): this redefines the earlier `perform_single_example_inference`
    in this notebook, which returned a tuple instead of the parsed output.

    Args:
        example: prompt text, tokenized without adding special tokens.
        model: object exposing `generate(**inputs, **gen_kwargs)`.
        tokenizer: tokenizer used to encode the prompt and decode the output.
        gen_kwargs: keyword arguments forwarded to `model.generate`.

    Returns:
        Whatever `decoder_parser` returns for the decoded output and the
        original prompt string.
    """
    encoded = tokenizer(example, return_tensors = "pt", add_special_tokens = False)
    for key in ('input_ids', 'attention_mask'):
        encoded[key] = encoded[key].to('cuda:0')

    generated = model.generate(**encoded, **gen_kwargs)

    prompt_tokens = len(encoded['input_ids'][0])
    print(f"Prompt tokens - {prompt_tokens}")

    # only the first sequence of the output is decoded
    output_token_ids = generated.tolist()[0]
    decoded_output = tokenizer.decode(output_token_ids, spaces_between_special_tokens = False)
    print(decoded_output)

    completion_tokens = len(output_token_ids) - prompt_tokens

    out = decoder_parser(outputs = [decoded_output], formatted_prompts = [example], prepoc = lambda x : x.strip())

    print(f"Completion Tokens - {completion_tokens}")

    return out

In [13]:
bm_samples[:2]

{'question': ['Mimi picked up 2 dozen seashells on the beach.  Kyle found twice as many shells as Mimi and put them in his pocket. Leigh grabbed one-third of the shells that Kyle found.  How many seashells did Leigh have?',
  "Frankie's parents let him have many pets. He has six more snakes than he has cats. He has one less parrot than cats. Six of his pets have four legs. He has 2 dogs. How many pets does he have in total?"],
 'answer': ['Mimi has 2 x 12 = <<2*12=24>>24 sea shells.\nKyle has 24 x 2 = <<24*2=48>>48 sea shells.\nLeigh has 48 / 3 = <<48/3=16>>16 sea shells.\n#### 16',
  'He has 6 - 2 = <<6-2=4>>4 cats.\nHe has 4 - 1 = <<4-1=3>>3 parrots.\nHe has 4 + 6 = <<4+6=10>>10 snakes.\nHe has a total of 2 + 4 + 3 + 10 = <<2+4+3+10=19>>19 pets.\n#### 19'],
 'X': ['<|im_start|>system\nSolve the following math problems, end with The answer is<|im_end|>\n<|im_start|>user\nQ: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there w

In [14]:
from pprint import pprint

my_list = ['apple', 'banana', 'cherry', 'date', 'elderberry']

pprint(my_list)

['apple', 'banana', 'cherry', 'date', 'elderberry']


In [16]:
from llmsearch.utils.model_utils import batcher

def batch_inputs(inputs, tokenizer,batch_size, tokenizer_encode_args):
    """Encode `inputs` batch by batch and return the token ids of every item.

    Args:
        inputs: items that each carry their prompt under the key 'X'.
        tokenizer: tokenizer called once per batch with `return_tensors = "pt"`.
        batch_size: number of items handed to `batcher` per batch.
        tokenizer_encode_args: extra keyword arguments for the tokenizer call.

    Returns:
        A flat list with one list of token ids per input item, encoded the way
        the item appears inside its batch.
    """
    encoded_rows = []

    for batch in tqdm(batcher(inputs, batch_size)):
        prompts = [sample['X'] for sample in batch]
        encoded = tokenizer(text = prompts, **tokenizer_encode_args, return_tensors = "pt")

        # TODO : When a batch is encoded an item in the list could be a batch, use torch.chunk to split and extend the list
        # final objective is to get a list of items how it would be encoded if batch size was some value
        encoded_rows.extend(row.tolist() for row in encoded['input_ids'])

    return encoded_rows

# Batched computation output could be different than the single computation output - https://fancyerii.github.io/2024/01/17/padding-debug/#problem


# Number of prompts encoded together, and number of benchmark samples inspected
batch_size = 4
sample_size = 4
tokenizer_encode_args = {
    # pad to longest seq in batch
    'padding' : 'longest',
    # adds <s> to the start (only adds to the longest sequence in the batch for some reason)
    'add_special_tokens' : False
}
# convert to a list of dicts
bm_sample_dicts = [{k: v[i] for k, v in bm_samples[:sample_size].items()} for i in range(sample_size)]
# chat-formatted prompt strings, in the same order as `bm_sample_dicts`
ct_inputs = [item['X'] for item in bm_sample_dicts]

# one list of token ids per sample
batched_input = batch_inputs(bm_sample_dicts, tokenizer, batch_size,tokenizer_encode_args)

tokenizer_decode_args = {
    # no need of this since in encoding we are not adding special tokens
    # 'skip_special_tokens' : True,

    # did not make any diff
    # 'clean_up_tokenization_spaces' : True,
}


for idx, (encoded_batch, ct_input) in enumerate(zip(batched_input, ct_inputs)):
    # encoded batch len will change based on the batch size
    print(idx, len(encoded_batch))


    # NOTE(review): `encoded_batch` is a flat list of token ids here, so batch_decode
    # returns one string per token rather than one string for the prompt - confirm this is intended
    decoded_input = tokenizer.batch_decode(encoded_batch, **tokenizer_decode_args)

    print(ct_input)
    print(decoded_input)
    print('\n\n', '---' * 10, '\n\n')

    if idx == 3:
        break

# Findings - just use padding = 'longest' and add_special_tokens = False, no need to skip_special_tokens or clean_up_tokenization_spaces
# TODO : check model generation output for the same sample when padded different ways and with no padding

# for idx, item in enumerate(tokenizer.batch_decode(batched_input['input_ids'])):
#     print(idx)
#     print(item)
#     print('\n\n', '---' * 10, '\n\n')

0it [00:00, ?it/s]

0 289
<|im_start|>system
Solve the following math problems, end with The answer is<|im_end|>
<|im_start|>user
Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

Q: Mimi picked up 2 dozen seashells on the beach.  Kyle found twice as many shells as Mimi and put them in his pocket. Leigh grabbed one-third of the shells that Kyle found.  How many seashells did Leigh have?<|im_end|>
<|im_start|>assistant

['<|im_end|>', '<|im_end|>', '<|im_end|>', '<|im_end|>', '<|im_end|>', '<|im_end|>', '<|im_end|>', '<|im_end|>', '<|im_end|>', '<|im_end

In [25]:
tokenizer.decode([28709])

'o'

In [31]:
ct_inputs

['<|im_start|>system\nSolve the following math problems, end with The answer is<|im_end|>\n<|im_start|>user\nQ: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\nA: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.\n\nQ: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\nA: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.\n\nQ: Mimi picked up 2 dozen seashells on the beach.  Kyle found twice as many shells as Mimi and put them in his pocket. Leigh grabbed one-third of the shells that Kyle found.  How many seashells did Leigh have?<|im_end|>\n<|im_start|>assistant\n',
 "<|im_start|>system\nSolve the following math problems, end with The answer is<|im_end|>\n<|im_start|>user\nQ: There are 15 tr

In [55]:
def pad_text(text, tokenizer, pad_nos = 0):
    """Tokenize `text` and pad it with `pad_nos` extra tokens.

    The text is first tokenized without padding to measure its length, then
    tokenized again with `padding = 'max_length'` and a `max_length` of that
    length plus `pad_nos`. No special tokens are added in either call.

    Args:
        text: prompt string to encode.
        tokenizer: callable tokenizer accepting `text`, `padding`,
            `max_length`, `add_special_tokens` and `return_tensors`.
        pad_nos: number of pad tokens to add on top of the unpadded length.

    Returns:
        The output of the padded tokenizer call (requested with `return_tensors = "pt"`).
    """
    unpadded_length = len(tokenizer(text = text, add_special_tokens=False)['input_ids'])

    return tokenizer(
        text = text,
        padding = 'max_length',
        max_length = unpadded_length + pad_nos,
        add_special_tokens=False,
        return_tensors = "pt",
    )


# Generation settings: cap on new tokens, and stop as soon as the EOS token is produced
gen_kwargs = {
    'max_new_tokens' : 500,
    'stopping_criteria' : StoppingCriteriaList([SingleTokenStoppingCriteria(tokenizer.eos_token_id)]),
}

# output deteriorates from 18 padding (0), Is it the same for other examples?
# for eg 3 - output deteriorates from pad = 3, has correct example within 10
# Re-run the same prompt with 0..39 extra pad tokens and print each parsed output
for i in range(40):
    encoded_input = pad_text(ct_inputs[3], tokenizer, pad_nos = i)

    print(f"Padding : {i}")
    out = perform_inference_from_encoded_input(encoded_input, model, tokenizer, gen_kwargs)


    print(out)

    print('\n\n', '---' * 10, '\n\n')


# pad_text("hello", tokenizer, pad_nos = 4)

Padding : 0
Prompt tokens - 255
<|im_start|>system
Solve the following math problems, end with The answer is<|im_end|>
<|im_start|>user
Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

Q: Emma's bank account has $100 in it. Each day of the week, she spends $8. At the end of the week, she goes to the bank and asks for as many $5 bills as her account can give her. She leaves the rest in the account. How many dollars remain in the account?<|im_end|>
<|im_start|>assistant
A: Emma starts with $100 in her account. She spends $8 each day f

In [54]:
bm_sample_dicts[3]

{'question': "Emma's bank account has $100 in it. Each day of the week, she spends $8. At the end of the week, she goes to the bank and asks for as many $5 bills as her account can give her. She leaves the rest in the account. How many dollars remain in the account?",
 'answer': 'She spend $56 because 7 x 8 = <<7*8=56>>56\nShe has $44 left in the bank because 100 - 56 = <<100-56=44>>44\nShe can get 8 five dollar bills because 44 / 5 = <<44/5=8.8>>8.8\nThis is equal to $40 because 8 x 5 = <<8*5=40>>40\nShe has $4 left in the account because 44 - 40 = <<44-40=4>>4\n#### 4',
 'X': "<|im_start|>system\nSolve the following math problems, end with The answer is<|im_end|>\n<|im_start|>user\nQ: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\nA: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6

In [54]:
# Stop on token id 32000 (presumably the id of '<|im_end|>' - TODO confirm against the tokenizer)
stopping_criteria = StoppingCriteriaList([MultiTokenEOSCriteria(sequence_ids = [32000])])
gen_kwargs = {
    'max_new_tokens' : 500,
    'stopping_criteria' : stopping_criteria
}

# Output changes based on skip_special_tokens value
# padding tokens influencing output

# `batched_input` is the flat list of token-id lists returned by `batch_inputs`,
# so the third sample is `batched_input[2]`; indexing it with 'input_ids' raises TypeError.
out = perform_single_example_inference(tokenizer.decode(batched_input[2], skip_special_tokens=True), model, tokenizer, gen_kwargs)

Prompt tokens - 288
<|im_start|>system
Solve the following math problems, end with The answer is
 <|im_start|>user
Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

Q: Olaf collects colorful toy cars. At first, his collection consisted of 150 cars. His family, knowing his hobby, decided to give him some toy cars. Grandpa gave Olaf twice as many toy cars as the uncle. Dad gave Olaf 10 toy cars, 5 less than Mum. Auntie gave Olaf 6 toy cars, 1 more than the uncle. How many toy cars does Olaf have in total, after receiving all these gift

In [35]:
bm_samples['answer'][:4]

['Mimi has 2 x 12 = <<2*12=24>>24 sea shells.\nKyle has 24 x 2 = <<24*2=48>>48 sea shells.\nLeigh has 48 / 3 = <<48/3=16>>16 sea shells.\n#### 16',
 'He has 6 - 2 = <<6-2=4>>4 cats.\nHe has 4 - 1 = <<4-1=3>>3 parrots.\nHe has 4 + 6 = <<4+6=10>>10 snakes.\nHe has a total of 2 + 4 + 3 + 10 = <<2+4+3+10=19>>19 pets.\n#### 19',
 "Dad gave Olaf 10 toy cars,\nMom has given Olaf 5 more toy cars than Dad, so 10 + 5 = <<10+5=15>>15 toy cars\nAuntie gave Olaf 6 toy cars,\nUncle has given 1 less toy than Auntie, so 6 - 1 = <<6-1=5>>5 toy cars\nGrandpa gave Olaf 2 * 5 = <<2*5=10>>10 toy cars.\nAll the family together gave Olaf 10 +15 + 6 + 5 + 10 = <<10+15+6+5+10=46>>46.\nAdding the cars Olaf already had, Olaf's collection has 150 + 46 = <<150+46=196>>196 cars.\n#### 196",
 'She spend $56 because 7 x 8 = <<7*8=56>>56\nShe has $44 left in the bank because 100 - 56 = <<100-56=44>>44\nShe can get 8 five dollar bills because 44 / 5 = <<44/5=8.8>>8.8\nThis is equal to $40 because 8 x 5 = <<8*5=40>>40\n

In [None]:
# `batched_input` is already the list of token-id lists returned by `batch_inputs`
batched_input

In [None]:
# Decode every encoded sample; `batched_input` is a list of token-id lists, not a dict
tokenizer.batch_decode(batched_input)