correct implementation for DoRA
KohakuBlueleaf committed Mar 16, 2024
1 parent 1fa03b8 commit 47922c5
Showing 3 changed files with 59 additions and 12 deletions.
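Summary of the change: in all three modules below, the weight-decompose (DoRA) path previously stored a per-channel mean of the original weight and divided the merged weight by its mean again at apply time. DoRA's magnitude term is instead the column-wise L2 norm of the weight: one value per input channel (dim 1), taken over every other dim. The sketch below is illustrative only and restates what the changed lines compute; a Conv2d-style (out, in, kh, kw) weight is assumed and the tensor names are local to this example.

import torch

org_weight = torch.randn(8, 4, 3, 3)  # assumed Conv2d-style weight: (out, in, kh, kw)

# Old behaviour (removed): per-input-channel mean over all other dims.
dora_mean_dim = tuple(i for i in range(org_weight.dim()) if i != 1)  # (0, 2, 3)
old_scale = torch.mean(org_weight, dim=dora_mean_dim, keepdim=True)  # (1, 4, 1, 1)

# New behaviour (added): per-input-channel L2 norm over the same dims,
# i.e. DoRA's column-wise weight norm, used to initialise dora_scale.
new_scale = (
    torch.norm(
        org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1),
        dim=1,
        keepdim=True,
    )
    .reshape(org_weight.shape[1], *[1] * (org_weight.dim() - 1))
    .transpose(1, 0)
)  # (1, 4, 1, 1); equals torch.linalg.vector_norm(org_weight, dim=(0, 2, 3), keepdim=True)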
20 changes: 16 additions & 4 deletions lycoris/modules/locon.py
@@ -109,9 +109,15 @@ def __init__(
         self.wd = weight_decompose
         if self.wd:
             org_weight: nn.Parameter = org_module.weight
-            self.dora_mean_dim = tuple(i for i in range(org_weight.dim()) if i != 1)
+            self.dora_norm_dims = org_weight.dim() - 1
             self.dora_scale = nn.Parameter(
-                torch.mean(org_weight, dim=self.dora_mean_dim, keepdim=True)
+                torch.norm(
+                    org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1),
+                    dim=1,
+                    keepdim=True,
+                )
+                .reshape(org_weight.shape[1], *[1] * self.dora_norm_dims)
+                .transpose(1, 0)
             ).float()
 
         if dropout:
@@ -190,10 +196,16 @@ def make_weight(self, device=None):
         return weight * self.scalar.to(device)
 
     def apply_weight_decompose(self, weight):
-        return weight * (
-            self.dora_scale / weight.mean(dim=self.dora_mean_dim, keepdim=True)
+        weight_norm = (
+            weight.transpose(0, 1)
+            .reshape(weight.shape[1], -1)
+            .norm(dim=1, keepdim=True)
+            .reshape(weight.shape[1], *[1] * self.dora_norm_dims)
+            .transpose(0, 1)
         )
 
+        return weight * (self.dora_scale / weight_norm)
+
     def custom_state_dict(self):
         destination = {}
         if self.wd:
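A side note, not part of the commit: the rewritten apply_weight_decompose keeps each column's direction and swaps in the learned magnitude. A minimal sketch with a 2-D Linear-style weight and made-up values:

import torch

w = torch.randn(16, 8)               # stand-in for a merged weight (org + scaled delta)
dora_scale = torch.rand(1, 8) + 0.5  # stand-in for the learned per-column magnitudes

weight_norm = (
    w.transpose(0, 1)
    .reshape(w.shape[1], -1)
    .norm(dim=1, keepdim=True)
    .reshape(w.shape[1], *[1] * (w.dim() - 1))
    .transpose(0, 1)
)  # (1, 8): current per-column norms of w

decomposed = w * (dora_scale / weight_norm)

# Each column keeps its direction, but its norm is now dora_scale.
print(torch.allclose(torch.linalg.vector_norm(decomposed, dim=0), dora_scale.flatten()))  # True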
20 changes: 16 additions & 4 deletions lycoris/modules/loha.py
@@ -175,9 +175,15 @@ def __init__(
         self.wd = weight_decompose
         if self.wd:
             org_weight: nn.Parameter = org_module.weight
-            self.dora_mean_dim = tuple(i for i in range(org_weight.dim()) if i != 1)
+            self.dora_norm_dims = org_weight.dim() - 1
             self.dora_scale = nn.Parameter(
-                torch.mean(org_weight, dim=self.dora_mean_dim, keepdim=True)
+                torch.norm(
+                    org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1),
+                    dim=1,
+                    keepdim=True,
+                )
+                .reshape(org_weight.shape[1], *[1] * self.dora_norm_dims)
+                .transpose(1, 0)
             ).float()
 
         self.dropout = dropout
@@ -261,10 +267,16 @@ def get_weight(self, shape):
         return weight
 
     def apply_weight_decompose(self, weight):
-        return weight * (
-            self.dora_scale / weight.mean(dim=self.dora_mean_dim, keepdim=True)
+        weight_norm = (
+            weight.transpose(0, 1)
+            .reshape(weight.shape[1], -1)
+            .norm(dim=1, keepdim=True)
+            .reshape(weight.shape[1], *[1] * self.dora_norm_dims)
+            .transpose(0, 1)
         )
 
+        return weight * (self.dora_scale / weight_norm)
+
     def custom_state_dict(self):
         destination = {}
         destination["alpha"] = self.alpha
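Worth noting (an observation, not something the commit states): because dora_scale is initialised from the base weight's own column norms, the decompose path is a no-op at initialisation, provided the freshly initialised delta is zero, which is the usual setup for these modules. A small sketch with a plain 2-D weight standing in for org_module.weight:

import torch

w0 = torch.randn(16, 8)  # stand-in for the frozen base weight

# dora_scale starts as w0's per-column norms (what the first hunk above stores)...
dora_scale = torch.linalg.vector_norm(w0, dim=0, keepdim=True)

# ...so before training, while the merged weight still equals w0, rescaling its
# columns to dora_scale changes nothing.
weight_norm = torch.linalg.vector_norm(w0, dim=0, keepdim=True)
print(torch.allclose(w0 * (dora_scale / weight_norm), w0))  # True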
31 changes: 27 additions & 4 deletions lycoris/modules/lokr.py
@@ -218,9 +218,15 @@ def __init__(
         self.wd = weight_decompose
         if self.wd:
             org_weight: nn.Parameter = org_module.weight
-            self.dora_mean_dim = tuple(i for i in range(org_weight.dim()) if i != 1)
+            self.dora_norm_dims = org_weight.dim() - 1
             self.dora_scale = nn.Parameter(
-                torch.mean(org_weight, dim=self.dora_mean_dim, keepdim=True)
+                torch.norm(
+                    org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1),
+                    dim=1,
+                    keepdim=True,
+                )
+                .reshape(org_weight.shape[1], *[1] * self.dora_norm_dims)
+                .transpose(1, 0)
             ).float()
 
         self.dropout = dropout
@@ -326,10 +332,16 @@ def get_weight(self, shape):
         return weight
 
     def apply_weight_decompose(self, weight):
-        return weight * (
-            self.dora_scale / weight.mean(dim=self.dora_mean_dim, keepdim=True)
+        weight_norm = (
+            weight.transpose(0, 1)
+            .reshape(weight.shape[1], -1)
+            .norm(dim=1, keepdim=True)
+            .reshape(weight.shape[1], *[1] * self.dora_norm_dims)
+            .transpose(0, 1)
         )
 
+        return weight * (self.dora_scale / weight_norm)
+
     def custom_state_dict(self):
         destination = {}
         destination["alpha"] = self.alpha
@@ -454,6 +466,17 @@ def forward(self, x):
     test_output = lokr(test_input)
     print(test_output.shape)
 
+    # opt = torch.optim.AdamW(lokr.parameters(), lr=1e-2)
+    # for _ in range(100):
+    #     x = torch.randn(128, 128).cuda()
+    #     t = x / 10
+    #     y = lokr(x)
+    #     loss = F.mse_loss(y, t)
+    #     loss.backward()
+    #     opt.step()
+    #     opt.zero_grad()
+    #     print(loss.item())
+
     base_4bit = LinearNF4(128, 128)
     base_4bit.load_state_dict(base.state_dict())
     base_4bit.cuda()
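The commented-out block added above is a sanity-check training loop for the DoRA path. For reference, here is a rough standalone toy in the same spirit that runs outside the repository; it is not the library's API: the plain low-rank delta and all names are made up for illustration, and the norm is recomputed from the merged weight each step, as in the code above.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
w0 = torch.randn(128, 128)                        # frozen base weight
down = nn.Parameter(torch.randn(4, 128) * 0.01)   # toy low-rank delta factors
up = nn.Parameter(torch.zeros(128, 4))
dora_scale = nn.Parameter(                        # magnitudes, init from base column norms
    torch.linalg.vector_norm(w0, dim=0, keepdim=True)
)

opt = torch.optim.AdamW([down, up, dora_scale], lr=1e-2)
for _ in range(100):
    x = torch.randn(128, 128)
    t = x / 10
    merged = w0 + up @ down                                        # W0 + delta_W
    norm = torch.linalg.vector_norm(merged, dim=0, keepdim=True)   # per-column norms
    y = F.linear(x, merged * (dora_scale / norm))                  # decomposed weight
    loss = F.mse_loss(y, t)
    loss.backward()
    opt.step()
    opt.zero_grad()
print(loss.item())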
