The images produced by the diffusers CNXS implementation for **SD2.1** with canny input are clearly wrong: they are far too brown.

Let's compare the Heidelberg and diffusers versions at the subblock level.

**Edit:** There were 3 errors, which I've fixed:
- The random noise given was not N(0,1) 
- The time schedule was wrong
- The group norm sizes were wrong
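
The first item in the list above (noise not being N(0,1)) is easy to sanity-check. Below is a minimal sketch, not the check used in this notebook: `looks_standard_normal` and its tolerance are hypothetical names chosen for illustration, and the tensor shape simply mirrors the `[2, 4, 64, 64]` latents seen in the logs.

```python
import torch

def looks_standard_normal(t, tol=0.05):
    """Hypothetical helper: True if sample mean is near 0 and sample std is near 1."""
    return abs(t.mean().item()) < tol and abs(t.std().item() - 1.0) < tol

gen = torch.Generator().manual_seed(0)
good_noise = torch.randn(2, 4, 64, 64, generator=gen)  # drawn from N(0,1)
bad_noise = torch.rand(2, 4, 64, 64, generator=gen)    # uniform on [0,1), not N(0,1)

print(looks_standard_normal(good_noise))
print(looks_standard_normal(bad_noise))
```

With this many samples the first call should pass and the second should fail, since uniform noise has a mean of about 0.5 and a std of about 0.29.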

The intermediate results at the level "subblock minus 1" look good. In `Compare intermediate results -- 15.ipynb`, I'll look at the intermediate results at the subblock level.

In [1]:
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # needed to make torch deterministic

In [2]:
import torch
from torch.testing import assert_close
from torch import allclose, nn, tensor
torch.set_printoptions(linewidth=200, precision=3, sci_mode=False)

In [3]:
device = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')
device_dtype = torch.float16 if device == 'cuda' else torch.float32

## Load logs

In [4]:
from diffusers.umer_debug_logger import UmerDebugLogger

Info: `UmerDebugLogger` created. This is a logging class that will be deleted when the PR to integrate ControlNet-XS is done.


In [6]:
cloud_cuda = UmerDebugLogger.load_log_objects_from_dir('logs/cloud')
local_cuda = UmerDebugLogger.load_log_objects_from_dir('logs/local_cuda')

print(len(cloud_cuda), len(local_cuda))

for i, (c,l) in enumerate(zip(cloud_cuda, local_cuda)):
    if c.msg!=l.msg: print(f'{i:<3}{c.msg:>20}{l.msg:>20}')

539 539


## Compare intermediate results

In [7]:
def mae(t1,t2):
    assert t1.shape==t2.shape
    return (t1-t2).abs().mean()

In [8]:
from functools import partial
from util_inspect import fmt_bool

def compare_intermediate_results(n=None,n_start=0,prec=5, compare_prec=3, ignore_base=False):
    if n is None: n=max(len(cloud_cuda), len(local_cuda))

    print(f'{"":<3} | {"block":<21} | {"name":<20} | {"shape":<20} | {"same names?":<12} | {"same shapes?":<12} | {"same values?":<12} | {"Δ cuda local -> cloud":<20}')
    print(f'{"":<3} | {"":<21} | {"":<20} | {"":<20} | {"":<12} | {"":<12} | {"prec="+str(compare_prec):^12} | {"prec="+str(prec):^20}')

    def calc_total_len(lens): return sum(lens)+3*len(lens)-1
    total_len = calc_total_len((3,20,20,20,12,12,12,20))

    line = partial(
        lambda txt, width: print(txt * (width//len(txt))),
        width=total_len
    )
    
    labels = []
    def add_label(lbs, ctrl=True):
        if not isinstance(lbs, (list, tuple)): lbs = [lbs]
        for l in lbs:
            labels.append(('Base',l))
        for l in lbs: 
            if ctrl: labels.append(('Ctrl',l))

    # # prep
    labels.append(('Prep','Prep'))
    # # conv_in
    add_label('ConvIn')
    # # down
    # 1
    add_label(('ResBlock d1.1', 'AttnBlock d1.1'))
    add_label(('ResBlock d1.2', 'AttnBlock d1.2'))
    add_label('Conv d1')
    # 2
    add_label(('ResBlock d2.1', 'AttnBlock d2.1'))
    add_label(('ResBlock d2.2', 'AttnBlock d2.2'))
    add_label('Conv d2')
    # 3
    add_label(('ResBlock d3.1', 'AttnBlock d3.1'))
    add_label(('ResBlock d3.2', 'AttnBlock d3.2')) 
    # # mid
    add_label(('ResBlock m1', 'AttnBlock m', 'ResBlock m2'))
    # # up
    for _ in range(1000): add_label('DONT CARE', ctrl=False)
    
    line('#')
    bc,block=labels.pop(0)
    for i in range(n_start,n):
        cc,lc = cloud_cuda[i], local_cuda[i]
                
        eq_name = cc.msg==lc.msg
        eq_shape = cc.shape==lc.shape

        if eq_shape:
            eq_vals = torch.allclose(cc.t,lc.t,atol=10**-compare_prec)
            mae_2 = mae(lc.t,cc.t) 
            mae_2 = ("{:>20."+str(prec)+"f}").format(mae_2)
        else:
            eq_vals,mae_2=False,'inf'
        
        if not (ignore_base and bc=='Base'):
            print(f'{i+1:<3} | {bc:<4} | {block:<14} | {cc.msg:<20} | {cc.shape:>20} | {fmt_bool(eq_name, "^12")} | {fmt_bool(eq_shape, "^12")} | {fmt_bool(eq_vals, "^12")} | {mae_2}')

        if cc.msg in ('conv_in.output', 'prep.guided_hint', 'add conv_shortcut','conv','proj_out'):
            line('=')
            bc,block=labels.pop(0)
        elif cc.msg in ('add ff','proj_in'): line('- ')
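
The table reports both an `allclose` verdict and a mean absolute error because the two can disagree. The sketch below uses made-up tensors (not the logged ones) to illustrate this: a single outlier barely moves the mean, but `torch.allclose` checks every element against `atol + rtol * |other|`, with `rtol` defaulting to `1e-5`.

```python
import torch

def mae(t1, t2):
    assert t1.shape == t2.shape
    return (t1 - t2).abs().mean()

a = torch.zeros(1000)
b = torch.zeros(1000)
b[0] = 1.0  # a single large outlier, everything else identical

small_mae = mae(a, b).item()            # mean over 1000 elements, one of which differs by 1
same = torch.allclose(a, b, atol=1e-2)  # elementwise check, fails on the outlier

print(small_mae, same)
```

So a small value in the Δ column does not by itself imply that "same values?" will be `y`, which is why it helps to read both columns together.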

In [10]:
compare_intermediate_results(compare_prec=3, prec=3, ignore_base=False)

    | block                 | name                 | shape                | same names?  | same shapes? | same values? | Δ cuda local -> cloud
    |                       |                      |                      |              |              |    prec=3    |        prec=3       
##############################################################################################################################################
1   | Prep | Prep           | prep.x               |       [2, 4, 64, 64] |      y       |      y       |      y       |                0.000
2   | Prep | Prep           | prep.temb            |            [2, 1280] |      y       |      y       |      y       |                0.000
3   | Prep | Prep           | prep.context         |        [2, 77, 1024] |      y       |      y       |      y       |                0.000
4   | Prep | Prep           | prep.raw_hint        |     [2, 3, 5

In [9]:
c,l = cloud_cuda, local_cuda

In [10]:
### Base

c_inp_base,  l_inp_base  = c[7].t,  l[7].t   # input into 1st resnet (for base)
c_outp_base, l_outp_base = c[8].t,  l[8].t   # output of 1st resnet norm (for base)

for a,b in [
    (c_inp_base,  l_inp_base),
    (c_outp_base, l_outp_base),
]:
    print(a.shape, b.shape, mae(a, b))

print()

### Control

c_inp_ctrl,  l_inp_ctrl  = c[27].t, l[27].t  # input into 1st resnet (for ctrl)
c_outp_ctrl, l_outp_ctrl = c[28].t, l[28].t  # output of 1st resnet norm (for ctrl)

for a,b in [
    (c_inp_ctrl,  l_inp_ctrl), 
    (c_outp_ctrl, l_outp_ctrl)
]:
    print(a.shape, b.shape, mae(a, b))

torch.Size([2, 320, 64, 64]) torch.Size([2, 320, 64, 64]) tensor(0.002)
torch.Size([2, 320, 64, 64]) torch.Size([2, 320, 64, 64]) tensor(0.002)

torch.Size([2, 324, 64, 64]) torch.Size([2, 324, 64, 64]) tensor(0.010)
torch.Size([2, 324, 64, 64]) torch.Size([2, 324, 64, 64]) tensor(0.011)


In [11]:
from torch import nn

assert 320%32==324%4==324%27==0

norm_base        = nn.GroupNorm(num_groups=32, num_channels=320)
norm_ctrl_is     = nn.GroupNorm(num_groups= 4, num_channels=324)
norm_ctrl_should = nn.GroupNorm(num_groups=27, num_channels=324)

In [12]:
# Should: norm_base(c_inp_base) == c_outp_base
mae(norm_base(c_inp_base), c_outp_base)

tensor(0.468, grad_fn=<MeanBackward0>)

In [13]:
# Should: norm_base(l_inp_base) == l_outp_base
mae(norm_base(l_inp_base), l_outp_base)

tensor(0.468, grad_fn=<MeanBackward0>)

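GroupNorm computes its mean and variance from each input at call time, so it stores no running statistics; what it learns during training are the affine parameters (`weight`, `bias`). The GroupNorms I created above only have the default initialization (`weight=1`, `bias=0`), and thus don't reproduce the logged outputs.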
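
A quick way to see what state a GroupNorm actually carries is to list its parameters and buffers and compare them with BatchNorm. This is a minimal sketch on random input, independent of the logged tensors; the scale and shift applied to the affine parameters are arbitrary values chosen for illustration.

```python
import torch
from torch import nn

gn = nn.GroupNorm(num_groups=32, num_channels=320)
bn = nn.BatchNorm2d(320)

gn_buffers = [name for name, _ in gn.named_buffers()]    # GroupNorm registers no buffers
bn_buffers = [name for name, _ in bn.named_buffers()]    # BatchNorm keeps running statistics
gn_params = [name for name, _ in gn.named_parameters()]  # learnable affine parameters

# A freshly created GroupNorm starts with weight=1 and bias=0.
fresh_is_identity_affine = bool((gn.weight == 1).all() and (gn.bias == 0).all())

# Changing the affine parameters changes the output for the same input.
x = torch.randn(2, 320, 8, 8, generator=torch.Generator().manual_seed(0))
with torch.no_grad():
    out_fresh = gn(x)
    gn.weight.mul_(2.0)  # arbitrary stand-in for "trained" weights
    gn.bias.add_(0.5)
    out_scaled = gn(x)

print(gn_buffers, bn_buffers, gn_params)
```

Any mismatch between a fresh GroupNorm and a logged output therefore comes from the affine parameters (or the group count), not from missing statistics.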

Let's load the GroupNorms from Heidelberg-CNXS-SD21

In [14]:
# import scripts.control_utils as cu

# torch.use_deterministic_algorithms(True)

# path_to_config = 'cnxs_config/sd/sd21_encD_canny_14m.yaml'

# model = cu.create_model(path_to_config).to('cuda')

# norm_base = model.model.diffusion_model.input_blocks[1][0].in_layers[0].to('cpu')
# norm_base
# # > GroupNorm32(32, 320, eps=1e-05, affine=True)

# # # Should: norm_base(c_inp_base) == c_outp_base
# mae(norm_base(c_inp_base), c_outp_base)
# # > tensor(    0.000, grad_fn=<MeanBackward0>)

# # # Should: norm_base(l_inp_base) == l_outp_base
# mae(norm_base(l_inp_base), l_outp_base)
# # > tensor(    0.000, grad_fn=<MeanBackward0>)

I could reproduce the norm output for base (for both cloud & local).

Let's now load the GroupNorms from Diffusers-CNXS-SD21.

In [15]:
from diffusers import StableDiffusionPipeline
from diffusers import ControlNetXSModel
from diffusers import StableDiffusionControlNetXSPipeline

sd_pipe = StableDiffusionPipeline.from_single_file('weights/sd/sd21/v2-1_512-ema-pruned.ckpt')
cnxs = ControlNetXSModel.from_pretrained('weights/cnxs_diffusers/sd21-canny')

cnxs_pipe = StableDiffusionControlNetXSPipeline(
    vae=sd_pipe.vae,
    text_encoder=sd_pipe.text_encoder,
    tokenizer=sd_pipe.tokenizer,
    unet=sd_pipe.unet,
    controlnet=cnxs,
    scheduler=sd_pipe.scheduler,
    safety_checker=sd_pipe.safety_checker,
    feature_extractor=sd_pipe.feature_extractor
).to('cpu')

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
You have disabled the safety checker for <class 'diffusers.pipelines.controlnet_xs.pipeline_controlnet_xs.StableDiffusionControlNetXSPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do

In [16]:
norm_ctrl = cnxs_pipe.controlnet.control_model.down_blocks[0].resnets[0].norm1.to('cpu')
norm_ctrl.num_groups = 4 # as currently is
norm_ctrl

# > GroupNorm(4, 324, eps=1e-05, affine=True)

# # Should: norm_ctrl(l_inp_ctrl) == l_outp_ctrl
mae(norm_ctrl(l_inp_ctrl), l_outp_ctrl)
# > tensor(    0.000, grad_fn=<MeanBackward0>)
# # I can recreate the error exactly

norm_ctrl.num_groups = 27 # as should be
mae(norm_ctrl(l_inp_ctrl), l_outp_ctrl)
# > tensor(0.217, grad_fn=<MeanBackward0>)

norm_ctrl.num_groups = 27
# Should: norm_ctrl(c_inp_ctrl) == c_outp_ctrl
mae(norm_ctrl(c_inp_ctrl), c_outp_ctrl)
# > tensor(    0.000, grad_fn=<MeanBackward0>)

tensor(    0.000, grad_fn=<MeanBackward0>)
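
The effect of the group count can also be reproduced in isolation. The sketch below uses random input with per-channel scales (an assumption made so that the grouping visibly matters), not the logged activations. It builds two GroupNorms over 324 channels with identical affine parameters and only a different `num_groups`.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(2, 324, 16, 16)
# Give channels different scales so that the choice of groups affects the statistics.
x = x * torch.linspace(0.1, 5.0, 324).view(1, -1, 1, 1)

norm_4 = nn.GroupNorm(num_groups=4, num_channels=324)    # 81 channels per group
norm_27 = nn.GroupNorm(num_groups=27, num_channels=324)  # 12 channels per group
norm_27.load_state_dict(norm_4.state_dict())             # same weights, different grouping

with torch.no_grad():
    diff = (norm_4(x) - norm_27(x)).abs().mean().item()

print(diff)
```

Because `weight` and `bias` have one entry per channel regardless of grouping, the same checkpoint loads under either group count without error, so a wrong `num_groups` shows up only as a numerical difference like the one measured above.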