
Commit

Add support for cp decomposition
KohakuBlueleaf committed Mar 12, 2023
1 parent 868660d commit 9abae9e
Showing 1 changed file with 50 additions and 11 deletions.
61 changes: 50 additions & 11 deletions scripts/main.py
@@ -141,6 +141,7 @@ def forward(self, x):
 class LoraUpDownModule:
     def __init__(self):
         self.up_model = None
+        self.mid_model = None
         self.down_model = None
         self.alpha = None
         self.dim = None
@@ -166,13 +167,23 @@ def inference(self, x):
                 **self.extra_args
             )
         else:
-            return self.up_model(self.down_model(x))
+            if self.mid_model is None:
+                return self.up_model(self.down_model(x))
+            else:
+                return self.up_model(self.mid_model(self.down_model(x)))
 
 
+def pro3(t, wa, wb):
+    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
+    return torch.einsum('i j k l, i r -> r j k l', temp, wa)
+
+
 class LoraHadaModule:
     def __init__(self):
+        self.t1 = None
         self.w1a = None
         self.w1b = None
+        self.t2 = None
         self.w2a = None
         self.w2b = None
         self.alpha = None
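
A minimal sketch, with made-up sizes (not the extension's loading code), of what the new up_model(mid_model(down_model(x))) path computes for a Conv2d layer: chaining a 1x1 down conv, a k x k mid conv, and a 1x1 up conv gives the same result as a single convolution with a full kernel assembled from the three factors, which is the CP-style factorization this commit adds.

    import torch
    import torch.nn.functional as F

    in_ch, out_ch, dim, k = 16, 32, 4, 3        # made-up sizes for illustration
    down_w = torch.randn(dim, in_ch, 1, 1)      # lora_down: 1x1 conv down to the rank
    mid_w  = torch.randn(dim, dim, k, k)        # lora_mid:  k x k conv at the rank
    up_w   = torch.randn(out_ch, dim, 1, 1)     # lora_up:   1x1 conv back to out channels

    x = torch.randn(2, in_ch, 8, 8)

    # Chained form, as in the new inference path.
    y_chain = F.conv2d(F.conv2d(F.conv2d(x, down_w), mid_w, padding=1), up_w)

    # Same thing as one conv with the composed kernel W[o,i] = sum_{a,b} up[o,a] * mid[a,b] * down[b,i].
    full_w = torch.einsum('oa,abhw,bi->oihw', up_w[:, :, 0, 0], mid_w, down_w[:, :, 0, 0])
    y_full = F.conv2d(x, full_w, padding=1)

    print(torch.allclose(y_chain, y_full, rtol=1e-4, atol=1e-4))   # True up to float error
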
@@ -191,20 +202,32 @@ def inference(self, x):
             bias = self.bias
         else:
             bias = 0
-        return self.op(
-            x,
-            ((self.w1a @ self.w1b) * (self.w2a @ self.w2b) + bias).view(self.shape),
-            **self.extra_args
-        )
+
+        if self.t1 is None:
+            return self.op(
+                x,
+                ((self.w1a @ self.w1b) * (self.w2a @ self.w2b) + bias).view(self.shape),
+                **self.extra_args
+            )
+        else:
+            return self.op(
+                x,
+                (pro3(self.t1, self.w1a, self.w1b)
+                 * pro3(self.t2, self.w2a, self.w2b) + bias).view(self.shape),
+                **self.extra_args
+            )
 
 
 CON_KEY = {
     "lora_up.weight",
-    "lora_down.weight"
+    "lora_down.weight",
+    "lora_mid.weight"
 }
 HADA_KEY = {
+    "hada_t1",
     "hada_w1_a",
     "hada_w1_b",
+    "hada_t2",
     "hada_w2_a",
     "hada_w2_b",
 }
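
For the Hadamard (hada) branch, pro3 rebuilds a full conv-shaped weight from a small core tensor and two projection matrices, and the new inference path multiplies two such reconstructions element-wise. A shape check with illustrative sizes (the function body is copied from the diff, everything else is made up):

    import torch

    def pro3(t, wa, wb):
        temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
        return torch.einsum('i j k l, i r -> r j k l', temp, wa)

    dim, out_ch, in_ch, k = 4, 32, 16, 3    # illustrative sizes
    t1, w1a, w1b = torch.randn(dim, dim, k, k), torch.randn(dim, out_ch), torch.randn(dim, in_ch)
    t2, w2a, w2b = torch.randn(dim, dim, k, k), torch.randn(dim, out_ch), torch.randn(dim, in_ch)

    w1 = pro3(t1, w1a, w1b)    # (out_ch, in_ch, k, k)
    w2 = pro3(t2, w2a, w2b)    # (out_ch, in_ch, k, k)
    print((w1 * w2).shape)     # element-wise (Hadamard) product, as in inference(): torch.Size([32, 16, 3, 3])
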
@@ -262,6 +285,11 @@ def load_lora(name, filename):
             lora_module.op = torch.nn.functional.linear
         elif type(sd_module) == torch.nn.Conv2d:
             if lora_key == "lora_down.weight":
+                if weight.shape[2] != 1 or weight.shape[3] != 1:
+                    module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False)
+                else:
+                    module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+            elif lora_key == "lora_mid.weight":
                 module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False)
             elif lora_key == "lora_up.weight":
                 module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
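
The extra check on lora_down.weight tells the two layouts apart: if the stored down weight still carries a k x k kernel, it is a conventional conv LoRA and the down convolution keeps the module's kernel_size, stride and padding; if it is 1x1, the spatial kernel has moved into the new lora_mid.weight. A hypothetical illustration of the shapes involved for a Conv2d(64, 128, kernel_size=3) at rank 8 (key names as in the diff, sizes made up):

    import torch

    down_plain = torch.randn(8, 64, 3, 3)    # lora_down.weight, conventional conv LoRA
    down_cp    = torch.randn(8, 64, 1, 1)    # lora_down.weight, CP layout
    mid_cp     = torch.randn(8, 8, 3, 3)     # lora_mid.weight, CP layout (carries the spatial kernel)
    up_weight  = torch.randn(128, 8, 1, 1)   # lora_up.weight, 1x1 in both layouts

    for w in (down_plain, down_cp):
        # Same test as the loader: a non-1x1 down weight means the down conv keeps the spatial kernel itself.
        has_spatial_kernel = w.shape[2] != 1 or w.shape[3] != 1
        print(tuple(w.shape), '-> expects a lora_mid.weight:', not has_spatial_kernel)
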
@@ -286,6 +314,8 @@ def load_lora(name, filename):
                 lora_module.up_model.weight,
                 lora_module.inference
             )
+        elif lora_key == "lora_mid.weight":
+            lora_module.mid_model = module
         elif lora_key == "lora_down.weight":
             lora_module.down_model = module
             lora_module.dim = weight.shape[0]
@@ -304,17 +334,26 @@ def load_lora(name, filename):
 
         if lora_key == 'hada_w1_a':
             lora_module.w1a = weight
-            lora_module.up = FakeModule(
-                lora_module.w1a,
-                lora_module.inference
-            )
+            if lora_module.up is None:
+                lora_module.up = FakeModule(
+                    lora_module.w1a,
+                    lora_module.inference
+                )
         elif lora_key == 'hada_w1_b':
             lora_module.w1b = weight
             lora_module.dim = weight.shape[0]
         elif lora_key == 'hada_w2_a':
             lora_module.w2a = weight
         elif lora_key == 'hada_w2_b':
             lora_module.w2b = weight
+        elif lora_key == 'hada_t1':
+            lora_module.t1 = weight
+            lora_module.up = FakeModule(
+                lora_module.t1,
+                lora_module.inference
+            )
+        elif lora_key == 'hada_t2':
+            lora_module.t2 = weight
 
         if type(sd_module) == torch.nn.Linear:
             lora_module.op = torch.nn.functional.linear
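
Once the core tensors are loaded, hada modules with hada_t1/hada_t2 take the pro3 branch in inference(), while plain hada weights keep the old matrix-product path; likewise lora_mid.weight stays optional for conv modules. A small hypothetical dispatch sketch (helper name made up, key sets copied from the diff):

    CON_KEY = {"lora_up.weight", "lora_down.weight", "lora_mid.weight"}
    HADA_KEY = {"hada_t1", "hada_w1_a", "hada_w1_b", "hada_t2", "hada_w2_a", "hada_w2_b"}

    def describe(lora_key):
        # Hypothetical helper, not part of the extension.
        if lora_key in CON_KEY:
            return "LoraUpDownModule (CP path if lora_mid.weight is present)"
        if lora_key in HADA_KEY:
            return "LoraHadaModule (CP path if hada_t1/hada_t2 are present)"
        return "unhandled key"

    for k in ("lora_mid.weight", "hada_t1", "lora_up.weight"):
        print(k, "->", describe(k))
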
