The control model's attention blocks were producing errors.

The bug: I had used the wrong number of attention heads. This is now fixed.

Let's see if this reduces the errors.

In [1]:
import torch
from torch.testing import assert_close
from torch import allclose, nn, tensor
# Compact, fixed-point tensor printing: 3 decimals, wide lines, no scientific notation.
torch.set_printoptions(precision=3, sci_mode=False, linewidth=200)

In [2]:
# Pick the best available accelerator: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
# Half precision only on CUDA; MPS and CPU run in float32.
device_dtype = torch.float16 if device == 'cuda' else torch.float32

## Load logs

In [3]:
from diffusers.umer_debug_logger import UmerDebugLogger

In [4]:
# Load the logged intermediate results of both runs (cloud first, then local CUDA).
cloud_cuda, local_cuda = (
    UmerDebugLogger.load_log_objects_from_dir(log_dir)
    for log_dir in ('logs/cloud', 'logs/local_cuda')
)

In [5]:
tuple(len(log) for log in (cloud_cuda, local_cuda))

(970, 970)

In [6]:
# Print every position where the two logs recorded a different message.
for idx, (cloud_entry, local_entry) in enumerate(zip(cloud_cuda, local_cuda)):
    if cloud_entry.msg == local_entry.msg:
        continue
    print(f'{idx:<3}{cloud_entry.msg:>20}{local_entry.msg:>20}')

## Compare intermediate results

In [7]:
def mae(t1, t2):
    """Mean absolute error between two tensors of identical shape.

    Args:
        t1, t2: tensors with the same shape.

    Returns:
        0-dim tensor holding mean(|t1 - t2|), computed in at least float32.

    Raises:
        AssertionError: if the shapes differ.
    """
    assert t1.shape == t2.shape, f'shape mismatch: {tuple(t1.shape)} vs {tuple(t2.shape)}'
    # Accumulate in at least float32: fp16 tensors can overflow to inf in the
    # subtraction, and integer tensors do not support .mean(). float64 is kept as-is.
    acc_dtype = torch.promote_types(torch.promote_types(t1.dtype, t2.dtype), torch.float32)
    return (t1.to(acc_dtype) - t2.to(acc_dtype)).abs().mean()

In [16]:
from functools import partial
from util_inspect import fmt_bool

def _build_block_labels():
    """Return the ordered list of (branch, block) labels for the logged stages.

    `branch` is 'Base' or 'Ctrl'. For each add_label call, the 'Base' labels are
    listed first, followed by the same labels tagged 'Ctrl' (unless ctrl=False).
    NOTE(review): the order is hard-coded and must match the order in which the
    debug logger recorded the blocks; update it if the model changes.
    Only the down and mid stages are labelled; compare_intermediate_results uses
    ('Base', 'DONT CARE') for every entry after these.
    """
    labels = []

    def add_label(lbs, ctrl=True):
        if not isinstance(lbs, (list, tuple)):
            lbs = [lbs]
        labels.extend(('Base', lb) for lb in lbs)
        if ctrl:
            labels.extend(('Ctrl', lb) for lb in lbs)

    # down 1
    add_label('ResBlock d1.1')
    add_label('ResBlock d1.2')
    add_label('Conv d1')
    # down 2
    add_label(('ResBlock d2.1', 'AttnBlock d2.1'))
    add_label(('ResBlock d2.2', 'AttnBlock d2.2'))
    add_label('Conv d2')
    # down 3
    add_label(('ResBlock d3.1', 'AttnBlock d3.1'))
    add_label(('ResBlock d3.2', 'AttnBlock d3.2'))
    # mid
    add_label(('ResBlock m1', 'AttnBlock m', 'ResBlock mid2'))
    # up: not labelled yet
    return labels


def compare_intermediate_results(n=None, n_start=0, prec=5, compare_prec=3, ignore_base=False):
    """Print a row-by-row comparison of the `cloud_cuda` and `local_cuda` logs.

    Each row shows the entry's block label, message and shape, whether names,
    shapes and values agree, and the mean absolute error (local vs cloud).

    Args:
        n: stop index (exclusive). Defaults to the length of the shorter log and
            is clamped to it, so logs of different length cannot index out of range.
        n_start: first index to print. Block labels are still advanced over the
            entries before it, so printed rows keep their correct block label.
        prec: number of decimals for the MAE column.
        compare_prec: values count as "same" if allclose with atol=10**-compare_prec.
        ignore_base: if True, hide rows labelled 'Base' (separators still print).
    """
    n_available = min(len(cloud_cuda), len(local_cuda))
    n = n_available if n is None else min(n, n_available)

    # Printed column widths: index, branch + block (4 + ' | ' + 14 = 21), name,
    # shape, three y/n flags, MAE. Columns are joined by ' | ' (3 chars).
    widths = (3, 21, 20, 20, 12, 12, 12, 20)
    total_len = sum(widths) + 3 * (len(widths) - 1)

    def line(txt):
        """Print a separator of `txt` repeated to (at most) the table width."""
        print(txt * (total_len // len(txt)))

    print(f'{"":<3} | {"block":<21} | {"name":<20} | {"shape":<20} | {"same names?":<12} | {"same shapes?":<12} | {"same values?":<12} | {"Δ cuda local -> cloud":<20}')
    print(f'{"":<3} | {"":<21} | {"":<20} | {"":<20} | {"":<12} | {"":<12} | {"prec="+str(compare_prec):^12} | {"prec="+str(prec):^20}')
    line('#')

    fallback_label = ('Base', 'DONT CARE')
    labels = iter(_build_block_labels())
    bc, block = next(labels, fallback_label)
    for i in range(n):
        cc, lc = cloud_cuda[i], local_cuda[i]
        visible = i >= n_start

        if visible and not (ignore_base and bc == 'Base'):
            eq_name = cc.msg == lc.msg
            eq_shape = cc.shape == lc.shape
            if cc.t.shape == lc.t.shape:
                eq_vals = torch.allclose(cc.t, lc.t, atol=10**-compare_prec)
                mae_2 = f'{float(mae(lc.t, cc.t)):>20.{prec}f}'
            else:
                # allclose would broadcast or raise, and mae asserts equal shapes.
                eq_vals = False
                mae_2 = f'{"n/a":>20}'
            print(f'{i+1:<3} | {bc:<4} | {block:<14} | {cc.msg:<20} | {str(cc.shape):>20} | {fmt_bool(eq_name, "^12")} | {fmt_bool(eq_shape, "^12")} | {fmt_bool(eq_vals, "^12")} | {mae_2}')

        # Advance the block label on every entry, including those before n_start,
        # so that labels stay aligned with the log when n_start > 0.
        if cc.msg in ('add conv_shortcut', 'conv', 'proj_out'):
            if visible:
                line('=')
            bc, block = next(labels, fallback_label)
        elif visible and cc.msg in ('add ff', 'proj_in'):
            line('- ')

In [17]:
# Show only the control-model rows, comparing and printing at 3 decimals.
compare_intermediate_results(prec=3, compare_prec=3, ignore_base=True)

    | block                 | name                 | shape                | same names?  | same shapes? | same values? | Δ cuda local -> cloud
    |                       |                      |                      |              |              |    prec=3    |        prec=3       
##############################################################################################################################################
5   | Ctrl | ResBlock d1.1  | conv1                |      [2, 32, 96, 96] | [92m     y      [0m | [92m     y      [0m | [91m     n      [0m |                0.000
6   | Ctrl | ResBlock d1.1  | add time_emb_proj    |      [2, 32, 96, 96] | [92m     y      [0m | [92m     y      [0m | [91m     n      [0m |                0.000
7   | Ctrl | ResBlock d1.1  | conv2                |      [2, 32, 96, 96] | [92m     y      [0m | [92m     y      [0m | [92m     y      [0m |                0.000
8   | Ctrl | ResBlock d1.1  | add conv_shortcut    |      [2, 32,

Yes! The errors in the control model's attention blocks are gone now.

____

Okay, next: there is a huge error in `ctrl.mid.resnet1.conv1`.