Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix weight initialization in LoRA and Adapter finetuning #117

Merged
merged 1 commit into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions finetune_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@ def main():

checkpoint = torch.load("checkpoints/lit-llama/7B/state_dict.pth")

with EmptyInitOnDevice(device=fabric.device, dtype=torch.bfloat16):
model = LLaMA(config)
with fabric.device:
torch.set_default_tensor_type(torch.HalfTensor)
lantiga marked this conversation as resolved.
Show resolved Hide resolved
model = LLaMA(config).bfloat16()
torch.set_default_tensor_type(torch.FloatTensor)
# strict=False because missing keys due to adapter weights not containted in state dict
model.load_state_dict(checkpoint, strict=False)

Expand Down
13 changes: 7 additions & 6 deletions finetune_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,15 @@ def main():
config = LLaMAConfig.from_name("7B")
config.block_size = block_size

with EmptyInitOnDevice(device=fabric.device, dtype=torch.bfloat16):
with lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
model = LLaMA(config)

checkpoint = torch.load("checkpoints/lit-llama/7B/state_dict.pth")

with fabric.device, lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
torch.set_default_tensor_type(torch.HalfTensor)
model = LLaMA(config).bfloat16()
torch.set_default_tensor_type(torch.FloatTensor)
# strict=False because missing keys due to LoRA weights not contained in checkpoint state
model.load_state_dict(checkpoint, strict=False)

# strict=False because missing keys due to LoRA weights not contained in checkpoint state
model.load_state_dict(checkpoint, strict=False)
mark_only_lora_as_trainable(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
Expand Down
2 changes: 1 addition & 1 deletion lit_llama/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)
av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)

amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool)
amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device)
ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)
y = y + self.gating_factor * ay

Expand Down