LongLora fine-tuning support #1237

belerico opened this issue Apr 3, 2024 · 5 comments

belerico commented Apr 3, 2024

LongLora is "an efficient fine-tuning approach that extends the context sizes of pre-trained large language models". The authors propose to fine-tune a model with sparse local attention while keeping dense attention during inference. The Shifted-Sparse Attention (S^2-Attn) is depicted below (figure from the paper):

[Figure from the paper: illustration of Shifted-Sparse Attention (S^2-Attn)]

Moreover, the required modification is relatively simple (pseudocode from the paper):

# B: batch size
# N: sequence length (number of tokens)
# G: group size
# H: number of attention heads
# D: dimension of each attention head
# qkv in shape (B, N, 3, H, D): projected queries, keys, and values

# Key line 1: split qkv on H into 2 chunks, and shift the second chunk by G/2 on N
qkv = cat((qkv.chunk(2, 3)[0], qkv.chunk(2, 3)[1].roll(-G // 2, 1)), 3).view(B * N // G, G, 3, H, D)

# standard self-attention function
out = self_attn(qkv)

# out is reshaped from (B*N/G, G, H, D) back to (B, N, H, D)
# Key line 2: split out on H into 2 chunks, and roll the second chunk back by G/2 on N
out = cat((out.chunk(2, 2)[0], out.chunk(2, 2)[1].roll(G // 2, 1)), 2)

This can be enabled only during the fine-tuning phase, while the standard dense attention can be used during inference.
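For concreteness, here is a minimal self-contained PyTorch sketch of how I read the pseudocode above: shift half of the heads by half a group along the sequence, fold the groups into the batch dimension, attend within each group, then undo the shift. The function name, the n_groups argument, and the use of F.scaled_dot_product_attention are my own choices for illustration, not the lit-gpt API, and attention masking is omitted for brevity, as in the paper's pseudocode.

import torch
import torch.nn.functional as F

def shifted_sparse_attention(qkv: torch.Tensor, n_groups: int, training: bool) -> torch.Tensor:
    # qkv: (B, N, 3, H, D) projected queries, keys, and values; returns (B, N, H, D)
    B, N, _, H, D = qkv.shape
    assert N % n_groups == 0, "sequence length must be divisible by the number of groups"
    group_size = N // n_groups

    if not training:
        # Inference: plain dense attention over the full sequence, no shifting
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))  # each (B, H, N, D)
        return F.scaled_dot_product_attention(q, k, v).transpose(1, 2)

    half = H // 2
    # Key line 1: shift the second half of the heads by half a group along N,
    # then fold the groups into the batch dimension
    shifted = torch.cat(
        (qkv[..., :half, :], qkv[..., half:, :].roll(-(group_size // 2), dims=1)), dim=3
    )
    grouped = shifted.view(B * n_groups, group_size, 3, H, D)

    # standard self-attention within each group
    q, k, v = (t.transpose(1, 2) for t in grouped.unbind(dim=2))  # (B*n_groups, H, G, D)
    out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)  # (B*n_groups, G, H, D)

    # Key line 2: un-fold the groups (the reshape that is implicit in the paper's
    # pseudocode) and roll the shifted heads back
    out = out.reshape(B, N, H, D)
    return torch.cat(
        (out[..., :half, :], out[..., half:, :].roll(group_size // 2, dims=1)), dim=2
    )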

Another thing that needs to be modified is the padded sequence length, which should be a multiple of the group size.
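As a rough sketch (the pad_to_multiple helper and pad_id below are hypothetical names for illustration, not part of lit-gpt), the collator would round the padded length up to that multiple:

def pad_to_multiple(ids: list[int], multiple: int, pad_id: int) -> list[int]:
    # Round the sequence length up to the next multiple, e.g. 726 tokens with
    # multiple=4 get padded to 728
    remainder = len(ids) % multiple
    if remainder:
        ids = ids + [pad_id] * (multiple - remainder)
    return ids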

If you think this can be added to lit-gpt, I'm willing to contribute a PR (I already have something working which I plan to test).

Edit:

I forgot to mention that they also use Position Interpolation to rescale the position indices. If I'm not mistaken, this can be achieved by simply changing the rope_condense_ratio to account for the increased context size.
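As a sketch of the arithmetic (variable names and the example values below are mine, just for illustration):

# Position Interpolation via RoPE condensing: rescale position indices by the ratio
# between the extended and the original context length
original_context_length = 4096      # pre-training context length (example value)
longlora_context_length = 8192      # extended fine-tuning context length
rope_condense_ratio = longlora_context_length / original_context_length  # -> 2.0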

belerico commented Apr 4, 2024

I've put together something here.

To further reduce memory consumption I've also added the option to remove the last n layers of the model, as described in "The Unreasonable Ineffectiveness of the Deeper Layers", Sec. 4.4.
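A minimal sketch of what I mean, assuming the transformer blocks live in an nn.ModuleList (the helper name and its fraction argument mirror the --train.remove_last_perc_layers flag below and are not an existing litgpt API):

import torch.nn as nn

def drop_last_layers(blocks: nn.ModuleList, remove_last_perc: float) -> nn.ModuleList:
    # Keep the shallowest blocks and drop the deepest ones, e.g. remove_last_perc=0.25
    # removes the last quarter of the transformer layers
    n_remove = int(len(blocks) * remove_last_perc)
    n_keep = len(blocks) - n_remove
    return nn.ModuleList(list(blocks)[:n_keep])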

I've trained a model on a Lightning Studio on an A10G with the following hyperparameters:

litgpt finetune lora \
--config "/teamspace/studios/this_studio/litgpt/config_hub/finetune/mistral-7b/lora.yaml" \
--lora_r 16 \
--lora_dropout 0 \
--lora_query true \
--lora_key true \
--lora_value true \
--lora_projection true \
--longlora_n_groups 4 \
--longlora_context_length 8192 \
--logger_name "tensorboard" \
--data.pad_multiple_of 4 \
--checkpoint_dir "/teamspace/studios/this_studio/checkpoints/mistralai/Mistral-7B-Instruct-v0.1/" \
--train.micro_batch_size 1 \
--train.max_seq_length 8192 \
--train.remove_last_perc_layers 0.0 \
--train.get_longest_seq_length true \
--train.trainable_params "wte,norm_" \
--precision "bf16-true"

With those settings I have:

  • model.block_size=8192 with an adjusted rope_condense_ratio=2.0
  • The maximum sequence length (even though set to 8192 from the CLI) ends up being 728, after padding to the nearest integer multiple of 4 (the LongLora number of groups)
  • LongLora Shifted-Sparse Attention is active (longlora_n_groups=4), even though it is not strictly necessary in this use case, since a sequence length of 8192 is never reached

Here are two generations:

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Recommend a movie to watch on the weekend.

### Response:
One great movie to watch on the weekend is "The Shawshank Redemption". It's a classic drama with a compelling story, great acting, and a positive message.

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Recommend a movie to watch on the weekend.

### Response:
One great movie to watch on the weekend is "The Shawshank Redemption". It's a timeless classic with a compelling storyline and excellent performances.

rasbt commented Apr 4, 2024

Thanks for sharing and writing up this thorough description. I saw the paper a few months back but must admit that I didn't have time to read it.

Btw, I am all for supporting interesting research techniques that help with real & common issues (e.g., high memory requirements, LLMs not being able to handle long contexts, etc.)

In general, something I am wondering about is whether it's really LoRA-specific, or whether it could also be used with "full"-parameter finetuning?

The --train.remove_last_perc_layers is also a nice-to-have. I would apply it in a separate PR, and I think it's useful to have.

What do you think @awaelchli @carmocca ?

belerico commented Apr 4, 2024

Hi @rasbt,

In general, something I am wondering about is whether it's really LoRA-specific, or whether it could also be used with "full"-parameter finetuning?

Even though the paper specifically designs everything around fine-tuning with LoRA, it's something that I also thought about. The concern I have is the flow of information between the first and the last token, which is mitigated during fine-tuning since pre-training has already been done on a ton of data. The authors ablate this in Sec. B.3 and find that it doesn't influence the fine-tuning. Maybe it could be applied during pre-training by adopting Variant-2 from Sec. B.3, where they use a separate group for the shifted tokens?

The --train.remove_last_perc_layers is also a nice-to-have. I would apply it in a separate PR, and I think it's useful to have.

Sure

@belerico

Hi guys, I'm catching up here. I've spotted a little bug due to a missing reshape, and I've also implemented LongLora for full fine-tuning. If it's OK with you I'll open a PR.

cc @rasbt @carmocca @awaelchli

rasbt commented Apr 23, 2024

To me both would be welcome and valuable contributions :).
I would maybe do both in separate PRs, as that would make the code review a bit easier.
