Let's analyze why the error after the ctrl->base concat in the 1st down block is large (0.05637)

In [1]:
import torch
from torch.testing import assert_close
from torch import allclose, nn, tensor
torch.set_printoptions(linewidth=200)

In [2]:
# Pick the best available backend. Fall back to the CPU so the notebook still
# runs on machines that have neither CUDA nor Apple MPS.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
# Half precision is only selected on CUDA; every other backend uses float32.
device_dtype = torch.float16 if device == 'cuda' else torch.float32

## Load the model

In [3]:
from diffusers import StableDiffusionXLPipeline
from diffusers import EulerDiscreteScheduler
from diffusers.models.controlnetxs import ControlNetXSModel
from diffusers.pipelines.controlnet_xs.pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline

In [4]:
# Load the SDXL base pipeline from a single-file checkpoint and the ControlNet-XS
# model from the `weights/cnxs` directory, then move both to `device`.
sdxl_pipe = StableDiffusionXLPipeline.from_single_file('weights/sdxl/sd_xl_base_1.0_0.9vae.safetensors').to(device)
cnxs = ControlNetXSModel.from_pretrained('weights/cnxs').to(device)

In [5]:
# Use the UNet of the loaded SDXL pipeline as the base model of the ControlNet-XS model.
cnxs.base_model = sdxl_pipe.unet

The example script of Heidelberg manually sets scale_list to 0.95

In [6]:
# Overwrite every entry of scale_list with 0.95 while keeping its shape.
cnxs.scale_list = cnxs.scale_list * 0. + 0.95
# Verify all entries, not only the first one.
assert (cnxs.scale_list == .95).all()

Heidelberg uses `timestep_spacing = 'linspace'` in their scheduler, so let's do that as well

In [7]:
# Copy the current scheduler config into a plain dict, switch the timestep
# spacing to 'linspace', and build a new Euler scheduler from it.
scheduler_cgf = dict(sdxl_pipe.scheduler.config)
scheduler_cgf['timestep_spacing'] = 'linspace'
sdxl_pipe.scheduler = EulerDiscreteScheduler.from_config(scheduler_cgf)

# Sanity check: with 50 steps the first timestep must be 999.
sdxl_pipe.scheduler.set_timesteps(50)
assert sdxl_pipe.scheduler.timesteps[0]==999

# Rebuild the scheduler from the same config so the `set_timesteps(50)` call
# made for the check above does not carry over.
sdxl_pipe.scheduler = EulerDiscreteScheduler.from_config(scheduler_cgf)

timestep_spacing = "leading" and timesteps=[999.      978.61224 958.2245  937.83673 917.449  ] ...
sigmas before interpolation: [0.02916753 0.04131448 0.05068044 0.05861427 0.06563709] ...
sigmas after (linear) interpolation: [14.61464691 12.93677721 11.49164976 10.24291444  9.16035419] ...
At end of `set_timesteps`:
sigmas =  tensor([14.6146, 12.9368, 11.4916, 10.2429,  9.1604]) ...
timesteps = tensor([999.0000, 978.6122, 958.2245, 937.8367, 917.4490]) ...


In [8]:
# Assemble the ControlNet-XS pipeline from the components of the already loaded
# SDXL pipeline, with `cnxs` as the controlnet.
cnxs_pipe = StableDiffusionXLControlNetXSPipeline(
    vae=sdxl_pipe.vae,
    text_encoder=sdxl_pipe.text_encoder,
    text_encoder_2=sdxl_pipe.text_encoder_2,
    tokenizer=sdxl_pipe.tokenizer,
    tokenizer_2=sdxl_pipe.tokenizer_2,
    unet=sdxl_pipe.unet,
    controlnet=cnxs,
    scheduler=sdxl_pipe.scheduler,
)

___

## Load intermediate outputs of 1st encoder block

These are the first 10 values for each step in cloud / local:

Cloud:
```
A1] [ 0.1755,  0.1357, -0.3653,  0.4109, -0.2517,  0.6668, -0.6832,  0.1072, -0.3240,  0.2146]
A2] [-0.0148, -0.3285, -0.0927,  0.0223,  0.0955, -0.2987,  0.0223,  0.0760, -0.0549, -0.7091]
A3] [ 0.1755,  0.1357, -0.3653,  0.4109, -0.2517,  0.6668, -0.6832,  0.1072, -0.3240,  0.2146]
B1] [ 0.3036,  0.2349,  0.2122,  0.2231,  0.4925,  0.0092,  0.3026,  0.3293,  0.0227, -0.1436]
C1] [-0.5792, -0.3964, -0.1935, -0.6617, -0.0559, -0.1097, -0.4058,  0.0838, -0.6492, -1.1216]
D1] [0.0524, 0.0221, 0.0505, 0.0232, 0.0388, 0.0288, 0.0665, 0.0229, 0.0460, 0.0424]
D2] [0.9500]
D3] [ 0.3533,  0.2559,  0.2602,  0.2452,  0.5294,  0.0366,  0.3658,  0.3511,  0.0664, -0.1034]
```

Local:
```
A1] [ 0.1755,  0.1357, -0.3653,  0.4109, -0.2517,  0.6668, -0.6832,  0.1072, -0.3240,  0.2147]
A2] [-0.0147, -0.3285, -0.0927,  0.0223,  0.0955, -0.2987,  0.0222,  0.0760, -0.0548, -0.7091]
A3] [ 0.1755,  0.1357, -0.3653,  0.4109, -0.2517,  0.6668, -0.6832,  0.1072, -0.3240,  0.2147]
B1] [ 0.3032,  0.2355,  0.2123,  0.2237,  0.4928,  0.0095,  0.3033,  0.3297,  0.0226, -0.1435]
C1] [-0.5794, -0.3963, -0.1934, -0.6616, -0.0560, -0.1097, -0.4058,  0.0836, -0.6489, -1.1214]
D1] [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
D2] [0.9500]
D3] [ 0.3032,  0.2355,  0.2123,  0.2237,  0.4928,  0.0095,  0.3033,  0.3297,  0.0226, -0.1435]
```

D1 is wrong in local! D1 is `add_to_base` which is `next(it_enc_convs_out)(h_ctrl)`

This is the 2nd entry of `enc_zero_convs_out` (the 1st is used in the `conv in` part). So let's look at it.

In [11]:
# Suspicious connection: the entry at index 1 of the encoder's ctrl -> base convs
sus_connection = cnxs.enc_zero_convs_out[1]

In [12]:
sus_connection

Conv2d(32, 320, kernel_size=(1, 1), stride=(1, 1))

In [19]:
# Compare the number of weight entries with 320*32
sus_connection.weight.flatten().shape,320*32

(torch.Size([10240]), 10240)

In [22]:
# Sum of the absolute weight values; a result of 0 means every weight is zero
sus_connection.weight.flatten().abs().sum()

tensor(0., device='mps:0', grad_fn=<SumBackward0>)

The connection is empty!

**Q:** Are all connections from ctrl to base empty?

In [23]:
def is_empty(t):
    """Return whether all values of a connection are zero.

    `t` may be a module exposing a `.weight` tensor (e.g. a conv layer) or a
    tensor itself. The result is the 0-dim boolean tensor produced by comparing
    the sum of absolute values with 0.
    """
    # Accept both modules and raw tensors so the helper also works on tensors
    # read directly from a checkpoint file.
    values = getattr(t, 'weight', t)
    return values.flatten().abs().sum()==0

In [26]:
# Report for each entry of `enc_zero_convs_out` whether its weights are all zero
for i, out_con in enumerate(cnxs.enc_zero_convs_out):
    print(f'In encoder, is out connection {i} empty? {is_empty(out_con)}')

In encoder, is out connection 0 empty? False
In encoder, is out connection 1 empty? True
In encoder, is out connection 2 empty? True
In encoder, is out connection 3 empty? True
In encoder, is out connection 4 empty? True
In encoder, is out connection 5 empty? True
In encoder, is out connection 6 empty? True
In encoder, is out connection 7 empty? True
In encoder, is out connection 8 empty? True


In [27]:
# Report for each entry of `enc_zero_convs_in` whether its weights are all zero
for i, in_con in enumerate(cnxs.enc_zero_convs_in):
    print(f'In encoder, is in connection {i} empty? {is_empty(in_con)}')

In encoder, is in connection 0 empty? False
In encoder, is in connection 1 empty? True
In encoder, is in connection 2 empty? True
In encoder, is in connection 3 empty? True
In encoder, is in connection 4 empty? True
In encoder, is in connection 5 empty? True
In encoder, is in connection 6 empty? True
In encoder, is in connection 7 empty? True
In encoder, is in connection 8 empty? True


In [28]:
# Report for each entry of `dec_zero_convs_out` whether its weights are all zero
for i, out_con in enumerate(cnxs.dec_zero_convs_out):
    print(f'In decoder, is out connection {i} empty? {is_empty(out_con)}')

In decoder, is out connection 0 empty? False
In decoder, is out connection 1 empty? True
In decoder, is out connection 2 empty? True
In decoder, is out connection 3 empty? True
In decoder, is out connection 4 empty? True
In decoder, is out connection 5 empty? True
In decoder, is out connection 6 empty? True
In decoder, is out connection 7 empty? True
In decoder, is out connection 8 empty? True


In [29]:
# Report for each entry of `dec_zero_convs_in` whether its weights are all zero
# (prints nothing when the collection has no entries)
for i, in_con in enumerate(cnxs.dec_zero_convs_in):
    print(f'In decoder, is in connection {i} empty? {is_empty(in_con)}')

There are none, which is as expected

**I seem to have a bug in loading the Heidelberg weights into the diffusers format.** Apparently only the 1st connection of each type (encoder/decoder × in/out) is saved correctly; all the others end up all-zero.

Let's **check** if the connection weights are non-empty in the **Heidelberg weights** I used

In [31]:
from safetensors import safe_open

In [32]:
# Path to the original Heidelberg checkpoint
file = '../../../../.hf-cache/CVL-Heidelberg/sdxl_encD_canny_48m.safetensors'

In [33]:
# Read every tensor stored in `file` into a dict keyed by tensor name (loaded on the CPU).
with safe_open(file, framework='pt', device='cpu') as f:
    weights_tensors = {key: f.get_tensor(key) for key in f.keys()}

In [35]:
from util import print_as_nested_dict

In [41]:
# Presumably prints the nested key structure of the entries whose name contains
# 'enc_zero_convs_in' - helper is defined in `util`, confirm its exact semantics there.
print_as_nested_dict(weights_tensors, contains='enc_zero_convs_in', lv=3, print_leaf=True)

enc_zero_convs_in
        0
                0	[320, 320, 1, 1]
        1
                0	[320, 320, 1, 1]
        2
                0	[320, 320, 1, 1]
        3
                0	[320, 320, 1, 1]
        4
                0	[640, 640, 1, 1]
        5
                0	[640, 640, 1, 1]
        6
                0	[640, 640, 1, 1]
        7
                0	[1280, 1280, 1, 1]
        8
                0	[1280, 1280, 1, 1]


In [51]:
def is_empty(t):
    """Return whether all values of `t` are zero.

    `t` is either a tensor or an object with a `.weight` tensor; in the latter
    case the weight is inspected. The result is the 0-dim boolean tensor that
    comes out of comparing the sum of absolute values with 0.
    """
    values = getattr(t, 'weight', t)
    total = values.flatten().abs().sum()
    return total == 0

# For each connection group, report whether each of the 9 weight tensors read
# from the checkpoint is all zero. Keys here have the form '<group>.<i>.0.weight'.
# NOTE: the printed label says "out connection" for every group, including 'enc_zero_convs_in'.
for conv_group in ('enc_zero_convs_in','enc_zero_convs_out','dec_zero_convs_out'):
    for i in range(9):
        con = weights_tensors[f'{conv_group}.{i}.0.weight']
        print(f'In {conv_group}, is out connection {i} empty? {is_empty(con)}')
    print('---')

In enc_zero_convs_in, is out connection 0 empty? False
In enc_zero_convs_in, is out connection 1 empty? False
In enc_zero_convs_in, is out connection 2 empty? False
In enc_zero_convs_in, is out connection 3 empty? False
In enc_zero_convs_in, is out connection 4 empty? False
In enc_zero_convs_in, is out connection 5 empty? False
In enc_zero_convs_in, is out connection 6 empty? False
In enc_zero_convs_in, is out connection 7 empty? False
In enc_zero_convs_in, is out connection 8 empty? False
---
In enc_zero_convs_out, is out connection 0 empty? False
In enc_zero_convs_out, is out connection 1 empty? False
In enc_zero_convs_out, is out connection 2 empty? False
In enc_zero_convs_out, is out connection 3 empty? False
In enc_zero_convs_out, is out connection 4 empty? False
In enc_zero_convs_out, is out connection 5 empty? False
In enc_zero_convs_out, is out connection 6 empty? False
In enc_zero_convs_out, is out connection 7 empty? False
In enc_zero_convs_out, is out connection 8 empty? False

Nope, as expected, they're non-empty

Let's **check** if the connection weights are non-empty in **my weights**

In [52]:
# Path to the converted (diffusers-format) ControlNet-XS weights
file = 'weights/cnxs/diffusion_pytorch_model.safetensors'

In [53]:
# Read every tensor stored in `file` into a dict keyed by tensor name (loaded on the CPU).
with safe_open(file, framework='pt', device='cpu') as f:
    weights_tensors = {key: f.get_tensor(key) for key in f.keys()}

In [54]:
# Presumably prints the nested key structure of the entries whose name contains
# 'enc_zero_convs_in' - helper is defined in `util`, confirm its exact semantics there.
print_as_nested_dict(weights_tensors, contains='enc_zero_convs_in', lv=3, print_leaf=True)

enc_zero_convs_in
        0	[320, 320, 1, 1]
        1	[320, 320, 1, 1]
        2	[320, 320, 1, 1]
        3	[320, 320, 1, 1]
        4	[640, 640, 1, 1]
        5	[640, 640, 1, 1]
        6	[640, 640, 1, 1]
        7	[1280, 1280, 1, 1]
        8	[1280, 1280, 1, 1]


In [55]:
# For each connection group, report whether each of the 9 weight tensors read
# from the file is all zero. Keys here have the form '<group>.<i>.weight'.
# NOTE: the printed label says "out connection" for every group, including 'enc_zero_convs_in'.
for conv_group in ('enc_zero_convs_in','enc_zero_convs_out','dec_zero_convs_out'):
    for i in range(9):
        con = weights_tensors[f'{conv_group}.{i}.weight']
        print(f'In {conv_group}, is out connection {i} empty? {is_empty(con)}')
    print('---')

In enc_zero_convs_in, is out connection 0 empty? False
In enc_zero_convs_in, is out connection 1 empty? True
In enc_zero_convs_in, is out connection 2 empty? True
In enc_zero_convs_in, is out connection 3 empty? True
In enc_zero_convs_in, is out connection 4 empty? True
In enc_zero_convs_in, is out connection 5 empty? True
In enc_zero_convs_in, is out connection 6 empty? True
In enc_zero_convs_in, is out connection 7 empty? True
In enc_zero_convs_in, is out connection 8 empty? True
---
In enc_zero_convs_out, is out connection 0 empty? False
In enc_zero_convs_out, is out connection 1 empty? True
In enc_zero_convs_out, is out connection 2 empty? True
In enc_zero_convs_out, is out connection 3 empty? True
In enc_zero_convs_out, is out connection 4 empty? True
In enc_zero_convs_out, is out connection 5 empty? True
In enc_zero_convs_out, is out connection 6 empty? True
In enc_zero_convs_out, is out connection 7 empty? True
In enc_zero_convs_out, is out connection 8 empty? True
---
In dec_ze

**Yes, they are empty! So a bug it is**

Edit: Bug fixed ✅

Now, the first 10 values for each step in cloud / local are:

Cloud:
```
A1] [ 0.1755,  0.1357, -0.3653,  0.4109, -0.2517,  0.6668, -0.6832,  0.1072, -0.3240,  0.2146]
A2] [-0.0148, -0.3285, -0.0927,  0.0223,  0.0955, -0.2987,  0.0223,  0.0760, -0.0549, -0.7091]
A3] [ 0.1755,  0.1357, -0.3653,  0.4109, -0.2517,  0.6668, -0.6832,  0.1072, -0.3240,  0.2146]
B1] [ 0.3036,  0.2349,  0.2122,  0.2231,  0.4925,  0.0092,  0.3026,  0.3293,  0.0227, -0.1436]
C1] [-0.5792, -0.3964, -0.1935, -0.6617, -0.0559, -0.1097, -0.4058,  0.0838, -0.6492, -1.1216]
D1] [0.0524, 0.0221, 0.0505, 0.0232, 0.0388, 0.0288, 0.0665, 0.0229, 0.0460, 0.0424]
D2] [0.9500]
D3] [ 0.3533,  0.2559,  0.2602,  0.2452,  0.5294,  0.0366,  0.3658,  0.3511,  0.0664, -0.1034]
```

Local:
```
A1] [ 0.1755,  0.1357, -0.3653,  0.4109, -0.2517,  0.6668, -0.6832,  0.1072, -0.3240,  0.2147]
A2] [-0.0147, -0.3285, -0.0927,  0.0223,  0.0955, -0.2987,  0.0222,  0.0760, -0.0548, -0.7091]
A3] [ 0.1755,  0.1357, -0.3653,  0.4109, -0.2517,  0.6668, -0.6832,  0.1072, -0.3240,  0.2147]
B1] [ 0.3032,  0.2355,  0.2123,  0.2237,  0.4928,  0.0095,  0.3033,  0.3297,  0.0226, -0.1435]
C1] [-0.5794, -0.3963, -0.1934, -0.6616, -0.0560, -0.1097, -0.4058,  0.0836, -0.6489, -1.1214]
D1] [0.0524, 0.0221, 0.0505, 0.0232, 0.0388, 0.0288, 0.0666, 0.0229, 0.0460, 0.0424]
D2] [0.9500]
D3] [ 0.3530,  0.2566,  0.2602,  0.2458,  0.5296,  0.0368,  0.3665,  0.3515,  0.0663, -0.1032]
```