In [2]:
"""
Model - https://huggingface.co/cognitivecomputations/dolphin-2.2.1-mistral-7b, fine-tuned with LoRA on the SAMSum validation set and then quantized with GPTQ for inference via exllamav2
Dataset fine-tuned on - SAMSum validation set

llmsearch example shown on - SAMSum train set (used for the hyperparameter search), evaluated on the SAMSum test set

Requires:
nltk==3.8.1
rouge_score==0.1.2
py7zr==0.21.0
exllamav2@https://github.com/turboderp/exllamav2/releases/download/v0.0.14/exllamav2-0.0.14+cu121-cp310-cp310-linux_x86_64.whl
"""

import torch
import transformers

import llmsearch
import exllamav2
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Config
)

# Record the exact library versions so results can be tied to this environment.
print(*(lib.__version__ for lib in (exllamav2, torch, transformers, llmsearch)))

  from .autonotebook import tqdm as notebook_tqdm


Monkey Patching .generate function of `transformers` library
0.0.14 2.2.0+cu121 4.38.2 0.1.0


In [3]:
import os
from pathlib import Path
from typing import Dict, Any, Optional, Union, List

import nltk
import datasets
import evaluate
import numpy as np

from auto_gptq import AutoGPTQForCausalLM
from llmsearch.tuner import Tuner
from llmsearch.utils.mem_utils import gc_cuda
from sklearn.model_selection import GridSearchCV
from llmsearch.utils.model_downloader import download_model_from_hf
from llmsearch.scripts.stopping_criteria import MultiTokenStoppingCriteria
from llmsearch.utils.logging_utils import set_verbosity_info, set_verbosity_debug, set_verbosity_warning
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import PreTrainedModel, PretrainedConfig, GenerationConfig, AutoTokenizer, StoppingCriteriaList, AutoModelForCausalLM

# Sentence-tokenizer data needed by nltk.sent_tokenize in postprocess_text (rougeLsum formatting).
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [4]:
# ---- Run configuration ----
# Seed reused for dataset shuffling and as the tuner's seed; batch_size is passed to Tuner.
seed, batch_size = 42, 1
# Rows taken from the shuffled train split (tuning) and test split (held-out evaluation).
num_tune_samples, num_test_samples = 1_200, 800

# Quantized GPTQ checkpoint on the Hugging Face Hub, and the target device.
model_id = "Praful932/dolphin-2.2.1-mistral-7b-samsum-ft-v1-GPTQ"
device = "cuda:0"

In [5]:
# ------ Model related code ------
class Exllamav2HF(PreTrainedModel):
    """Adapter that exposes an ExLlamaV2 model through the `transformers` `PreTrainedModel` API.

    This lets `transformers`' `.generate()` (and therefore llmsearch's Tuner) drive an
    ExLlamaV2 model. The wrapper keeps an ExLlamaV2 cache. On each call it reuses the
    longest token prefix shared with the previous call's sequence instead of re-encoding
    the whole input.

    Runtime options are read from the module-level `shared.args` object:
    `gpu_split`, `cache_8bit`, `cfg_cache` (here) and `max_seq_len`, `compress_pos_emb`,
    `alpha_value`, `no_flash_attn` (in `from_pretrained`).
    """

    def __init__(self, config: ExLlamaV2Config):
        super().__init__(PretrainedConfig())
        self.ex_config = config
        self.ex_model = ExLlamaV2(config)

        # Optional comma-separated list of per-device allocations passed to ExLlamaV2.load.
        split = None
        if shared.args.gpu_split:
            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]

        self.ex_model.load(split)
        self.generation_config = GenerationConfig()
        self.loras = None

        if shared.args.cache_8bit:
            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model)
        else:
            self.ex_cache = ExLlamaV2Cache(self.ex_model)

        # Token sequence (1-D tensor) last fed through `self.ex_cache`; used for prefix reuse.
        self.past_seq = None
        if shared.args.cfg_cache:
            # A second, independent cache for the negative prompt used by CFG.
            if shared.args.cache_8bit:
                self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)
            else:
                self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)

            self.past_seq_negative = None

    # The two hooks below are no-ops. Presumably this keeps transformers' generate()
    # validation from rejecting this custom wrapper class and its kwargs.
    def _validate_model_class(self):
        pass

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        pass

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Pass generate()'s inputs straight through to `__call__`."""
        return {'input_ids': input_ids, **kwargs}

    @property
    def device(self) -> torch.device:
        """Always reports CUDA device 0."""
        return torch.device(0)

    def __call__(self, *args, **kwargs):
        """Run one forward step and return a `CausalLMOutputWithPast`.

        Positive path: `input_ids` is passed as a keyword and the main cache is used.
        Negative (CFG) path: `input_ids` is passed positionally and the negative cache is
        used. This path requires `shared.args.cfg_cache`; otherwise a message is printed
        and `None` is returned.

        Keyword args:
            input_ids: token-id tensor; only row 0 is used.
            labels: if given, logits for the full sequence are computed and a shifted
                next-token cross-entropy `loss` is returned.
            use_cache: when true, the returned `past_key_values` holds the plain token-id
                list that was processed, not real KV tensors.
            past_key_values: on the negative path, a token list that is prepended to
                `input_ids[0]`.
        """
        use_cache = kwargs.get('use_cache', True)
        labels = kwargs.get('labels', None)
        past_key_values = kwargs.get('past_key_values', None)

        if len(args) > 0:
            if not shared.args.cfg_cache:
                print("Please enable the cfg-cache option to use CFG with ExLlamav2_HF.")
                return

            input_ids = args[0]
            is_negative = True
            past_seq = self.past_seq_negative
            ex_cache = self.ex_cache_negative
        else:
            input_ids = kwargs['input_ids']
            is_negative = False
            past_seq = self.past_seq
            ex_cache = self.ex_cache

        seq = input_ids[0].tolist()
        if is_negative and past_key_values is not None:
            seq = past_key_values + seq

        seq_tensor = torch.tensor(seq)
        reset = True

        # Make the forward call
        if labels is None:
            if past_seq is not None:
                # Length of the common prefix between the cached sequence and the new one.
                min_length = min(past_seq.shape[0], seq_tensor.shape[0])
                indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
                if len(indices) > 0:
                    longest_prefix = indices[0].item()
                else:
                    longest_prefix = min_length

                if longest_prefix > 0:
                    reset = False
                    ex_cache.current_seq_len = longest_prefix
                    if len(seq_tensor) - longest_prefix > 1:
                        # Prefill the new tokens except the last one; it is fed below to get logits.
                        self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
                    elif len(seq_tensor) == longest_prefix:
                        # Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
                        # because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
                        ex_cache.current_seq_len -= 1

            if reset:
                # No reusable prefix: prefill everything but the last token from scratch.
                ex_cache.current_seq_len = 0
                if len(seq_tensor) > 1:
                    self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)

            logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float()
        else:
            # Training/eval-style call: logits for every position, computed from an empty cache.
            ex_cache.current_seq_len = 0
            logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float()

        if is_negative:
            self.past_seq_negative = seq_tensor
        else:
            self.past_seq = seq_tensor

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens. CrossEntropyLoss was previously referenced without being
            # imported (NameError); use the fully-qualified torch class.
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, logits.shape[-1])
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        """Build the wrapper from a local ExLlamaV2-compatible model directory.

        Args:
            pretrained_model_name_or_path: path to the model directory (str or PathLike).

        Raises:
            AssertionError: if any extra positional or keyword arguments are given.
        """
        assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported"
        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)

        config = ExLlamaV2Config()
        config.model_dir = str(pretrained_model_name_or_path)
        config.prepare()

        # Overrides are applied after prepare() so they take precedence over the model's own config.
        config.max_seq_len = shared.args.max_seq_len
        config.scale_pos_emb = shared.args.compress_pos_emb
        config.scale_alpha_value = shared.args.alpha_value
        config.no_flash_attn = shared.args.no_flash_attn

        # Use `cls` so subclasses get instances of themselves.
        return cls(config)

class Shared:
    """Minimal settings container; options are exposed under `.args` and read by Exllamav2HF."""

    class Args:
        """Runtime options, each declared up front with the value this notebook uses."""

        def __init__(self):
            # Comma-separated per-device split for ExLlamaV2.load, or None for the default.
            self.gpu_split = None
            # Truthy -> ExLlamaV2Cache_8bit instead of ExLlamaV2Cache.
            self.cache_8bit = None
            # Truthy -> allocate a second cache for CFG negative prompts.
            self.cfg_cache = None
            # Copied onto ExLlamaV2Config in Exllamav2HF.from_pretrained:
            self.max_seq_len = 2048      # -> config.max_seq_len
            self.compress_pos_emb = 1    # -> config.scale_pos_emb
            self.alpha_value = 1         # -> config.scale_alpha_value
            self.no_flash_attn = 1       # -> config.no_flash_attn

    def __init__(self):
        self.args = Shared.Args()


# Single module-level settings instance consumed by Exllamav2HF.
shared = Shared()


# ------ Dataset related code ------

def postprocess_text(preds, labels):
    """Normalise predictions and references before ROUGE scoring.

    Every string is stripped of surrounding whitespace and then rewritten with one
    sentence per line (via `nltk.sent_tokenize`), the layout rougeLsum expects.

    Returns:
        A `(preds, labels)` tuple of new lists, in input order.
    """
    stripped_preds = [text.strip() for text in preds]
    stripped_refs = [text.strip() for text in labels]

    def one_sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text))

    return (
        list(map(one_sentence_per_line, stripped_preds)),
        list(map(one_sentence_per_line, stripped_refs)),
    )

def get_rouge_score(y_true: List, y_pred: List):
    """Mean per-example ROUGE-2 of `y_pred` against the `'summary'` field of `y_true`.

    Args:
        y_true: sequence of mappings with a `'summary'` key (reference summaries).
        y_pred: generated summary strings, aligned with `y_true`.

    Returns:
        `np.mean` of the per-example `rouge2` scores.

    Requires the module-level `rouge_metric` (loaded via `evaluate.load("rouge")`).
    """
    references = [example['summary'] for example in y_true]
    preds, gts = postprocess_text(preds=y_pred, labels=references)

    # use_aggregator=False -> one rouge2 value per example, averaged below.
    per_example = rouge_metric.compute(
        predictions=preds,
        references=gts,
        use_stemmer=True,
        use_aggregator=False,
    )
    return np.mean(per_example['rouge2'])

class DatasetWrapper:
    """Adds chat-message columns to a dataset so it can be rendered with a chat template.

    After construction, every row has a `chat_format` list of `{'role', 'content'}`
    messages: an optional system message, a user message built from `prompt_template`
    and the row's `input_key` field, and (if `add_output`) an assistant message holding
    the row's `output_key` field.
    """

    def __init__(self, hf_dataset, tokenizer, prompt_template = "Summarize : {dialogue}", input_key = "", output_key = "", system_prompt = "", add_output = True):
        self.tokenizer = tokenizer

        def to_chat(row):
            messages = []
            if system_prompt:
                messages.append({'role' : "system", "content" : system_prompt})
            # The template placeholder must be named after input_key (e.g. "{dialogue}").
            user_text = prompt_template.format(**{input_key : row[input_key]})
            messages.append({'role' : "user", "content" : user_text})
            if add_output:
                messages.append({'role' : 'assistant', "content" : row[output_key]})
            return {"chat_format" : messages}

        self.hf_dataset = hf_dataset.map(to_chat)

    def apply_chat_template(self, add_gen_prompt = True):
        """Render `chat_format` to a string column `formatted_chat` with the tokenizer's chat template.

        Args:
            add_gen_prompt: forwarded as `add_generation_prompt` to the tokenizer.
        """
        def render(row):
            text = self.tokenizer.apply_chat_template(row["chat_format"], tokenize=False, add_generation_prompt=add_gen_prompt)
            return {"formatted_chat": text}

        self.hf_dataset = self.hf_dataset.map(render)



def load_model_and_tokenizer(model_id, temp_model_dir):
    """Download (or re-validate) `model_id` under `temp_model_dir` and load it via ExLlamaV2.

    Args:
        model_id: Hugging Face repo id of the quantized model.
        temp_model_dir: directory (str or Path) under which the model folder is stored;
            created if missing.

    Returns:
        `(model, tokenizer)`: an `Exllamav2HF` wrapper and the matching `AutoTokenizer`,
        with `pad_token` set to the tokenizer's `unk_token`.
    """
    # Accept plain strings as well; `.mkdir` below requires a Path.
    temp_model_dir = Path(temp_model_dir)
    temp_model_dir.mkdir(exist_ok=True, parents=True)
    output_folder = download_model_from_hf(model_id, save_dir=temp_model_dir, branch="main")

    # Presumably frees cached GPU memory (e.g. a previously loaded model) before loading.
    gc_cuda()

    model = Exllamav2HF.from_pretrained(output_folder)
    tokenizer = AutoTokenizer.from_pretrained(output_folder, local_files_only=True)
    # Padding needs a pad token; reuse unk (assumes the tokenizer ships without one — TODO confirm).
    tokenizer.pad_token = tokenizer.unk_token

    return model, tokenizer

def _prepare_samsum_split(split_dataset, num_samples):
    """Chat-format one SAMSum split, shuffle it with the global `seed`, and keep the first `num_samples` rows."""
    wrapped = DatasetWrapper(split_dataset, tokenizer, input_key = "dialogue", output_key = "summary", add_output = False)
    wrapped.apply_chat_template(add_gen_prompt=True)
    shuffled = wrapped.hf_dataset.shuffle(seed=seed)
    return shuffled.select(range(num_samples))


def load_dataset():
    """Build the tuning and held-out test samples from SAMSum.

    Uses the module-level `tokenizer`, `seed`, `num_tune_samples` and `num_test_samples`.

    Returns:
        `(samples_to_tune_on, test_samples)`: `num_tune_samples` shuffled train rows and
        `num_test_samples` shuffled test rows. Each row has `chat_format` and
        `formatted_chat` columns (user prompt only, with a generation prompt appended).
    """
    # The model was fine-tuned on the validation split, so only train (for tuning)
    # and test (for held-out evaluation) are used. The dataset is loaded once.
    samsum = datasets.load_dataset("samsum")

    samples_to_tune_on = _prepare_samsum_split(samsum['train'], num_tune_samples)
    test_samples = _prepare_samsum_split(samsum['test'], num_test_samples)

    return samples_to_tune_on, test_samples




In [8]:

# ---- Load model, tokenizer and datasets ----
temp_model_dir = Path("./temp_dir/")
temp_model_dir.mkdir(parents=True, exist_ok=True)

model, tokenizer = load_model_and_tokenizer(model_id, temp_model_dir)

# Train-split rows used to search generation params, plus held-out test rows.
# (load_dataset reads the global `tokenizer`, so this must follow the line above.)
samples_to_tune_on, test_dataset = load_dataset()

# Stop generation on token id 32000 (the generation logs report it as eos_token_id).
stop_token_ids = [32000]
multi_token_stop_criteria_ob = MultiTokenStoppingCriteria(sequence_ids=stop_token_ids)
stopping_criteria = StoppingCriteriaList([multi_token_stop_criteria_ob])
# Called after every inference so the stopping criterion starts fresh (presumably clears its state).
callbacks_after_inference = [multi_token_stop_criteria_ob.reset]

rouge_metric = evaluate.load("rouge")

Model already exists in temp_dir/Praful932_dolphin-2.2.1-mistral-7b-samsum-ft-v1-GPTQ. Checking the model files...


Checksum validated: model.safetensors  817eec4c0f73483e67516e28b499ab75f11a4639aad5ffa04c389f8f25ce2cf8
Checksum validated: tokenizer.model  dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
[+] Validated checksums of all model files!


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
# Tuner over the tuning samples: feeds `formatted_chat` as the prompt and scores the
# generations against `summary` with get_rouge_score (mean ROUGE-2).
tuner_ob = Tuner(
    model=model,
    tokenizer=tokenizer,
    dataset=samples_to_tune_on,
    # Use the configured device instead of repeating the "cuda:0" literal.
    device=device,
    batch_size=batch_size,
    # add_special_tokens=False: the chat template output is tokenized as-is.
    tokenizer_encode_args={"padding": "longest", 'truncation' : True, "add_special_tokens": False},
    tokenizer_decode_args={"spaces_between_special_tokens": False, 'skip_special_tokens' : True},
    scorer=get_rouge_score,
    prompt_template="{formatted_chat}",
    seed=seed,
    column_mapping={"input_cols": ["formatted_chat"], "eval_cols": ["summary"]},
    callbacks_after_inference=callbacks_after_inference,
)

In [7]:
# Inspect the first formatted input stored by the tuner (run after the Tuner cell above).
print(tuner_ob.dataset['_X'][0])

NameError: name 'tuner_ob' is not defined

In [11]:

# Quick probe on the tuning set with a hand-picked low-temperature sampling config.
# Named distinctly because later cells bind `gen_params1` to different settings.
probe_gen_params = {
    'max_new_tokens' : 70,
    'stopping_criteria' : stopping_criteria,
    'generation_seed' : 42,
    'temperature' : 0.1,
    'top_k' : 70,
    'no_repeat_ngram_size' : 0,
    'do_sample' : True,
}

s, o = tuner_ob.get_score(probe_gen_params)

100%|██████████| 1200/1200 [17:00<00:00,  1.18it/s]


In [12]:
# Score of the probe config (scorer is get_rouge_score -> mean ROUGE-2).
print(s)

0.2475812543892885


In [None]:
# Baseline on the tuning set: only length cap, stop criterion and seed are set; every other
# decoding option is left at the model's generation defaults.
gen_params1 = dict(
    max_new_tokens=70,
    stopping_criteria=stopping_criteria,
    generation_seed=42,
)
scores_before, outputs_before = tuner_ob.get_score(gen_params1)

In [9]:
# Baseline mean ROUGE-2 on the tuning samples.
print(scores_before)

0.24729290809820706


In [10]:
# Search space: each value is a list, as GridSearchCV requires. Single-element lists pin a
# setting, so only temperature (6 values) x top_k (4 values) vary -> 24 candidates.
hyp_space = {
    # pinned
    'max_new_tokens': [70],
    'stopping_criteria': [stopping_criteria],
    'generation_seed': [42],
    'do_sample': [True],
    # searched
    'temperature': [0.1, 0.3, 0.5, 0.7, 0.9, 1.0],
    'top_k': [50, 60, 70, 80],
    # pinned
    'no_repeat_ngram_size': [0],
}

# 2-fold CV, run sequentially (n_jobs=None), scored with the tuner's scorer.
clf = GridSearchCV(
    tuner_ob.estimator,
    hyp_space,
    scoring=tuner_ob.scorer,
    cv=2,
    n_jobs=None,
    verbose=3,
)

In [11]:
# 2 folds x 24 candidates = 48 fits; each fit generates for half (600) of the tuning samples.
clf.fit(X=tuner_ob.dataset["_X"], y=tuner_ob.dataset['_y'])

Fitting 2 folds for each of 24 candidates, totalling 48 fits


  8%|▊         | 45/600 [00:35<06:16,  1.48it/s]

100%|██████████| 600/600 [08:02<00:00,  1.24it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.1, top_k=50;, score=0.253 total time= 8.1min


100%|██████████| 600/600 [07:55<00:00,  1.26it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.1, top_k=50;, score=0.242 total time= 8.0min


100%|██████████| 600/600 [07:51<00:00,  1.27it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.1, top_k=60;, score=0.252 total time= 7.9min


100%|██████████| 600/600 [07:51<00:00,  1.27it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.1, top_k=60;, score=0.240 total time= 7.9min


100%|██████████| 600/600 [07:46<00:00,  1.29it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.1, top_k=70;, score=0.247 total time= 7.8min


100%|██████████| 600/600 [07:49<00:00,  1.28it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.1, top_k=70;, score=0.239 total time= 7.9min


100%|██████████| 600/600 [07:53<00:00,  1.27it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.1, top_k=80;, score=0.255 total time= 7.9min


100%|██████████| 600/600 [07:47<00:00,  1.28it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.1, top_k=80;, score=0.233 total time= 7.8min


100%|██████████| 600/600 [07:58<00:00,  1.25it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.3, top_k=50;, score=0.242 total time= 8.0min


100%|██████████| 600/600 [07:54<00:00,  1.27it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.3, top_k=50;, score=0.234 total time= 7.9min


100%|██████████| 600/600 [08:10<00:00,  1.22it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.3, top_k=60;, score=0.251 total time= 8.2min


100%|██████████| 600/600 [07:51<00:00,  1.27it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.3, top_k=60;, score=0.236 total time= 7.9min


100%|██████████| 600/600 [07:52<00:00,  1.27it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.3, top_k=70;, score=0.243 total time= 7.9min


100%|██████████| 600/600 [07:38<00:00,  1.31it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.3, top_k=70;, score=0.228 total time= 7.7min


100%|██████████| 600/600 [07:55<00:00,  1.26it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.3, top_k=80;, score=0.238 total time= 7.9min


100%|██████████| 600/600 [07:38<00:00,  1.31it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.3, top_k=80;, score=0.236 total time= 7.7min


100%|██████████| 600/600 [07:58<00:00,  1.25it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.5, top_k=50;, score=0.221 total time= 8.0min


100%|██████████| 600/600 [07:45<00:00,  1.29it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.5, top_k=50;, score=0.209 total time= 7.8min


100%|██████████| 600/600 [08:09<00:00,  1.23it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.5, top_k=60;, score=0.227 total time= 8.2min


100%|██████████| 600/600 [07:52<00:00,  1.27it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.5, top_k=60;, score=0.217 total time= 7.9min


100%|██████████| 600/600 [08:13<00:00,  1.22it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.5, top_k=70;, score=0.221 total time= 8.2min


100%|██████████| 600/600 [07:56<00:00,  1.26it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.5, top_k=70;, score=0.215 total time= 8.0min


100%|██████████| 600/600 [07:50<00:00,  1.28it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.5, top_k=80;, score=0.226 total time= 7.9min


100%|██████████| 600/600 [08:05<00:00,  1.24it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.5, top_k=80;, score=0.218 total time= 8.1min


100%|██████████| 600/600 [08:22<00:00,  1.19it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.7, top_k=50;, score=0.206 total time= 8.4min


100%|██████████| 600/600 [07:51<00:00,  1.27it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.7, top_k=50;, score=0.184 total time= 7.9min


100%|██████████| 600/600 [08:06<00:00,  1.23it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.7, top_k=60;, score=0.199 total time= 8.1min


100%|██████████| 600/600 [08:10<00:00,  1.22it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.7, top_k=60;, score=0.190 total time= 8.2min


100%|██████████| 600/600 [08:32<00:00,  1.17it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.7, top_k=70;, score=0.200 total time= 8.6min


100%|██████████| 600/600 [08:03<00:00,  1.24it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.7, top_k=70;, score=0.192 total time= 8.1min


100%|██████████| 600/600 [08:22<00:00,  1.20it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.7, top_k=80;, score=0.200 total time= 8.4min


100%|██████████| 600/600 [08:02<00:00,  1.24it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.7, top_k=80;, score=0.190 total time= 8.1min


100%|██████████| 600/600 [08:32<00:00,  1.17it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.9, top_k=50;, score=0.168 total time= 8.6min


100%|██████████| 600/600 [08:18<00:00,  1.20it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.9, top_k=50;, score=0.164 total time= 8.3min


100%|██████████| 600/600 [08:27<00:00,  1.18it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.9, top_k=60;, score=0.163 total time= 8.5min


100%|██████████| 600/600 [08:14<00:00,  1.21it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.9, top_k=60;, score=0.157 total time= 8.3min


100%|██████████| 600/600 [08:21<00:00,  1.20it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.9, top_k=70;, score=0.152 total time= 8.4min


100%|██████████| 600/600 [08:33<00:00,  1.17it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.9, top_k=70;, score=0.150 total time= 8.6min


100%|██████████| 600/600 [08:36<00:00,  1.16it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.9, top_k=80;, score=0.152 total time= 8.6min


100%|██████████| 600/600 [08:27<00:00,  1.18it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=0.9, top_k=80;, score=0.162 total time= 8.5min


100%|██████████| 600/600 [08:31<00:00,  1.17it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=1.0, top_k=50;, score=0.138 total time= 8.5min


100%|██████████| 600/600 [08:10<00:00,  1.22it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=1.0, top_k=50;, score=0.141 total time= 8.2min


100%|██████████| 600/600 [08:45<00:00,  1.14it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=1.0, top_k=60;, score=0.143 total time= 8.8min


100%|██████████| 600/600 [08:25<00:00,  1.19it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=1.0, top_k=60;, score=0.136 total time= 8.5min


100%|██████████| 600/600 [08:41<00:00,  1.15it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=1.0, top_k=70;, score=0.138 total time= 8.7min


100%|██████████| 600/600 [08:31<00:00,  1.17it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=1.0, top_k=70;, score=0.143 total time= 8.5min


100%|██████████| 600/600 [08:33<00:00,  1.17it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=1.0, top_k=80;, score=0.138 total time= 8.6min


100%|██████████| 600/600 [08:38<00:00,  1.16it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], temperature=1.0, top_k=80;, score=0.127 total time= 8.7min


In [12]:
# Re-score the whole tuning set with the best parameters found by the grid search.
scores_after, outputs_after = tuner_ob.get_score(clf.best_params_)

  2%|▏         | 23/1200 [00:19<15:47,  1.24it/s]

100%|██████████| 1200/1200 [15:39<00:00,  1.28it/s]


In [None]:
# In-sample comparison on the tuning set: baseline vs. tuned mean ROUGE-2.
print(scores_before, scores_after)

In [1]:
# (stray scratch cell — intentionally empty)

1

In [14]:
# Winning parameter set, stringified the same way it is saved to JSON below.
str(clf.best_params_)

"{'do_sample': True, 'generation_seed': 42, 'max_new_tokens': 70, 'no_repeat_ngram_size': 0, 'stopping_criteria': [<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], 'temperature': 0.1, 'top_k': 50}"

In [15]:
from llmsearch.utils.common_utils import json_load, json_dump

In [16]:

# Checkpoint the tuning-set (in-sample) results so they survive a kernel restart.
# NOTE(review): the filename mentions "capybara-7b" although `model_id` is the
# dolphin-2.2.1-mistral-7b fine-tune; it is kept as-is because the next cell reloads this path.
f = "./samsum-best-params-1200s-capybara-7b.json"
d = dict(
    scores_before=scores_before,
    scores_after=scores_after,
    outputs_before=outputs_before,
    outputs_after=outputs_after,
    # Stringified, presumably because it holds a StoppingCriteriaList that JSON can't encode.
    best_params=str(clf.best_params_),
)
json_dump(d, f)

In [17]:
# Reload the checkpoint written above to confirm it round-trips through JSON.
d = json_load("./samsum-best-params-1200s-capybara-7b.json")

In [18]:
# best_params was saved as a string, so this prints its repr text.
print(d['best_params'])

{'do_sample': True, 'generation_seed': 42, 'max_new_tokens': 70, 'no_repeat_ngram_size': 0, 'stopping_criteria': [<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7f26eb8db160>], 'temperature': 0.1, 'top_k': 50}


In [19]:
# Sanity check: number of held-out test samples (expected to equal num_test_samples).
len(test_dataset)

800

In [20]:
# Held-out (out-of-sample) evaluation on the test samples, starting with the baseline
# settings used on the tuning set.
gen_params1 = dict(max_new_tokens=70, stopping_criteria=stopping_criteria, generation_seed=42)

oos_scores_before, oos_outputs_before = tuner_ob.get_score(gen_params1, test_dataset)

: 

In [None]:
# Inspect one held-out example, including the added chat_format / formatted_chat columns.
test_dataset[1]

{'id': '13681165-1',
 'dialogue': "Alyssa: Have you seen Fergie’s national anthem? Illuminati does a great job.\r\nDerek: This is not normal. I saw it last week…\r\nAlyssa: What do you think about it?\r\nDerek: I can fart bright stripes and bright stars better then she sings.\r\nAlyssa: The best part is that she acts like she nailed it. But at least it's funny in a good way.\r\nDerek: It is 😂",
 'summary': "Derek and Alyssa make fun of Fergie's performance of the national anthem.",
 'chat_format': [{'content': "Summarize : Alyssa: Have you seen Fergie’s national anthem? Illuminati does a great job.\r\nDerek: This is not normal. I saw it last week…\r\nAlyssa: What do you think about it?\r\nDerek: I can fart bright stripes and bright stars better then she sings.\r\nAlyssa: The best part is that she acts like she nailed it. But at least it's funny in a good way.\r\nDerek: It is 😂",
   'role': 'user'}],
 'formatted_chat': "<|im_start|>user\nSummarize : Alyssa: Have you seen Fergie’s nation

In [None]:
# Held-out evaluation with the grid-search winner.
oos_scores_after, oos_outputs_after = tuner_ob.get_score(clf.best_params_,test_dataset)

  0%|          | 0/200 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:32000 for open-end generation.
  0%|          | 1/200 [00:05<17:12,  5.19s/it]Setting `pad_token_id` to `eos_token_id`:32000 for open-end generation.
  1%|          | 2/200 [00:12<21:07,  6.40s/it]Setting `pad_token_id` to `eos_token_id`:32000 for open-end generation.
  2%|▏         | 3/200 [00:31<39:25, 12.01s/it]Setting `pad_token_id` to `eos_token_id`:32000 for open-end generation.
  2%|▏         | 4/200 [00:51<49:45, 15.23s/it]Setting `pad_token_id` to `eos_token_id`:32000 for open-end generation.
  2%|▎         | 5/200 [00:59<41:20, 12.72s/it]Setting `pad_token_id` to `eos_token_id`:32000 for open-end generation.
  3%|▎         | 6/200 [01:02<30:21,  9.39s/it]Setting `pad_token_id` to `eos_token_id`:32000 for open-end generation.
  4%|▎         | 7/200 [01:06<24:24,  7.59s/it]Setting `pad_token_id` to `eos_token_id`:32000 for open-end generation.
  4%|▍         | 8/200 [01:26<36:44, 11.48s/it]Setting `

In [None]:
# Out-of-sample comparison on the test samples: baseline vs. tuned mean ROUGE-2.
print(oos_scores_before, oos_scores_after)

0.2540499441403945 0.24443739584141397


In [None]:
# Persist in-sample (tuning set) and out-of-sample (test set) results together.
d = {
    'scores_before' : scores_before,
    'scores_after' : scores_after,
    'outputs_before' : outputs_before,
    'outputs_after' : outputs_after,

    'oos_scores_before' : oos_scores_before,
    'oos_scores_after' : oos_scores_after,
    'oos_outputs_before' : oos_outputs_before,
    'oos_outputs_after' : oos_outputs_after,
    'best_params' : str(clf.best_params_),
}

# Derive the filename from `model_id` so the file is labelled with the model actually
# evaluated (a hard-coded "capybara-7b" name mislabelled the dolphin fine-tune's results).
model_slug = model_id.split("/")[-1]
f = f"./samsum-best-params-{num_tune_samples}s-tune-{model_slug}.json"
json_dump(d, f)