Add a function to convert nanogpt weights #475
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
I have been using TransformerLens on some models I trained using Karpathy's popular nanogpt repository. There are two complications I had to deal with:
The first is that some state dicts saved after using torch.compile() have an unwanted prefix on keys that needs to be removed. Karpathy deals with it like this. I added this same unwanted prefix removal in the conversion function.
The second is that the nanogpt models can be created with or without bias. By default, there is no bias. This function can handle both cases. To verify that my conversion function works as expected, I created this Colab:
https://colab.research.google.com/drive/1CqMRAezkc2vVJKPiKA3Q7ACdjgzH_y7K?authuser=0#scrollTo=4pB3Ecg7B0X-
In it, I take models that I have created and trained, one with and one without bias. For each stock NanoGPT model, I run a sample input of length 339 and run it through the model. I store the 339 output tokens in expected output. Next, I convert the model to Transformer Lens format using my conversion function. I again forward the same sample input, and check that the outputs exactly match the expected output.
In terms of tests, documentation, comments, code style conventions, etc, I just matched the level of coverage of what shows up when I search 'mingpt' in the codebase. Please let me know if you want additional documentation or testing.
Type of change