Skip to content

Commit

Permalink
Controlnet refactor.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Jun 27, 2024
1 parent 97b409c commit 66aaa14
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 32 deletions.
9 changes: 5 additions & 4 deletions comfy/cldm/cldm.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs):

guided_hint = self.input_hint_block(hint, emb, context)

outs = []
out_output = []
out_middle = []

hs = []
if self.num_classes is not None:
Expand All @@ -304,10 +305,10 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
guided_hint = None
else:
h = module(h, emb, context)
outs.append(zero_conv(h, emb, context))
out_output.append(zero_conv(h, emb, context))

h = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h, emb, context))
out_middle.append(self.middle_block_out(h, emb, context))

return outs
return {"middle": out_middle, "output": out_output}

35 changes: 10 additions & 25 deletions comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,27 +89,12 @@ def inference_memory_requirements(self, dtype):
return self.previous_controlnet.inference_memory_requirements(dtype)
return 0

def control_merge(self, control_input, control_output, control_prev, output_dtype):
def control_merge(self, control, control_prev, output_dtype):
out = {'input':[], 'middle':[], 'output': []}

if control_input is not None:
for i in range(len(control_input)):
key = 'input'
x = control_input[i]
if x is not None:
x *= self.strength
if x.dtype != output_dtype:
x = x.to(output_dtype)
out[key].insert(0, x)

if control_output is not None:
for key in control:
control_output = control[key]
for i in range(len(control_output)):
if i == (len(control_output) - 1):
key = 'middle'
index = 0
else:
key = 'output'
index = i
x = control_output[i]
if x is not None:
if self.global_average_pooling:
Expand All @@ -120,6 +105,7 @@ def control_merge(self, control_input, control_output, control_prev, output_dtyp
x = x.to(output_dtype)

out[key].append(x)

if control_prev is not None:
for x in ['input', 'middle', 'output']:
o = out[x]
Expand Down Expand Up @@ -182,7 +168,7 @@ def get_control(self, x_noisy, t, cond, batched_number):
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype)
return self.control_merge(control, control_prev, output_dtype)

def copy(self):
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
Expand Down Expand Up @@ -490,12 +476,11 @@ def get_control(self, x_noisy, t, cond, batched_number):
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
self.t2i_model.cpu()

control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
mid = None
if self.t2i_model.xl == True:
mid = control_input[-1:]
control_input = control_input[:-1]
return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
control_input = {}
for k in self.control_input:
control_input[k] = list(map(lambda a: None if a is None else a.clone(), self.control_input[k]))

return self.control_merge(control_input, control_prev, x_noisy.dtype)

def copy(self):
c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
Expand Down
2 changes: 1 addition & 1 deletion comfy/ldm/cascade/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,4 @@ def forward(self, x):
proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
for i, idx in enumerate(self.proj_blocks):
proj_outputs[idx] = self.projections[i](x)
return proj_outputs
return {"input": proj_outputs[::-1]}
10 changes: 8 additions & 2 deletions comfy/t2i_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,13 @@ def forward(self, x):
features.append(None)
features.append(x)

return features
features = features[::-1]

if self.xl:
return {"input": features[1:], "middle": features[:1]}
else:
return {"input": features}



class LayerNorm(nn.LayerNorm):
Expand Down Expand Up @@ -290,4 +296,4 @@ def forward(self, x):
features.append(None)
features.append(x)

return features
return {"input": features[::-1]}

0 comments on commit 66aaa14

Please sign in to comment.