
Need Forwarding with state. #22

Open
lxianl455 opened this issue Jun 14, 2024 · 5 comments

Comments

@lxianl455

When training, it runs without state:

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        x = self.token_embedding(idx)
        x = self.emb_dropout(x)
        x = self.xlstm_block_stack(x)
        logits = self.lm_head(x)
        return logits

Can you give a "forward with state" version?

    def forward(self, idx: torch.Tensor, state) -> torch.Tensor:
        x = self.token_embedding(idx)
        x = self.emb_dropout(x)
        x = self.xlstm_block_stack(x, state)
        logits = self.lm_head(x)
        return logits
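
For reference, a minimal sketch of what such a stateful forward could look like as a method on the model above, assuming `xlstm_block_stack` exposes a `step(x, state)` method that returns `(output, new_state)`. Whether that method exists with this exact signature depends on the installed xlstm version, so treat the names below as assumptions and see the `step()` discussion further down this thread:

```python
import torch

def forward_with_state(self, idx: torch.Tensor, state=None):
    # idx: (batch, seq_len) token ids; state: per-block recurrent state or None.
    x = self.token_embedding(idx)
    x = self.emb_dropout(x)
    outputs = []
    for t in range(x.shape[1]):
        # Assumption: step() consumes one time step of shape (batch, 1, dim)
        # and returns (output, new_state). Verify against your xlstm version.
        y, state = self.xlstm_block_stack.step(x[:, t : t + 1, :], state)
        outputs.append(y)
    x = torch.cat(outputs, dim=1)
    logits = self.lm_head(x)
    return logits, state
```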

@sieusaoml

sieusaoml commented Jun 16, 2024


https://github.com/sieusaoml/xLSTM-custom-block
This is a custom xLSTM block of mine.

@lxianl455
Author

Yes, I want to do something similar. But in the code, is it only sLSTM that can be initialized with the previous hidden state? Can't mLSTM be initialized with the previous state?

@hiimbach

The step() method and the forward() method of mLSTMLayer use different types of conv1d forward, so I think if you want to use a hidden state, you need to call step() token by token instead of forwarding all tokens at the same time.
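
As a usage illustration of that token-by-token approach, here is a hedged sketch of an autoregressive decode loop that carries the state between calls. `model` mirrors the class from the first comment, and `xlstm_block_stack.step(x, state) -> (output, new_state)` is an assumption that should be checked against the installed library:

```python
import torch

@torch.no_grad()
def generate(model, prompt_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    # Call model.eval() first so emb_dropout is disabled.
    state = None
    tokens = prompt_ids

    # Warm up the recurrent state on the prompt, one token at a time.
    for t in range(prompt_ids.shape[1]):
        x = model.emb_dropout(model.token_embedding(prompt_ids[:, t : t + 1]))
        y, state = model.xlstm_block_stack.step(x, state)  # assumed signature

    for _ in range(max_new_tokens):
        logits = model.lm_head(y)                                 # (batch, 1, vocab)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)   # greedy pick
        tokens = torch.cat([tokens, next_token], dim=1)
        x = model.emb_dropout(model.token_embedding(next_token))
        y, state = model.xlstm_block_stack.step(x, state)
    return tokens
```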

@lxianl455
Author

lxianl455 commented Jun 16, 2024

Yes, I am not asking to forward all the tokens at the same time. In fact, my original model was an LSTM, which processes each token in a loop. I just want to replace this LSTM with xLSTM. But it seems that step() is meant for inference, right? Can it backpropagate normally during training? Will the in-place operations lead to backpropagation errors?
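
One hedged way to answer that empirically: run a short sequence through a stateful forward like the sketch earlier in this thread, take a loss, and call backward() with anomaly detection enabled, which makes PyTorch point at any in-place modification of a tensor the backward pass still needs. `forward_with_state`, `model`, and the step() signature are the assumptions from that sketch, not an official API:

```python
import torch
import torch.nn.functional as F

vocab_size = 1000  # set to the model's actual vocabulary size

with torch.autograd.set_detect_anomaly(True):
    idx = torch.randint(0, vocab_size, (2, 8))      # (batch=2, seq_len=8)
    targets = torch.randint(0, vocab_size, (2, 8))

    # forward_with_state is the sketch from earlier in the thread (assumed API).
    logits, state = forward_with_state(model, idx, state=None)
    loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
    loss.backward()  # raises a RuntimeError here if an in-place op broke the graph

# Parameters whose grad is still None did not receive gradients via step().
print([name for name, p in model.named_parameters() if p.grad is None])
```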

@sieusaoml

mLSTMLayer can be used with a previous hidden state, but in my test with context_length=1, backpropagating the gradient raises an error.
