Update pip package RWKV.model and v2/chat.py to support LoRA #82

Open · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion rwkv_pip_package/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "rwkv"
-version = "0.7.3"
+version = "0.7.4"
 authors = [
     { name="Bo PENG" },
 ]
29 changes: 28 additions & 1 deletion rwkv_pip_package/src/rwkv/model.py
@@ -72,7 +72,7 @@ def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry):
 ########################################################################################################
 
 class RWKV(MyModule):
-    def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = None):
+    def __init__(self, model, strategy, lora=None, verbose = True, convert_and_save_and_exit = None):
         super().__init__()
         if verbose:
             prxxx = lambda *args, **kwargs: print(*args, **kwargs)
@@ -102,6 +102,33 @@ def __init__(self, model, strategy, lora=None, verbose = True, convert_and_save_and_exit = None):
             gc.collect()
             w = self.w
 
+            if lora is not None and lora.lora_r > 0:
+                prxxx('Loading LoRA ...')
+                # merge the LoRA-only slim checkpoint into the main weights
+                w_lora = torch.load(lora.MODEL_LORA + '.pth', map_location='cpu')
+                for k in w_lora.keys():
+                    w[k] = w_lora[k]
+                # merge LoRA weights
+                keys = set(w.keys())
+                for k in keys:
+                    k: str
+                    if k.endswith('.weight'):
+                        prefix = k[:-len('.weight')]
+                        lora_A = prefix + '.lora_A'
+                        lora_B = prefix + '.lora_B'
+                        if lora_A in keys:
+                            assert lora_B in keys
+                            prxxx(f'merging {lora_A} and {lora_B} into {k}')
+                            assert w[lora_B].shape[1] == w[lora_A].shape[0] == lora.lora_r
+                            # merging needs a matmul, which is slow on CPU; use the GPU if possible
+                            if lora.RUN_DEVICE == 'cuda':
+                                w[k] = w[k].cuda()
+                                w[lora_A] = w[lora_A].cuda()
+                                w[lora_B] = w[lora_B].cuda()
+                            w[k] += w[lora_B] @ w[lora_A] * (lora.lora_alpha / lora.lora_r)
+                            del w[lora_A]
+                            del w[lora_B]
 
             ALREADY_CONVERTED = False
             if '_strategy' in w:
                 ALREADY_CONVERTED = True
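For reference, the merge applies the standard LoRA update W ← W + (α/r)·B·A once at load time, after which inference runs exactly as with an ordinary checkpoint. A minimal sketch of that equivalence (shapes, names, and values here are illustrative, not taken from the PR):

import torch

r, alpha = 2, 8        # illustrative rank and scaling factor
W = torch.randn(4, 4)  # base weight
A = torch.randn(r, 4)  # lora_A: (r, in_features)
B = torch.randn(4, r)  # lora_B: (out_features, r)
x = torch.randn(4)

# applying the adapter at runtime: y = W x + (alpha / r) * B (A x)
y_adapter = W @ x + (alpha / r) * (B @ (A @ x))

# merged weight, as computed in model.py above: W += B @ A * (alpha / r)
y_merged = (W + B @ A * (alpha / r)) @ x

assert torch.allclose(y_adapter, y_merged, atol=1e-5)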
10 changes: 9 additions & 1 deletion v2/chat.py
@@ -49,6 +49,14 @@
 # args.strategy = 'cuda fp16i8 -> cpu fp32 *10'
 # args.strategy = 'cuda fp16i8 *10+'
 
+lora = types.SimpleNamespace()
+lora.MODEL_LORA = './cp/rwkv-10'
+lora.lora_r = 0  # r = 0 disables LoRA
+lora.lora_alpha = 8
+lora.RUN_DEVICE = "cuda"
+
 os.environ["RWKV_JIT_ON"] = '1' # '1' or '0', please use torch 1.13+ and benchmark speed
 os.environ["RWKV_CUDA_ON"] = '0' # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries
@@ -119,7 +127,7 @@
 # Load Model
 
 print(f'Loading model - {args.MODEL_NAME}')
-model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
+model = RWKV(model=args.MODEL_NAME, strategy=args.strategy, lora=lora)
 if not PILE_v2_MODEL:
     pipeline = PIPELINE(model, f"{current_path}/20B_tokenizer.json")
     END_OF_TEXT = 0
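Taken together, a downstream script could enable LoRA roughly like this (a sketch, not part of the PR: the base-model path and strategy are placeholders, and lora_r must match the rank the adapter was trained with, as the assert in model.py enforces):

import os, types

# the rwkv package reads these at import time, so set them first
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '0'
from rwkv.model import RWKV

lora = types.SimpleNamespace()
lora.MODEL_LORA = './cp/rwkv-10'  # LoRA-only checkpoint, given without the '.pth' suffix
lora.lora_r = 8                   # rank used at fine-tuning time; 0 would skip merging
lora.lora_alpha = 8
lora.RUN_DEVICE = "cuda"          # merge on the GPU; the matmul is slow on CPU

# './models/base-model' is a placeholder for an actual RWKV checkpoint
model = RWKV(model='./models/base-model', strategy='cuda fp16', lora=lora)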