Commit b42fc10
+ dropout
www committed Aug 6, 2023
1 parent 69e6c50 commit b42fc10
Showing 2 changed files with 20 additions and 9 deletions.
RWKV-v4neo/src/model.py: 19 additions & 4 deletions
@@ -467,6 +467,10 @@ def __init__(self, args, layer_id):
             self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
             self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
 
+        if args.dropout > 0:
+            self.drop0 = nn.Dropout(p = args.dropout)
+            self.drop1 = nn.Dropout(p = args.dropout)
+
     def forward(self, x, x_emb=None):
         args = self.args
         B, T, C = x.size()
@@ -476,11 +480,18 @@ def forward(self, x, x_emb=None):
             pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
             x = x + pos_emb
 
-        if self.layer_id == 0 and args.pre_ffn > 0:
-            x = x + self.ffnPre(self.ln1(x))
+        if self.args.dropout == 0:
+            if self.layer_id == 0 and args.pre_ffn > 0:
+                x = x + self.ffnPre(self.ln1(x))
+            else:
+                x = x + self.att(self.ln1(x))
+            x = x + self.ffn(self.ln2(x))
         else:
-            x = x + self.att(self.ln1(x))
-        x = x + self.ffn(self.ln2(x))
+            if self.layer_id == 0 and args.pre_ffn > 0:
+                x = self.drop0(x + self.ffnPre(self.ln1(x)))
+            else:
+                x = self.drop0(x + self.att(self.ln1(x)))
+            x = self.drop1(x + self.ffn(self.ln2(x)))
 
         if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
             xx = self.tiny_ln(x)
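Note: when args.dropout > 0, every residual update in Block.forward now passes through a dropout layer, and the dropout wraps the whole residual sum (x + branch), not just the branch output. A minimal standalone sketch of this pattern follows; TinyBlock, the Linear stand-ins for RWKV's time-mix and channel-mix, and the sizes are all illustrative, not the repo's API:

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    # Simplified stand-in for Block: dropout wraps each residual sum,
    # mirroring x = self.drop0(x + self.att(self.ln1(x))) in this commit.
    def __init__(self, n_embd, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.att = nn.Linear(n_embd, n_embd, bias=False)  # placeholder for time-mix
        self.ffn = nn.Linear(n_embd, n_embd, bias=False)  # placeholder for channel-mix
        self.drop0 = nn.Dropout(p=dropout)
        self.drop1 = nn.Dropout(p=dropout)

    def forward(self, x):
        x = self.drop0(x + self.att(self.ln1(x)))  # dropout after the attention residual
        x = self.drop1(x + self.ffn(self.ln2(x)))  # dropout after the FFN residual
        return x

block = TinyBlock(n_embd=64, dropout=0.05)
y = block(torch.randn(1, 16, 64))  # (batch, seq, n_embd)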
@@ -533,6 +544,8 @@ def __init__(self, args):
             self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
             self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
             self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
+        if args.dropout > 0:
+            self.drop0 = nn.Dropout(p = args.dropout)
 
     def configure_optimizers(self):
         args = self.args
@@ -617,6 +630,8 @@ def forward(self, idx):
         x = self.emb(idx)
         x_emb = x
 
+        if args.dropout > 0:
+            x = self.drop0(x)
         if args.tiny_att_dim > 0:
             for block in self.blocks:
                 if args.grad_cp == 1:
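Note: at the model level the commit applies a single dropout to the token embeddings before the first block. A hedged sketch of that input path; the vocabulary size, embedding width, and dropout rate below are assumptions for illustration:

import torch
import torch.nn as nn

emb = nn.Embedding(50257, 768)  # stand-in for self.emb
drop0 = nn.Dropout(p=0.05)      # stand-in for self.drop0

idx = torch.randint(0, 50257, (1, 16))  # (batch, seq) token ids
x = emb(idx)
x = drop0(x)  # in the actual code this runs only when args.dropout > 0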
RWKV-v4neo/train.py: 1 addition & 5 deletions
@@ -84,7 +84,7 @@
     parser.add_argument("--beta2", default=0.99, type=float)  # use 0.999 when your model is close to convergence
     parser.add_argument("--adam_eps", default=1e-8, type=float)
     parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
-    parser.add_argument("--dropout", default=0, type=float)
+    parser.add_argument("--dropout", default=0, type=float)  # try 0.01 / 0.02 / 0.05 / 0.1
     parser.add_argument("--weight_decay", default=0, type=float)  # try 0.1 / 0.01 / 0.001
     parser.add_argument("--weight_decay_final", default=-1, type=float)
 
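Note: the flag defaults to 0, and the model.py branches above only build and apply the dropout layers when it is positive, so existing configurations behave exactly as before. For reference, a small sketch of how torch.nn.Dropout behaves in its two modes; the tensor values are illustrative:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.05)
x = torch.ones(4)

drop.train()    # training mode: each element is zeroed with probability p,
print(drop(x))  # survivors are scaled by 1/(1-p), here 1/0.95

drop.eval()     # eval/inference mode: dropout is the identity
print(drop(x))  # tensor([1., 1., 1., 1.])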
@@ -305,10 +305,6 @@
         from src.model_img import RWKV_IMG
         model = RWKV_IMG(args)
     else:
-        # if args.dropout > 0:
-        #     from src.model_drop2 import RWKV
-        #     model = RWKV(args)
-        # else:
         from src.model import RWKV
         model = RWKV(args)
 
