diff --git a/rwkv_pip_package/pyproject.toml b/rwkv_pip_package/pyproject.toml
index d774bd13..5a0aff7a 100644
--- a/rwkv_pip_package/pyproject.toml
+++ b/rwkv_pip_package/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "rwkv"
-version = "0.7.3"
+version = "0.7.4"
 authors = [
   { name="Bo PENG" },
 ]
diff --git a/rwkv_pip_package/src/rwkv/model.py b/rwkv_pip_package/src/rwkv/model.py
index 45f2a295..2fb9510e 100644
--- a/rwkv_pip_package/src/rwkv/model.py
+++ b/rwkv_pip_package/src/rwkv/model.py
@@ -72,7 +72,7 @@ def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry):
 ########################################################################################################
 
 class RWKV(MyModule):
-    def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = None):
+    def __init__(self, model, strategy, lora, verbose = True, convert_and_save_and_exit = None):
         super().__init__()
         if verbose:
             prxxx = lambda *args, **kwargs: print(*args, **kwargs)
@@ -102,6 +102,33 @@ def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit =
             gc.collect()
 
             w = self.w
+            if lora.lora_r > 0:
+                prxxx(f'Loading lora ...')
+                # merge LoRA-only slim checkpoint into the main weights
+                w_lora = torch.load(lora.MODEL_LORA + '.pth', map_location='cpu')
+                for k in w_lora.keys():
+                    w[k] = w_lora[k]
+                # merge LoRA weights
+                keys = set(w.keys())
+                for k in keys:
+                    k: str
+                    if k.endswith('.weight'):
+                        prefix = k[:-len('.weight')]
+                        lora_A = prefix + '.lora_A'
+                        lora_B = prefix + '.lora_B'
+                        if lora_A in keys:
+                            assert lora_B in keys
+                            print(f'merging {lora_A} and {lora_B} into {k}')
+                            assert w[lora_B].shape[1] == w[lora_A].shape[0] == lora.lora_r
+                            # merging needs matmul, which is slow on cpu; work on gpu if possible
+                            if lora.RUN_DEVICE == 'cuda':
+                                w[k] = w[k].cuda()
+                                w[lora_A] = w[lora_A].cuda()
+                                w[lora_B] = w[lora_B].cuda()
+                            w[k] += w[lora_B] @ w[lora_A] * (lora.lora_alpha / lora.lora_r)
+                            del w[lora_A]
+                            del w[lora_B]
+
             ALREADY_CONVERTED = False
             if '_strategy' in w:
                 ALREADY_CONVERTED = True
diff --git a/v2/chat.py b/v2/chat.py
index d9ecf9c7..aba46181 100644
--- a/v2/chat.py
+++ b/v2/chat.py
@@ -49,6 +49,14 @@
 # args.strategy = 'cuda fp16i8 -> cpu fp32 *10'
 # args.strategy = 'cuda fp16i8 *10+'
 
+
+lora = types.SimpleNamespace()
+lora.MODEL_LORA = './cp/rwkv-10'
+lora.lora_r = 0 # r = 0 for no LoRA
+lora.lora_alpha = 8
+lora.RUN_DEVICE = "cuda"
+
+
 os.environ["RWKV_JIT_ON"] = '1' # '1' or '0', please use torch 1.13+ and benchmark speed
 os.environ["RWKV_CUDA_ON"] = '0' # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries
 
@@ -119,7 +127,7 @@
 # Load Model
 
 print(f'Loading model - {args.MODEL_NAME}')
-model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
+model = RWKV(model=args.MODEL_NAME, strategy=args.strategy, lora=lora)
 if not PILE_v2_MODEL:
     pipeline = PIPELINE(model, f"{current_path}/20B_tokenizer.json")
 END_OF_TEXT = 0
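
Below is a minimal, self-contained sketch (not part of the patch) of the merge the new model.py block performs: it folds the low-rank update into the base weight once at load time and checks that the merged matrix reproduces the unmerged LoRA forward pass. The names and shapes here (`out_dim`, `in_dim`, `r`, `alpha`, the random tensors) are illustrative assumptions; only the update rule `w[k] += w[lora_B] @ w[lora_A] * (lora.lora_alpha / lora.lora_r)` comes from the diff.

```python
import torch

out_dim, in_dim, r, alpha = 8, 16, 4, 8      # illustrative sizes, not from the patch
W = torch.randn(out_dim, in_dim)             # base '.weight' as stored in the checkpoint
A = torch.randn(r, in_dim) * 0.01            # '.lora_A'  (r x in_dim)
B = torch.randn(out_dim, r) * 0.01           # '.lora_B'  (out_dim x r), matching the shape assert in the patch
x = torch.randn(3, in_dim)                   # a small batch of inputs

# Unmerged LoRA forward: base path plus low-rank update, scaled by alpha / r
y_lora = x @ W.t() + (x @ A.t() @ B.t()) * (alpha / r)

# Merged weight: the same computation the patch does once at load time,
#   w[k] += w[lora_B] @ w[lora_A] * (lora.lora_alpha / lora.lora_r)
W_merged = W + B @ A * (alpha / r)
y_merged = x @ W_merged.t()

assert torch.allclose(y_lora, y_merged, atol=1e-5)
print('merged weight matches the LoRA forward pass')
```

Because the update is folded into the checkpoint tensors before the strategy conversion runs, inference after loading costs the same as running the base model.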