Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

complex rotary memory #185

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions RWKV-v4neo/src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,12 +450,66 @@ def forward(self, x):
x = RUN_CUDA(B, T, C, H, r, k, v, w=self.time_decay, u=self.time_faaaa)

return self.jit_func_2(x, g)

########################################################################################################
# RWKV: RWKV Wavenet-mem + rotary memory
########################################################################################################
class Short_Mem(nn.Module):
def __init__(self, args, shiftAmount=1):
super().__init__()
# self.time_shift1 = TimeShift(args.n_embd, shiftAmount=shiftAmount, batch=args.micro_bsz)
self.time_shift1 = nn.ZeroPad2d((0, 0, shiftAmount, -shiftAmount))
self.activation = nn.Sequential(
nn.Linear(args.n_embd*2, args.n_embd, bias=False),
nn.Sigmoid(),
)

def forward(self, x):
xv = self.activation(torch.cat([self.time_shift1(x),x], dim=-1))
return xv

class WaveNet_Mem(Short_Mem):
def __init__(self, args, layer_id, modulo=12, undialated=False):
if undialated:
super().__init__(args, shiftAmount=1)
else:
super().__init__(args, shiftAmount=2**(layer_id%modulo))

class Rotary_Memory(nn.Module):
def __init__(self, args, layer_id):
nn.Module.__init__(self)
self.args = args
self.layer_id = layer_id

self.complexsize = args.n_embd
self.short = WaveNet_Mem(args, layer_id, undialated=True)
self.key = nn.Linear(args.n_embd,self.complexsize*2, bias=False, dtype=torch.bfloat16)
# self.cumprod = CumProd(torch.complex(torch.ones(args.micro_bsz, 1, self.complexsize), torch.zeros(args.micro_bsz, 1, self.complexsize)))
# self.cummax = CumMax()
self.activation = nn.Linear(self.complexsize*2, args.n_embd, bias=False, dtype=torch.bfloat16)

def forward(self, x):
B, T, C = x.size()
k = self.key(x).float()


complexval = torch.view_as_complex(k.reshape(B, T, self.complexsize,2))
# scale = self.cummax(torch.abs(complexval))
scale = torch.cummax(torch.abs(complexval), dim=-2)[0]
complexval2 = complexval / scale
# kv = self.cumprod(complexval2)
kv = torch.cumprod(complexval2, dim=-2)
out = self.activation(torch.view_as_real(kv).reshape(B, T, self.complexsize*2)) * self.short(x* scale)

return out

########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################




class RWKV_TimeMix(MyModule):
def __init__(self, args, layer_id):
super().__init__()
Expand Down Expand Up @@ -632,6 +686,11 @@ def __init__(self, args, layer_id):
self.ln1 = nn.LayerNorm(args.n_embd)
self.ln2 = nn.LayerNorm(args.n_embd)

self.use_mem = 'c' in os.environ["RWKV_MY_TESTING"]
if self.use_mem:
self.ln3 = nn.LayerNorm(args.n_embd)
self.mem = Rotary_Memory(args, layer_id)

if self.layer_id == 0:
self.ln0 = nn.LayerNorm(args.n_embd)
if args.my_pos_emb > 0:
Expand Down Expand Up @@ -663,6 +722,8 @@ def __init__(self, args, layer_id):
if args.dropout > 0:
self.drop0 = nn.Dropout(p = args.dropout)
self.drop1 = nn.Dropout(p = args.dropout)
if self.use_mem:
self.drop2 = nn.Dropout(p = args.dropout)

def forward(self, x, x_emb=None):
args = self.args
Expand All @@ -678,12 +739,16 @@ def forward(self, x, x_emb=None):
x = x + self.ffnPre(self.ln1(x))
else:
x = x + self.att(self.ln1(x))
if self.use_mem:
x = x + self.mem(self.ln3(x))
x = x + self.ffn(self.ln2(x))
else:
if self.layer_id == 0 and args.pre_ffn > 0:
x = self.drop0(x + self.ffnPre(self.ln1(x)))
else:
x = self.drop0(x + self.att(self.ln1(x)))
if self.use_mem:
x = self.drop2(x + self.mem(self.ln3(x)))
x = self.drop1(x + self.ffn(self.ln2(x)))

if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
Expand Down