add nano-rwkv rnn mode #1

Open
wants to merge 1 commit into base: master

Conversation

@N1kSt4r commented Dec 16, 2023

I thought this code would be useful. I have implemented an RNN mode, but I got acquainted with the RWKV project just a few hours ago, so I'm a little unsure about my implementation and would like it to be reviewed.

Known issues:

  • Forward-pass change: the positional embedding is temporarily removed
  • There are also some magic constants in the dimensions

I will clean everything up into a proper implementation if you are interested in it.
I checked it on the demo version of the Shakespeare model.


@SmerkyG left a comment


Didn't test it, but looks almost right to me upon visual review.

`w` and `u` need to be reshaped to function properly for RWKV5.1. They should be something like `w, u = w.view(1, H, 1, 1), u.view(1, H, 1, 1)`.
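For example (untested sketch; the shapes are illustrative, assuming `w` and `u` start as per-head vectors of shape `(H,)` inside the RWKV-5.1 time-mixing block):

```python
import torch

B, H, T, K = 2, 4, 8, 16             # batch, heads, sequence length, head dim
w = torch.rand(H)                    # per-head decay
u = torch.rand(H)                    # per-head bonus for the current token
k = torch.randn(B, H, T, K)

# Reshape so the per-head parameters broadcast against (B, H, T, K) tensors.
w, u = w.view(1, H, 1, 1), u.view(1, H, 1, 1)

decayed_k = w * k                    # now broadcasts over batch, time and channel dims
```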

One problem I did see is that the variable you call `state` (`x_state` from my rwkv_explained repo) isn't used correctly for looping RNN mode in circumstances where the loop length C > 1. `x_state` is intended to include only the most recent input token processed, so `xx = state - x` doesn't make sense in that context; you would need to accept a longer `x` in time terms and concatenate `(x_state, x_shifted)` to do that. That said, I'm not sure there's a reason to support a loop in RNN mode, since that's only useful for training. If you're trying to add support for infctx training, that would be better done by augmenting the forward() call to accept state, etc. Also, the C, T calculations have some temporary testing code overriding the proper definitions, and if you want to support calling with T > 1 you need some changes to `xx`.
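Roughly what I mean, as an untested sketch (the `token_shift` name and state layout here are just illustrative, not your code):

```python
import torch

def token_shift(x, x_state):
    # x:       (B, T, C) newly arriving tokens
    # x_state: (B, C)    the single most recent token from the previous call
    x_shifted = torch.cat([x_state.unsqueeze(1), x[:, :-1, :]], dim=1)
    xx = x_shifted - x            # shifted-minus-current delta used by the mixing layers
    new_x_state = x[:, -1, :]     # carry the newest token forward for the next call
    return xx, new_x_state
```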

There's some test code you left in at the bottom that should be removed for a full PR. To ensure correctness I'd suggest adding and running a test that recurrently processes some inputs to create outputs by calling the step function many times, and compares the results to the output from running the original code.
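Something along these lines (untested sketch; the `model(idx)` and `model.step(...)` signatures and the state format are assumptions about your interface, not the actual code):

```python
import torch

@torch.no_grad()
def check_step_matches_forward(model, idx, atol=1e-5):
    # Full parallel (transformer-style) pass over the whole sequence.
    logits_parallel, _ = model(idx)

    # Recurrent pass: feed one token at a time and thread the state through.
    state = None
    chunks = []
    for t in range(idx.size(1)):
        logits_t, state = model.step(idx[:, t:t + 1], state)
        chunks.append(logits_t)
    logits_step = torch.cat(chunks, dim=1)

    # Compare the prediction for the final position; a nanoGPT-style forward()
    # may only return the last position's logits at inference time, and this is
    # the part both paths are guaranteed to produce.
    assert torch.allclose(logits_parallel[:, -1, :], logits_step[:, -1, :], atol=atol)
```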

@BlinkDL The removal of posemb is important and probably best submitted as a separate PR.

@N1kSt4r (Author) commented Dec 19, 2023

Thank you for your review!
I have fixed all the problems that were mentioned.
But I also left two different methods for the transformer and RNN modes.

@N1kSt4r changed the title from "initially add nano-rwkv rnn mode" to "add nano-rwkv rnn mode" on Dec 19, 2023