In [85]:
# autoreload 
%load_ext autoreload
%autoreload 2
%matplotlib inline
from uio.unified_io_hf import image_processing_uio as ip_uio 
import torch
from PIL import Image
import urllib 
import numpy as np

def load_image_from_url(url):
    with urllib.request.urlopen(url) as f:
        img = Image.open(f)
        return np.array(img)

from flax.serialization import from_bytes

# add to sys path
import sys
sys.path.append("./")
from torch import nn

from uio import network as nw_jax 
from uio.unified_io_hf import modeling_uio_vae as nw_torch 
import jax.numpy as jnp 
import jax

from uio.t5x_layers import Conv
from flax.core.frozen_dict import FrozenDict

def convert_params_to_ones(params):
    params = dict(params)
    for k, v in params.items():
        if isinstance(v, jnp.ndarray):
            params[k] = jnp.ones_like(v)
            print(k, v.shape)
        elif isinstance(v, dict) or isinstance(v, FrozenDict): 
            params[k] = convert_params_to_ones(v)
        else:
            params[k] = v
    return params

def load_checkpoint(checkpoint):
  """Load a bin file as a tree of jax arrays"""
  with open(checkpoint, "rb") as state_f:
    state = from_bytes(None, state_f.read())
  state = jax.tree_util.tree_map(jnp.array, state)
  return state


def flatten_dict(d):
    """
    Takes an input dict and concatenates all nested keys into a single key using '.' as a separator.
    """
    out = {}
    for k, v in d.items():
        if isinstance(v, dict):
            v = flatten_dict(v)
            for k2, v2 in v.items():
                out[k + '.' + k2] = v2
        else:
            out[k] = v
    return out

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
path = "./ckpts/small.bin"
weights = load_checkpoint(path)

In [694]:
from uio.unified_io_hf.modeling_uio_vae import VAEConfig
DTYPE = "float32"
VAE_CONFIG = VAEConfig(
  embed_dim=256,
  n_embed=16384,
  double_z=False,
  z_channels=256,
  resolution=256,
  in_channels=3,
  out_ch=3,
  ch=128,
  ch_mult=(1,1,2,2,4),
  num_res_blocks=2,
  attn_resolutions=(16,),
  dropout=0,
  dtype=DTYPE,
)

## Weight conversions

In [584]:
def translate_weights_downsample_block(weights_jax):
    weights_torch = {} 
    for k, v in weights_jax.items():
        if "kernel" in k:
            #gotta be careful and transpose the kernel dims coz their impl is different
            weights_torch[k.replace("kernel", "weight")] = torch.from_numpy(np.array(v).transpose(3, 2, 0, 1)).float()
        elif "bias" in k:
            weights_torch[k] = torch.from_numpy(np.array(v)).float()
    return weights_torch
    
def translate_weights_upsample_block(weights_jax):
    weights_torch = {} 
    for k, v in weights_jax.items():
        if "kernel" in k:
            #gotta be careful and transpose the kernel dims coz their impl is different
            weights_torch[k.replace("kernel", "weight")] = torch.from_numpy(np.array(v).transpose(3,2,0,1)).float()
        elif "bias" in k:
            weights_torch[k] = torch.from_numpy(np.array(v)).float()
    return weights_torch


def translate_weights_res_block(weights_jax):
    weights_torch = {}
    for k, v in weights_jax.items():
        if "kernel" in k:
            weights_torch[k.replace("kernel", "weight")] = torch.from_numpy(np.array(v).transpose(3,2,0,1)).float()
        elif "scale" in k:
            weights_torch[k.replace("scale", "weight")] = torch.from_numpy(np.array(v)).float()
        elif "bias" in k:
            weights_torch[k] = torch.from_numpy(np.array(v)).float()
    return weights_torch

# vae encoder block

In [740]:
# given the outputs below, write a function to translate weights from jax to torch

def translate_weights_vae_encoder(weights_jax):
    weights_torch = {}
    for k, v in weights_jax.items():

        if k == "down_2_block_0.norm1.scale":

            flag = True
        if "down_" in k and ("block_" in k or "attn_" in k):
            idxs = k.split(".")[0].split("_")
            idxs = [idxs[1], idxs[3]]
            if "down_" and "block_" in k:
                k = k.replace("down_" + idxs[0] + "_block_" + idxs[1], "res_blocks." + idxs[0] + "." + idxs[1])
            if "down_" and "attn_" in k:
                k = k.replace("down_" + idxs[0] + "_attn_" + idxs[1], "attn_blocks." + idxs[0] + "." + idxs[1])


        if "downsample" in k:
            idxs = k.split(".")[0].split("_")
            idxs = [idxs[1]]
            k = k.replace(f"down_{idxs[0]}_downsample", f"downsamples.{idxs[0]}")
        if "kernel" in k:
            k = k.replace("kernel", "weight")
            weights_torch[k] = torch.from_numpy(np.array(v).transpose(3,2, 0, 1)).float()
        elif "scale" in k:
            k = k.replace("scale", "weight")
            weights_torch[k.replace("scale", "weight")] = torch.from_numpy(np.array(v)).float()
        elif "bias" in k:
            weights_torch[k] = torch.from_numpy(np.array(v)).float()
        print("Jax", v.shape, "tORCH", state_dict[k].shape, k)
    return weights_torch
flat_weights = flatten_dict(weights["discrete_vae"]["encoder"])
weights_torch= translate_weights_vae_encoder(flat_weights)

Jax (128,) tORCH torch.Size([128]) conv_in.bias
Jax (3, 3, 3, 128) tORCH torch.Size([128, 3, 3, 3]) conv_in.weight
Jax (256,) tORCH torch.Size([256]) conv_out.bias
Jax (3, 3, 512, 256) tORCH torch.Size([256, 512, 3, 3]) conv_out.weight
Jax (128,) tORCH torch.Size([128]) res_blocks.0.0.conv1.bias
Jax (3, 3, 128, 128) tORCH torch.Size([128, 128, 3, 3]) res_blocks.0.0.conv1.weight
Jax (128,) tORCH torch.Size([128]) res_blocks.0.0.conv2.bias
Jax (3, 3, 128, 128) tORCH torch.Size([128, 128, 3, 3]) res_blocks.0.0.conv2.weight
Jax (128,) tORCH torch.Size([128]) res_blocks.0.0.norm1.bias
Jax (128,) tORCH torch.Size([128]) res_blocks.0.0.norm1.weight
Jax (128,) tORCH torch.Size([128]) res_blocks.0.0.norm2.bias
Jax (128,) tORCH torch.Size([128]) res_blocks.0.0.norm2.weight
Jax (128,) tORCH torch.Size([128]) res_blocks.0.1.conv1.bias
Jax (3, 3, 128, 128) tORCH torch.Size([128, 128, 3, 3]) res_blocks.0.1.conv1.weight
Jax (128,) tORCH torch.Size([128]) res_blocks.0.1.conv2.bias
Jax (3, 3, 128, 128)

In [737]:
state_dict = vae_encoder_torch.state_dict()
total_sum = 0
for k, v in state_dict.items():
    if "norm1" in k:
        print(k, v.shape)
        total_sum+=v.shape[0]
total_sum


res_blocks.0.0.norm1.weight torch.Size([128])
res_blocks.0.0.norm1.bias torch.Size([128])
res_blocks.0.1.norm1.weight torch.Size([128])
res_blocks.0.1.norm1.bias torch.Size([128])
res_blocks.1.0.norm1.weight torch.Size([128])
res_blocks.1.0.norm1.bias torch.Size([128])
res_blocks.1.1.norm1.weight torch.Size([128])
res_blocks.1.1.norm1.bias torch.Size([128])
res_blocks.2.0.norm1.weight torch.Size([256])
res_blocks.2.0.norm1.bias torch.Size([256])
res_blocks.2.1.norm1.weight torch.Size([256])
res_blocks.2.1.norm1.bias torch.Size([256])
res_blocks.3.0.norm1.weight torch.Size([256])
res_blocks.3.0.norm1.bias torch.Size([256])
res_blocks.3.1.norm1.weight torch.Size([256])
res_blocks.3.1.norm1.bias torch.Size([256])
res_blocks.4.0.norm1.weight torch.Size([512])
res_blocks.4.0.norm1.bias torch.Size([512])
res_blocks.4.1.norm1.weight torch.Size([512])
res_blocks.4.1.norm1.bias torch.Size([512])
mid_block_1.norm1.weight torch.Size([512])
mid_block_1.norm1.bias torch.Size([512])
mid_block_2.norm

7168

In [730]:
flat_weights = flatten_dict(weights["discrete_vae"]["encoder"])
total_sum = 0
for k, v in flat_weights.items():
    if "norm1" in k:
        print(k, v.shape)
        total_sum+=v.shape[0]
total_sum

down_0_block_0.norm1.bias (128,)
down_0_block_0.norm1.scale (128,)
down_0_block_1.norm1.bias (128,)
down_0_block_1.norm1.scale (128,)
down_1_block_0.norm1.bias (128,)
down_1_block_0.norm1.scale (128,)
down_1_block_1.norm1.bias (128,)
down_1_block_1.norm1.scale (128,)
down_2_block_0.norm1.bias (128,)
down_2_block_0.norm1.scale (128,)
down_2_block_1.norm1.bias (256,)
down_2_block_1.norm1.scale (256,)
down_3_block_0.norm1.bias (256,)
down_3_block_0.norm1.scale (256,)
down_3_block_1.norm1.bias (256,)
down_3_block_1.norm1.scale (256,)
down_4_block_0.norm1.bias (256,)
down_4_block_0.norm1.scale (256,)
down_4_block_1.norm1.bias (512,)
down_4_block_1.norm1.scale (512,)
mid_block_1.norm1.bias (512,)
mid_block_1.norm1.scale (512,)
mid_block_2.norm1.bias (512,)
mid_block_2.norm1.scale (512,)


6400

In [741]:
vae_encoder_torch = nw_torch.VAE_Encoder(VAE_CONFIG)
vae_encoder_torch.load_state_dict(weights_torch)

128 128
128 128
128 256
256 256
256 512


RuntimeError: Error(s) in loading state_dict for VAE_Encoder:
	size mismatch for res_blocks.2.0.norm1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for res_blocks.2.0.norm1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for res_blocks.4.0.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for res_blocks.4.0.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).

In [700]:
state_dict = vae_encoder_torch.state_dict()

for k, v in state_dict.items():
    print(k, v.shape)

conv_in.weight torch.Size([128, 3, 3, 3])
conv_in.bias torch.Size([128])
res_blocks.0.0.conv1.weight torch.Size([128, 128, 3, 3])
res_blocks.0.0.conv1.bias torch.Size([128])
res_blocks.0.0.conv2.weight torch.Size([128, 128, 3, 3])
res_blocks.0.0.conv2.bias torch.Size([128])
res_blocks.0.0.norm1.weight torch.Size([128])
res_blocks.0.0.norm1.bias torch.Size([128])
res_blocks.0.0.norm2.weight torch.Size([128])
res_blocks.0.0.norm2.bias torch.Size([128])
res_blocks.0.1.conv1.weight torch.Size([128, 128, 3, 3])
res_blocks.0.1.conv1.bias torch.Size([128])
res_blocks.0.1.conv2.weight torch.Size([128, 128, 3, 3])
res_blocks.0.1.conv2.bias torch.Size([128])
res_blocks.0.1.norm1.weight torch.Size([128])
res_blocks.0.1.norm1.bias torch.Size([128])
res_blocks.0.1.norm2.weight torch.Size([128])
res_blocks.0.1.norm2.bias torch.Size([128])
res_blocks.1.0.conv1.weight torch.Size([128, 128, 3, 3])
res_blocks.1.0.conv1.bias torch.Size([128])
res_blocks.1.0.conv2.weight torch.Size([128, 128, 3, 3])
res_b

## downsample block done

In [605]:
flat_weights = flatten_dict(weights["discrete_vae"]["encoder"]["down_1_downsample"])
weights_torch = translate_weights_downsample_block(flat_weights)

In [606]:
np.random.seed(3)
image_input = np.random.randn(1, 16, 16, 128)
image_input_torch = torch.from_numpy(image_input).float()
image_input_torch = image_input_torch.permute(0, 3, 1, 2)
downsample = nw_torch.Downsample(128)
downsample.load_state_dict(weights_torch, strict=False)
output_torch = downsample(image_input_torch).detach().numpy()

In [607]:
downsample_jax = nw_jax.Downsample(128)
key = jax.random.PRNGKey(0)
# init the module with constant weights
params = downsample_jax.init(key, image_input)
params = dict(params)
params["params"] = weights["discrete_vae"]["encoder"]["down_1_downsample"]
output_jax = downsample_jax.apply(params, image_input, )

In [608]:
output_torch.shape, output_jax.shape

((1, 128, 8, 8), (1, 8, 8, 128))

In [609]:
output_torch = output_torch.transpose(0, 2, 3, 1)

np.abs(output_torch - output_jax).mean()

0.00028494012

# res block, 


In [653]:


flat_weights = flatten_dict(weights["discrete_vae"]["encoder"]["mid_block_1"])
weights_torch = translate_weights_res_block(flat_weights)

In [654]:
for k, v in flat_weights.items():
    print(k, v.shape)


conv1.bias (512,)
conv1.kernel (3, 3, 512, 512)
conv2.bias (512,)
conv2.kernel (3, 3, 512, 512)
norm1.bias (512,)
norm1.scale (512,)
norm2.bias (512,)
norm2.scale (512,)


In [660]:
image_input = np.random.randn(1, 16, 16, 512)
image_input_torch = torch.from_numpy(image_input).float()
image_input_torch = image_input_torch.permute(0, 3, 1, 2)
res_block_torch = nw_torch.ResBlock(512, 512)
res_block_torch.load_state_dict(weights_torch, strict=True)
output_torch = res_block_torch(image_input_torch).detach().numpy()
output_torch.shape

(1, 512, 16, 16)

In [661]:
res_block_jax = nw_jax.ResBlock(512, 512)
key = jax.random.PRNGKey(0)
# init the module with constant weights
params = res_block_jax.init(key, image_input)
params = dict(params)
params["params"] = weights["discrete_vae"]["encoder"]["mid_block_1"]
output_jax = res_block_jax.apply(params, image_input, )

In [662]:
output_torch = output_torch.transpose(0, 2, 3, 1)

np.abs(output_torch - output_jax).mean()

5.1574694e-05

# Upsample bloc, done

In [632]:
flat_weights = flatten_dict(weights["discrete_vae"]["decoder"]["up_1_upsample"])
weights_torch = translate_weights_upsample_block(flat_weights)
np.random.seed(3)
image_input = np.random.randn(1, 16, 16, 256)
image_input_torch = torch.from_numpy(image_input).float()
image_input_torch = image_input_torch.permute(0, 3, 1, 2)
upsample = nw_torch.Upsample(256)
upsample.load_state_dict(weights_torch, strict=False)
output_torch = upsample(image_input_torch).detach().numpy()


In [633]:
upsample_jax = nw_jax.Upsample(256)
key = jax.random.PRNGKey(0)
# init the module with constant weights
params = upsample_jax.init(key, image_input)
params = dict(params)
params["params"] = weights["discrete_vae"]["decoder"]["up_1_upsample"]
output_jax = upsample_jax.apply(params, image_input, )

In [634]:
output_torch = output_torch.transpose(0, 2, 3, 1)

np.abs(output_torch - output_jax).mean()

7.232075e-07

## Testing attnblock, done

In [672]:
def translate_weights_attn_block(weights_jax):
    # v is a jax numpy array, turn it into a numpy array
    weights_torch = {}
    for k, v in weights_jax.items():
        if "kernel" in k:
            weights_torch[k.replace("kernel", "weight")] = torch.from_numpy(np.array(v).transpose(3,2, 0, 1)).float()
        elif "scale" in k:
            weights_torch[k.replace("scale", "weight")] = torch.from_numpy(np.array(v)).float()
        elif "bias" in k:
            weights_torch[k] = torch.from_numpy(np.array(v)).float()
    return weights_torch

In [673]:
weights_attn_jax = weights["discrete_vae"]["encoder"]["mid_attn_1"]
weights_attn_torch = translate_weights_attn_block(flatten_dict(weights["discrete_vae"]["encoder"]["mid_attn_1"]))

In [686]:
# set numpy random seed 
np.random.seed(5)
image_input = np.random.randn(1, 12, 12, 512).astype(np.float32)*10
image_input_torch = torch.from_numpy(image_input).float()

attn_block_torch = nw_torch.AttnBlock(512)

# set all weights to 1
attn_block_torch.load_state_dict(weights_attn_torch)
# put in channels first 
image_input_torch = image_input_torch.permute(0, 3, 1, 2)
output_torch = attn_block_torch(image_input_torch).detach().numpy()

In [687]:
# state_dict = attn_block_torch.state_dict()
# for key in state_dict.keys():
#     print(key, state_dict[key].shape)

In [688]:
# test forward on jax, on image_input,

attn_block_jax = nw_jax.AttnBlock(512)
key = jax.random.PRNGKey(0)
# init the module with constant weights
params = attn_block_jax.init(key, image_input)
params = dict(params)
params["params"] = weights_attn_jax
output_jax = attn_block_jax.apply(params, image_input, )
output_jax.shape, output_torch.shape

((1, 12, 12, 512), (1, 512, 12, 12))

In [689]:
output_jax = output_jax.transpose(0, 3, 1, 2)

In [691]:
output_torch[0][0], output_jax[0][0]

(array([[-39139.152, -39131.5  , -39132.508, -39140.992, -39149.39 ,
         -39158.906, -39147.14 ,   9580.306, -39163.58 , -39140.184,
         -39138.94 , -39151.848],
        [-39127.1  , -39147.492, -39154.066, -39141.39 , -39139.6  ,
         -39145.734, -39140.53 ,   9573.505, -39150.555, -39151.242,
         -39130.66 , -39138.492],
        [-39148.37 , -39109.05 , -39135.023,   9574.731, -39156.473,
         -39143.582, -39156.36 , -39155.285, -39155.53 , -39129.875,
         -39136.133, -39135.23 ],
        [-39145.9  , -39159.402, -39151.387, -39126.402, -39152.523,
         -39136.97 , -39153.82 , -39141.363, -39150.344, -39133.797,
         -39133.797, -39153.863],
        [-39135.156, -39147.402, -39155.117, -39127.855, -39140.11 ,
         -39160.223, -39162.7  , -39144.88 , -39153.9  , -39164.562,
         -39149.133, -39143.164],
        [-39166.375, -39158.64 , -39155.27 , -39130.42 , -39145.188,
         -39135.383, -39131.742, -39140.27 , -39158.87 , -39151.254,
  

In [692]:
output_torch.shape

(1, 512, 12, 12)

In [693]:
np.abs(output_torch - output_jax).mean()

0.35482204

In [19]:
vae_model_jax = DiscreteVAEjax(VAE_CONFIG)
# init the module's weights
import jax 
# import jax np 

vae_model_jax.init(jax.random.PRNGKey(0), processed[0]);

In [23]:
# test vae model on a single image made of zeros 
import jax.numpy as jnp

zeros = jnp.zeros((1, 3, 256, 256), dtype=DTYPE)
vae_model_jax.apply(zeros);


AttributeError: 'Array' object has no attribute 'items'

In [35]:
vae_model = DiscreteVAE(VAE_CONFIG)

In [36]:
import torch
input_img = torch.randn(1, 3, 256, 256, dtype=torch.float32)


In [133]:
in_channels = 512
out_channels = 512 

input_img = np.random.randn(1, 512, 4, 4).astype(np.float32)
input_img_torch = torch.from_numpy(input_img).float()

input_img_jax = input_img.transpose(0, 2, 3, 1)

print(input_img_jax.shape)


(1, 4, 4, 512)


In [134]:
conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding="same", bias=False)

In [135]:
weight = conv2d.weight.detach().numpy()

In [136]:
conv2d_jax = Conv(features=512,
            kernel_size=(1, 1),
            dtype=float,
            kernel_axes=("axis_0", "axis_1", "axis_2", "axis_3"),
            bias_axes=("axis_3",),
            name="k",
            )

key = jax.random.PRNGKey(0)
# init the module with constant weights
params = conv2d_jax.init(key, input_img_jax)
params = dict(params)
params["params"] = dict(params["params"])
params["params"]["kernel"] =  jnp.array(weight.transpose(2, 3, 1, 0))
output_jax = conv2d_jax.apply(params, input_img_jax, )

In [137]:
output_torch = conv2d(input_img_torch).detach().numpy()

In [138]:
output_torch.shape, output_jax.shape

((1, 512, 4, 4), (1, 4, 4, 512))

In [139]:
output_torch = output_torch.transpose(0, 2, 3, 1)

np.abs(output_torch - output_jax).mean()

0.00013946078

In [140]:
output_torch

array([[[[-0.37397608, -0.18934363, -0.19102705, ..., -0.11944845,
          -0.20906502, -1.5298786 ],
         [ 0.31949556,  1.0684662 , -0.07598345, ..., -0.3701464 ,
           0.22494976, -0.8975331 ],
         [-0.541537  , -0.8052794 , -0.18761101, ...,  0.5390114 ,
          -0.44420135,  0.03182103],
         [-0.38844013, -0.20471531,  0.18559667, ...,  0.19249475,
          -1.0414706 ,  0.54464597]],

        [[ 0.473306  ,  0.93011314, -0.35211036, ..., -0.4534372 ,
           0.63449764, -1.146805  ],
         [-0.25913465,  0.5615176 ,  0.58334273, ..., -0.46822223,
          -0.16222015,  0.27533892],
         [ 0.01665792,  0.21350853,  0.43480274, ...,  0.11611384,
          -0.42272973,  0.41423053],
         [ 0.60471356,  0.19033806,  0.14340717, ...,  0.3603496 ,
           0.9142468 , -0.82812697]],

        [[ 0.50708675,  0.25339413,  0.1398566 , ..., -0.14297803,
           0.71717924,  0.20662865],
         [ 0.5424281 , -1.0149813 ,  0.61394894, ...,  0.354