Fix weight initialization in LoRA and Adapter finetuning (#117)
awaelchli authored Apr 11, 2023
1 parent bae421f commit 1f61627
Showing 3 changed files with 12 additions and 9 deletions.
finetune_adapter.py (4 additions, 2 deletions)

@@ -55,8 +55,10 @@ def main():

     checkpoint = torch.load("checkpoints/lit-llama/7B/state_dict.pth")

-    with EmptyInitOnDevice(device=fabric.device, dtype=torch.bfloat16):
-        model = LLaMA(config)
+    with fabric.device:
+        torch.set_default_tensor_type(torch.HalfTensor)
+        model = LLaMA(config).bfloat16()
+        torch.set_default_tensor_type(torch.FloatTensor)
     # strict=False because missing keys due to adapter weights not contained in state dict
     model.load_state_dict(checkpoint, strict=False)

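The point of this change is that the adapter weights are deliberately absent from the pretrained checkpoint: the model has to be built with its regular initialization (rather than empty, uninitialized storage), cast to bfloat16, and only then overlaid with the checkpoint via strict=False. Below is a minimal, self-contained sketch of that pattern; TinyAdapterModel and the in-memory fake checkpoint are hypothetical stand-ins for LLaMA and the real state dict, and it assumes PyTorch 2.0+ so that a torch.device can be used as a context manager (presumably the same mechanism behind `with fabric.device:`).

```python
import torch
import torch.nn as nn

class TinyAdapterModel(nn.Module):
    """Hypothetical stand-in for LLaMA + adapter: a "pretrained" layer plus a
    newly added adapter layer that does not exist in the checkpoint."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 16)      # present in the checkpoint
        self.adapter_proj = nn.Linear(16, 16)  # new, must keep its fresh init

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fake checkpoint containing only the "pretrained" backbone weights.
checkpoint = {"backbone.weight": torch.randn(16, 16), "backbone.bias": torch.zeros(16)}

# Build the model directly on the target device with real initialization,
# using a half-precision default tensor type during construction, then cast
# to bfloat16 and restore the default (mirrors the finetune_adapter.py change).
with device:
    torch.set_default_tensor_type(torch.HalfTensor)
    model = TinyAdapterModel().bfloat16()
    torch.set_default_tensor_type(torch.FloatTensor)

# strict=False: the adapter weights are intentionally missing from the
# checkpoint and keep the initialization they just received.
model.load_state_dict(checkpoint, strict=False)
```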
finetune_lora.py (7 additions, 6 deletions)

@@ -52,14 +52,15 @@ def main():
     config = LLaMAConfig.from_name("7B")
     config.block_size = block_size

-    with EmptyInitOnDevice(device=fabric.device, dtype=torch.bfloat16):
-        with lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
-            model = LLaMA(config)
-
     checkpoint = torch.load("checkpoints/lit-llama/7B/state_dict.pth")

-    # strict=False because missing keys due to LoRA weights not contained in checkpoint state
-    model.load_state_dict(checkpoint, strict=False)
+    with fabric.device, lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
+        torch.set_default_tensor_type(torch.HalfTensor)
+        model = LLaMA(config).bfloat16()
+        torch.set_default_tensor_type(torch.FloatTensor)
+        # strict=False because missing keys due to LoRA weights not contained in checkpoint state
+        model.load_state_dict(checkpoint, strict=False)
+
     mark_only_lora_as_trainable(model)

     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
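Proper initialization matters here for the same reason: the LoRA matrices are created by the module's own __init__ and are not in the checkpoint, and mark_only_lora_as_trainable then freezes everything else so the optimizer only updates them. The sketch below illustrates that idea under the common loralib convention of naming LoRA parameters with a lora_ prefix; LoRALinearSketch and the freezing helper are illustrative stand-ins, not the lit_llama.lora implementation.

```python
import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    """Illustrative LoRA-style layer: a frozen base weight plus a low-rank update.
    lora_A gets a small random init and lora_B starts at zero, so training starts
    from the pretrained behaviour, which only holds if __init__ actually runs."""
    def __init__(self, in_features: int, out_features: int, r: int = 8, alpha: int = 16):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.02)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ (self.weight + self.scaling * self.lora_B @ self.lora_A).T

def mark_only_lora_as_trainable_sketch(model: nn.Module) -> None:
    # Freeze every parameter whose name lacks the "lora_" prefix.
    for name, param in model.named_parameters():
        param.requires_grad = "lora_" in name

model = LoRALinearSketch(32, 32)
mark_only_lora_as_trainable_sketch(model)

# Frozen parameters never receive gradients, so AdamW leaves them untouched
# even though all parameters are passed in, as in the script above.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
print([name for name, p in model.named_parameters() if p.requires_grad])
# ['lora_A', 'lora_B']
```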
lit_llama/adapter.py (1 addition, 1 deletion)

@@ -93,7 +93,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)
         av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)

-        amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool)
+        amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device)
         ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)
         y = y + self.gating_factor * ay
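The one-line change in lit_llama/adapter.py fixes a device mismatch: torch.ones allocates on the CPU by default, so on a GPU run the boolean adapter-attention mask ended up on a different device than q, ak, and av and scaled_dot_product_attention fails with a device error. A small self-contained sketch of the corrected pattern, with made-up shapes:

```python
import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

B, n_head, T, aT, head_size = 2, 4, 8, 10, 16
q = torch.randn(B, n_head, T, head_size, device=device)
ak = torch.randn(B, n_head, aT, head_size, device=device)
av = torch.randn(B, n_head, aT, head_size, device=device)

# Before the fix the mask defaulted to the CPU, which breaks as soon as
# q/ak/av live on the GPU:
# amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool)

# After the fix the mask is allocated on the same device as the activations.
amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=device)
ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)
print(ay.shape)  # torch.Size([2, 4, 8, 16])
```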
