Here, I do 1 `forward` pass and manually compare each intermediate output with the cloud version.

As the first large difference between local and cloud outputs are after 1st resnet, I'll compare the intermediate outputs of that 1st resnet in detail.

In [1]:
import torch
from torch.testing import assert_close

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'mps'
device_dtype = torch.float16 if device == 'cuda' else torch.float32

## Load the model

In [18]:
from diffusers import StableDiffusionXLPipeline
from diffusers import EulerDiscreteScheduler
from diffusers.models.controlnetxs import ControlNetXSModel
from diffusers.pipelines.controlnet_xs.pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline

In [19]:
sdxl_pipe = StableDiffusionXLPipeline.from_single_file('weights/sdxl/sd_xl_base_1.0_0.9vae.safetensors').to(device)
cnxs = ControlNetXSModel.from_pretrained('weights/cnxs').to(device)

At the end of __init__, the sigmas are tensor([14.6146, 14.5263, 14.4386, 14.3515, 14.2651]) ...
At the end of __init__, the sigmas are tensor([14.6146, 14.5263, 14.4386, 14.3515, 14.2651]) ...


In [20]:
cnxs.base_model = sdxl_pipe.unet

The example script of Heidelberg manually sets scale_list to 0.95

In [21]:
cnxs.scale_list = cnxs.scale_list * 0. + 0.95
assert cnxs.scale_list[0] == .95

Heidelberg uses `timestep_spacing = 'linspace'` in their scheduler, so let's do that as well

In [22]:
scheduler_cgf = dict(sdxl_pipe.scheduler.config)
scheduler_cgf['timestep_spacing'] = 'linspace'
sdxl_pipe.scheduler = EulerDiscreteScheduler.from_config(scheduler_cgf)

# test it worked
sdxl_pipe.scheduler.set_timesteps(50)
assert sdxl_pipe.scheduler.timesteps[0]==999

# reset
sdxl_pipe.scheduler = EulerDiscreteScheduler.from_config(scheduler_cgf)

At the end of __init__, the sigmas are tensor([14.6146, 14.5263, 14.4386, 14.3515, 14.2651]) ...
timestep_spacing = "leading" and timesteps=[999.      978.61224 958.2245  937.83673 917.449  ] ...
sigmas before interpolation: [0.02916753 0.04131448 0.05068044 0.05861427 0.06563709] ...
sigmas after (linear) interpolation: [14.61464691 12.93677721 11.49164976 10.24291444  9.16035419] ...
At end of `set_timesteps`:
sigmas =  tensor([14.6146, 12.9368, 11.4916, 10.2429,  9.1604]) ...
timesteps = tensor([999.0000, 978.6122, 958.2245, 937.8367, 917.4490]) ...
At the end of __init__, the sigmas are tensor([14.6146, 14.5263, 14.4386, 14.3515, 14.2651]) ...


In [23]:
cnxs_pipe = StableDiffusionXLControlNetXSPipeline(
    vae=sdxl_pipe.vae,
    text_encoder=sdxl_pipe.text_encoder,
    text_encoder_2=sdxl_pipe.text_encoder_2,
    tokenizer=sdxl_pipe.tokenizer,
    tokenizer_2=sdxl_pipe.tokenizer_2,
    unet=sdxl_pipe.unet,
    controlnet=cnxs,
    scheduler=sdxl_pipe.scheduler,
)

___

## Load intermediate outputs of full model

These were computed on the cloud with Heidelberg code

In [30]:
from util_inspect import load_intermediate_outputs, print_metadata, compare_intermediate_results

In [31]:
model_outp_cloud = load_intermediate_outputs('intermediate_output/cloud_debug_log.pkl')
model_outp_local = load_intermediate_outputs('intermediate_output/local_debug_log.pkl')
len(model_outp_cloud),len(model_outp_local)

(72, 72)

In [32]:
compare_intermediate_results(model_outp_cloud, model_outp_local, n=20)

-  | cloud               | local               | equal name? | equal shape? | equal values? | mean abs Δ
   |                     |                     |             |              |    prec=2     |           
--------------------------------------------------------------------------------------------------------
0  | prep   x            | prep   x            | [92m     y     [0m | [92m     y      [0m | [92m      y      [0m |    0.00000
1  | prep   temb         | prep   temb         | [92m     y     [0m | [91m     n      [0m | [91m      n      [0m |    0.11491
2  | prep   context      | prep   context      | [92m     y     [0m | [92m     y      [0m | [92m      y      [0m |    0.00053
3  | prep   raw hint     | prep   raw hint     | [92m     y     [0m | [92m     y      [0m | [92m      y      [0m |    0.00000
4  | prep   guided_hint  | prep   guided_hint  | [92m     y     [0m | [92m     y      [0m | [92m      y      [0m |    0.00009
-------------------------

## Load intermediate outputs of 1st resnet

In [25]:
resnet_outp_cloud = load_intermediate_outputs('intermediate_output/cloud_resnet.pkl')
resnet_outp_local = load_intermediate_outputs('intermediate_output/local_resnet.pkl')
len(resnet_outp_cloud),len(resnet_outp_local)

(8, 14)

## Compare!

In [33]:
from diffusers.models.controlnetxs import to_sub_blocks
my_subs = to_sub_blocks(cnxs.base_model.down_blocks[0])
assert len(my_subs)==3
first_base_enc_subblock = my_subs[0] # that's the local module (I can't access cloud modules)

In [34]:
prev_h_base_cloud = model_outp_cloud[8].t
prev_h_base_local = model_outp_local[8].t

temb_cloud = model_outp_cloud[1].t
temb_local = model_outp_local[1].t

current_h_base_cloud = model_outp_cloud[10].t
current_h_base_local = model_outp_local[10].t

In [36]:
model_outp_cloud[1].msg, model_outp_cloud[1].shape

('temb', [2, 1280])

In [38]:
temb_cloud[0].shape,temb_cloud[1].shape

(torch.Size([1280]), torch.Size([1280]))

**These should be equal! Why are they not?** 😤🤔

In [43]:
temb_cloud[0]

tensor([ 1.5199, -3.6599,  1.8496,  ..., -0.8640, -1.3085, -1.5696])

In [44]:
temb_cloud[1]

tensor([ 1.4099, -3.6581,  1.9248,  ..., -0.9404, -1.2247, -1.4276])

In [37]:
#assert_close(temb_cloud[0], temb_cloud[1]) # AssertionError

AssertionError: Tensor-likes are not close!

Mismatched elements: 1280 / 1280 (100.0%)
Greatest absolute difference: 0.3006887435913086 at index (790,) (up to 1e-05 allowed)
Greatest relative difference: 64.20502471923828 at index (1161,) (up to 1.3e-06 allowed)

Plausibility check:

In [48]:
#assert_close(resnet_outp_cloud[-1].t, current_h_base_cloud) # AssertionError
#assert_close(resnet_outp_local[-1].t, current_h_base_local) # AssertionError

In [49]:
print_metadata(resnet_outp_cloud)

0 hidden_states start [2, 320, 96, 96]
1 temb start [2, 1280]
2 scale start []
3 hidden_states after norm1/silu/conv1 [2, 320, 96, 96]
4 temb after silu/linear [2, 320]
5 hidden_states after time add [2, 320, 96, 96]
6 hidden_states after norm2/dropout/silu/conv2 [2, 320, 96, 96]
7 hidden_states after norm1/silu/conv1 [2, 320, 96, 96]


In [50]:
print_metadata(resnet_outp_local)

0 hidden_states start [2, 320, 96, 96]
1 temb start [1, 1280]
2 scale start []
3 hidden_states after norm1 [2, 320, 96, 96]
4 hidden_states after silu [2, 320, 96, 96]
5 hidden_states after conv1 [2, 320, 96, 96]
6 temb after silu [1, 1280]
7 temb after linear [1, 320, 1, 1]
8 hidden_states after time add [2, 320, 96, 96]
9 hidden_states after norm2 [2, 320, 96, 96]
10 hidden_states after silu [2, 320, 96, 96]
11 hidden_states after dropout [2, 320, 96, 96]
12 hidden_states after conv2 [2, 320, 96, 96]
13 hidden_states after skip + scale [2, 320, 96, 96]


In [51]:
c2l = dict(enumerate((0,1,2,5,7,8,12,13)))

l2c = {l:c for c,l in c2l.items()}
l2c.update({l: None for l in range(len(resnet_outp_local)) if l not in l2c.keys()})

In [52]:
l2c

{0: 0,
 1: 1,
 2: 2,
 5: 3,
 7: 4,
 8: 5,
 12: 6,
 13: 7,
 3: None,
 4: None,
 6: None,
 9: None,
 10: None,
 11: None}

In [53]:
from torch import tensor
from util_inspect import divider, fmt_bool

def compare_intermediate_results():
    step_map = (0,1,2,5,7,8,12,13)
    
    l,c,es,ev,d = 'local','cloud','equal shape?','equal values?','mean abs Δ'
    print(f'{l:<36} | {c:<48} | {es:<12} | {ev:<13} | {d:<10}')
    c,l,es,ev,d = '','','','prec=2',''
    print(f'{l:<36} | {c:<48} | {es:<12} | {ev:^13} | {d:<10}')
    divider(l=36+3+48+3+12+3+13+3+10)
    for li in range(len(resnet_outp_local)):
        l=resnet_outp_local[li]
        print(f'{li:>2} {l.stage:<14} {l.msg:<18} | ',end='')

        if l2c[li] is not None:
            ci = l2c[li]
            c=resnet_outp_cloud[ci]
        
            print(f'{ci:>2} {c.stage:<14} {c.msg:<30} | ',end='')
            
            eq_shape = False #have_same_shapes(i,c,l,do_print=False)
            eq_vals = False # have_same_values(i,c,l,do_print=False,prec=2)

            delta = c.t-l.t
            if isinstance(delta, float): delta = tensor(delta)
            mean_difference = delta.abs().mean()
                
            print(fmt_bool(eq_shape, '^12')+' | '+fmt_bool(eq_vals, '^13')+' | ', end='')
            print(f'{mean_difference:>10.5f}',end='')
        
        print()

In [55]:
compare_intermediate_results()

local                                | cloud                                            | equal shape? | equal values? | mean abs Δ
                                     |                                                  |              |    prec=2     |           
-----------------------------------------------------------------------------------------------------------------------------------
 0 hidden_states  start              |  0 hidden_states  start                          | [91m     n      [0m | [91m      n      [0m |    0.00088
 1 temb           start              |  1 temb           start                          | [91m     n      [0m | [91m      n      [0m |    0.23343
 2 scale          start              |  2 scale          start                          | [91m     n      [0m | [91m      n      [0m |    0.00000
 3 hidden_states  after norm1        | 
 4 hidden_states  after silu         | 
 5 hidden_states  after conv1        |  3 hidden_states  after norm1/silu/

Something is wrong with the time information or the time projection

In [56]:
temb_cloud.shape, temb_local.shape

(torch.Size([2, 1280]), torch.Size([1, 1280]))

In [36]:
#assert_close(temb_cloud[0], temb_cloud[1]) # AssertionError

In [37]:
#assert_close(temb_cloud[0], temb_local[0]) # AssertionError

In [38]:
#assert_close(temb_cloud[1], temb_local[0]) # AssertionError

In [39]:
cnxs.learn_embedding

True

## Run 1 step locally

In [24]:
import torch
import random
import numpy as np
import cv2
from diffusers.utils import load_image
import matplotlib.pyplot as plt

class CannyDetector:
    def __call__(self, img, low_threshold, high_threshold):
        return cv2.Canny(img, low_threshold, high_threshold)

def get_canny_edges(image, threshold=(100, 250)):
    image = np.array(image).astype(np.uint8)
    edges = CannyDetector()(image, *threshold)  # original sized greyscale edges
    edges = edges / 255.
    return edges

def seed_everything(seed):
    # paper used deprecated `seed_everything` from pytorch lightning
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(True)

RANDOM_SEED_IN_PAPER = 1999158951

In [25]:
latents_sdxl_cloud = torch.load('latents_cloud_no_control.pth', map_location=torch.device(device))
rand_from_cloud = latents_sdxl_cloud[0] / 14.6146

In [26]:
prompt = 'cinematic, shoe in the streets, made from meat, photorealistic shoe, highly detailed'
neg_prompt = 'lowres, bad anatomy, worst quality, low quality'

image = load_image('input_images/shoe_cloud.png')
edges = get_canny_edges(image)

edges_tensor = torch.tensor(edges)
three_edges = torch.stack((edges_tensor,edges_tensor,edges_tensor))
three_edges.shape

torch.Size([3, 768, 768])

In [28]:
cnxs_pipe.controlnet.DEBUG_LOG_by_Umer = True
cnxs_pipe.controlnet.DEBUG_LOG_by_Umer_file = 'intermediate_output/local_debug_log.pkl'
cnxs_pipe.controlnet.DEBUG_LOG_by_Umer

True

In [29]:
seed_everything(RANDOM_SEED_IN_PAPER)
cnxs_pipe(prompt, negative_prompt=neg_prompt,image=three_edges, latents=rand_from_cloud)

timestep_spacing = "leading" and timesteps=[999.      978.61224 958.2245  937.83673 917.449  ] ...
sigmas before interpolation: [0.02916753 0.04131448 0.05068044 0.05861427 0.06563709] ...
sigmas after (linear) interpolation: [14.61464691 12.93677721 11.49164976 10.24291444  9.16035419] ...
At end of `set_timesteps`:
sigmas =  tensor([14.6146, 12.9368, 11.4916, 10.2429,  9.1604], device='mps:0') ...
timesteps = tensor([999.0000, 978.6122, 958.2245, 937.8367, 917.4490], device='mps:0') ...
Passed in latents:  tensor([ 1.3333,  0.5155,  0.4647, -0.5344,  1.0102], device='mps:0')
initial_unscaled_latents:  tensor([ 1.3333,  0.5155,  0.4647, -0.5344,  1.0102], device='mps:0')
latents:  tensor([19.4858,  7.5339,  6.7910, -7.8098, 14.7639], device='mps:0')


  0%|          | 0/50 [00:00<?, ?it/s]

timesteps = tensor(999., device='mps:0')
timesteps = tensor([999.], device='mps:0')
t_emb.shape = [1, 320]
learn_embedding = True
t_emb.shape = [1, 1280]


RuntimeError: Debug Log saved successfully