Fix issue with full diffusers SD3 loras.
comfyanonymous committed Jun 20, 2024
1 parent 0d6a579 commit 028a583
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions comfy/model_patcher.py
@@ -210,16 +210,19 @@ def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
         model_sd = self.model.state_dict()
         for k in patches:
             offset = None
+            function = None
             if isinstance(k, str):
                 key = k
             else:
                 offset = k[1]
                 key = k[0]
+                if len(k) > 2:
+                    function = k[2]
 
             if key in model_sd:
                 p.add(k)
                 current_patches = self.patches.get(key, [])
-                current_patches.append((strength_patch, patches[k], strength_model, offset))
+                current_patches.append((strength_patch, patches[k], strength_model, offset, function))
                 self.patches[key] = current_patches
 
         self.patches_uuid = uuid.uuid4()
@@ -347,6 +350,9 @@ def calculate_weight(self, patches, weight, key):
             v = p[1]
             strength_model = p[2]
             offset = p[3]
+            function = p[4]
+            if function is None:
+                function = lambda a: a
 
             old_weight = None
             if offset is not None:
@@ -371,7 +377,7 @@ def calculate_weight(self, patches, weight, key):
                     if w1.shape != weight.shape:
                         logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                     else:
-                        weight += strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
+                        weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
             elif patch_type == "lora": #lora/locon
                 mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
                 mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
@@ -389,9 +395,9 @@ def calculate_weight(self, patches, weight, key):
                 try:
                     lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                     else:
-                        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
+                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "lokr":
@@ -435,9 +441,9 @@ def calculate_weight(self, patches, weight, key):
                 try:
                     lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                     else:
-                        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
+                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "loha":
@@ -472,9 +478,9 @@ def calculate_weight(self, patches, weight, key):
                 try:
                     lora_diff = (m1 * m2).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                     else:
-                        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
+                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "glora":
@@ -493,9 +499,9 @@ def calculate_weight(self, patches, weight, key):
                 try:
                     lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                     else:
-                        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
+                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             else:
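In short, the commit lets a patch key carry an optional callable: add_patches now accepts tuple keys of the form (key, offset, function), and calculate_weight applies that function to each computed delta (diff, LoRA, LoKr, LoHa, GLoRA, and the DoRA-decomposed result) before it is merged into the weight, falling back to the identity when no function is given. The snippet below is a minimal standalone sketch of that behaviour, not ComfyUI's actual API; apply_diff_patch and the example tensors are made up purely for illustration.

import torch

def apply_diff_patch(weight, diff, strength, function=None):
    # Mirrors the new logic: when no function is supplied the patch behaves
    # exactly as before (identity, i.e. `lambda a: a`); otherwise the callable
    # post-processes the scaled delta before it is merged into the weight.
    if function is None:
        function = lambda a: a
    return weight + function(strength * diff.to(weight.dtype))

w = torch.zeros(4)
d = torch.ones(4)

# Plain patch: the scaled delta is added unchanged.
print(apply_diff_patch(w, d, 0.5))                      # tensor([0.5000, 0.5000, 0.5000, 0.5000])

# Patch whose key carries a function, e.g. an extra per-key scale that a
# loader (such as one handling full diffusers SD3 LoRAs) could attach.
print(apply_diff_patch(w, d, 0.5, lambda a: a * 2.0))   # tensor([1., 1., 1., 1.])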
