Commit 5368633: "better"
www committed Aug 5, 2023 (1 parent: bff997a)
Showing 3 changed files with 46 additions and 15 deletions.
RWKV-v4neo/src/model.py (14 additions, 2 deletions)

@@ -170,7 +170,10 @@ def __init__(self, args, layer_id):
             self.time_decay = nn.Parameter(decay_speed)
             # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
 
-            self.time_first = nn.Parameter(torch.ones(self.n_head) * (-3.0))
+            if 'r2' in os.environ["RWKV_MY_TESTING"]:
+                self.time_faaaa = nn.Parameter(torch.ones(self.n_head) * 0.05)
+            else:
+                self.time_first = nn.Parameter(torch.ones(self.n_head) * (-3.0))
 
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
         self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
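Note that both parameterizations start the per-head bonus u at (almost) the same value: the old code stores it in log space (u = exp(time_first), with time_first = -3.0), while the 'r2' branch stores u itself (time_faaaa = 0.05, and exp(-3.0) ≈ 0.0498). A quick standalone check of that correspondence (my sketch; only the parameter names come from the diff):

    import torch

    n_head = 8  # any head count; value chosen just for this check

    # old parameterization: store log(u) in time_first, recover u with exp()
    time_first = torch.ones(n_head) * (-3.0)
    u_old = torch.exp(time_first)            # ~0.0498 per head

    # new 'r2' parameterization: store u directly in time_faaaa
    time_faaaa = torch.ones(n_head) * 0.05   # ~= exp(-3.0)

    print(torch.allclose(u_old, time_faaaa, atol=2e-3))  # True: the inits nearly coincide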
@@ -227,7 +230,11 @@ def forward(self, x):
         r, k, v = self.jit_func(x)
 
         w = torch.exp(-torch.exp(self.time_decay.float())).unsqueeze(-1)
-        u = torch.exp(self.time_first.float()).unsqueeze(-1)
+
+        if 'r2' in os.environ["RWKV_MY_TESTING"]:
+            u = self.time_faaaa.float().unsqueeze(-1)
+        else:
+            u = torch.exp(self.time_first.float()).unsqueeze(-1)
 
 ################################################################################
 ########
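In the forward pass the difference is the constraint, not just the naming: exp(time_first) is positive by construction, while time_faaaa is used as-is, so training can push an individual head's u to zero or below. A toy illustration of that property (not the repo's attention kernel):

    import torch

    time_first = torch.tensor([-3.0, 0.0, 5.0])
    print(torch.exp(time_first))   # tensor([  0.0498,   1.0000, 148.4132]); strictly positive

    time_faaaa = torch.tensor([0.05, 0.0, -0.2])
    print(time_faaaa)              # zero and negative values are representable directly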
@@ -545,6 +552,11 @@ def configure_optimizers(self):
                     lr_3x.add(n)
                 else:
                     lr_2x.add(n)
+            elif ("time_faaaa" in n) and (args.layerwise_lr > 0):
+                if args.my_pile_stage == 2:
+                    lr_3x.add(n)
+                else:
+                    lr_2x.add(n)
             elif ("time_first" in n) and (args.layerwise_lr > 0):
                 lr_3x.add(n)
             elif (len(p.squeeze().shape) >= 2) and (args.weight_decay > 0):
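For context on where lr_1x/lr_2x/lr_3x go: later in configure_optimizers each name set becomes an optimizer param group tagged with a "my_lr_scale" factor, which the trainer multiplies into the scheduled learning rate (see the trainer.py hunk below). A condensed, self-contained sketch of that shape (assumed from the surrounding code, not a verbatim excerpt):

    import torch.nn as nn

    # hypothetical module and name sets, standing in for the real model
    model = nn.Linear(4, 4)
    lr_1x, lr_2x, lr_3x = {"weight"}, set(), {"bias"}

    param_dict = {n: p for n, p in model.named_parameters()}
    optim_groups = [
        {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
        {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
        {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
    ]
    # the trainer later sets param_group["lr"] = lr * param_group["my_lr_scale"]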
RWKV-v4neo/src/trainer.py (21 additions, 6 deletions)

@@ -52,32 +52,43 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
         # if trainer.is_global_zero:
         #     print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
 
-        if args.my_exit_tokens > 0: # cosine decay
+        if args.my_exit_tokens != 0: # cosine decay
             if trainer.global_step < w_step:
                 lr = args.lr_init * (0.2 + 0.8 * trainer.global_step / w_step)
             else:
                 real_tokens = real_step * args.ctx_len * args.real_bsz
                 warmup_tokens = w_step * args.ctx_len * args.real_bsz
-                progress = (real_tokens - warmup_tokens) / (args.my_exit_tokens - warmup_tokens)
+                progress = (real_tokens - warmup_tokens) / (abs(args.my_exit_tokens) - warmup_tokens)
                 progress = max(0, min(1, progress))
-                lr_final_factor = 0.1
+                lr_final_factor = args.lr_final / args.lr_init
                 lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress)
-                lr = args.lr_init * lr_mult
+                if args.my_exit_tokens > 0:
+                    lr = args.lr_init * lr_mult
+                else:
+                    lr = (lr + args.lr_init * lr_mult) / 2
                 if progress >= 1:
                     my_save(
                         pl_module.state_dict(),
                         f"{args.proj_dir}/rwkv-final.pth",
                     )
                     exit(0)
 
+        if args.weight_decay_final > 0:
+            wd_now = args.weight_decay * math.exp(math.log(args.weight_decay_final / args.weight_decay) * progress)
+        else:
+            wd_now = args.weight_decay
+
         for param_group in trainer.optimizers[0].param_groups:
+            if param_group["weight_decay"] > 0:
+                param_group["weight_decay"] = wd_now
             if args.layerwise_lr > 0:
                 param_group["lr"] = lr * param_group["my_lr_scale"]
                 # print(param_group["lr"], param_group["my_lr_scale"])
             else:
                 param_group["lr"] = lr
 
         trainer.my_lr = lr
+        trainer.my_wd = wd_now
         # rank_zero_info(f"{real_step} {lr}")
 
         if trainer.global_step == 0:
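Two schedule changes in this hunk are worth unpacking. The cosine multiplier now interpolates from 1 at progress 0 down to lr_final/lr_init at progress 1 (the old code hard-wired the floor to 0.1), and the new weight-decay schedule interpolates geometrically from weight_decay to weight_decay_final; the resulting trainer.my_wd is what the later wandb hunk logs under "wd". A small standalone check, with made-up hyperparameters:

    import math

    # hypothetical values for illustration only
    lr_init, lr_final = 6e-4, 6e-5
    wd_init, wd_final = 0.1, 0.01

    for progress in (0.0, 0.5, 1.0):
        f = lr_final / lr_init
        lr_mult = (0.5 + f / 2) + (0.5 - f / 2) * math.cos(math.pi * progress)
        wd_now = wd_init * math.exp(math.log(wd_final / wd_init) * progress)
        print(f"p={progress}: lr={lr_init * lr_mult:.2e}, wd={wd_now:.3f}")

    # p=0.0: lr=6.00e-04, wd=0.100   (starts at lr_init / weight_decay)
    # p=0.5: lr=3.30e-04, wd=0.032   (midpoint of both schedules)
    # p=1.0: lr=6.00e-05, wd=0.010   (ends at lr_final / weight_decay_final)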
@@ -127,7 +138,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
             # self.log("s", real_step, prog_bar=True, on_step=True)
 
             if len(args.wandb) > 0:
-                lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
+                lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "wd": trainer.my_wd, "Gtokens": real_step * token_per_step / 1e9}
                 if kt_s > 0:
                     lll["kt/s"] = kt_s
                 trainer.my_wandb.log(lll, step=int(real_step))
@@ -187,7 +198,11 @@ def generate_init_weight(model, init_weight_name):
print(f"Combine weights from {model.args.load_model}...")
load_dict = torch.load(model.args.load_model, map_location="cpu")
for k in load_dict:
assert k in mm
try:
assert k in mm
except:
print('missing', k)
exit(0)
src = load_dict[k]
try:
mm[k] = src.reshape(mm[k].shape)
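The try/assert here just turns a bare AssertionError into a message naming the offending checkpoint key before exiting. An equivalent, more direct formulation (my sketch with placeholder dicts, not the repo's code):

    import sys

    load_dict = {"emb.weight": 0, "head.weight": 1}   # placeholder checkpoint keys
    mm = {"emb.weight": 0}                            # placeholder model state

    for k in load_dict:
        if k not in mm:
            print('missing', k)   # reports "missing head.weight", then stops
            sys.exit(0)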
RWKV-v4neo/train.py (11 additions, 7 deletions)

@@ -2,6 +2,9 @@
 # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
 ########################################################################################################
 
+import logging
+logging.basicConfig(level=logging.INFO)
+
 if __name__ == "__main__":
     from argparse import ArgumentParser
     from pytorch_lightning import Trainer
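The new logging.basicConfig call matters because Python's root logger only emits WARNING and above by default, so INFO-level records from train.py's dependencies were previously invisible. A minimal demonstration of the effect:

    import logging

    logging.basicConfig(level=logging.INFO)  # without this line, the record below is dropped
    logging.getLogger("demo").info("now visible on stderr")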
@@ -83,6 +86,7 @@
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--dropout", default=0, type=float)
parser.add_argument("--weight_decay", default=0, type=float) # try 0.1 / 0.01 / 0.001
parser.add_argument("--weight_decay_final", default=-1, type=float)

parser.add_argument("--my_pile_version", default=1, type=int) # my special pile version
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
@@ -110,7 +114,7 @@
parser.add_argument("--my_random_steps", default=0, type=int)
parser.add_argument("--my_testing", default='', type=str)
parser.add_argument("--my_exit", default=99999999, type=int)
parser.add_argument("--my_exit_tokens", default=-1, type=int)
parser.add_argument("--my_exit_tokens", default=0, type=int)

parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
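Combined with the trainer change above, the sign of my_exit_tokens now selects the schedule: 0 (the new default) leaves it off, a positive token budget runs the cosine decay from lr_init down to lr_final, and a negative one uses the averaged variant; weight_decay_final stays disabled at its -1 default. A hypothetical invocation exercising the new flags (all values illustrative, not recommendations):

    python train.py --proj_dir out --my_exit_tokens 10000000000 \
        --lr_init 6e-4 --lr_final 6e-5 \
        --weight_decay 0.1 --weight_decay_final 0.01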
@@ -301,12 +305,12 @@
         from src.model_img import RWKV_IMG
         model = RWKV_IMG(args)
     else:
-        if args.dropout > 0:
-            from src.model_drop2 import RWKV
-            model = RWKV(args)
-        else:
-            from src.model import RWKV
-            model = RWKV(args)
+        # if args.dropout > 0:
+        #     from src.model_drop2 import RWKV
+        #     model = RWKV(args)
+        # else:
+        from src.model import RWKV
+        model = RWKV(args)
 
     if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights?
         init_weight_name = f"{args.proj_dir}/rwkv-init.pth"