
[Proposal] Support DDP for activation generation and SAE training. #14

Hzfinfdu opened this issue Jun 6, 2024 · 4 comments

Hzfinfdu commented Jun 6, 2024

A natural approach to faster SAE training is data parallelism. Maybe we can simply use DDP to make 8 copies of the TL model to generate activations and synchronize SAE gradients. This may help accelerate activation generation, which is the speed bottleneck for larger LMs.
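A rough sketch of what I have in mind (only the SAE is wrapped in DDP so that only its gradients get all-reduced; `get_token_batch` and `sae_loss` are hypothetical placeholders, not code in this repo):

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def train_rank(model_factory, sae_factory, hook_name, steps):
    # Launched with torchrun so that rank/world-size env vars are already set.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = f"cuda:{rank}"
    torch.cuda.set_device(device)

    # Each rank keeps its own frozen copy of the LM purely for activation generation.
    model = model_factory().to(device).eval()
    # Only the SAE is wrapped in DDP, so only the (small) SAE gradients are synchronized.
    sae = DDP(sae_factory().to(device), device_ids=[rank])
    opt = torch.optim.Adam(sae.parameters(), lr=1e-4)

    for _ in range(steps):
        tokens = get_token_batch(rank).to(device)  # each rank reads a different data shard
        with torch.no_grad():
            _, cache = model.run_with_cache(tokens, names_filter=hook_name)
        acts = cache[hook_name]

        loss = sae_loss(sae, acts)  # e.g. reconstruction + sparsity penalty
        opt.zero_grad()
        loss.backward()  # DDP all-reduces the SAE gradients here
        opt.step()

    dist.destroy_process_group()
```

This way activation generation scales with the number of GPUs, while gradient sync cost stays proportional to the SAE rather than the LM.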

This may not work for larger models, say 70B. Maybe the ultimate solution is a producer-consumer design pattern, with activation generation and SAE training running as separate processes. Let's leave this for later.
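For the record, the producer-consumer version would look roughly like this (nothing below exists in the repo yet; `make_model`, `make_batches`, `make_sae` and `sae_loss` are placeholders):

```python
import torch
import torch.multiprocessing as mp

def producer(queue, make_model, make_batches, hook_name, device):
    # Holds a copy (or shard) of the LM and streams activations into the queue.
    model = make_model().to(device).eval()
    for tokens in make_batches():
        with torch.no_grad():
            _, cache = model.run_with_cache(tokens.to(device), names_filter=hook_name)
        queue.put(cache[hook_name].cpu())
    queue.put(None)  # sentinel: this producer is finished

def consumer(queue, make_sae, n_producers, device):
    # Trains the SAE on whatever the producers send, decoupled from generation speed.
    sae = make_sae().to(device)
    opt = torch.optim.Adam(sae.parameters(), lr=1e-4)
    finished = 0
    while finished < n_producers:
        acts = queue.get()
        if acts is None:
            finished += 1
            continue
        loss = sae_loss(sae, acts.to(device))
        opt.zero_grad()
        loss.backward()
        opt.step()
```

Wiring would be one `mp.Process` per producer GPU plus a consumer process, all sharing an `mp.Queue(maxsize=...)` for back-pressure.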

  • Support DDP
@Hzfinfdu Hzfinfdu self-assigned this Jun 6, 2024
@dest1n1s dest1n1s added the enhancement New feature or request label Jun 8, 2024
@alanxmay

@Hzfinfdu Hi, thanks for your amazing work.

Is DDP working now? I tried with 4 GPUs, but found that the processes on devices 1, 2 and 3 never end.

@alanxmay

BTW, I made some modifications to get around a bug in the DDP code.

Error message without modification:

...
[rank3]:   File "/home/alan/dev/sae/Language-Model-SAEs/TransformerLens/transformer_lens/components/embed.py", line 34, in forward
[rank3]:     return self.W_E[tokens, :]
[rank3]: RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:0)
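
The error comes from indexing `W_E` (left on `cuda:0`) with token indices that live on another rank's device. A sketch of the kind of guard that gets past it in `embed.py` (illustrative only, not necessarily the exact modification; moving the tokens and the model copy onto the local rank's device up front is probably the cleaner fix):

```python
# transformer_lens/components/embed.py, Embed.forward (sketch of a workaround)
def forward(self, tokens):
    # Make sure the index tensor is on the same device as the embedding matrix
    # before fancy-indexing it.
    if tokens.device != self.W_E.device:
        tokens = tokens.to(self.W_E.device)
    return self.W_E[tokens, :]
```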


Hzfinfdu commented Jun 13, 2024

@alanxmay Thanks for your comment!

We are currently working on it. Initially we did not take industry-scale models into account, nor DDP, so we may need about a week of refactoring to support them.

8B models just work on a single A100 GPU with a small batch size.

If this does not fit your scenario, you may have to wait for a while xd.

@alanxmay

@Hzfinfdu Thanks for your reply, my setup is 8×V100 (32 GB) sadly 🤷‍♂️.
