In [4]:
import yaml
from typing import Dict, Any
from pathlib import Path

def parse_config(config_path: str) -> Dict[str, Any]:
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    return config['model']['params']

def calculate_resnet_block_params(in_channels: int, out_channels: int) -> int:
    params = 0

    params += in_channels * out_channels * 3 * 3
    params += out_channels 

    params += out_channels * out_channels * 3 * 3
    params += out_channels 
    
    if in_channels != out_channels:
        params += in_channels * out_channels + out_channels
        
    return params

def calculate_transformer_params(channels: int, context_dim: int, num_heads: int, depth: int = 1) -> int:
    head_dim = channels // num_heads
    params = 0
    
    params += channels * channels * 3 
    params += channels * channels 
    
    params += channels * channels * 3 
    params += channels * channels
    params += context_dim * channels
    
    params += channels * (channels * 4) * 2
    
    return params * depth

def calculate_controlnet_params(config: Dict[str, Any]) -> int:
    control_config = config['control_stage_config']['params']
    base_channels = control_config['model_channels']
    channel_multipliers = control_config['channel_mult']
    num_res_blocks = control_config['num_res_blocks']
    transformer_depth = control_config['transformer_depth']
    context_dim = control_config['context_dim']
    num_heads = base_channels // control_config['num_head_channels']
    
    total_params = 0
    
    total_params += control_config['in_channels'] * base_channels * 3 * 3 + base_channels
    
    total_params += control_config['hint_channels'] * base_channels * 3 * 3 + base_channels
    
    current_channels = base_channels
    for mult in channel_multipliers:
        out_channels = base_channels * mult
        
        for _ in range(num_res_blocks):
            total_params += calculate_resnet_block_params(current_channels, out_channels)
            current_channels = out_channels
        
        if mult >= 2:
            total_params += calculate_transformer_params(
                current_channels, 
                context_dim, 
                num_heads, 
                transformer_depth
            )
        
        if mult != channel_multipliers[-1]:
            total_params += current_channels * current_channels * 3 * 3 + current_channels
    
    return total_params

def calculate_unet_params(config: Dict[str, Any]) -> int:
    unet_config = config['unet_config']['params']
    base_channels = unet_config['model_channels']
    channel_multipliers = unet_config['channel_mult']
    num_res_blocks = unet_config['num_res_blocks']
    transformer_depth = unet_config['transformer_depth']
    context_dim = unet_config['context_dim']
    num_heads = base_channels // unet_config['num_head_channels']
    
    params = calculate_controlnet_params(config)
    
    middle_channels = base_channels * channel_multipliers[-1]
    params += calculate_transformer_params(
        middle_channels,
        context_dim,
        num_heads,
        transformer_depth
    )
    
    return params

def calculate_vae_params(config: Dict[str, Any]) -> int:
    vae_config = config['first_stage_config']['params']['ddconfig']
    base_channels = vae_config['ch']
    channel_multipliers = vae_config['ch_mult']
    num_res_blocks = vae_config['num_res_blocks']
    z_channels = vae_config['z_channels']
    
    total_params = 0
    
    total_params += vae_config['in_channels'] * base_channels * 3 * 3 + base_channels
    
    current_channels = base_channels
    for mult in channel_multipliers:
        out_channels = base_channels * mult
        
        for _ in range(num_res_blocks):
            total_params += calculate_resnet_block_params(current_channels, out_channels)
            current_channels = out_channels
        
        if mult != channel_multipliers[-1]:
            total_params += current_channels * current_channels * 3 * 3 + current_channels
    
    total_params *= 2
    
    if vae_config.get('double_z', False):
        total_params += z_channels * 2 * current_channels
    
    return total_params

def main(config_path: str):
    try:
        config = parse_config(config_path)
        
        controlnet_params = calculate_controlnet_params(config)
        unet_params = calculate_unet_params(config)
        vae_params = calculate_vae_params(config)
        clip_params = 123_000_000 
        
        print("\nParameter Count Breakdown:")
        print("-" * 50)
        print(f"ControlNet parameters:    {controlnet_params:,}")
        print(f"UNet parameters:          {unet_params:,}")
        print(f"VAE parameters:           {vae_params:,}")
        print(f"CLIP parameters:          {clip_params:,}")
        print("-" * 50)
        total_params = controlnet_params + unet_params + vae_params + clip_params
        print(f"Total parameters:         {total_params:,}")
        
    except Exception as e:
        print(f"Error processing config file: {str(e)}")
        raise

In [5]:
main("depth.yaml")


Parameter Count Breakdown:
--------------------------------------------------
ControlNet parameters:    195,104,000
UNet parameters:          222,629,120
VAE parameters:           42,524,928
CLIP parameters:          123,000,000
--------------------------------------------------
Total parameters:         583,258,048


In [6]:
main("sem.yaml")


Parameter Count Breakdown:
--------------------------------------------------
ControlNet parameters:    195,109,760
UNet parameters:          222,634,880
VAE parameters:           42,524,928
CLIP parameters:          123,000,000
--------------------------------------------------
Total parameters:         583,269,568


In [7]:
main("full.yaml")


Parameter Count Breakdown:
--------------------------------------------------
ControlNet parameters:    195,112,640
UNet parameters:          222,637,760
VAE parameters:           42,524,928
CLIP parameters:          123,000,000
--------------------------------------------------
Total parameters:         583,275,328


In [15]:
import yaml
import numpy as np

def calculate_params(config):
    def count_layer_params(input_dim, output_dim, kernel_size=1, stride=1, padding=0):
        return input_dim * output_dim * (kernel_size ** 2) + output_dim

    def parse_unet(config):
        unet_params = 0
        model_channels = config['model_channels']
        ch_mult = np.array(config['channel_mult'])
        attention_resolutions = config['attention_resolutions']
        num_res_blocks = config['num_res_blocks']

        base_channels = model_channels
        prev_channels = base_channels
        for i, mult in enumerate(ch_mult):
            channels = base_channels * mult

            for _ in range(num_res_blocks):
                unet_params += count_layer_params(prev_channels, channels, kernel_size=3)
                unet_params += count_layer_params(channels, channels, kernel_size=3)
                prev_channels = channels

            if 2 ** i in attention_resolutions:
                unet_params += count_layer_params(channels, channels)

        return unet_params

    def parse_control_stage(config):
        control_params = 0
        in_channels = config['in_channels']
        hint_channels = config['hint_channels']
        model_channels = config['model_channels']

        control_params += count_layer_params(hint_channels, model_channels, kernel_size=3)
        control_params += count_layer_params(in_channels, model_channels, kernel_size=3)

        ch_mult = np.array(config['channel_mult'])
        attention_resolutions = config['attention_resolutions']
        num_res_blocks = config['num_res_blocks']
        prev_channels = model_channels

        for i, mult in enumerate(ch_mult):
            channels = model_channels * mult
            for _ in range(num_res_blocks):
                control_params += count_layer_params(prev_channels, channels, kernel_size=3)
                control_params += count_layer_params(channels, channels, kernel_size=3)
                prev_channels = channels

            if 2 ** i in attention_resolutions:
                control_params += count_layer_params(channels, channels)

        return control_params

    def parse_first_stage(config):
        first_stage_params = 0
        ddconfig = config['ddconfig']
        in_channels = ddconfig['in_channels']
        out_channels = ddconfig['out_ch']
        ch = ddconfig['ch']
        ch_mult = np.array(ddconfig['ch_mult'])
        num_res_blocks = ddconfig['num_res_blocks']

        first_stage_params += count_layer_params(in_channels, ch, kernel_size=3)

        prev_channels = ch
        for mult in ch_mult:
            channels = ch * mult
            for _ in range(num_res_blocks):
                first_stage_params += count_layer_params(prev_channels, channels, kernel_size=3)
                first_stage_params += count_layer_params(channels, channels, kernel_size=3)
                prev_channels = channels

        first_stage_params += count_layer_params(prev_channels, out_channels, kernel_size=3)

        return first_stage_params

    def parse_cond_stage(config):
        return 0

    unet_config = config['model']['params']['unet_config']['params']
    control_stage_config = config['model']['params']['control_stage_config']['params']
    first_stage_config = config['model']['params']['first_stage_config']['params']
    cond_stage_config = config['model']['params']['cond_stage_config']['params']

    total_params = 0
    total_params += parse_unet(unet_config)
    total_params += parse_control_stage(control_stage_config)
    total_params += parse_first_stage(first_stage_config)
    total_params += parse_cond_stage(cond_stage_config)

    return total_params

In [10]:
with open('depth.yaml', 'r') as file:
    config = yaml.safe_load(file)
num_params = calculate_params(config)
print(f"Total number of parameters: {num_params}")

Total number of parameters: 279082051


In [11]:
with open('sem.yaml', 'r') as file:
    config = yaml.safe_load(file)
num_params = calculate_params(config)
print(f"Total number of parameters: {num_params}")

Total number of parameters: 279087811


In [12]:
with open('full.yaml', 'r') as file:
    config = yaml.safe_load(file)
num_params = calculate_params(config)
print(f"Total number of parameters: {num_params}")

Total number of parameters: 279090691
