Efficient Sampling #32
Comments
Per Stack Overflow, you apparently cannot have multiple assignees on a private repo. No idea why. Mentioning @Mistobaan @lucidrains should ensure you both get notifications.
Reopening this because there are still some problems with fast sampling:
I'm happy to look into this if you're busy, @lucidrains. I suspect 1) is just an off-by-one error somewhere in the attention mask, maybe. 2) might be a little trickier, but I'd be happy to look into it, since TPUs might be the one area where I can claim to have more experience, haha. I think it may be a shape error relating to the mtf splitting.
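To illustrate the kind of off-by-one mask bug suspected in 1), here is a minimal NumPy sketch (hypothetical function names, not the repo's code) contrasting a correct causal mask with one that accidentally excludes the diagonal, so each token cannot attend to itself:

```python
import numpy as np

def causal_mask(seq_len):
    # Position i may attend to positions 0..i (inclusive):
    # lower triangle including the diagonal.
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))

def buggy_causal_mask(seq_len):
    # Off-by-one variant (k=-1 excludes the diagonal): a token is
    # hidden from itself, and position 0 sees nothing at all.
    return np.tril(np.ones((seq_len, seq_len), dtype=bool), k=-1)

mask = causal_mask(4)
bad = buggy_causal_mask(4)
```

Bugs like this often only surface as subtly degraded samples rather than a crash, which would fit the symptoms described here.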
After a bit more digging, it seems the sample quality is only worse with MoE models. I'm going to keep investigating why.
Ok, for now we have set fast sampling as the default, but revert to slow sampling for MoE models. Error no. 2 wasn't in fact an error with larger pods or models; it was an error with fast sampling when recompute_grad was set to True. That is fixed in 8415e56, so what's left now is to 1) figure out what's wrong with MoE and 2) try out step 4 in #43.
Closing this issue to open a more relevant one w/r/t MoE.
Currently our sampling is incredibly inefficient: it doesn't cache the past key/value tensors, and instead recomputes them for every generated token.
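For reference, a minimal NumPy sketch (hypothetical names, not this codebase) of what caching buys: incremental decoding appends one k/v row per new token and produces the same outputs as recomputing the whole prefix at every step:

```python
import numpy as np

def attention(q, k, v):
    # Single-head scaled dot-product attention; q is [1, dim],
    # k and v are [t, dim] (everything decoded so far).
    scores = q @ k.T / np.sqrt(k.shape[-1])
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v

rng = np.random.default_rng(0)
dim, steps = 4, 5
wq, wk, wv = (rng.normal(size=(dim, dim)) for _ in range(3))
xs = rng.normal(size=(steps, dim))  # stand-in for embedded tokens

# Cached decoding: project k/v once per new token and append.
k_cache = np.empty((0, dim))
v_cache = np.empty((0, dim))
cached_out = []
for x in xs:
    k_cache = np.vstack([k_cache, (x @ wk)[None, :]])
    v_cache = np.vstack([v_cache, (x @ wv)[None, :]])
    cached_out.append(attention((x @ wq)[None, :], k_cache, v_cache)[0])
cached_out = np.stack(cached_out)

# Naive decoding: recompute k/v for the whole prefix at every step
# (what the sampling loop currently does).
naive_out = np.stack([
    attention((xs[t] @ wq)[None, :], xs[:t + 1] @ wk, xs[:t + 1] @ wv)[0]
    for t in range(steps)
])
# cached_out and naive_out agree; the cache just avoids the O(t)
# re-projection work at every step.
```

The per-step projection cost drops from O(t) to O(1), which is where the "incredibly inefficient" constant factor comes from.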
We should look again at how sampling is done in mesh / T5 (maybe ask Colin?) and see whether we can cache k/v values to make our sampling code more efficient.
The basic infrastructure for this is already in place, but commented out (the k/v values would be stored in this Context object: https://github.com/EleutherAI/GPTNeo/blob/master/sample.py#L148, https://github.com/EleutherAI/GPTNeo/blob/master/models/gpt2/gpt2.py#L138). The problem we ran into last time is that the mesh code seems to feed in the inputs one token at a time, but the token is missing a batch dimension, and we're not exactly sure where the batch dimension gets added back on, and therefore how to replicate this. (see: https://github.com/tensorflow/mesh/blob/a3c05f705641dfe144f70b7b5230db4933ce8ca9/mesh_tensorflow/transformer/transformer.py#L1137)
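As a loose sketch of one way the per-step state could carry an explicit batch axis (purely hypothetical shapes and names, not the actual Context in sample.py or the mesh internals): the per-step input can be just `[batch]` token ids, with the batch dimension reappearing at the embedding lookup each step:

```python
import numpy as np

class DecodeState:
    # Hypothetical stand-in for a Context-like object: holds per-layer
    # (k, v) pairs, each shaped [batch, t, dim], across decode steps.
    def __init__(self, n_layers):
        self.kv = [None] * n_layers

    def record(self, layer, k_t, v_t):
        if self.kv[layer] is None:
            self.kv[layer] = (k_t, v_t)
        else:
            k, v = self.kv[layer]
            # grow along the time axis; batch stays on axis 0 throughout
            self.kv[layer] = (np.concatenate([k, k_t], axis=1),
                              np.concatenate([v, v_t], axis=1))

def decode_step(token_ids, state, embed, wk, wv, layer=0):
    # token_ids: [batch] -- one token id per sequence, i.e. the shape
    # mesh appears to feed in; batch is re-attached by the lookup below.
    x = embed[token_ids]             # [batch, dim]
    k_t = (x @ wk)[:, None, :]       # [batch, 1, dim]
    v_t = (x @ wv)[:, None, :]
    state.record(layer, k_t, v_t)
    return state.kv[layer]

vocab, dim = 10, 4
rng = np.random.default_rng(0)
embed = rng.normal(size=(vocab, dim))
wk, wv = rng.normal(size=(dim, dim)), rng.normal(size=(dim, dim))
state = DecodeState(n_layers=1)
decode_step(np.array([1, 7]), state, embed, wk, wv)
k, v = decode_step(np.array([3, 2]), state, embed, wk, wv)
# after two steps, k and v are [batch=2, t=2, dim=4]
```

If mesh does something equivalent, the "missing" batch dimension may simply be implicit in the token-id tensor and only materialize after the embedding, which would explain why it's hard to spot in the layer code.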
This is (the last?) issue I'd like to get solved before the code is released publicly, so I will try to look into this soon, but any help would be appreciated.