The current discrepancy seems to arise when the ctrl subblock is applied, so let's analyze that step.

In [1]:
import torch
from torch.testing import assert_close
from torch import allclose, nn, tensor
torch.set_printoptions(linewidth=200)

In [2]:
# Pick the best available accelerator. Fall back to the CPU so the notebook still runs
# on machines that have neither CUDA nor MPS (previously 'mps' was assumed unconditionally).
if torch.cuda.is_available():
    device = 'cuda'
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
# Half precision is only selected for CUDA; every other device uses float32.
device_dtype = torch.float16 if device == 'cuda' else torch.float32

## Load the model

In [3]:
from diffusers import StableDiffusionXLPipeline
from diffusers import EulerDiscreteScheduler
from diffusers.models.controlnetxs import ControlNetXSModel
from diffusers.pipelines.controlnet_xs.pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline

In [4]:
# Load the SDXL base pipeline from a local single-file checkpoint and the ControlNet-XS
# model from a local weights directory, then move both to `device`.
# NOTE(review): `device_dtype` is defined above but not passed to either loader, so the
# weights keep whatever dtype the loaders default to — confirm this is intended.
sdxl_pipe = StableDiffusionXLPipeline.from_single_file('weights/sdxl/sd_xl_base_1.0_0.9vae.safetensors').to(device)
cnxs = ControlNetXSModel.from_pretrained('weights/cnxs').to(device)

In [5]:
cnxs.base_model = sdxl_pipe.unet

Heidelberg's example script manually sets `scale_list` to 0.95, so let's do the same.

In [6]:
# Overwrite every control scale with 0.95 while keeping shape, dtype and device.
cnxs.scale_list = torch.full_like(cnxs.scale_list, 0.95)
assert cnxs.scale_list[0] == .95

Heidelberg uses `timestep_spacing = 'linspace'` in their scheduler, so let's do that as well

In [7]:
# Copy the current scheduler config, switch its timestep spacing to 'linspace', and
# install a new Euler scheduler built from that config.
scheduler_cgf = dict(sdxl_pipe.scheduler.config)
scheduler_cgf['timestep_spacing'] = 'linspace'
sdxl_pipe.scheduler = EulerDiscreteScheduler.from_config(scheduler_cgf)

# Check it worked: with 50 inference steps the first timestep is expected to be 999.
sdxl_pipe.scheduler.set_timesteps(50)
assert sdxl_pipe.scheduler.timesteps[0]==999

# `set_timesteps` was only called for the check above; rebuild the scheduler from the
# same config so the pipeline gets a fresh instance.
sdxl_pipe.scheduler = EulerDiscreteScheduler.from_config(scheduler_cgf)

sigmas after (linear) interpolation: [14.61464691 12.93677721 11.49164976 10.24291444  9.16035419] ...


In [8]:
# Assemble the ControlNet-XS pipeline from the components of the already-loaded SDXL
# pipeline (same VAE, text encoders, tokenizers, UNet and scheduler objects), adding
# `cnxs` as the controlnet.
cnxs_pipe = StableDiffusionXLControlNetXSPipeline(
    vae=sdxl_pipe.vae,
    text_encoder=sdxl_pipe.text_encoder,
    text_encoder_2=sdxl_pipe.text_encoder_2,
    tokenizer=sdxl_pipe.tokenizer,
    tokenizer_2=sdxl_pipe.tokenizer_2,
    unet=sdxl_pipe.unet,
    controlnet=cnxs,
    scheduler=sdxl_pipe.scheduler,
)

## Compare intermediate results

In [9]:
from util_inspect import load_intermediate_outputs, print_metadata, compare_intermediate_results

In [10]:
model_outp_cloud = load_intermediate_outputs('intermediate_output/cloud_debug_log.pkl')
model_outp_local = load_intermediate_outputs('intermediate_output/local_debug_log.pkl')
len(model_outp_cloud),len(model_outp_local)

(72, 72)

In [11]:
compare_intermediate_results(model_outp_cloud, model_outp_local, n=8+8*4, n_start=8,prec=5, compare_prec=3)

-  | cloud               | local               | equal name? | equal shape? | equal values? | mean abs Δ
   |                     |                     |             |              |    prec=3     |     prec=5
--------------------------------------------------------------------------------------------------------
8  | enc    h_base       | enc    h_base       | [92m     y     [0m | [92m     y      [0m | [92m      y      [0m |    0.00008   added ctrl -> base
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
9  | enc    h_ctrl       | enc    h_ctrl       | [92m     y     [0m | [92m     y      [0m | [92m      y      [0m |    0.00006   concatted base -> ctrl
10 | enc    h_base       | enc    h_base       | [92m     y     [0m | [92m     y      [0m | [91m      n      [0m |    0.00226   applied base subblock
11 | enc    h_ctrl       | enc    h_ctrl       | [92m     y     [0m | [92m     y      [0m | [92m      y      

## Analyze difference after application of `m_ctrl`

In [12]:
from diffusers.models.controlnetxs import to_sub_blocks
ctrl_down_subblocks = to_sub_blocks(cnxs.control_model.down_blocks)

In [13]:
len(ctrl_down_subblocks)

8

To get a better understanding of the downblocks, let's print each

In [15]:
from util import cls_name

In [16]:
# One line per subblock: its index followed by the class names of the modules it holds.
for i, b in enumerate(ctrl_down_subblocks):
    module_names = [cls_name(module) for module in b]
    print(i,':',' '.join(module_names))

0 : ResnetBlock2D
1 : ResnetBlock2D
2 : Downsample2D
3 : ResnetBlock2D Transformer2DModel
4 : ResnetBlock2D Transformer2DModel
5 : Downsample2D
6 : ResnetBlock2D Transformer2DModel
7 : ResnetBlock2D Transformer2DModel


**Q:** Did I maybe also load their weights incorrectly?

In [17]:
first_resblock = cnxs.control_model.down_blocks[0].resnets[0]
first_resblock

ResnetBlock2D(
  (norm1): GroupNorm(32, 352, eps=1e-05, affine=True)
  (conv1): LoRACompatibleConv(352, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=32, bias=True)
  (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (conv2): LoRACompatibleConv(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (nonlinearity): SiLU()
  (conv_shortcut): LoRACompatibleConv(352, 32, kernel_size=(1, 1), stride=(1, 1))
)

In [18]:
# Names of the weight-carrying layers of a ResnetBlock2D.
# NOTE(review): `param_parts` is rebound to a dict keyed by class name in a later cell
# (`param_parts = {...}`), and the loops further down index it by class name — this
# tuple form appears to be superseded.
param_parts = ('conv1','time_emb_proj','conv2','conv_shortcut')

In [19]:
from util import cls_name

In [20]:
def generate_transformer_keys(block_num):
    """Return the dotted layer paths for the first `block_num` transformer blocks.

    For every block index the same ten paths are emitted in a fixed order: the four
    `attn1` projections, the four `attn2` projections, then the two feed-forward layers.
    The result is ordered block by block.
    """
    per_block_suffixes = (
        'attn1.to_q',
        'attn1.to_k',
        'attn1.to_v',
        'attn1.to_out.0',
        'attn2.to_q',
        'attn2.to_k',
        'attn2.to_v',
        'attn2.to_out.0',
        'ff.net.0.proj',
        'ff.net.2',
    )
    return [
        f'transformer_blocks.{block_idx}.{suffix}'
        for block_idx in range(block_num)
        for suffix in per_block_suffixes
    ]

# Layers whose parameters get printed, keyed by the module's class name.
# The transformer entry lists paths for 10 blocks; loops that consume it skip the
# indices a given module does not have.
param_parts = dict(
    ResnetBlock2D=['conv1','time_emb_proj','conv2','conv_shortcut'],
    Downsample2D=['conv'],
    Transformer2DModel=['proj_in'] + generate_transformer_keys(10) + ['proj_out'],
)

In [21]:
len(param_parts['ResnetBlock2D']), len(param_parts['Downsample2D']), len(param_parts['Transformer2DModel'])

(4, 1, 102)

In [22]:
def attr(o, attr_str):
    """Resolve a dotted path such as ``'blocks.0.ff'`` starting from `o`.

    Path segments made up of digits are used as integer indices (``obj[0]``); every
    other segment is looked up with ``getattr``.
    """
    target = o
    for segment in attr_str.split('.'):
        if segment.isdigit():
            target = target[int(segment)]
        else:
            target = getattr(target, segment)
    return target

def get_tf_idx(param_name):
    """Return the transformer-block index encoded in `param_name`, or -1 if there is none.

    A name is treated as a transformer parameter when it contains 'transformer_blocks';
    it is then expected to have the form ``transformer_blocks.<idx>...``.
    """
    if 'transformer_blocks' not in param_name:
        return -1
    prefix, block_idx, *_ = param_name.split('.')
    assert prefix=='transformer_blocks'
    return int(block_idx)

# Inline self-checks for the three kinds of names used in this notebook.
assert get_tf_idx('time_emb_proj') == -1
assert get_tf_idx('transformer_blocks.0.ff.net.2') == 0
assert get_tf_idx('transformer_blocks.8.attn1.to_out.0') == 8

In [23]:
# Print the first ten weight values of every listed layer, subblock by subblock.
for i,b in enumerate(ctrl_down_subblocks):
    print(f'>>>>>> Subblock {i}:')
    for m in b:
        print('>>',cls_name(m))
        # Number of transformer blocks this module really has (0 when it has none).
        # Derived from the module itself instead of hard-coded subblock indices, so the
        # listing stays correct if the block layout changes.
        n_tf = len(getattr(m, 'transformer_blocks', ()))
        for p in param_parts[cls_name(m)]:
            # `param_parts` lists paths for up to 10 transformer blocks; skip the ones
            # this module lacks. Non-transformer names yield -1 and are never skipped.
            if get_tf_idx(p) >= n_tf: continue
            first_values = attr(m,p).weight.flatten().cpu().detach()[:10]
            label = p.replace('transformer_blocks','tf')
            print(f'{label:<30} {first_values}')
        print()

>>>>>> Subblock 0:
>> ResnetBlock2D
conv1                          tensor([-0.0225,  0.0219, -0.0374,  0.0109,  0.0028,  0.0212,  0.0164, -0.0133, -0.0110, -0.0135])
time_emb_proj                  tensor([-0.0152,  0.0160, -0.0025,  0.0155, -0.0167,  0.0045, -0.0243, -0.0039,  0.0127, -0.0094])
conv2                          tensor([-0.0180, -0.0231, -0.0207, -0.0339, -0.0294, -0.0046, -0.0047, -0.0194,  0.0065,  0.0078])
conv_shortcut                  tensor([ 0.0269,  0.0493, -0.0816,  0.0297, -0.0484,  0.0082,  0.0849, -0.0301, -0.1588,  0.0405])

>>>>>> Subblock 1:
>> ResnetBlock2D
conv1                          tensor([ 0.0023, -0.0076, -0.0082,  0.0090,  0.0185,  0.0055, -0.0044, -0.0300,  0.0023,  0.0091])
time_emb_proj                  tensor([-0.0142,  0.0132, -0.0026,  0.0191, -0.0068,  0.0182, -0.0076,  0.0256,  0.0187, -0.0159])
conv2                          tensor([ 0.0116,  0.0100, -0.0316, -0.0018,  0.0002, -0.0166, -0.0173,  0.0020, -0.0167, -0.0232])
conv_shortcut    

**A:** No, none of the downblocks is empty.

Also, the (non-attention) weights on cloud are exactly identical:
```
Subblock 0 - ResBlock:
conv1                tensor([-0.0225,  0.0219, -0.0374,  0.0109,  0.0028,  0.0212,  0.0164, -0.0133, -0.0110, -0.0135])
time_emb_proj        tensor([-0.0152,  0.0160, -0.0025,  0.0155, -0.0167,  0.0045, -0.0243, -0.0039,  0.0127, -0.0094])
conv2                tensor([-0.0180, -0.0231, -0.0207, -0.0339, -0.0294, -0.0046, -0.0047, -0.0194,  0.0065,  0.0078])
conv_shortcut        tensor([ 0.0269,  0.0493, -0.0816,  0.0297, -0.0484,  0.0082,  0.0849, -0.0301, -0.1588,  0.0405])

Subblock 1 - ResBlock:
conv1                tensor([ 0.0023, -0.0076, -0.0082,  0.0090,  0.0185,  0.0055, -0.0044, -0.0300,  0.0023,  0.0091])
time_emb_proj        tensor([-0.0142,  0.0132, -0.0026,  0.0191, -0.0068,  0.0182, -0.0076,  0.0256,  0.0187, -0.0159])
conv2                tensor([ 0.0116,  0.0100, -0.0316, -0.0018,  0.0002, -0.0166, -0.0173,  0.0020, -0.0167, -0.0232])
conv_shortcut        tensor([ 0.0024, -0.0124, -0.0311, -0.0614, -0.0990, -0.0994, -0.0391, -0.1036,  0.0628,  0.0304])

Subblock 2 - Downsample:
conv                 tensor([ 0.0014, -0.0107, -0.0256, -0.0322, -0.0727, -0.0276, -0.0175, -0.0603, -0.0378, -0.0008])

Subblock 3 - ResBlock:
conv1                tensor([ 0.0038, -0.0225, -0.0137,  0.0118,  0.0261,  0.0445,  0.0046,  0.0254,  0.0123,  0.0196])
time_emb_proj        tensor([ 0.0029,  0.0281,  0.0008,  0.0192,  0.0290,  0.0038, -0.0110,  0.0048, -0.0119,  0.0320])
conv2                tensor([-0.0015,  0.0039,  0.0196,  0.0022,  0.0137,  0.0201,  0.0038,  0.0015,  0.0076, -0.0116])
conv_shortcut        tensor([ 0.0105,  0.0405, -0.0068, -0.0370,  0.0426, -0.0199,  0.0712, -0.0207, -0.0158,  0.0146])

Subblock 4 - ResBlock:
conv1                tensor([ 0.0125,  0.0231,  0.0078,  0.0088,  0.0212, -0.0005,  0.0063, -0.0082, -0.0049, -0.0073])
time_emb_proj        tensor([-0.0077,  0.0140, -0.0198, -0.0020,  0.0209,  0.0031,  0.0255,  0.0376,  0.0209,  0.0419])
conv2                tensor([ 0.0068,  0.0122,  0.0078, -0.0034,  0.0013,  0.0023, -0.0065,  0.0048,  0.0061, -0.0060])
conv_shortcut        tensor([ 0.0409,  0.0009,  0.0166,  0.0155, -0.0032,  0.0379, -0.0255,  0.0058, -0.0301,  0.0224])

Subblock 5 - Downsample:
conv                 tensor([-0.0161, -0.0169, -0.0305, -0.0299, -0.0325, -0.0410, -0.0279, -0.0187, -0.0390,  0.0234])

Subblock 6 - ResBlock:
conv1                tensor([ 0.0206,  0.0154,  0.0105,  0.0079,  0.0243,  0.0134,  0.0070,  0.0179,  0.0059, -0.0110])
time_emb_proj        tensor([-1.9000e-02,  1.4582e-02, -2.6406e-02,  2.2612e-02,  1.5059e-05,  3.3383e-02,  1.0330e-02, -1.6467e-03,  8.3065e-03, -2.8895e-03])
conv2                tensor([ 0.0040,  0.0018,  0.0039, -0.0016, -0.0038, -0.0039, -0.0098, -0.0114, -0.0120,  0.0087])
conv_shortcut        tensor([-0.0092,  0.0251,  0.0145, -0.0126, -0.0410,  0.0210, -0.0353, -0.0262, -0.0098, -0.0207])

Subblock 7 - ResBlock:
conv1                tensor([-9.8836e-05,  1.3444e-02,  1.0174e-02,  4.2135e-03,  5.0547e-03,  4.8702e-03, -7.1490e-03,  6.3899e-03,  8.2130e-03,  5.3117e-03])
time_emb_proj        tensor([-0.0073,  0.0077, -0.0156,  0.0150, -0.0060,  0.0162,  0.0031, -0.0053,  0.0093, -0.0236])
conv2                tensor([ 0.0107, -0.0005, -0.0032,  0.0017, -0.0033, -0.0045,  0.0066,  0.0030,  0.0046, -0.0098])
conv_shortcut        tensor([-0.0185, -0.0004, -0.0143,  0.0176,  0.0277,  0.0025, -0.0249, -0.0079,  0.0086,  0.0037])
```

Hmmmmmm....

Sanity check: Is the local bias non-empty?

Edit: Yes

In [101]:
# Print the first ten bias values of each listed layer, one section per subblock.
# NOTE(review): only the first entry of every subblock is inspected, so a second module
# in a two-module subblock is not covered by this check.
for i,m in enumerate(ctrl_down_subblocks):
    m = m[0] # unwrap EmbedSequential; only the first entry of the subblock is inspected
    print(f'Subblock {i} - {cls_name(m)}:')
    for p in param_parts[cls_name(m)]:
        # Plain `getattr`, so `p` must be a single attribute name (no dotted paths).
        first_values = getattr(m,p).bias.flatten().cpu().detach()[:10]
        print(f'{p:<20} {first_values}')
    print()

Subblock 0 - ResnetBlock2D:
conv1                tensor([-0.0062,  0.0141, -0.0095, -0.0061,  0.0170, -0.0133, -0.0058, -0.0010, -0.0014,  0.0013])
time_emb_proj        tensor([-0.0074, -0.0122,  0.0255,  0.0128,  0.0008, -0.0081,  0.0084, -0.0132, -0.0229,  0.0129])
conv2                tensor([-0.0061, -0.0048, -0.0087,  0.0158,  0.0009, -0.0058, -0.0021, -0.0009, -0.0136,  0.0085])
conv_shortcut        tensor([-0.0050, -0.0129, -0.0212,  0.0425, -0.0277,  0.0116, -0.0132,  0.0346,  0.0079, -0.0265])

Subblock 1 - ResnetBlock2D:
conv1                tensor([-0.0027,  0.0005, -0.0158,  0.0097, -0.0110,  0.0160,  0.0056,  0.0125,  0.0093, -0.0074])
time_emb_proj        tensor([ 0.0102,  0.0091,  0.0214,  0.0138, -0.0254, -0.0157, -0.0252, -0.0064, -0.0081,  0.0197])
conv2                tensor([ 0.0032,  0.0020, -0.0006, -0.0114,  0.0030,  0.0048,  0.0055,  0.0015,  0.0002, -0.0041])
conv_shortcut        tensor([ 0.0400,  0.0060, -0.0259, -0.0382,  0.0090,  0.0439,  0.0229, -0.0307, -0

___

Sanity check: base and ctrl have same number of subblocks

In [26]:
base_down_subblocks = to_sub_blocks(cnxs.base_model.down_blocks)

In [27]:
assert len(base_down_subblocks)==len(ctrl_down_subblocks)
len(base_down_subblocks)

8

Sanity check: base and ctrl have same number of transformers

In [28]:
# For both networks, report how many transformer blocks each two-part subblock holds.
listings = (('base subblock', base_down_subblocks), ('ctrl subblock', ctrl_down_subblocks))
for listing_nr, (caption, subblocks) in enumerate(listings):
    if listing_nr:
        print()  # blank line between the two listings
    for i, b in enumerate(subblocks):
        if len(b) != 1:  # a second part, when present, is the attention
            print(caption,i,'has',len(b[1].transformer_blocks),'transformers')

base subblock 3 has 2 transformers
base subblock 4 has 2 transformers
base subblock 6 has 10 transformers
base subblock 7 has 10 transformers

ctrl subblock 3 has 2 transformers
ctrl subblock 4 has 2 transformers
ctrl subblock 6 has 10 transformers
ctrl subblock 7 has 10 transformers


Okay, I've saved a detailed debug log (at level subblock minus 1). Let's confirm that none of its entries is empty.

In [60]:
import pickle

def _load_pickled_log(path):
    """Unpickle and return the debug log stored at `path`."""
    # pickle can execute arbitrary code while loading; only use it on self-generated files.
    with open(path, 'rb') as f:
        return pickle.load(f)

dlog_c = _load_pickled_log('intermediate_output/subblock-minus-1/cloud_detailled_debug_log.pkl')
dlog_l = _load_pickled_log('intermediate_output/subblock-minus-1/local_detailled_debug_log.pkl')

In [61]:
torch.set_printoptions(precision=3)

In [62]:
# Preview the local log: entry name, tensor shape and the first ten flattened values.
for name, t in dlog_l:
    shape_txt = str(list(t.shape))
    preview = t.flatten()[:10]
    print(f'{name:<20}{shape_txt:<20}{preview}')

conv1               [2, 32, 96, 96]     tensor([-0.828, -0.465, -0.597, -0.728, -0.992,  0.050, -0.500, -1.920, -1.507, -0.325])
add time_emb_proj   [2, 32, 96, 96]     tensor([-0.434, -0.070, -0.202, -0.334, -0.597,  0.445, -0.106, -1.526, -1.112,  0.070])
conv2               [2, 32, 96, 96]     tensor([-0.201, -0.217, -0.178, -0.103, -0.235, -0.112, -0.032,  0.107, -0.073, -0.313])
add conv_shortcut   [2, 32, 96, 96]     tensor([-0.579, -0.396, -0.193, -0.662, -0.056, -0.110, -0.406,  0.084, -0.649, -1.121])
conv1               [2, 32, 96, 96]     tensor([-0.493, -1.457, -1.604, -1.583, -1.943, -1.831, -2.015, -1.617, -1.150, -1.342])
add time_emb_proj   [2, 32, 96, 96]     tensor([ 0.781, -0.184, -0.331, -0.310, -0.670, -0.557, -0.742, -0.344,  0.124, -0.068])
conv2               [2, 32, 96, 96]     tensor([-0.110, -0.188,  0.008, -0.117,  0.199, -0.032,  0.062,  0.046,  0.088, -0.130])
add conv_shortcut   [2, 32, 96, 96]     tensor([-0.587, -0.252,  0.140, -0.243,  0.250, -0.864, -

Local intermediate results are not empty

In [52]:
# Preview the cloud log: entry name, tensor shape and the first ten flattened values.
for name, t in dlog_c:
    shape_txt = str(list(t.shape))
    preview = t.flatten()[:10]
    print(f'{name:<20}{shape_txt:<20}{preview}')

conv1               [2, 32, 96, 96]     tensor([-0.828, -0.465, -0.597, -0.728, -0.992,  0.050, -0.500, -1.920, -1.507, -0.325])
add time_emb_proj   [2, 32, 96, 96]     tensor([-0.418, -0.055, -0.187, -0.318, -0.582,  0.460, -0.090, -1.510, -1.097,  0.085])
conv2               [2, 32, 96, 96]     tensor([-0.201, -0.217, -0.178, -0.103, -0.235, -0.112, -0.033,  0.107, -0.073, -0.313])
add conv_shortcut   [2, 32, 96, 96]     tensor([-0.579, -0.396, -0.193, -0.662, -0.056, -0.110, -0.406,  0.084, -0.649, -1.122])
conv1               [2, 32, 96, 96]     tensor([-0.489, -1.452, -1.598, -1.578, -1.935, -1.825, -2.010, -1.613, -1.147, -1.336])
add time_emb_proj   [2, 32, 96, 96]     tensor([ 0.798, -0.165, -0.311, -0.291, -0.648, -0.538, -0.723, -0.326,  0.140, -0.049])
conv2               [2, 32, 96, 96]     tensor([-0.109, -0.187,  0.009, -0.117,  0.200, -0.033,  0.062,  0.047,  0.089, -0.130])
add conv_shortcut   [2, 32, 96, 96]     tensor([-0.586, -0.251,  0.140, -0.243,  0.251, -0.865, -

Cloud intermediate results are also not empty

___

Okay, now let's compare local and cloud intermediate outputs (at level subblock minus 1):

In [147]:
from functools import partial
from util_inspect import fmt_bool

def compare_intermediate_results(n=None,n_start=0,prec=5, compare_prec=2):
    """Print a row-by-row comparison table of the cloud and local debug logs.

    Reads the module-level logs `dlog_c` (cloud) and `dlog_l` (local), each a sequence
    of `(name, tensor)` pairs. For every compared index it prints whether name, shape
    and values agree, the mean absolute difference, and mean ± std of the cloud tensor.
    Separator lines are printed after 'conv' (new level), after 'add conv_shortcut' /
    'proj_out' (new block) and after 'proj_in' / 'add ff'.

    NOTE: this definition rebinds the name `compare_intermediate_results` that an
    earlier cell imports from `util_inspect`; the two are called with different arguments.

    Args:
        n: Exclusive end index. Defaults to, and is capped at, the number of entries
            present in both logs, so logs of different length do not raise IndexError.
        n_start: First index to compare.
        prec: Number of decimals shown for the mean absolute difference.
        compare_prec: Values count as equal within an absolute tolerance of
            10**-compare_prec (passed as `atol` to `torch.allclose`).
    """
    # Only indices that exist in both logs can be compared.
    n_common = min(len(dlog_c), len(dlog_l))
    n = n_common if n is None else min(n, n_common)
    delta_width = prec + 5

    def print_header(cells):
        i, lv, b, c, l, en, es, ev, d, stats = cells
        print(f'{i:<3} | {lv:<5} | {b:<5} | {c:<19} | {l:<19} | {en:<11} | {es:<12} | {ev:^13} | {d:>{delta_width}} | {stats:>12}')

    print_header(('-','level','block','cloud','local','equal name?','equal shape?','equal values?','mean abs Δ','mean ± std'))
    print_header(('','','','','','','','prec='+str(compare_prec),'prec='+str(prec),''))

    # Sum of all column widths plus the ' | ' separators between them.
    total_len = 3+3+5+3+5+3+19+3+19+3+11+3+12+3+13+3+delta_width+3+12

    def line(txt):
        # Repeat the pattern as often as it fits completely into the table width.
        print(txt * (total_len//len(txt)))

    line('=')

    lv,block=1,1
    for i in range(n_start,n):
        (cn,ct),(ln,lt)=dlog_c[i],dlog_l[i]
        eq_name = cn==ln
        eq_shape = ct.shape==lt.shape
        try:
            eq_vals = torch.allclose(ct,lt,atol=10**-compare_prec)
            mae = (ct-lt).abs().mean().item()
        except RuntimeError:
            # Tensors that cannot be compared element-wise (e.g. incompatible shapes or
            # dtypes) are reported as a mismatch instead of aborting the whole table.
            eq_vals, mae = False, float('nan')

        mean,std = ct.mean().item(),ct.std().item()

        print(f'{i:<3} | {lv:^5} | {block:^5} | {cn:<19} | {ln:<19} | ', end='')
        print(fmt_bool(eq_name, '^11')+' | '+fmt_bool(eq_shape, '^12')+' | '+fmt_bool(eq_vals, '^13')+' | ', end='')
        print(f'{mae:>{delta_width}.{prec}f} | ', end='')
        print(f'{mean:>5.2f} ± {std:>3.2f}')

        if cn=='conv':
            # A 'conv' entry closes a resolution level.
            line('=')
            lv += 1
            block = 1
        if cn in ('add conv_shortcut','proj_out'):
            # These entries close a block within the current level.
            line('-')
            block += 1
        if cn in ('proj_in', 'add ff'): line('-   ')

In [148]:
compare_intermediate_results()

-   | level | block | cloud               | local               | equal name? | equal shape? | equal values? | mean abs Δ |   mean ± std
    |       |       |                     |                     |             |              |    prec=2     |     prec=5 |             
0   |   1   |   1   | conv1               | conv1               | [92m     y     [0m | [92m     y      [0m | [92m      y      [0m |    0.00023 | -0.39 ± 1.27
1   |   1   |   1   | add time_emb_proj   | add time_emb_proj   | [92m     y     [0m | [92m     y      [0m | [91m      n      [0m |    0.00685 | -0.26 ± 1.49
2   |   1   |   1   | conv2               | conv2               | [92m     y     [0m | [92m     y      [0m | [92m      y      [0m |    0.00008 | -0.03 ± 0.41
3   |   1   |   1   | add conv_shortcut   | add conv_shortcut   | [92m     y     [0m | [92m     y      [0m | [92m      y      [0m |    0.00019 |  0.00 ± 0.99
----------------------------------------------------------------------

In [102]:
contexts_l = [t for n,t in dlog_l if n=='context']
contexts_c = [t for n,t in dlog_c if n=='context']
len(contexts_l), len(contexts_c)

(34, 34)

In [132]:
# For each context tensor: does it equal its successor (local / cloud), and do the
# local and cloud versions at the same position agree within atol=1e-2?
header = ('nr','local equal to next?','equal across local/cloud?','cloud equal to next?')
print(' | '.join(header))
col_widths = [len(h) for h in header]
for i in range(len(contexts_l)-1):
    same_as_next_local = (contexts_l[i]==contexts_l[i+1]).all().item()
    same_as_next_cloud = (contexts_c[i]==contexts_c[i+1]).all().item()
    same_across = torch.allclose(contexts_l[i],contexts_c[i], atol=1e-2)

    row = (i, same_as_next_local, same_across, same_as_next_cloud)

    # Center every cell under its column title.
    print(' | '.join(f'{str(cell):^{width}}' for cell, width in zip(row, col_widths)))

nr | local equal to next? | equal across local/cloud? | cloud equal to next?
0  |         True         |           False           |         True        
1  |         True         |           False           |         True        
2  |         True         |           False           |         True        
3  |         True         |           False           |         True        
4  |         True         |           False           |         True        
5  |         True         |           False           |         True        
6  |         True         |           False           |         True        
7  |         True         |           False           |         True        
8  |         True         |           False           |         True        
9  |         True         |           False           |         True        
10 |         True         |           False           |         True        
11 |         True         |           False           |         True        

In [140]:
cxt_l = contexts_l[0]
cxt_c = contexts_c[0]

In [141]:
cxt_l.flatten()[:10]

tensor([-3.892, -2.511,  4.717,  1.092, -1.341, -4.957, -2.133, -2.664,  3.387,  0.525])

In [142]:
cxt_c.flatten()[:10]

tensor([-3.892, -2.511,  4.718,  1.092, -1.341, -4.955, -2.133, -2.663,  3.388,  0.525])

In [146]:
torch.allclose(cxt_l, cxt_c, atol=1e-1)

False