
LoRA: zero_pad speed improvements #770

Merged

Conversation

@Andrei-Aksionov (Collaborator) commented Nov 23, 2023

While experimenting with GitHub Actions I deleted my fork (I know, I know), and thus all open PRs were automatically closed. This PR mirrors #630.

Hi there 👋

This PR is a result of #461.
In that issue I found out that creating a new tensor from lora_ind (which is stored as a Python list on the CPU) on every zero_pad call ...
https://github.com/Lightning-AI/lit-gpt/blob/807c7bc17413d53961f96dc668aa03c0b970a43f/lit_gpt/lora.py#L293-L295

... implicitly calls cudaStreamSynchronize each time, which slows down the forward pass a bit.
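
For illustration, here is a minimal sketch of the idea behind the change (function names and signatures are hypothetical, not the exact PR diff): building the index tensor from a host-side Python list on every call forces a host-to-device copy and an implicit sync, while reusing an index tensor that already lives on the right device avoids it.

```python
import torch

# Before (simplified): lora_ind is a Python list, so every zero_pad call builds a new
# device tensor from host memory, which implicitly synchronizes with the GPU.
def zero_pad_old(x: torch.Tensor, lora_ind: list, out_features: int) -> torch.Tensor:
    result = x.new_zeros(*x.shape[:-1], out_features)      # (..., out_features)
    index = torch.tensor(lora_ind, device=x.device)        # host -> device copy on every call
    return result.index_copy_(dim=-1, index=index, source=x)

# After (simplified): the index is created once as a tensor on the target device and
# reused across calls, so the forward pass avoids the per-call copy/synchronization.
def zero_pad_new(x: torch.Tensor, lora_ind: torch.Tensor, out_features: int) -> torch.Tensor:
    result = x.new_zeros(*x.shape[:-1], out_features)
    return result.index_copy_(dim=-1, index=lora_ind, source=x)
```

In the actual module the reusable index would typically be stored on the LoRA layer itself so it moves to the right device together with the model; see lit_gpt/lora.py in this PR for the exact mechanism.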

Traces

Note

Numbers are provided for an Nvidia T4 with 16-mixed precision.

Let's take a look at the traces for Pythia-410m.

Currently zero_pad takes a significant part of the time:
[trace screenshot: zero_pad before the optimization]

Note

Compare the size of cudaStreamSynchronize in the screenshot above (CUDA 12.1) with the one from the "Performance Study" issue (CUDA 11.8): it's much smaller thanks to the newer CUDA version.

After the optimization, the trace shows that zero_pad now takes a much smaller portion of the time:
[trace screenshot: zero_pad after the optimization]

In numbers, it's 830 μs vs 126 μs.

LoRA fine-tuning

Comparing LoRA fine-tuning over 1k iterations (current main as control vs. this PR as test), we have:

| Model | Loss $_{control}$ | Loss $_{test}$ | Time $_{control}$ | Time $_{test}$ |
| --- | --- | --- | --- | --- |
| Pythia-70m | 2.5835 | 2.5802 | 30.90 | 28.51 |
| Pythia-410m | 1.7976 | 1.7976 | 124.63 | 114.51 |

Not a drastic difference, but still a nice optimization.

@Andrei-Aksionov Andrei-Aksionov marked this pull request as draft May 6, 2024 13:27
@Andrei-Aksionov Andrei-Aksionov marked this pull request as ready for review May 6, 2024 14:01
@Andrei-Aksionov (Collaborator, Author) commented:

I did a very quick benchmark with Pythia-410m on 1xT4, comparing the code from this PR against the current main.
The loss is the same, memory is ever so slightly lower, and the time for a short run of 20 steps is comparable.
So I guess it's a green light.

It would be nice if someone with access to a multi-GPU machine could run a quick LoRA finetune, just to make sure.
In a Studio I only have access to a single GPU.

@carmocca (Contributor) left a comment:

LGTM. I'll run finetuning

@carmocca carmocca merged commit f84b610 into Lightning-AI:main May 6, 2024
9 checks passed
@Andrei-Aksionov Andrei-Aksionov deleted the lora_zero_pad_speed_improvements branch May 6, 2024 17:17
x = torch.randint(0, config.padded_vocab_size, size=(2, config.block_size), dtype=torch.int64, device=fabric.device)  # random token ids: (batch_size=2, block_size)
model = fabric.setup(model)  # materialize the model on the fabric device
y = model(x)
assert y.shape == torch.Size([2, 8, 512])  # (batch_size, block_size, padded_vocab_size)
A Member commented:

@Andrei-Aksionov Could we maybe add a sanity test that iterates over all model attributes of all submodules and asserts that, if an attribute is a tensor, its .is_meta is False? The previous bug wasn't caught simply because the defaults enabled LoRA for all of q, k, and v, which essentially skips this code path (a rough sketch of such a test follows the quoted snippet below):

litgpt/lora.py, lines 330 to 342 at 90a16e4:

if all(self.enable_lora):
    return x
# Let's imagine that:
# ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size)
# ⚬ embeddings_size: 128
# ⚬ self.linear.out_features: 384 (3 * embeddings_size)
# ⚬ enable_lora: [True, False, True]
# Then x has an embeddings_size of 256 (2 * 128, since enable_lora applies only to query and value, not key), while the
# expected embeddings_size is 384 (self.linear.out_features), so we need to pad from 256 to 384 with zeros, but
# only for the key updates (this is where self.lora_ind comes in handy)
result = x.new_zeros(*x.shape[:-1], self.linear.out_features)  # (64, 64, 384)
return result.index_copy_(dim=-1, index=self.lora_ind, source=x)  # (64, 64, 384)
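
A rough sketch of what such a sanity check could look like (the helper name and traversal details are illustrative assumptions, not the final test):

```python
import torch
import torch.nn as nn

def assert_no_meta_tensors(model: nn.Module) -> None:
    """Fail if any tensor attached to any submodule still lives on the meta device."""
    for mod_name, module in model.named_modules():
        tensors = dict(module.named_parameters(recurse=False))
        tensors.update(module.named_buffers(recurse=False))
        # Plain tensor attributes (e.g. index tensors assigned directly to the module)
        # are not returned by named_parameters/named_buffers, so check __dict__ as well.
        tensors.update((name, value) for name, value in vars(module).items() if isinstance(value, torch.Tensor))
        for attr_name, tensor in tensors.items():
            assert not tensor.is_meta, f"{mod_name or 'model'}.{attr_name} is still on the meta device"
```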

@Andrei-Aksionov (Collaborator, Author) replied:

Makes sense.
Sure, I'll do this.
