In [2]:
from dataclasses import dataclass
from typing import List, Tuple, Union

# ===== Capas soportadas =====
@dataclass
class Dense:
    """Fully connected layer, described only by the sizes needed to count parameters."""

    units: int              # number of output units
    input_dim: int          # number of input features
    use_bias: bool = True   # whether each output unit has a bias term

    def params(self) -> int:
        """Return the number of trainable parameters of the layer.

        The weight matrix contributes ``input_dim * units`` values; when
        ``use_bias`` is true, one extra bias per output unit is added.
        """
        weight_count = self.input_dim * self.units
        if not self.use_bias:
            return weight_count
        return weight_count + self.units

@dataclass
class Conv2D:
    """2D convolution layer, described only by the sizes needed to count parameters."""

    filters: int                              # number of output filters
    kernel_size: Union[int, Tuple[int, int]]  # (kh, kw), or a single int for a square kernel
    in_channels: int                          # number of input channels
    use_bias: bool = True                     # whether each filter has a bias term

    def params(self) -> int:
        """Return the number of trainable parameters of the layer.

        Each filter holds ``in_channels * kh * kw`` weights plus one bias when
        ``use_bias`` is true, so the total is
        ``filters * (in_channels * kh * kw + bias)``.

        ``kernel_size`` may be a ``(kh, kw)`` pair or a single int, which is
        interpreted as a square ``k x k`` kernel.
        """
        if isinstance(self.kernel_size, int):
            # Shorthand: one int means the same extent in both dimensions.
            kh = kw = self.kernel_size
        else:
            kh, kw = self.kernel_size
        weights_per_filter = self.in_channels * kh * kw
        bias_per_filter = 1 if self.use_bias else 0
        return self.filters * (weights_per_filter + bias_per_filter)

# Type alias for the layer kinds supported here; both expose a params() method.
Layer = Union[Dense, Conv2D]

# ===== Utilidades generales =====
def count_parameters(layers: List[Layer]) -> int:
    """Return the total number of trainable parameters (weights + biases) of a network.

    The total is the sum of ``params()`` over every element of ``layers``;
    an empty list yields 0.
    """
    total = 0
    for current in layers:
        total += current.params()
    return total

def count_mlp_connections(units_per_layer: List[int], include_bias: bool = True) -> int:
    """Count the weights (and optionally biases) of a fully connected MLP.

    For layer sizes ``[n0, n1, ..., nL]`` the result is:

    - weights: the sum of ``n_i * n_{i+1}`` over consecutive layer pairs
    - biases (only when ``include_bias`` is true): the sum of ``n_{i+1}``,
      i.e. one per unit in every layer except the input layer

    Returns 0 when fewer than two layer sizes are given, since there are no
    connections to count.
    """
    if len(units_per_layer) < 2:
        return 0
    # Every layer after the first receives connections from its predecessor.
    receiving_sizes = units_per_layer[1:]
    total = sum(n_in * n_out for n_in, n_out in zip(units_per_layer, receiving_sizes))
    if include_bias:
        total += sum(receiving_sizes)
    return total

# ===== Ejemplos de uso =====
if __name__ == "__main__":
    # Example 1: MLP 784 -> 128 -> 64 -> 10 (a typical MNIST layout),
    # counted once with biases (the default) and once without.
    mlp_arch = [784, 128, 64, 10]
    mlp_con = count_mlp_connections(mlp_arch)
    mlp_con_sin_bias = count_mlp_connections(mlp_arch, include_bias=False)
    print(f"MLP {mlp_arch} con bias: {mlp_con} parámetros")
    print(f"MLP {mlp_arch} sin bias: {mlp_con_sin_bias} parámetros")

    # Example 2: small CNN with two 3x3 convolutions followed by two dense
    # layers. use_bias is left at its default (True) for every layer.
    cnn_layers = [
        Conv2D(filters=32, kernel_size=(3, 3), in_channels=3),   # 32 * (3*3*3 + 1)
        Conv2D(filters=64, kernel_size=(3, 3), in_channels=32),  # 64 * (32*3*3 + 1)
        Dense(units=128, input_dim=3136),                        # 3136*128 + 128
        Dense(units=10, input_dim=128),                          # 128*10 + 10
    ]
    total_params_cnn = count_parameters(cnn_layers)
    print(f"Total parámetros CNN: {total_params_cnn}")


MLP [784, 128, 64, 10] con bias: 109386 parámetros
MLP [784, 128, 64, 10] sin bias: 109184 parámetros
Total parámetros CNN: 422218
