In [1]:
"""
Requires nltk==3.8.1
rouge_score==0.1.2
py7zr==0.21.0
"""

import torch
import transformers

import llmsearch
import exllamav2
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Config
)

# Record library versions (exllamav2, torch, transformers, llmsearch) so results can be tied to the environment.
print(exllamav2.__version__, torch.__version__, transformers.__version__, llmsearch.__version__)

  from .autonotebook import tqdm as notebook_tqdm


Monkey Patching .generate function of `transformers` library
0.0.14 2.2.0+cu121 4.38.2 0.1.0


In [2]:
import os
from pathlib import Path
from typing import Dict, Any, Optional, Union, List

import nltk
import datasets
import evaluate
import numpy as np

from llmsearch.tuner import Tuner
from llmsearch.utils.mem_utils import gc_cuda
from sklearn.model_selection import GridSearchCV
from llmsearch.utils.model_downloader import download_model_from_hf
from llmsearch.scripts.stopping_criteria import MultiTokenStoppingCriteria
from llmsearch.utils.logging_utils import set_verbosity_info, set_verbosity_debug, set_verbosity_warning
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import PreTrainedModel, PretrainedConfig, GenerationConfig, AutoTokenizer, StoppingCriteriaList




In [3]:
# Sentence-tokenizer data needed by nltk.sent_tokenize in postprocess_text.
nltk.download("punkt")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [4]:
# --- Reproducibility / data sizes ---
seed = 42
num_tune_samples = 500   # rows used for the hyper-parameter search
num_test_samples = 200   # rows taken right after the tuning rows, kept aside
batch_size = 1

# --- Model and hardware ---
model_id = "Praful932/dolphin-2.2.1-mistral-7b-samsum-ft-v1-GPTQ"
device = "cuda:0"

In [5]:
# ------ Model related code ------
class Exllamav2HF(PreTrainedModel):
    """Adapter exposing an ExLlamaV2 model through the ``transformers`` ``PreTrainedModel`` API.

    Loader options are read from the module-level ``shared.args`` namespace
    (``gpu_split``, ``cache_8bit``, ``cfg_cache``, ``max_seq_len``,
    ``compress_pos_emb``, ``alpha_value``, ``no_flash_attn``).

    Between calls the adapter remembers the last token sequence it processed
    (``past_seq`` / ``past_seq_negative``). On the next call it reuses the
    matching prefix already held in the ExLlamaV2 cache, so only the new tokens
    are run through the model.
    """

    def __init__(self, config: ExLlamaV2Config):
        # The base class just needs *a* config; the real model settings live in `config`.
        super().__init__(PretrainedConfig())
        self.ex_config = config
        self.ex_model = ExLlamaV2(config)
        split = None
        if shared.args.gpu_split:
            # Comma-separated per-device allocations, e.g. "20,24" -> [20.0, 24.0].
            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]

        self.ex_model.load(split)
        self.generation_config = GenerationConfig()
        self.loras = None

        if shared.args.cache_8bit:
            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model)
        else:
            self.ex_cache = ExLlamaV2Cache(self.ex_model)

        self.past_seq = None
        if shared.args.cfg_cache:
            # Separate cache and sequence memory for the CFG negative prompt.
            if shared.args.cache_8bit:
                self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)
            else:
                self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)

            self.past_seq_negative = None

    def _validate_model_class(self):
        # No-op override of the transformers hook of the same name, presumably so
        # generate() does not reject this custom model class.
        pass

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        # No-op override: extra generation kwargs are accepted without validation.
        pass

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return ``input_ids`` plus every other kwarg unchanged as the model inputs."""
        return {'input_ids': input_ids, **kwargs}

    @property
    def device(self) -> torch.device:
        """Always reports CUDA device 0."""
        return torch.device(0)

    def __call__(self, *args, **kwargs):
        """Run one forward step and return a ``CausalLMOutputWithPast``.

        Call styles:
        * ``model(input_ids=...)``: normal prompt, uses ``ex_cache``.
        * ``model(ids, ...)`` (positional): negative prompt for classifier-free
          guidance, uses ``ex_cache_negative``. This requires
          ``shared.args.cfg_cache``; otherwise a message is printed and ``None``
          is returned.

        Only the first row of the input ids (``input_ids[0]``) is used.

        Without ``labels``, the longest common prefix with the previously
        processed sequence is reused. Any remaining tokens except the last are
        pre-filled, and logits are returned for the last token only. With
        ``labels``, the whole sequence is run from an empty cache. Logits are
        returned for every position, together with a shifted next-token
        cross-entropy ``loss``.

        The returned ``past_key_values`` is the processed token-id list, not real
        key/value tensors. On the negative path, an incoming ``past_key_values``
        list is prepended to the new tokens.
        """
        use_cache = kwargs.get('use_cache', True)
        labels = kwargs.get('labels', None)
        past_key_values = kwargs.get('past_key_values', None)

        if len(args) > 0:
            if not shared.args.cfg_cache:
                print("Please enable the cfg-cache option to use CFG with ExLlamav2_HF.")
                return

            input_ids = args[0]
            is_negative = True
            past_seq = self.past_seq_negative
            ex_cache = self.ex_cache_negative
        else:
            input_ids = kwargs['input_ids']
            is_negative = False
            past_seq = self.past_seq
            ex_cache = self.ex_cache

        # Only the first sequence of the batch is processed.
        seq = input_ids[0].tolist()
        if is_negative and past_key_values is not None:
            seq = past_key_values + seq

        seq_tensor = torch.tensor(seq)
        reset = True

        # Make the forward call
        if labels is None:
            if past_seq is not None:
                # Longest common prefix = index of the first position where the
                # previously processed sequence and the current one differ.
                min_length = min(past_seq.shape[0], seq_tensor.shape[0])
                indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
                if len(indices) > 0:
                    longest_prefix = indices[0].item()
                else:
                    longest_prefix = min_length

                if longest_prefix > 0:
                    # Reuse the cached prefix: rewind the cache pointer to its end and
                    # pre-fill only the new tokens (all but the last, which is fed below).
                    reset = False
                    ex_cache.current_seq_len = longest_prefix
                    if len(seq_tensor) - longest_prefix > 1:
                        self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
                    elif len(seq_tensor) == longest_prefix:
                        # Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
                        # because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
                        ex_cache.current_seq_len -= 1

            if reset:
                # Nothing reusable: start from an empty cache and pre-fill all but the last token.
                ex_cache.current_seq_len = 0
                if len(seq_tensor) > 1:
                    self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)

            # Logits for the final token, moved to the caller's device as float32.
            logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float()
        else:
            # Loss path: run the whole sequence from scratch, keeping every position's logits.
            ex_cache.current_seq_len = 0
            logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float()

        # Remember what the cache now holds, for prefix reuse on the next call.
        if is_negative:
            self.past_seq_negative = seq_tensor
        else:
            self.past_seq = seq_tensor

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens. Use the fully qualified name: the bare
            # `CrossEntropyLoss` is not imported in this notebook.
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, logits.shape[-1])
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        """Build an instance from a local ExLlamaV2 model directory.

        Only the directory argument is supported; any extra positional or keyword
        argument fails the assertion below. ``max_seq_len``, ``scale_pos_emb``,
        ``scale_alpha_value`` and ``no_flash_attn`` on the ExLlamaV2 config are
        taken from ``shared.args``.
        """
        assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported"
        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)

        config = ExLlamaV2Config()
        config.model_dir = str(pretrained_model_name_or_path)
        config.prepare()

        config.max_seq_len = shared.args.max_seq_len
        config.scale_pos_emb = shared.args.compress_pos_emb
        config.scale_alpha_value = shared.args.alpha_value
        config.no_flash_attn = shared.args.no_flash_attn

        return cls(config)

class Shared:
    """Small settings container standing in for a global ``shared`` options module.

    ``Exllamav2HF`` reads its loader options from ``shared.args``.
    """

    class Args:
        """Loader options read by ``Exllamav2HF``; every field can be overridden by keyword."""

        def __init__(self, gpu_split=None, cache_8bit=None, cfg_cache=None, max_seq_len=2048,
                     compress_pos_emb=1, alpha_value=1, no_flash_attn=1):
            # Comma-separated per-GPU split string, or None for no explicit split.
            self.gpu_split = gpu_split
            # Truthy -> ExLlamaV2Cache_8bit instead of ExLlamaV2Cache.
            self.cache_8bit = cache_8bit
            # Truthy -> also allocate a negative-prompt cache for CFG.
            self.cfg_cache = cfg_cache
            # Copied onto ExLlamaV2Config.max_seq_len.
            self.max_seq_len = max_seq_len
            # Copied onto ExLlamaV2Config.scale_pos_emb.
            self.compress_pos_emb = compress_pos_emb
            # Copied onto ExLlamaV2Config.scale_alpha_value.
            self.alpha_value = alpha_value
            # Copied onto ExLlamaV2Config.no_flash_attn.
            self.no_flash_attn = no_flash_attn

    def __init__(self, **arg_overrides):
        self.args = Shared.Args(**arg_overrides)


# The single global settings object used by Exllamav2HF.
shared = Shared(
    gpu_split=None,
    cache_8bit=None,
    cfg_cache=None,
    max_seq_len=2048,
    compress_pos_emb=1,
    alpha_value=1,
    no_flash_attn=1,
)


# ------ Dataset related code ------

def postprocess_text(preds, labels):
    """Prepare predictions and references for ROUGE scoring.

    Each text is stripped of surrounding whitespace and re-joined with one
    sentence per line (nltk sentence tokenizer), because rougeLSum splits on
    newlines. Returns a ``(preds, labels)`` tuple of new lists.
    """
    def _sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text.strip()))

    cleaned_preds = [_sentence_per_line(pred) for pred in preds]
    cleaned_labels = [_sentence_per_line(lab) for lab in labels]
    return cleaned_preds, cleaned_labels

def get_rouge_score(y_true: List, y_pred: List, rouge_type: str = "rouge2", reference_key: str = "summary"):
    """Mean per-sample ROUGE score of generated texts against reference texts.

    Args:
        y_true: dict-like rows holding the reference text under ``reference_key``.
        y_pred: generated texts, aligned one-to-one with ``y_true``.
        rouge_type: ROUGE variant to average. It must be a key of the ``evaluate``
            rouge result, e.g. "rouge1", "rouge2", "rougeL" or "rougeLsum".
            Defaults to "rouge2", this notebook's tuning objective.
        reference_key: field of each ``y_true`` row containing the reference.

    Returns:
        The mean of the per-sample scores (computed with ``use_aggregator=False``).

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` have different lengths.
    """
    if len(y_true) != len(y_pred):
        raise ValueError(f"y_true has {len(y_true)} items but y_pred has {len(y_pred)}")
    references = [item[reference_key] for item in y_true]
    preds, gts = postprocess_text(preds=y_pred, labels=references)

    # Uses the module-level `rouge_metric` (evaluate.load("rouge")).
    result = rouge_metric.compute(predictions=preds, references=gts, use_stemmer=True, use_aggregator=False)
    return np.mean(result[rouge_type])

class DatasetWrapper:
    """Wrap a Hugging Face dataset and add chat-formatted columns.

    On construction, every row gains a ``chat_format`` column holding a list of
    ``{'role', 'content'}`` messages:
    * an optional system message (only when ``system_prompt`` is non-empty);
    * the user message, i.e. ``prompt_template`` filled with ``row[input_key]``;
    * when ``add_output`` is true, an assistant message holding ``row[output_key]``.
    """

    def __init__(self, hf_dataset, tokenizer, prompt_template = "Summarize : {dialogue}",
                 input_key = "", output_key = "", system_prompt = "", add_output = True):
        self.tokenizer = tokenizer

        def to_chat(row):
            messages = []
            if system_prompt:
                messages.append({'role' : "system", "content" : system_prompt})
            user_text = prompt_template.format(**{input_key : row[input_key]})
            messages.append({'role' : "user", "content" : user_text})
            if add_output:
                messages.append({'role' : 'assistant', "content" : row[output_key]})
            return {"chat_format" : messages}

        self.hf_dataset = hf_dataset.map(to_chat)

    def apply_chat_template(self, add_gen_prompt = True):
        """Render ``chat_format`` into a ``formatted_chat`` text column via the tokenizer's chat template.

        ``add_gen_prompt`` is forwarded as ``add_generation_prompt``; the result
        is kept as text (``tokenize=False``).
        """
        def render(row):
            text = self.tokenizer.apply_chat_template(row["chat_format"], tokenize=False, add_generation_prompt=add_gen_prompt)
            return {"formatted_chat": text}

        self.hf_dataset = self.hf_dataset.map(render)



def load_model_and_tokenizer(model_id, temp_model_dir):
    """Download ``model_id`` into ``temp_model_dir`` (if needed) and load model + tokenizer.

    Args:
        model_id: Hugging Face repo id, e.g. "user/model".
        temp_model_dir: target directory as ``str`` or ``os.PathLike``; it is
            created if it does not exist.

    Returns:
        ``(model, tokenizer)``: an ``Exllamav2HF`` model and the repo's tokenizer,
        whose pad token is set to its UNK token.
    """
    # Accept plain strings as well as Path objects.
    temp_model_dir = Path(temp_model_dir)
    temp_model_dir.mkdir(exist_ok=True, parents=True)
    output_folder = download_model_from_hf(model_id, save_dir=temp_model_dir, branch="main")

    # Presumably releases cached CUDA memory before the weights are loaded (llmsearch helper).
    gc_cuda()

    model = Exllamav2HF.from_pretrained(output_folder)
    tokenizer = AutoTokenizer.from_pretrained(
        output_folder, local_files_only=True
    )
    # Batched encoding with padding needs a pad token; reuse UNK for it.
    tokenizer.pad_token = tokenizer.unk_token

    return model, tokenizer

def load_dataset():
    """Build the tuning and held-out slices from the SAMSum *test* split.

    The validation split is deliberately not used, because the model was
    fine-tuned on it. Instead, the test split goes through four steps:
    1. It is wrapped in chat format: user turn only, with the generation prompt appended.
    2. It is shuffled with ``seed``.
    3. The first ``num_tune_samples`` rows become the tuning slice.
    4. The next ``num_test_samples`` rows become the held-out slice.

    Reads the notebook globals ``tokenizer``, ``seed``, ``num_tune_samples`` and
    ``num_test_samples``.

    Returns:
        ``(samples_to_tune_on, test_samples)`` as two disjoint ``select`` slices.

    Raises:
        ValueError: if the split has fewer rows than
            ``num_tune_samples + num_test_samples``.
    """
    test_dataset = datasets.load_dataset("samsum")['test']
    test_dataset = DatasetWrapper(test_dataset, tokenizer, input_key = "dialogue", output_key = "summary", add_output = False)
    test_dataset.apply_chat_template(add_gen_prompt=True)
    test_dataset.hf_dataset = test_dataset.hf_dataset.shuffle(seed=seed)

    n_needed = num_tune_samples + num_test_samples
    n_available = len(test_dataset.hf_dataset)
    if n_needed > n_available:
        raise ValueError(
            f"Requested {num_tune_samples} tuning + {num_test_samples} test samples "
            f"but the SAMSum test split only has {n_available} rows"
        )

    samples_to_tune_on = test_dataset.hf_dataset.select(range(num_tune_samples))
    test_samples = test_dataset.hf_dataset.select(range(num_tune_samples, n_needed))

    return samples_to_tune_on, test_samples




In [6]:
# Authentication: the Hugging Face token must come from the environment
# (e.g. `export HF_TOKEN=...` before starting Jupyter, or a secrets manager), never from source.
# NOTE(review): a token was previously committed here in plain text; it should be revoked.
if not os.environ.get("HF_TOKEN"):
    print("HF_TOKEN is not set; downloads that require authentication may fail.")

# Load Model, Tokenizer, Dataset (load_model_and_tokenizer creates the directory if needed)
temp_model_dir = Path("./temp_dir/")

model, tokenizer = load_model_and_tokenizer(model_id, temp_model_dir)

# Dataset we will use to find the best generation parameters
samples_to_tune_on, test_samples = load_dataset()

# Stop generation once token id 32000 is produced.
# NOTE(review): presumably the id of <|im_end|> in this fine-tune's vocabulary; confirm with
# tokenizer.convert_tokens_to_ids("<|im_end|>").
multi_token_stop_criteria_ob = MultiTokenStoppingCriteria(sequence_ids=[32000])
stopping_criteria = StoppingCriteriaList([multi_token_stop_criteria_ob])
# Called by the Tuner after each inference, presumably to clear the criterion's per-run state.
callbacks_after_inference = [multi_token_stop_criteria_ob.reset]

rouge_metric = evaluate.load("rouge")

Model already exists in temp_dir/Praful932_dolphin-2.2.1-mistral-7b-samsum-ft-v1-GPTQ. Checking the model files...


Checksum validated: model.safetensors  817eec4c0f73483e67516e28b499ab75f11a4639aad5ffa04c389f8f25ce2cf8
Checksum validated: tokenizer.model  dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
[+] Validated checksums of all model files!


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
# Keep llmsearch logging at warning level; switch to set_verbosity_debug() (or set_verbosity_info()) for more detail.
set_verbosity_warning()

In [8]:
# The Tuner bundles model, tokenizer, tuning rows and scorer; its `estimator`, `scorer`
# and `dataset` are what the grid search consumes.
tuner_ob = Tuner(
    model=model,
    tokenizer=tokenizer,
    dataset=samples_to_tune_on,
    device=device,  # use the config constant instead of repeating "cuda:0"
    batch_size=batch_size,
    tokenizer_encode_args={"padding": "longest",'truncation' : True, "add_special_tokens": False},
    tokenizer_decode_args={"spaces_between_special_tokens": False, 'skip_special_tokens' : True},
    scorer=get_rouge_score,
    # Prompts are already fully rendered by the chat template, so pass them through as-is.
    prompt_template="{formatted_chat}",
    seed=seed,
    column_mapping={"input_cols": ["formatted_chat"], "eval_cols": ["summary"]},
    callbacks_after_inference=callbacks_after_inference,
)

In [9]:
# Inspect the first formatted prompt in the Tuner's dataset (print keeps the newlines readable).
print(tuner_ob.dataset['_X'][0])

<|im_start|>user
Summarize : Claire: <file_photo>
Kim: Looks delicious...
Linda: No way... Look what I'm cooking right now:
Linda: <file_photo>
Claire: hahahaha 
Kim: Curry dream team
Claire: Enjoy your dinner :*<|im_end|>
<|im_start|>assistant



In [10]:
max_new_tokens = 70

# Baseline: score the tuning rows with fixed settings (no sampling options set)
# before any hyper-parameter search. Reuses the config constants instead of literals.
gen_params1 = {
    'max_new_tokens' : max_new_tokens,
    'stopping_criteria' : stopping_criteria,
    'generation_seed' : seed,
}

scores_before, outputs_before = tuner_ob.get_score(gen_params1)

100%|██████████| 500/500 [07:15<00:00,  1.15it/s]


In [11]:
# Baseline mean ROUGE-2 on the tuning samples, before the grid search.
print(scores_before)

0.257852927825895


In [13]:
# Search space: 6 temperatures x 4 top_k values = 24 candidates; with cv=2 that is 48 fits.
# Fixed entries reuse the baseline settings so only the sampling parameters vary.
hyp_space = {
    'max_new_tokens' : [max_new_tokens],
    'stopping_criteria' : [stopping_criteria],
    'generation_seed' : [seed],
    'do_sample' : [True],

    'temperature': [0.1,0.3,0.5,0.7,0.9,1.0],
    'top_k': [50,60,70,80],
    'no_repeat_ngram_size': [0],
}

clf = GridSearchCV(
    estimator = tuner_ob.estimator,
    param_grid=hyp_space,
    scoring = tuner_ob.scorer,
    cv = 2,
    n_jobs = None,  # sklearn default: one job unless a joblib backend context overrides it
    verbose=3,
)

In [15]:
# Run the full grid search over the tuning rows (inputs, reference rows).
clf.fit(tuner_ob.dataset["_X"], tuner_ob.dataset["_y"])

Fitting 2 folds for each of 24 candidates, totalling 48 fits


100%|██████████| 250/250 [04:01<00:00,  1.03it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.1, top_k=50;, score=0.256 total time= 4.0min


100%|██████████| 250/250 [03:49<00:00,  1.09it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.1, top_k=50;, score=0.249 total time= 3.8min


100%|██████████| 250/250 [04:16<00:00,  1.02s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.1, top_k=60;, score=0.259 total time= 4.3min


100%|██████████| 250/250 [03:55<00:00,  1.06it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.1, top_k=60;, score=0.254 total time= 3.9min


100%|██████████| 250/250 [04:11<00:00,  1.01s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.1, top_k=70;, score=0.257 total time= 4.2min


100%|██████████| 250/250 [04:01<00:00,  1.03it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.1, top_k=70;, score=0.257 total time= 4.0min


100%|██████████| 250/250 [04:15<00:00,  1.02s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.1, top_k=80;, score=0.258 total time= 4.3min


100%|██████████| 250/250 [03:55<00:00,  1.06it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.1, top_k=80;, score=0.254 total time= 3.9min


100%|██████████| 250/250 [04:14<00:00,  1.02s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.3, top_k=50;, score=0.250 total time= 4.3min


100%|██████████| 250/250 [04:00<00:00,  1.04it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.3, top_k=50;, score=0.254 total time= 4.0min


100%|██████████| 250/250 [04:22<00:00,  1.05s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.3, top_k=60;, score=0.250 total time= 4.4min


100%|██████████| 250/250 [03:54<00:00,  1.07it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.3, top_k=60;, score=0.237 total time= 3.9min


100%|██████████| 250/250 [03:58<00:00,  1.05it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.3, top_k=70;, score=0.247 total time= 4.0min


100%|██████████| 250/250 [03:50<00:00,  1.08it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.3, top_k=70;, score=0.251 total time= 3.9min


100%|██████████| 250/250 [04:05<00:00,  1.02it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.3, top_k=80;, score=0.253 total time= 4.1min


100%|██████████| 250/250 [03:47<00:00,  1.10it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.3, top_k=80;, score=0.242 total time= 3.8min


100%|██████████| 250/250 [04:02<00:00,  1.03it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.5, top_k=50;, score=0.217 total time= 4.1min


100%|██████████| 250/250 [03:53<00:00,  1.07it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.5, top_k=50;, score=0.224 total time= 3.9min


100%|██████████| 250/250 [04:06<00:00,  1.01it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.5, top_k=60;, score=0.229 total time= 4.1min


100%|██████████| 250/250 [03:50<00:00,  1.08it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.5, top_k=60;, score=0.224 total time= 3.9min


100%|██████████| 250/250 [03:56<00:00,  1.06it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.5, top_k=70;, score=0.227 total time= 4.0min


100%|██████████| 250/250 [03:48<00:00,  1.10it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.5, top_k=70;, score=0.230 total time= 3.8min


100%|██████████| 250/250 [03:57<00:00,  1.05it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.5, top_k=80;, score=0.225 total time= 4.0min


100%|██████████| 250/250 [03:51<00:00,  1.08it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.5, top_k=80;, score=0.236 total time= 3.9min


100%|██████████| 250/250 [04:02<00:00,  1.03it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.7, top_k=50;, score=0.193 total time= 4.0min


100%|██████████| 250/250 [03:47<00:00,  1.10it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.7, top_k=50;, score=0.218 total time= 3.8min


100%|██████████| 250/250 [04:08<00:00,  1.01it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.7, top_k=60;, score=0.205 total time= 4.1min


100%|██████████| 250/250 [03:45<00:00,  1.11it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.7, top_k=60;, score=0.197 total time= 3.8min


100%|██████████| 250/250 [04:01<00:00,  1.04it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.7, top_k=70;, score=0.194 total time= 4.0min


100%|██████████| 250/250 [03:55<00:00,  1.06it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.7, top_k=70;, score=0.207 total time= 3.9min


100%|██████████| 250/250 [03:59<00:00,  1.04it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.7, top_k=80;, score=0.183 total time= 4.0min


100%|██████████| 250/250 [03:48<00:00,  1.09it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.7, top_k=80;, score=0.209 total time= 3.8min


100%|██████████| 250/250 [04:04<00:00,  1.02it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.9, top_k=50;, score=0.169 total time= 4.1min


100%|██████████| 250/250 [04:02<00:00,  1.03it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.9, top_k=50;, score=0.159 total time= 4.0min


100%|██████████| 250/250 [04:07<00:00,  1.01it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.9, top_k=60;, score=0.148 total time= 4.1min


100%|██████████| 250/250 [03:54<00:00,  1.07it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.9, top_k=60;, score=0.162 total time= 3.9min


100%|██████████| 250/250 [04:16<00:00,  1.02s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.9, top_k=70;, score=0.148 total time= 4.3min


100%|██████████| 250/250 [04:03<00:00,  1.03it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.9, top_k=70;, score=0.168 total time= 4.1min


100%|██████████| 250/250 [04:11<00:00,  1.01s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.9, top_k=80;, score=0.152 total time= 4.2min


100%|██████████| 250/250 [03:52<00:00,  1.08it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=0.9, top_k=80;, score=0.165 total time= 3.9min


100%|██████████| 250/250 [04:18<00:00,  1.04s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=1.0, top_k=50;, score=0.142 total time= 4.3min


100%|██████████| 250/250 [04:03<00:00,  1.03it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=1.0, top_k=50;, score=0.145 total time= 4.1min


100%|██████████| 250/250 [04:05<00:00,  1.02it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=1.0, top_k=60;, score=0.135 total time= 4.1min


100%|██████████| 250/250 [03:58<00:00,  1.05it/s]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=1.0, top_k=60;, score=0.130 total time= 4.0min


100%|██████████| 250/250 [04:08<00:00,  1.00it/s]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=1.0, top_k=70;, score=0.132 total time= 4.2min


100%|██████████| 250/250 [04:13<00:00,  1.01s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=1.0, top_k=70;, score=0.131 total time= 4.2min


100%|██████████| 250/250 [04:12<00:00,  1.01s/it]


[CV 1/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=1.0, top_k=80;, score=0.135 total time= 4.2min


100%|██████████| 250/250 [04:10<00:00,  1.00s/it]


[CV 2/2] END do_sample=True, generation_seed=42, max_new_tokens=70, no_repeat_ngram_size=0, stopping_criteria=[<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], temperature=1.0, top_k=80;, score=0.146 total time= 4.2min


In [16]:
# Re-score the tuning rows with the best parameters found by the grid search.
# NOTE(review): tuner_ob was built on samples_to_tune_on, so this is an in-sample comparison;
# the held-out test_samples are not evaluated anywhere yet.
scores_after, outputs_after = tuner_ob.get_score(clf.best_params_)

100%|██████████| 500/500 [07:58<00:00,  1.05it/s]


In [20]:
# Mean ROUGE-2 before vs. after tuning (both measured on the tuning samples).
print(scores_before, scores_after)

0.257852927825895 0.2581935494075166


In [21]:
# Best grid-search parameters in string form (the same form saved to JSON below).
str(clf.best_params_)

"{'do_sample': True, 'generation_seed': 42, 'max_new_tokens': 70, 'no_repeat_ngram_size': 0, 'stopping_criteria': [<llmsearch.scripts.stopping_criteria.MultiTokenStoppingCriteria object at 0x7fe3468afa90>], 'temperature': 0.1, 'top_k': 70}"

In [23]:
import json

def json_dump(ob : dict, file_path: Path):
    """Write ``ob`` to ``file_path`` as 4-space-indented UTF-8 JSON, overwriting any existing file."""
    with open(file_path, mode="w", encoding="utf-8") as out:
        json.dump(obj=ob, fp=out, indent=4)

# Scores and raw generations before/after tuning, plus the winning parameters
# (stringified: the stopping-criteria object is not JSON serialisable).
d = {
    'scores_before' : scores_before,
    'scores_after' : scores_after,
    'outputs_before' : outputs_before,
    'outputs_after' : outputs_after,
    'best_params' : str(clf.best_params_),
}

# Name the file after the model actually evaluated and the tuning-set size,
# instead of a hardcoded (and mismatched) "capybara-7b" label.
f = f"./samsum-best-params-{num_tune_samples}-{model_id.split('/')[-1]}.json"
json_dump(d, f)

In [None]:
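# TODO: evaluate clf.best_params_ on the held-out test_samples (out-of-sample) split; the scores above are on the tuning samples.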