[RWKV-v5] use register_buffer instead of frozen params #213

Open
wants to merge 1 commit into main
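For background (a sketch of the general PyTorch behavior this change relies on, not code from the PR): tensors registered with register_buffer are excluded from model.parameters(), so optimizers and weight decay never touch them, yet they still appear in state_dict and move with the module under .to()/.cuda(). The Demo module below is hypothetical, for illustration only:

import torch
import torch.nn as nn

class Demo(nn.Module):  # hypothetical module, not from the PR
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(4))                # trainable: listed in parameters()
        self.register_buffer("c", torch.linspace(0, 1, 4))   # constant: a buffer, not a parameter

    def forward(self, x):
        return x * self.w + self.c

m = Demo()
print([n for n, _ in m.named_parameters()])  # ['w'] -- the optimizer only ever sees this
print(sorted(m.state_dict().keys()))         # ['c', 'w'] -- the buffer is still checkpointed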
RWKV-v5/src/model.py (86 changes: 41 additions & 45 deletions)
@@ -19,7 +19,7 @@

try:
    print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
-except:
+except KeyError:
    os.environ["RWKV_MY_TESTING"] = ''

def __nop(ob):
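An editorial aside on this first hunk: the bare except also caught unrelated errors (even KeyboardInterrupt), while except KeyError limits the fallback to the one expected failure, the environment variable being unset. An equivalent form without try/except, shown only as a sketch:

import os

# Default the variable when unset, then read it unconditionally.
os.environ.setdefault("RWKV_MY_TESTING", "")
print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])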
@@ -106,32 +106,30 @@ def __init__(self, args, layer_id):
        assert args.dim_att % self.n_head == 0
        self.head_size_divisor = args.head_size_divisor

-        with torch.no_grad():
-            ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
-            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
-            ddd = torch.ones(1, 1, args.n_embd)
-            for i in range(args.n_embd):
-                ddd[0, 0, i] = i / args.n_embd
-
-            # fancy time_mix
-            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
-            self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
-            self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
-            self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
-
-            # fancy time_decay
-            decay_speed = torch.ones(args.dim_att)
-            for n in range(args.dim_att):
-                decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
-            self.time_decay = nn.Parameter(decay_speed.reshape(self.n_head, self.head_size))
-            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
-
-            tmp = torch.zeros(args.dim_att)
-            for n in range(args.dim_att):
-                zigzag = ((n + 1) % 3 - 1) * 0.1
-                tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
-
-            self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
+        ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
+        ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
+        ddd = torch.ones(1, 1, args.n_embd)
+        for i in range(args.n_embd):
+            ddd[0, 0, i] = i / args.n_embd
+
+        # fancy time_mix
+        self.register_buffer("time_mix_k", torch.pow(ddd, ratio_1_to_almost0))
+        self.register_buffer("time_mix_v", torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+        self.register_buffer("time_mix_r", torch.pow(ddd, 0.5 * ratio_1_to_almost0))
+        self.register_buffer("time_mix_g", torch.pow(ddd, 0.5 * ratio_1_to_almost0))
+
+        # fancy time_decay
+        decay_speed = torch.ones(args.dim_att)
+        for n in range(args.dim_att):
+            decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+        self.register_buffer("time_decay", decay_speed.reshape(self.n_head, self.head_size))
+        # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
+
+        tmp = torch.zeros(args.dim_att)
+        for n in range(args.dim_att):
+            zigzag = ((n + 1) % 3 - 1) * 0.1
+            tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
+        self.register_buffer("time_faaaa", tmp.reshape(self.n_head, self.head_size))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
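One property of the buffers registered above worth noting (an observation, not from the PR thread): register_buffer defaults to persistent=True, so time_mix_*, time_decay, and time_faaaa keep their entries in state_dict, which is what keeps existing RWKV checkpoints loadable; persistent=False would drop a buffer from saved weights. A toy illustration:

import torch
import torch.nn as nn

class Toy(nn.Module):  # illustrative only
    def __init__(self):
        super().__init__()
        self.register_buffer("kept", torch.ones(2))                       # saved in state_dict
        self.register_buffer("dropped", torch.ones(2), persistent=False)  # excluded from it

print(list(Toy().state_dict().keys()))  # ['kept']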
@@ -187,13 +185,13 @@ def __init__(self, args, layer_id):
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

-        with torch.no_grad():  # fancy init of time_mix
-            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
-            ddd = torch.ones(1, 1, args.n_embd)
-            for i in range(args.n_embd):
-                ddd[0, 0, i] = i / args.n_embd
-            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
-            self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+        # fancy init of time_mix
+        ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
+        ddd = torch.ones(1, 1, args.n_embd)
+        for i in range(args.n_embd):
+            ddd[0, 0, i] = i / args.n_embd
+        self.register_buffer("time_mix_k", torch.pow(ddd, ratio_1_to_almost0))
+        self.register_buffer("time_mix_r", torch.pow(ddd, ratio_1_to_almost0))

        self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
@@ -216,18 +214,16 @@ def __init__(self, args, layer_id):
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

-        with torch.no_grad():
-            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)
-
-            x = torch.ones(1, 1, args.n_embd)
-            for i in range(args.n_embd):
-                x[0, 0, i] = i / args.n_embd
-
-            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
-            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
-            self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
-            self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
-            self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
+        ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)
+        x = torch.ones(1, 1, args.n_embd)
+        for i in range(args.n_embd):
+            x[0, 0, i] = i / args.n_embd
+        self.register_buffer("time_mix_k", torch.pow(x, ratio_1_to_almost0))
+        self.register_buffer("time_mix_r", torch.pow(x, ratio_1_to_almost0))
+
+        self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
+        self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
+        self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
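A final compatibility note (an assumption about intent, not stated in the PR as captured here): because a buffer and a parameter of the same name produce the same state_dict key, checkpoints trained with the old nn.Parameter layout should still load into the buffer version. A toy round-trip under that assumption:

import torch
import torch.nn as nn

class OldStyle(nn.Module):  # hypothetical stand-ins for the two layouts
    def __init__(self):
        super().__init__()
        self.time_mix_k = nn.Parameter(torch.rand(1, 1, 8))

class NewStyle(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("time_mix_k", torch.zeros(1, 1, 8))

new = NewStyle()
new.load_state_dict(OldStyle().state_dict())  # same key 'time_mix_k', so this succeeds
print(new.time_mix_k.abs().sum() > 0)         # tensor(True): values came from the old model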