From 4440f87722ca9ae81e9d6123ed4b265ca2d4dae6 Mon Sep 17 00:00:00 2001 From: tdrussell <6509934+tdrussell@users.noreply.github.com> Date: Mon, 23 Oct 2023 00:28:07 -0500 Subject: [PATCH] Add additive_repetition_penalty sampler setting. (#3627) --- api-examples/api-example-chat-stream.py | 1 + api-examples/api-example-chat.py | 1 + api-examples/api-example-stream.py | 1 + api-examples/api-example.py | 1 + "docs/03 \342\200\220 Parameters Tab.md" | 1 + extensions/api/util.py | 1 + extensions/openai/defaults.py | 1 + modules/loaders.py | 7 +++++++ modules/presets.py | 1 + modules/sampler_hijack.py | 23 ++++++++++++++++------- modules/text_generation.py | 2 +- modules/ui.py | 1 + modules/ui_parameters.py | 1 + 13 files changed, 34 insertions(+), 8 deletions(-) diff --git a/api-examples/api-example-chat-stream.py b/api-examples/api-example-chat-stream.py index bfa5d4f580..31bd120cea 100644 --- a/api-examples/api-example-chat-stream.py +++ b/api-examples/api-example-chat-stream.py @@ -52,6 +52,7 @@ async def run(user_input, history): 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1.18, + 'additive_repetition_penalty': 0, 'repetition_penalty_range': 0, 'top_k': 40, 'min_length': 0, diff --git a/api-examples/api-example-chat.py b/api-examples/api-example-chat.py index b2a1e1e42b..e7c0ae7d78 100644 --- a/api-examples/api-example-chat.py +++ b/api-examples/api-example-chat.py @@ -46,6 +46,7 @@ def run(user_input, history): 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1.18, + 'additive_repetition_penalty': 0, 'repetition_penalty_range': 0, 'top_k': 40, 'min_length': 0, diff --git a/api-examples/api-example-stream.py b/api-examples/api-example-stream.py index 966ca6f62d..ad907196de 100644 --- a/api-examples/api-example-stream.py +++ b/api-examples/api-example-stream.py @@ -35,6 +35,7 @@ async def run(context): 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1.18, + 'additive_repetition_penalty': 0, 'repetition_penalty_range': 0, 'top_k': 40, 'min_length': 0, diff --git a/api-examples/api-example.py b/api-examples/api-example.py index d9fd60d05c..2f0267f294 100644 --- a/api-examples/api-example.py +++ b/api-examples/api-example.py @@ -27,6 +27,7 @@ def run(prompt): 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1.18, + 'additive_repetition_penalty': 0, 'repetition_penalty_range': 0, 'top_k': 40, 'min_length': 0, diff --git "a/docs/03 \342\200\220 Parameters Tab.md" "b/docs/03 \342\200\220 Parameters Tab.md" index 44abf291c7..d6566aed8f 100644 --- "a/docs/03 \342\200\220 Parameters Tab.md" +++ "b/docs/03 \342\200\220 Parameters Tab.md" @@ -35,6 +35,7 @@ For more information about the parameters, the [transformers documentation](http * **top_p**: If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results. * **top_k**: Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results. * **repetition_penalty**: Penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition. +* **additive_repetition_penalty**: Similar to repetition_penalty, but with an additive offset on the raw token scores instead of a multiplicative factor. It may generate better results. 0 means no penalty, higher value = less repetition, lower value = more repetition. * **repetition_penalty_range**: The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used. * **typical_p**: If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text. * **tfs**: Tries to detect a tail of low-probability tokens in the distribution and removes those tokens. See [this blog post](https://www.trentonbricken.com/Tail-Free-Sampling/) for details. The closer to 0, the more discarded tokens. diff --git a/extensions/api/util.py b/extensions/api/util.py index 2e42770d1b..e08c9c7909 100644 --- a/extensions/api/util.py +++ b/extensions/api/util.py @@ -32,6 +32,7 @@ def build_parameters(body, chat=False): 'tfs': float(body.get('tfs', 1)), 'top_a': float(body.get('top_a', 0)), 'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))), + 'additive_repetition_penalty': float(body.get('additive_repetition_penalty', body.get('additive_rep_pen', 0))), 'repetition_penalty_range': int(body.get('repetition_penalty_range', 0)), 'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)), 'top_k': int(body.get('top_k', 0)), diff --git a/extensions/openai/defaults.py b/extensions/openai/defaults.py index 2ebade8272..1115ba97ff 100644 --- a/extensions/openai/defaults.py +++ b/extensions/openai/defaults.py @@ -10,6 +10,7 @@ 'top_p': 1.0, 'top_k': 1, # choose 20 for chat in absence of another default 'repetition_penalty': 1.18, + 'additive_repetition_penalty': 0, 'repetition_penalty_range': 0, 'encoder_repetition_penalty': 1.0, 'suffix': None, diff --git a/modules/loaders.py b/modules/loaders.py index b76c85dff9..c7e5d80031 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -153,6 +153,7 @@ 'tfs', 'top_a', 'repetition_penalty', + 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', @@ -186,6 +187,7 @@ 'tfs', 'top_a', 'repetition_penalty', + 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', @@ -244,6 +246,7 @@ 'tfs', 'top_a', 'repetition_penalty', + 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', @@ -273,6 +276,7 @@ 'tfs', 'top_a', 'repetition_penalty', + 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', @@ -306,6 +310,7 @@ 'tfs', 'top_a', 'repetition_penalty', + 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', @@ -353,6 +358,7 @@ 'tfs', 'top_a', 'repetition_penalty', + 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', @@ -389,6 +395,7 @@ 'tfs', 'top_a', 'repetition_penalty', + 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', diff --git a/modules/presets.py b/modules/presets.py index 96d6e994e4..07b78539ee 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -16,6 +16,7 @@ def default_preset(): 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1, + 'additive_repetition_penalty': 0, 'repetition_penalty_range': 0, 'encoder_repetition_penalty': 1, 'no_repeat_ngram_size': 0, diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index 0a724f478c..c0c85c2dec 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -139,11 +139,12 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor): Copied from the transformers library ''' - def __init__(self, penalty: float, _range: int): - if not isinstance(penalty, float) or not (penalty > 0): - raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") + def __init__(self, penalty: float, additive_penalty: float, _range: int): + if not (penalty > 0): + raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}") self.penalty = penalty + self.additive_penalty = additive_penalty self._range = _range def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: @@ -153,6 +154,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability score = torch.where(score < 0, score * self.penalty, score / self.penalty) + score -= self.additive_penalty scores.scatter_(1, input_ids, score) return scores @@ -185,14 +187,20 @@ def get_logits_warper_patch(self, generation_config): def get_logits_processor_patch(self, **kwargs): - result = self._get_logits_processor_old(**kwargs) - repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range repetition_penalty = kwargs['generation_config'].repetition_penalty + additive_repetition_penalty = kwargs['generation_config'].additive_repetition_penalty + repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range + do_rep_pen_hijack = (repetition_penalty > 1) or (additive_repetition_penalty > 0) + if do_rep_pen_hijack: + # Make sure that a RepetitionPenaltyLogitsProcessor will be created + kwargs['generation_config'].repetition_penalty = 1.1 # must set to some value > 1 + + result = self._get_logits_processor_old(**kwargs) - if repetition_penalty_range > 0: + if do_rep_pen_hijack: for i in range(len(result)): if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor': - result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, repetition_penalty_range) + result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, additive_repetition_penalty, repetition_penalty_range) return result @@ -205,6 +213,7 @@ def generation_config_init_patch(self, **kwargs): self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1) self.mirostat_tau = kwargs.pop("mirostat_tau", 5) self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0) + self.additive_repetition_penalty = kwargs.pop("additive_repetition_penalty", 0) def hijack_samplers(): diff --git a/modules/text_generation.py b/modules/text_generation.py index 295c7cdd6f..b824ccf042 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -273,7 +273,7 @@ def apply_stopping_strings(reply, all_stop_strings): def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False): generate_params = {} - for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']: + for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']: generate_params[k] = state[k] if state['negative_prompt'] != '': diff --git a/modules/ui.py b/modules/ui.py index ce92464d54..df9906835d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -105,6 +105,7 @@ def list_interface_input_elements(): 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', + 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index d75d420245..15c6c72ee5 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py @@ -31,6 +31,7 @@ def create_ui(default_preset): shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p') shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k') shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty') + shared.gradio['additive_repetition_penalty'] = gr.Slider(0, 4, value=generate_params['additive_repetition_penalty'], step=0.05, label='additive_repetition_penalty') shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range') shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p') shared.gradio['tfs'] = gr.Slider(0.0, 1.0, value=generate_params['tfs'], step=0.01, label='tfs')