In [1]:
import climate_learn as cl
import torch.nn as nn

In [2]:
# https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9
def count_parameters(model) -> int:
    """Return the number of trainable scalar parameters in ``model``.

    Args:
        model: Object exposing ``parameters()`` whose items provide
            ``numel()`` and a ``requires_grad`` flag (e.g. a
            ``torch.nn.Module``).

    Returns:
        Sum of ``numel()`` over every parameter whose ``requires_grad`` is
        truthy; ``0`` when there are no such parameters.
    """
    trainable_total = 0
    for param in model.parameters():
        # Parameters with requires_grad unset are excluded from the count.
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

In [3]:
# Instantiate the ResNet from the climate_learn model hub.
resnet = cl.models.hub.ResNet(in_channels=36, out_channels=3, history=3, n_blocks=19)

In [4]:
# Instantiate the U-Net from the climate_learn model hub.
unet = cl.models.hub.Unet(
    # Input/output configuration.
    in_channels=36, out_channels=3, history=3,
    # Architecture size.
    ch_mults=[1, 1, 2], n_blocks=4,
)

In [15]:
# Instantiate the Vision Transformer from the climate_learn model hub.
vit = cl.models.hub.VisionTransformer(
    # Input/output configuration.
    img_size=(32, 64),
    in_channels=36, out_channels=3, history=3,
    # Patch embedding.
    patch_size=2, embed_dim=256,
    # Transformer stack.
    depth=8, decoder_depth=2, num_heads=16, mlp_ratio=4,
)

In [16]:
# Count trainable parameters for each model, in the order ResNet, U-Net, ViT.
resnet_params, unet_params, vit_params = map(
    count_parameters, (resnet, unet, vit)
)

# Report each count in millions with two decimals, one line per model.
print(
    f"# of Resnet params: {resnet_params/1e6:.2f}M",
    f"# of Unet params: {unet_params/1e6:.2f}M",
    f"# of ViT params: {vit_params/1e6:.2f}M",
    sep="\n",
)

# of Resnet params: 6.31M
# of Unet params: 6.11M
# of ViT params: 6.57M
