In [1]:
import torch
from src.ml_models.utils import count_params
from torch.utils.data import Dataset, DataLoader
from src.ml_models.utils import get_device

  from .autonotebook import tqdm as notebook_tqdm
2026-01-05 03:37:44,407	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
save_model_path = "model_architecture/"

## configurations

In [3]:
save_model_architecture = False
show_model_vram_usage = False

In [4]:
def vram_usage(training_function):
    torch.cuda.reset_peak_memory_stats()

    training_function()

    current = torch.cuda.memory_allocated() / 1024**2
    peak = torch.cuda.max_memory_allocated() / 1024**2

    print(f"Current VRAM: {current:.2f} MB")
    print(f"Peak VRAM: {peak:.2f} MB")

In [5]:
def get_dataloader():

    class InMemoryDictDataset(Dataset):
        def __init__(self, num_samples, num_classes=10):
            self.images = torch.rand(num_samples, 1, 28, 28)
            self.labels = torch.randint(
                0, num_classes, (num_samples,), dtype=torch.long
            )

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            return {
                "image": self.images[idx],
                "label": self.labels[idx],
            }

    dataset = InMemoryDictDataset(num_samples=10000)

    return DataLoader(
        dataset,
        batch_size=64,
        shuffle=True,
    )

# CNN

In [6]:
from src.ml_models.cnn import CNN
from src.ml_models.train_net import train_net

cnn = CNN("cnn5")

if show_model_vram_usage or True:
    dataloader = get_dataloader()

    vram_usage(
        lambda: train_net(
            net=cnn,
            trainloader=dataloader,
            testloader=dataloader,
            epochs=10,
            learning_rate=0.001,
            device=get_device(),
            dataset_input_feature="image",
            dataset_target_feature="label",
            optimizer_strategy="adam",
        )
    )

    cnn.to("cpu")
    torch.cuda.empty_cache()

if save_model_architecture:
    cnn_dummy_input = torch.randn(1, 3, 32, 32)
    torch.onnx.export(cnn, cnn_dummy_input, save_model_path + "cnn.onnx")

print("Cnn Parameters:", count_params(cnn))

Current VRAM: 67.14 MB
Peak VRAM: 155.09 MB
Cnn Parameters: 6669258


# HFedCVAE

In [7]:
from src.ml_models.vae import CVAE
from src.ml_models.train_vae import train_vae

HFedCVAE = {
    "cvae_parameters": {
        "h_dim": 224,
        "res_h_dim": 56,
        "n_res_layers": 3,
        "latent_dim": 100,
    }
}

cvae = CVAE(**HFedCVAE["cvae_parameters"])

if show_model_vram_usage or True:
    dataloader = get_dataloader()

    vram_usage(
        lambda: train_vae(
            cvae=cvae,
            trainloader=dataloader,
            epochs=10,
            device=get_device(),
            dataset_input_feature="image",
            dataset_target_feature="label",
        )
    )

    cvae.to("cpu")
    torch.cuda.empty_cache()

if save_model_architecture:
    cvae_dummy_input = torch.randn(1, 1, 28, 28)
    torch.onnx.export(cvae, cvae_dummy_input, save_model_path + "HFedCVAE_vae.onnx")

print("Vae Parameters:", count_params(cvae))

Current VRAM: 63.85 MB
Peak VRAM: 157.03 MB
Vae Parameters: 5916921


# HFedCGAN

In [None]:
from src.ml_models.train_gan import train_gan
from src.ml_models.discriminator import Discriminator
from src.ml_models.generator import Generator

HFedCGAN = {
    "generator_parameters": {
        "n_block_layers": 2,
        "h_dim": 120,
        "latent_dim": 100,
        "init_img_dim": 7,
    },
    "discriminator_parameters": {
        "block_repeat": 1,
        "n_block_layers": 3,
        "h_dim": 42,
    },
}

generator = Generator(**HFedCGAN["generator_parameters"])
discriminator = Discriminator(**HFedCGAN["discriminator_parameters"])

if show_model_vram_usage or True:
    dataloader = get_dataloader()

    vram_usage(
        lambda: train_gan(
            generator=generator,
            discriminator=discriminator,
            trainloader=dataloader,
            epochs=10,
            device=get_device(),
            dataset_input_feature="image",
        )
    )

    generator.to("cpu")
    discriminator.to("cpu")
    torch.cuda.empty_cache()

if save_model_architecture:
    generator_dummy_input = torch.randn(1, 100)
    torch.onnx.export(
        generator, generator_dummy_input, save_model_path + "HFedCGAN_generator.onnx"
    )
    discriminator_dummy_input = torch.randn(1, 1, 28, 28)
    torch.onnx.export(
        discriminator,
        discriminator_dummy_input,
        save_model_path + "HFedCGAN_discriminator.onnx",
    )

generator_parameters_count = count_params(generator)
discriminator_parameters_count = count_params(discriminator)

print("Generator Parameters:", generator_parameters_count)
print("Discriminator Parameters:", discriminator_parameters_count)
print(
    "Total GAN Parameters:",
    generator_parameters_count + discriminator_parameters_count,
)

Current VRAM: 26.57 MB
Peak VRAM: 151.22 MB
Generator Parameters: 675421
Discriminator Parameters: 674353
Total GAN Parameters: 1349774


# HFedCVAEGAN

In [None]:
from src.ml_models.vae import VAE
from src.ml_models.discriminator import Discriminator
from src.ml_models.train_vae_gan import train_vae_gan

HFedCVAEGAN = {
    "vae_parameters": {
        "h_dim": 210,
        "res_h_dim": 56,
        "n_res_layers": 2,
        "latent_dim": 100,
    },
    "discriminator_parameters": {
        "block_repeat": 1,
        "n_block_layers": 3,
        "h_dim": 30,
    },
}

vae = VAE(**HFedCVAEGAN["vae_parameters"])
discriminator = Discriminator(**HFedCVAEGAN["discriminator_parameters"])

if show_model_vram_usage or True:
    dataloader = get_dataloader()

    vram_usage(
        lambda: train_vae_gan(
            vae=vae,
            discriminator=discriminator,
            trainloader=dataloader,
            epochs=10,
            device=get_device(),
            dataset_input_feature="image",
        )
    )

    vae.to("cpu")
    discriminator.to("cpu")
    torch.cuda.empty_cache()

if save_model_architecture:
    vae_dummy_input = torch.randn(1, 1, 28, 28)
    torch.onnx.export(vae, vae_dummy_input, save_model_path + "HFedCVAEGAN_vae.onnx")
    discriminator_dummy_input = torch.randn(1, 1, 28, 28)
    torch.onnx.export(
        discriminator,
        discriminator_dummy_input,
        save_model_path + "HFedCVAEGAN_discriminator.onnx",
    )

vae_parameters_count = count_params(vae)
discriminator_parameters_count = count_params(discriminator)

print("Vae Parameters:", vae_parameters_count)
print("Discriminator Parameters:", discriminator_parameters_count)
print(
    "Total VAE-GAN Parameters:",
    vae_parameters_count + discriminator_parameters_count,
)

Current VRAM: 28.94 MB
Peak VRAM: 103.82 MB
Vae Parameters: 1808301
Discriminator Parameters: 345601
Total VAE-GAN Parameters: 2153902


# Save

In [22]:
from src.scripts.helper import save_metadata

save_metadata(
    {
        "HFedCVAE": HFedCVAE,
        "HFedCGAN": HFedCGAN,
        "HFedCVAEGAN": HFedCVAEGAN,
    }
)