From 457d9d1b8ba91705cd741726d6397c28da7078b9 Mon Sep 17 00:00:00 2001
From: yynil
Date: Mon, 31 Jul 2023 12:22:59 +0800
Subject: [PATCH] add "stop words" support

---
 DEMO_FOR_STOPWORDS.py              | 34 ++++++++++++++++++++++++++++++
 rwkv_pip_package/src/rwkv/utils.py | 11 +++++++---
 2 files changed, 42 insertions(+), 3 deletions(-)
 create mode 100644 DEMO_FOR_STOPWORDS.py

diff --git a/DEMO_FOR_STOPWORDS.py b/DEMO_FOR_STOPWORDS.py
new file mode 100644
index 00000000..e859a934
--- /dev/null
+++ b/DEMO_FOR_STOPWORDS.py
@@ -0,0 +1,34 @@
+import os
+os.environ['RWKV_JIT_ON'] = '1'
+os.environ["RWKV_CUDA_ON"] = '1'
+from rwkv.model import RWKV
+from rwkv.utils import PIPELINE, PIPELINE_ARGS
+
+MODEL_FILE = "/media/yueyulin/KINGSTON/pretrained_models/rwkv/RWKV-4-World-CHNtuned-7B-v1-20230709-ctx4096.pth"
+model = RWKV(model=MODEL_FILE, strategy="cuda fp16")
+pipeline = PIPELINE(model, "rwkv_vocab_v20230424") #### vocab for rwkv-4-world models
+print(model)
+
+print(pipeline)
+
+ctx = "User:请根据以下材料设计一道中餐菜谱。要求生成菜名和具体做法,菜谱最后以“完成!”结束。材料:猪后腿肉,青椒,洋葱,盐,胡椒。\nAssistant:菜名:" # Chinese prompt: design a Chinese recipe (dish name + steps) from pork leg, green pepper, onion, salt and pepper, ending with "完成!" ("Done!")
+print(ctx, end="")
+
+def my_print(s):
+    print(s, end="", flush=True)
+
+end_token = pipeline.encode("完成!") # token ids of the stop phrase "完成!" ("Done!")
+print(end_token)
+args = PIPELINE_ARGS(
+    temperature=1.5,
+    top_p=0.3,
+    top_k=0, # top_k = 0 -> ignore top_k
+    alpha_frequency=0.2, # frequency penalty - see https://platform.openai.com/docs/api-reference/parameter-details
+    alpha_presence=0.2, # presence penalty - see https://platform.openai.com/docs/api-reference/parameter-details
+    token_ban=[], # ban the generation of some tokens
+    token_stop=end_token, # stop generation once this token sequence has been generated
+    chunk_len=256, # split input into chunks to save VRAM (shorter -> less VRAM, but slower)
+)
+
+pipeline.generate(ctx, token_count=1024, args=args, callback=my_print)
+print("\n")
\ No newline at end of file
diff --git a/rwkv_pip_package/src/rwkv/utils.py b/rwkv_pip_package/src/rwkv/utils.py
index 4db97700..bbd5550b 100644
--- a/rwkv_pip_package/src/rwkv/utils.py
+++ b/rwkv_pip_package/src/rwkv/utils.py
@@ -88,7 +88,9 @@ def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, st
         out_last = 0
         out_str = ''
         occurrence = {}
-        for i in range(token_count):
+        continue_generating = True
+        i = 0
+        while continue_generating:
 
             # forward & adjust prob.
             tokens = self.encode(ctx) if i == 0 else [token]
@@ -103,8 +105,8 @@ def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, st
 
             # sampler
             token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
-            if token in args.token_stop:
-                break
             all_tokens += [token]
+            if len(args.token_stop) > 0 and args.token_stop == all_tokens[-len(args.token_stop):]: # stop once the generated tail matches the stop sequence
+                continue_generating = False
             for xxx in occurrence:
                 occurrence[xxx] *= args.alpha_decay
@@ -121,4 +123,7 @@
                     callback(tmp)
                 out_str += tmp
                 out_last = i + 1
+            i += 1
+            if i >= token_count:
+                continue_generating = False
         return out_str
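
--
Note: with this patch, token_stop is interpreted as a single stop *sequence*
of token ids (e.g. pipeline.encode("完成!")) rather than a set of individual
stop tokens, so multi-token stop words work. A minimal, runnable sketch of
the tail-matching check the patch adds, using hypothetical token ids 101 and
102 as stand-ins for an encoded stop phrase:

    token_stop = [101, 102]         # hypothetical encoded stop phrase
    all_tokens = []
    for token in [7, 101, 102, 9]:  # pretend the sampler produced these in order
        all_tokens += [token]
        # stop as soon as the tail of all_tokens equals the stop sequence
        if len(token_stop) > 0 and token_stop == all_tokens[-len(token_stop):]:
            print("stop sequence generated; halting")
            break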