In [5]:
from unet_torch import UNet

In [12]:
from unet_torch import UNet
import torch

# Hyper-parameters for the PyTorch UNet, gathered in one mapping so they can
# be inspected or tweaked in a single place (fill in any args you need).
torch_unet_kwargs = dict(
    num_classes=1,
    ch=256,
    emb_ch=1024,
    out_ch=3,  # e.g., RGB output
    ch_mult=(1, 1, 1),
    num_res_blocks=3,
    attn_resolutions=(8, 16),
    num_heads=1,
    dropout=0.2,
    logsnr_input_type="inv_cos",
    resblock_resample=True,
    logsnr_scale_range=(-10, 10),
)

# Build the model by unpacking the mapping into the constructor.
model = UNet(**torch_unet_kwargs)

# Walk the model's parameters once, recording each tensor's element count
# together with its requires_grad flag.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]

# Every parameter counts toward the total; only those flagged with
# requires_grad count toward the trainable figure.
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)

print(f"Total parameters:  {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

Total parameters:  52,926,211
Trainable parameters: 52,926,211


In [1]:
from unet_jaxcode import UNet as JaxUNet

In [13]:
import jax
import jax.numpy as jnp
from flax.core import FrozenDict
from unet_jaxcode import UNet

# 1) collect the hyper-parameters you plan to train with in one mapping,
#    then build the UNet from it
jax_unet_kwargs = dict(
    num_classes=1,
    ch=256,
    emb_ch=1024,
    out_ch=3,  # e.g., RGB output
    ch_mult=(1, 1, 1),
    num_res_blocks=3,
    attn_resolutions=(8, 16),
    num_heads=1,
    dropout=0.2,
    logsnr_input_type="inv_cos",
    resblock_resample=True,
    logsnr_scale_range=(-10, 10),
)
model = UNet(**jax_unet_kwargs)

# 2) zero-filled placeholder inputs passed to model.init below
rng = jax.random.PRNGKey(0)
# one 32x32 sample with 3 entries on the last axis
# (presumably batch, height, width, channels -- confirm against the model)
dummy_x = jnp.zeros((1, 32, 32, 3), dtype=jnp.float32)
dummy_logsnr = jnp.zeros((1,), dtype=jnp.float32)
# integer placeholder for the third init argument
# (if num_classes>1 it matters; otherwise 0 can be passed)
dummy_y = jnp.zeros((1,), dtype=jnp.int32)

# 3) run init to obtain the variables dict and keep only its 'params' sub-tree
variables = model.init(rng, dummy_x, dummy_logsnr, dummy_y, train=False)
params: FrozenDict = variables['params']

# 4) flatten the parameter tree and add up the size of every leaf array
param_count = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f"Total parameters: {param_count:,}")

Total parameters: 60,004,099
