diff --git a/RWKV-v5/src/model.py b/RWKV-v5/src/model.py index 700d482c..b929fbdd 100644 --- a/RWKV-v5/src/model.py +++ b/RWKV-v5/src/model.py @@ -19,7 +19,7 @@ try: print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"]) -except: +except KeyError: os.environ["RWKV_MY_TESTING"] = '' def __nop(ob): @@ -106,32 +106,30 @@ def __init__(self, args, layer_id): assert args.dim_att % self.n_head == 0 self.head_size_divisor = args.head_size_divisor - with torch.no_grad(): - ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1 - ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0 - ddd = torch.ones(1, 1, args.n_embd) - for i in range(args.n_embd): - ddd[0, 0, i] = i / args.n_embd - - # fancy time_mix - self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) - self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) - self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0)) - self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0)) - - # fancy time_decay - decay_speed = torch.ones(args.dim_att) - for n in range(args.dim_att): - decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1) - self.time_decay = nn.Parameter(decay_speed.reshape(self.n_head, self.head_size)) - # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy()) - - tmp = torch.zeros(args.dim_att) - for n in range(args.dim_att): - zigzag = ((n + 1) % 3 - 1) * 0.1 - tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag - - self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size)) + ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1 + ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0 + ddd = torch.ones(1, 1, args.n_embd) + for i in range(args.n_embd): + ddd[0, 0, i] = i / args.n_embd + + # fancy time_mix + self.register_buffer("time_mix_k", torch.pow(ddd, ratio_1_to_almost0)) + self.register_buffer("time_mix_v", torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + self.register_buffer("time_mix_r", torch.pow(ddd, 0.5 * ratio_1_to_almost0)) + self.register_buffer("time_mix_g", torch.pow(ddd, 0.5 * ratio_1_to_almost0)) + + # fancy time_decay + decay_speed = torch.ones(args.dim_att) + for n in range(args.dim_att): + decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1) + self.register_buffer("time_decay", decay_speed.reshape(self.n_head, self.head_size)) + # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy()) + + tmp = torch.zeros(args.dim_att) + for n in range(args.dim_att): + zigzag = ((n + 1) % 3 - 1) * 0.1 + tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag + self.register_buffer("time_faaaa", tmp.reshape(self.n_head, self.head_size)) self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False) @@ -187,13 +185,13 @@ def __init__(self, args, layer_id): self.layer_id = layer_id self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) - with torch.no_grad(): # fancy init of time_mix - ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0 - ddd = torch.ones(1, 1, args.n_embd) - for i in range(args.n_embd): - ddd[0, 0, i] = i / args.n_embd - self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) - self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) + # fancy init of time_mix + ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0 + ddd = torch.ones(1, 1, args.n_embd) + for i in range(args.n_embd): + ddd[0, 0, i] = i / args.n_embd + self.register_buffer("time_mix_k", torch.pow(ddd, ratio_1_to_almost0)) + self.register_buffer("time_mix_r", torch.pow(ddd, ratio_1_to_almost0)) self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False) self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False) @@ -216,18 +214,16 @@ def __init__(self, args, layer_id): self.layer_id = layer_id self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) - with torch.no_grad(): - ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) - - x = torch.ones(1, 1, args.n_embd) - for i in range(args.n_embd): - x[0, 0, i] = i / args.n_embd + ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) + x = torch.ones(1, 1, args.n_embd) + for i in range(args.n_embd): + x[0, 0, i] = i / args.n_embd + self.register_buffer("time_mix_k", torch.pow(x, ratio_1_to_almost0)) + self.register_buffer("time_mix_r", torch.pow(x, ratio_1_to_almost0)) - self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) - self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) - self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False) - self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False) - self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False) + self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False) + self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False) + self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False) @MyFunction def forward(self, x):