In [1]:
import gc
import time
import random

from typing import List, Any, Dict, Union, Tuple

import nltk
import torch

import numpy as np
from tqdm.notebook import tqdm


import datasets
import evaluate
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [2]:
seed = 42

dataset = datasets.load_dataset("samsum")
rouge_metric  = evaluate.load("rouge")

device = "cpu"

if torch.backends.mps.is_built() and torch.backends.mps.is_available():
    device = "mps"

def seed_everything(seed):
    """Seed for reproducibilty"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

seed_everything(seed=seed)

Found cached dataset samsum (/Users/praful932/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)


  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
print(torch.backends.mps.is_built(), torch.backends.mps.is_available())

True True


In [4]:
model_id = "google/flan-t5-small"

tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)

In [5]:
dataset['validation'][:100].keys()

dict_keys(['id', 'dialogue', 'summary'])

In [6]:
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels

def get_rouge_score(y_true : List, y_pred : List):
    preds, gts = postprocess_text(preds = y_pred, labels = y_true)

    result = rouge_metric.compute(predictions=preds, references=gts, use_stemmer=True)
    
    return result

def infer_single_sample(model, input_text : str, model_input_tokenizer_kwargs : Dict, generation_kwargs : Dict):
    gc.collect()
    start = time.time()
    input_ids = tokenizer(text = input_text, **model_input_tokenizer_kwargs, return_tensors="pt").input_ids.to(device)
    output_ids = model.generate(input_ids, **generation_kwargs)[0]
    decoded_output = tokenizer.decode(token_ids=output_ids, skip_special_tokens=True)
    end = time.time()
    latency = (end - start) * 1000
    return decoded_output, latency
    

In [7]:
def get_metrics(model, prompt, dataset, model_input_tokenizer_kwargs, generation_kwargs):
    preds = []
    gts = []

    for item in tqdm(samples_to_tune_on):
        model_input = prompt.format(input_text = item['dialogue'])
        output, _ = infer_single_sample(model, model_input, model_input_tokenizer_kwargs, generation_kwargs)
        preds.append(output)
        gts.append(item['summary'])
        # print(f"Model Output : {output}")
        # print(f"Label : {item['summary']}")
        # print()
        # print("===" * 10)

    metric = get_rouge_score(y_true=gts, y_pred=preds)
    return metric, preds, gts

# take 1st 100 samples
sample_size = 2
samples_to_tune_on = datasets.Dataset.from_dict(dataset['validation'][:sample_size])

model_input_tokenizer_kwargs = {
    'max_length' : 1000,
    'truncation' : True,
    'padding' : False,
}
prompt = "{input_text} Summarize:"

generation_kwargs = {
    'temperature' : 1,
}

# Check metric on these examples before any generation param tuning

metric, _, _ = get_metrics(model=model, prompt=prompt, dataset=samples_to_tune_on, model_input_tokenizer_kwargs=model_input_tokenizer_kwargs, generation_kwargs = generation_kwargs)
print(metric)

  0%|          | 0/2 [00:00<?, ?it/s]

  if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):


{'rouge1': 0.14634146341463414, 'rouge2': 0.025641025641025637, 'rougeL': 0.12195121951219515, 'rougeLsum': 0.12195121951219515}


In [26]:
import warnings
from collections import defaultdict

from typing import Dict
from sklearn.utils.estimator_checks import check_estimator
from sklearn.model_selection import RandomizedSearchCV

from sklearn.base import BaseEstimator
from sklearn.metrics import make_scorer

model_input_tokenizer_kwargs = {
    'max_length' : 1000,
    'truncation' : True,
    'padding' : 'max_length',
}

def get_rouge_score(y_true : List, y_pred : List):
    preds, gts = postprocess_text(preds = y_pred, labels = y_true)

    result = rouge_metric.compute(predictions=preds, references=gts, use_stemmer=True)
    
    return result['rouge2']

def infer_single_sample(model, input_text : str, model_input_tokenizer_kwargs : Dict, generation_kwargs : Dict):
    gc.collect()
    start = time.time()
    input_ids = tokenizer(text = input_text, **model_input_tokenizer_kwargs, return_tensors="pt").input_ids.to(device)
    output_ids = model.generate(input_ids, **generation_kwargs)
    decoded_output = tokenizer.batch_decode(sequences=output_ids, skip_special_tokens=True)
    end = time.time()
    latency = (end - start) * 1000
    return decoded_output, latency

class EstimatorWrapper(BaseEstimator):
    def __init__(self, model, **kwargs):
        self.model = model
        for k, v in kwargs.items():
            self.__setattr__(k, v)

    def fit(self, X, y = None):
        self.is_fitted_ = True
        return self

    def predict(self, X):
        model_generation_params = {attr: self.__getattribute__(attr) for attr in self.model_generation_param_keys}
        output, _ = infer_single_sample(model = self.model, input_text=X,model_input_tokenizer_kwargs=model_input_tokenizer_kwargs,generation_kwargs=model_generation_params)
        return output
    
    def set_params(self, **params):
        """Set the parameters of this estimator.

        The method works on simple estimators as well as on nested objects
        (such as :class:`~sklearn.pipeline.Pipeline`). The latter have
        parameters of the form ``<component>__<parameter>`` so that it's
        possible to update each component of a nested object.

        Parameters
        ----------
        **params : dict
            Estimator parameters.

        Returns
        -------
        self : estimator instance
            Estimator instance.
        """
        if not params:
            # Simple optimization to gain speed (inspect is slow)
            return self
        valid_params = self.get_params(deep=True)

        nested_params = defaultdict(dict)  # grouped by prefix
        self.model_generation_param_keys = params.keys()
        for key, value in params.items():
            key, delim, sub_key = key.partition("__")
            if delim:
                nested_params[key][sub_key] = value
            else:
                setattr(self, key, value)
                valid_params[key] = value

        for key, sub_params in nested_params.items():
            # TODO(1.4): remove specific handling of "base_estimator".
            # The "base_estimator" key is special. It was deprecated and
            # renamed to "estimator" for several estimators. This means we
            # need to translate it here and set sub-parameters on "estimator",
            # but only if the user did not explicitly set a value for
            # "base_estimator".
            if (
                key == "base_estimator"
                and valid_params[key] == "deprecated"
                and self.__module__.startswith("sklearn.")
            ):
                warnings.warn(
                    (
                        f"Parameter 'base_estimator' of {self.__class__.__name__} is"
                        " deprecated in favor of 'estimator'. See"
                        f" {self.__class__.__name__}'s docstring for more details."
                    ),
                    FutureWarning,
                    stacklevel=2,
                )
                key = "estimator"
           
            valid_params[key].set_params(**sub_params)

        return self

# def get_best_hyperparams(model, dataset, hyp_param_grid : Dict, metric : function, output_extraction_function = None, model_generation_attribute = "generate"):
    

# take 1st 100 samples
samples_to_tune_on = dataset['validation'][:100]
hyp_param_grid = {
    'min_new_tokens' : [5],
    'max_new_tokens' : [100],
    'num_beams' : [1,2,3],
    'temperature' : [0.3,0.7,1],
    'epsilon_cutoff' : [3e-4,6e-4,9e-4,0],
    'repetition_penalty' : [0.3,0.7,1.0],
    'no_repeat_ngram_size' : [0,2,3,4],
}
model_estimator = EstimatorWrapper(model=model) 
scorer = make_scorer(score_func=get_rouge_score, greater_is_better = True)

clf = RandomizedSearchCV(estimator=model_estimator,param_distributions=hyp_param_grid,n_iter=10, scoring = scorer, cv = 2,random_state=seed, n_jobs=1, verbose=1)



In [27]:
clf.fit(X = samples_to_tune_on['dialogue'], y=samples_to_tune_on['summary'])

Fitting 2 folds for each of 10 candidates, totalling 20 fits


KeyboardInterrupt: 

In [25]:
clf.best_params_

{'temperature': 0.3,
 'repetition_penalty': 1.0,
 'num_beams': 1,
 'no_repeat_ngram_size': 4,
 'min_new_tokens': 5,
 'max_new_tokens': 100,
 'epsilon_cutoff': 0.0006}

In [24]:
clf.cv_results_

{'mean_fit_time': array([6.00814819e-05, 4.35113907e-05, 3.86238098e-05, 3.91006470e-05,
        4.76837158e-05, 3.94582748e-05, 4.55379486e-05, 3.80277634e-05,
        4.25577164e-05, 4.22000885e-05]),
 'std_fit_time': array([1.50203705e-05, 3.57627869e-07, 1.43051147e-06, 2.86102295e-06,
        6.43730164e-06, 5.96046448e-07, 9.53674316e-06, 1.19209290e-07,
        1.54972076e-06, 1.90734863e-06]),
 'mean_score_time': array([4.32253551, 1.69391119, 1.93649256, 0.85784805, 1.54696596,
        0.80367553, 2.95882702, 2.02347541, 1.12767506, 1.57333779]),
 'std_score_time': array([2.13820338, 0.10636604, 0.09814036, 0.09990704, 0.12719214,
        0.10669839, 1.3027761 , 0.49564958, 0.19829798, 0.11063814]),
 'param_temperature': masked_array(data=[0.7, 0.3, 0.3, 0.3, 1, 1, 0.3, 0.7, 0.3, 0.7],
              mask=[False, False, False, False, False, False, False, False,
                    False, False],
        fill_value='?',
             dtype=object),
 'param_repetition_penalty': ma