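This notebook contains code snippets showing how to check the number of trainable parameters of our DDPM++ models (`himyb.models.ddpmpp.DDPMPP`) and how to print their layer-by-layer summaries with `torchinfo`.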

In [None]:
import os
import sys
import torch
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
import himyb.models.ddpmpp as ddpmpp
import torchinfo
%load_ext autoreload
%autoreload 2

## 32x32 case

In [None]:
# Baseline 32x32 DDPM++ model: 3 input / 3 output channels, 11-dim labels,
# 128 base channels. channel_mult=[2, 2, 2] is presumably the per-resolution
# multiplier applied to model_channels (TODO confirm in himyb.models.ddpmpp).
model = ddpmpp.DDPMPP(
    img_resolution=32,
    in_channels=3, out_channels=3,
    label_dim=11,
    model_channels=128,
    channel_mult=[2, 2, 2],
    verbose_init=True,
)

In [None]:
# Sum element counts over parameters with requires_grad=True, i.e. the trainable ones.
num_params = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
print(f"Number of trainable parameters: {num_params}")

In [None]:
# One input shape per forward argument, with a batch of 10:
torchinfo.summary(
    model,
    [
        (10, 3, 32, 32),  # 3-channel 32x32 images
        (10,),            # one value per sample (presumably the noise level - TODO confirm)
        (10, 11),         # label vectors; 11 must match label_dim used to build `model`
    ],
)

In [None]:
# Narrower 32x32 variant: 64 base channels, num_blocks=2 (the variable name
# suggests 2 residual blocks per level - TODO confirm) and 3-dim labels.
# `label_dim` is kept as a notebook-level name because later summary cells read it.
label_dim = 3
model_res_block2 = ddpmpp.DDPMPP(
    img_resolution=32,
    in_channels=3, out_channels=3,
    label_dim=label_dim,
    model_channels=64,
    channel_mult=[2, 2, 2],
    num_blocks=2,
    verbose_init=False,
)

In [None]:
# Batch of 32; the label shape reuses the global `label_dim` the model was built with.
torchinfo.summary(
    model_res_block2,
    [
        (32, 3, 32, 32),
        (32,),
        (32, label_dim),
    ],
)

In [None]:
# Downsized 32x32 variant: only 32 base channels but num_blocks=3, 11-dim labels.
# Prints its trainable-parameter count for comparison with the baseline `model`.
model_downsized = ddpmpp.DDPMPP(
    img_resolution=32,
    in_channels=3, out_channels=3,
    label_dim=11,
    model_channels=32,
    channel_mult=[2, 2, 2],
    num_blocks=3,
    verbose_init=False,
)
num_params = sum(
    param.numel()
    for param in model_downsized.parameters()
    if param.requires_grad
)
print(f"Number of trainable parameters: {num_params}")

## 64x64 case

In [None]:
# 64x64 model (the name suggests it is meant for the Waterbirds data - TODO confirm):
# channel_mult=[1, 2, 2, 2], 128 base channels, num_blocks=4, 3-dim labels.
# `label_dim` is bound here (not inherited from the 32x32 section) so the summary
# cell that follows uses the same value the model was actually built with, and this
# section can be run on its own.
label_dim = 3
model_waterbirds = ddpmpp.DDPMPP(
    img_resolution=64,
    in_channels=3,
    out_channels=3,
    label_dim=label_dim,
    channel_mult=[1, 2, 2, 2],
    model_channels=128,
    verbose_init=False,
    num_blocks=4,
)
num_params = sum(p.numel() for p in model_waterbirds.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_params}")

In [None]:
# Summarise model_waterbirds on a batch of 3-channel 64x64 inputs.
# NOTE: `label_dim` is a notebook-global; it must equal the label_dim (3) that
# model_waterbirds was constructed with.
batch_size = 64
torchinfo.summary(
    model_waterbirds,
    [
        (batch_size, 3, 64, 64),
        (batch_size,),
        (batch_size, label_dim),
    ],
)