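The current discrepancy seems to arise when applying the ctrl subblock: in the comparison below, the mean abs Δ is 0.00000 after the base subblock (row 6) but 0.00009 after the ctrl subblock (row 7). Let's analyze it.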

In [1]:
import torch
from torch.testing import assert_close
from torch import allclose, nn, tensor
# Wider lines for printed tensors, so the 10-value weight previews later in the notebook fit on one line.
torch.set_printoptions(linewidth=200)

In [2]:
# Prefer CUDA, then Apple's MPS backend, and fall back to CPU when neither is available.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# fp16 only on CUDA; MPS and CPU stay in fp32.
# NOTE(review): device_dtype is not passed to the `.to(device)` calls that load the models,
# so they run in their default dtype — confirm whether that is intended.
device_dtype = torch.float16 if device == 'cuda' else torch.float32

## Load the model

In [3]:
from diffusers import StableDiffusionXLPipeline
from diffusers import EulerDiscreteScheduler
from diffusers.models.controlnetxs import ControlNetXSModel
from diffusers.pipelines.controlnet_xs.pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline

In [4]:
# Base SDXL pipeline from a single-file checkpoint, plus the ControlNet-XS model; both moved to `device`.
sdxl_pipe = StableDiffusionXLPipeline.from_single_file(
    'weights/sdxl/sd_xl_base_1.0_0.9vae.safetensors',
).to(device)
cnxs = ControlNetXSModel.from_pretrained(
    'weights/cnxs',
).to(device)

In [5]:
# Give the ControlNet-XS model the SDXL UNet as its base model (the "base" side in the comparisons below).
# NOTE(review): if cnxs is an nn.Module, this also registers the UNet as a submodule of cnxs
# (it then appears in cnxs.parameters()/state_dict()) — confirm that is acceptable.
cnxs.base_model = sdxl_pipe.unet

Heidelberg's example script manually sets every entry of `scale_list` to 0.95, so we do the same.

In [6]:
# Heidelberg's example script overrides every control scale with a single value.
CTRL_SCALE = 0.95
# `full_like` keeps scale_list's shape/dtype/device; unlike `x * 0. + c`, it cannot
# turn inf/NaN entries into NaN.
cnxs.scale_list = torch.full_like(cnxs.scale_list, CTRL_SCALE)
# Check every entry, not only the first one.
assert (cnxs.scale_list == CTRL_SCALE).all()

Heidelberg uses `timestep_spacing = 'linspace'` in their scheduler, so let's do the same.

In [7]:
# Rebuild the Euler scheduler from the pipeline's own config, overriding only the timestep spacing.
scheduler_cgf = dict(sdxl_pipe.scheduler.config)
scheduler_cgf['timestep_spacing'] = 'linspace'
sdxl_pipe.scheduler = EulerDiscreteScheduler.from_config(scheduler_cgf)

# Test it worked. Linspace spacing runs from 999 down to 0. A start of 999 alone does not prove
# it ('trailing' also starts at 999), so also check the config value and the last timestep.
# NOTE(review): the scheduler's debug print labels this run `timestep_spacing = "leading"`, although
# the printed timesteps (999, 978.61, 958.22, ...) are linspace-spaced — that label looks stale.
sdxl_pipe.scheduler.set_timesteps(50)
assert sdxl_pipe.scheduler.config['timestep_spacing'] == 'linspace'
assert sdxl_pipe.scheduler.timesteps[0] == 999
assert sdxl_pipe.scheduler.timesteps[-1] == 0

# Reset: give the pipeline a fresh scheduler on which set_timesteps has not been called yet.
sdxl_pipe.scheduler = EulerDiscreteScheduler.from_config(scheduler_cgf)

timestep_spacing = "leading" and timesteps=[999.      978.61224 958.2245  937.83673 917.449  ] ...
sigmas before interpolation: [0.02916753 0.04131448 0.05068044 0.05861427 0.06563709] ...
sigmas after (linear) interpolation: [14.61464691 12.93677721 11.49164976 10.24291444  9.16035419] ...
At end of `set_timesteps`:
sigmas =  tensor([14.6146, 12.9368, 11.4916, 10.2429,  9.1604]) ...
timesteps = tensor([999.0000, 978.6122, 958.2245, 937.8367, 917.4490]) ...


In [8]:
# Assemble the ControlNet-XS pipeline by reusing the already-loaded SDXL components (the same
# objects, not copies), including the linspace scheduler set above, plus the ControlNet-XS model.
cnxs_pipe = StableDiffusionXLControlNetXSPipeline(
    controlnet=cnxs,
    **{
        component: getattr(sdxl_pipe, component)
        for component in (
            'vae',
            'text_encoder',
            'text_encoder_2',
            'tokenizer',
            'tokenizer_2',
            'unet',
            'scheduler',
        )
    },
)

___

In [9]:
from util_inspect import load_intermediate_outputs, print_metadata, compare_intermediate_results

In [22]:
# Intermediate outputs logged by the same instrumented forward pass, one log per environment.
model_outp_cloud = load_intermediate_outputs('intermediate_output/cloud_debug_log.pkl')
model_outp_local = load_intermediate_outputs('intermediate_output/local_debug_log.pkl')
# The entry counts should match so the index-wise comparison lines up.
tuple(map(len, (model_outp_cloud, model_outp_local)))

(72, 72)

In [52]:
# Side-by-side table of cloud vs local entries, starting at entry 5 (n/n_start select the entries).
# Values are judged equal at 3 decimals (compare_prec); the mean abs Δ is printed with 5 (prec).
compare_intermediate_results(
    model_outp_cloud,
    model_outp_local,
    n=28,
    n_start=5,
    prec=5,
    compare_prec=3,
)

-  | cloud               | local               | equal name? | equal shape? | equal values? | mean abs Δ
   |                     |                     |             |              |    prec=3     |     prec=5
--------------------------------------------------------------------------------------------------------
5  | enc    h_base       | enc    h_base       | [92m     y     [0m | [92m     y      [0m | [92m      y      [0m |    0.00008   concatted base -> ctrl
6  | enc    h_ctrl       | enc    h_ctrl       | [92m     y     [0m | [92m     y      [0m | [92m      y      [0m |    0.00000   applied base subblock
7  | enc    h_ctrl       | enc    h_ctrl       | [92m     y     [0m | [92m     y      [0m | [92m      y      [0m |    0.00009   applied ctrl subblock
8  | enc    h_base       | enc    h_base       | [92m     y     [0m | [92m     y      [0m | [92m      y      [0m |    0.00008   added ctrl -> base
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

___

**Q:** Could I also have loaded their weights incorrectly?

In [54]:
# Grab (and display) the first resnet of the first control down block to see its layer layout.
(first_resblock := cnxs.control_model.down_blocks[0].resnets[0])

ResnetBlock2D(
  (norm1): GroupNorm(32, 352, eps=1e-05, affine=True)
  (conv1): LoRACompatibleConv(352, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=32, bias=True)
  (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (conv2): LoRACompatibleConv(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (nonlinearity): SiLU()
  (conv_shortcut): LoRACompatibleConv(352, 32, kernel_size=(1, 1), stride=(1, 1))
)

In [55]:
# Weight-bearing layers of a ResnetBlock2D to inspect. This is kept under its own name so that
# `param_parts` only ever refers to the per-class dict defined later (no tuple→dict type change).
resnet_param_parts = ('conv1', 'time_emb_proj', 'conv2', 'conv_shortcut')

In [65]:
from diffusers.models.controlnetxs import to_sub_blocks
# Split the control model's down blocks into the subblocks that are applied one at a time.
ctrl_down_subblocks = to_sub_blocks(
    cnxs.control_model.down_blocks
)

In [67]:
# The cloud printout further down lists 8 control down-subblocks (0-7). The local split must
# yield the same count for the subblock-by-subblock comparison to line up.
assert len(ctrl_down_subblocks) == 8, len(ctrl_down_subblocks)
len(ctrl_down_subblocks)

8

In [69]:
from util import cls_name

In [77]:
# Per layer class: the attributes whose `.weight` is previewed for each subblock.
param_parts = dict(
    ResnetBlock2D=['conv1', 'time_emb_proj', 'conv2', 'conv_shortcut'],
    Downsample2D=['conv'],
)

In [78]:
# For every control down-subblock, print the first 10 weight values of each inspected layer,
# to compare by eye with the same printout from the cloud run.
for i, subblock in enumerate(ctrl_down_subblocks):
    layer = subblock[0]  # unwrap EmbedSequential
    layer_cls = cls_name(layer)
    print(f'Subblock {i} - {layer_cls}:')
    for p in param_parts[layer_cls]:
        module = getattr(layer, p)
        if module is None:
            # NOTE(review): presumably optional layers (e.g. a resnet without conv_shortcut)
            # are stored as None; report them instead of crashing on `.weight`.
            print(f'{p:<20} None')
            continue
        # Slice before `.cpu()` so only 10 values are copied off the device, not the whole weight.
        first_values = module.weight.detach().flatten()[:10].cpu()
        print(f'{p:<20} {first_values}')
    print()

Subblock 0 - ResnetBlock2D:
conv1                tensor([-0.0225,  0.0219, -0.0374,  0.0109,  0.0028,  0.0212,  0.0164, -0.0133, -0.0110, -0.0135])
time_emb_proj        tensor([-0.0152,  0.0160, -0.0025,  0.0155, -0.0167,  0.0045, -0.0243, -0.0039,  0.0127, -0.0094])
conv2                tensor([-0.0180, -0.0231, -0.0207, -0.0339, -0.0294, -0.0046, -0.0047, -0.0194,  0.0065,  0.0078])
conv_shortcut        tensor([ 0.0269,  0.0493, -0.0816,  0.0297, -0.0484,  0.0082,  0.0849, -0.0301, -0.1588,  0.0405])

Subblock 1 - ResnetBlock2D:
conv1                tensor([ 0.0023, -0.0076, -0.0082,  0.0090,  0.0185,  0.0055, -0.0044, -0.0300,  0.0023,  0.0091])
time_emb_proj        tensor([-0.0142,  0.0132, -0.0026,  0.0191, -0.0068,  0.0182, -0.0076,  0.0256,  0.0187, -0.0159])
conv2                tensor([ 0.0116,  0.0100, -0.0316, -0.0018,  0.0002, -0.0166, -0.0173,  0.0020, -0.0167, -0.0232])
conv_shortcut        tensor([ 0.0024, -0.0124, -0.0311, -0.0614, -0.0990, -0.0994, -0.0391, -0.1036,  0

No down subblock is empty.

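Also, the weights on the cloud look identical. The class names differ (`ResBlock`/`Downsample` in the cloud code vs `ResnetBlock2D`/`Downsample2D` locally), but the first 10 values of each parameter match: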
```
Subblock 0 - ResBlock:
conv1                tensor([-0.0225,  0.0219, -0.0374,  0.0109,  0.0028,  0.0212,  0.0164, -0.0133, -0.0110, -0.0135])
time_emb_proj        tensor([-0.0152,  0.0160, -0.0025,  0.0155, -0.0167,  0.0045, -0.0243, -0.0039,  0.0127, -0.0094])
conv2                tensor([-0.0180, -0.0231, -0.0207, -0.0339, -0.0294, -0.0046, -0.0047, -0.0194,  0.0065,  0.0078])
conv_shortcut        tensor([ 0.0269,  0.0493, -0.0816,  0.0297, -0.0484,  0.0082,  0.0849, -0.0301, -0.1588,  0.0405])

Subblock 1 - ResBlock:
conv1                tensor([ 0.0023, -0.0076, -0.0082,  0.0090,  0.0185,  0.0055, -0.0044, -0.0300,  0.0023,  0.0091])
time_emb_proj        tensor([-0.0142,  0.0132, -0.0026,  0.0191, -0.0068,  0.0182, -0.0076,  0.0256,  0.0187, -0.0159])
conv2                tensor([ 0.0116,  0.0100, -0.0316, -0.0018,  0.0002, -0.0166, -0.0173,  0.0020, -0.0167, -0.0232])
conv_shortcut        tensor([ 0.0024, -0.0124, -0.0311, -0.0614, -0.0990, -0.0994, -0.0391, -0.1036,  0.0628,  0.0304])

Subblock 2 - Downsample:
conv                 tensor([ 0.0014, -0.0107, -0.0256, -0.0322, -0.0727, -0.0276, -0.0175, -0.0603, -0.0378, -0.0008])

Subblock 3 - ResBlock:
conv1                tensor([ 0.0038, -0.0225, -0.0137,  0.0118,  0.0261,  0.0445,  0.0046,  0.0254,  0.0123,  0.0196])
time_emb_proj        tensor([ 0.0029,  0.0281,  0.0008,  0.0192,  0.0290,  0.0038, -0.0110,  0.0048, -0.0119,  0.0320])
conv2                tensor([-0.0015,  0.0039,  0.0196,  0.0022,  0.0137,  0.0201,  0.0038,  0.0015,  0.0076, -0.0116])
conv_shortcut        tensor([ 0.0105,  0.0405, -0.0068, -0.0370,  0.0426, -0.0199,  0.0712, -0.0207, -0.0158,  0.0146])

Subblock 4 - ResBlock:
conv1                tensor([ 0.0125,  0.0231,  0.0078,  0.0088,  0.0212, -0.0005,  0.0063, -0.0082, -0.0049, -0.0073])
time_emb_proj        tensor([-0.0077,  0.0140, -0.0198, -0.0020,  0.0209,  0.0031,  0.0255,  0.0376,  0.0209,  0.0419])
conv2                tensor([ 0.0068,  0.0122,  0.0078, -0.0034,  0.0013,  0.0023, -0.0065,  0.0048,  0.0061, -0.0060])
conv_shortcut        tensor([ 0.0409,  0.0009,  0.0166,  0.0155, -0.0032,  0.0379, -0.0255,  0.0058, -0.0301,  0.0224])

Subblock 5 - Downsample:
conv                 tensor([-0.0161, -0.0169, -0.0305, -0.0299, -0.0325, -0.0410, -0.0279, -0.0187, -0.0390,  0.0234])

Subblock 6 - ResBlock:
conv1                tensor([ 0.0206,  0.0154,  0.0105,  0.0079,  0.0243,  0.0134,  0.0070,  0.0179,  0.0059, -0.0110])
time_emb_proj        tensor([-1.9000e-02,  1.4582e-02, -2.6406e-02,  2.2612e-02,  1.5059e-05,  3.3383e-02,  1.0330e-02, -1.6467e-03,  8.3065e-03, -2.8895e-03])
conv2                tensor([ 0.0040,  0.0018,  0.0039, -0.0016, -0.0038, -0.0039, -0.0098, -0.0114, -0.0120,  0.0087])
conv_shortcut        tensor([-0.0092,  0.0251,  0.0145, -0.0126, -0.0410,  0.0210, -0.0353, -0.0262, -0.0098, -0.0207])

Subblock 7 - ResBlock:
conv1                tensor([-9.8836e-05,  1.3444e-02,  1.0174e-02,  4.2135e-03,  5.0547e-03,  4.8702e-03, -7.1490e-03,  6.3899e-03,  8.2130e-03,  5.3117e-03])
time_emb_proj        tensor([-0.0073,  0.0077, -0.0156,  0.0150, -0.0060,  0.0162,  0.0031, -0.0053,  0.0093, -0.0236])
conv2                tensor([ 0.0107, -0.0005, -0.0032,  0.0017, -0.0033, -0.0045,  0.0066,  0.0030,  0.0046, -0.0098])
conv_shortcut        tensor([-0.0185, -0.0004, -0.0143,  0.0176,  0.0277,  0.0025, -0.0249, -0.0079,  0.0086,  0.0037])
```

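Hmmmmmm.... The previewed weight values match between cloud and local, so incorrectly loaded weights look unlikely to explain the Δ after the ctrl subblock. Note that only the first 10 values of each weight were compared, so this is not a full check.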