In [32]:
from __future__ import annotations

from typing import Callable, Optional, Literal

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

In [33]:
# Load iris as plain arrays and hold out 20% of the rows for testing.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Prepend a constant column of ones to each split. ANN layers compute
# `previous @ W` with no separate bias vector, so this column presumably
# serves as the intercept input for the first layer.
X_train = np.column_stack((np.ones(len(X_train)), X_train))
X_test = np.column_stack((np.ones(len(X_test)), X_test))

In [None]:
Activation = Callable[[np.ndarray], np.ndarray]
LossFn     = Callable[[np.ndarray, np.ndarray], float]


class ANN:
    """
    Build a feedforward net with flexible sizing.

    Every layer computes ``activation(previous_activation @ W)``; there are no
    separate bias vectors. Weights are created lazily on the first forward pass,
    and the first weight matrix is sized from that input's number of columns.

    Parameters
    ----------
    input_size : int
        Number of input features. Stored on the instance; note that the first
        weight matrix is sized from ``X.shape[1]`` of the first forward input.
    output_size : int
        Number of outputs.
    net_length : int
        Number of hidden layers when ``net_width`` is an int. May be 0, which
        gives a single input -> output layer. Ignored when ``net_width`` is a
        list/tuple (the depth is then ``len(net_width)``).
    net_width : int | list[int]
        - If int: every hidden layer has this many neurons (requires `net_length`).
        - If list[int]: each entry is the width of that hidden layer (net_length inferred).
          The list is copied, so later changes to the caller's list have no effect.
    hidden_activation : Activation | Literal["relu","sigmoid","tanh","leaky_relu"]
        Activation used on hidden layers (callable or known string).
    output_activation : Activation | Literal["identity","sigmoid","tanh","softmax"]
        Activation used on the output layer (callable or known string).
    loss : LossFn | Literal["mse","bce","cross_entropy"]
        Loss function (callable or known string), called as ``loss(pred, target)``.
    seed : int | None, default None
        Seed for weight initialisation. A private ``np.random.RandomState`` is
        used, so the global NumPy RNG is not reset (the draws are identical to
        ``np.random.seed(seed)`` followed by ``np.random.uniform``). With
        ``None`` the global RNG is drawn from.

    Raises
    ------
    ValueError
        Unknown activation/loss name, or a negative ``net_length`` combined
        with an int ``net_width``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        net_length: int,
        net_width: int | list[int],
        hidden_activation: Activation | Literal["relu","sigmoid","tanh","leaky_relu"],
        output_activation: Activation | Literal["identity","sigmoid","tanh","softmax"],
        loss: LossFn | Literal["mse","bce","cross_entropy"],
        seed: Optional[int] = None,
    ) -> None:
        # initial configs
        self.hidden_activation = self._resolve_activation(hidden_activation)
        self.output_activation = self._resolve_activation(output_activation)
        self.loss_fn           = self._resolve_loss(loss)
        self.seed              = seed
        self.input_size        = input_size
        self.output_size       = output_size

        if isinstance(net_width, (list, tuple)):
            # copy so mutating the caller's list cannot silently resize the net
            self.net_width = list(net_width)
        else:
            if net_length < 0:
                raise ValueError(f"net_length must be >= 0, got {net_length!r}")
            self.net_width = [net_width] * net_length
        # Derive depth from the widths so the two can never disagree; the
        # weight-building code relies on them matching.
        self.net_length = len(self.net_width)

        # Empty until the first forward pass triggers lazy initialisation.
        self.weights: list[np.ndarray] = []

    def _resolve_activation(self, fn):
        """Return ``fn`` if callable, otherwise the built-in activation named ``fn``.

        Raises ValueError for an unrecognised name.
        """
        if callable(fn):
            return fn

        def _softmax(x):
            # Shift by the max along the last axis so exp() cannot overflow.
            # For 2-D input this is the usual per-row softmax; 1-D input works too.
            e = np.exp(x - np.max(x, axis=-1, keepdims=True))
            return e / np.sum(e, axis=-1, keepdims=True)

        lut = {
            "relu":        lambda x: np.maximum(0, x),
            "sigmoid":     lambda x: 1.0 / (1.0 + np.exp(-x)),
            "tanh":        np.tanh,
            "leaky_relu":  lambda x, a=0.01: np.where(x > 0, x, a * x),
            "identity":    lambda x: x,
            "softmax":     _softmax,
        }
        if isinstance(fn, str) and fn in lut:
            return lut[fn]
        raise ValueError(f"Unknown activation: {fn!r}")

    def _resolve_loss(self, fn):
        """Return ``fn`` if callable, otherwise the built-in loss named ``fn``.

        Every built-in loss takes ``(pred, target)`` and returns a scalar mean.
        Raises ValueError for an unrecognised name.
        """
        if callable(fn):
            return fn

        def _mse(y, t):
            return np.mean((y - t) ** 2)

        def _bce(y, t, eps=1e-12):
            # clip so log() never sees exactly 0 or 1
            y = np.clip(y, eps, 1 - eps)
            return -np.mean(t * np.log(y) + (1 - t) * np.log(1 - y))

        def _cross_entropy(y, t, eps=1e-12):
            y = np.clip(y, eps, 1 - eps)
            t = np.asarray(t)
            # allow class indices (shape (n,) or (n, 1)) or one-hot targets
            if t.ndim == 1 or (t.ndim == 2 and t.shape[1] == 1):
                # float-typed labels such as 0.0/2.0 cannot be used as indices
                idx = t.reshape(-1).astype(np.intp)
                return -np.mean(np.log(y[np.arange(y.shape[0]), idx]))
            return -np.mean(np.sum(t * np.log(y), axis=1))

        lut = {
            "mse": _mse,
            "bce": _bce,
            "cross_entropy": _cross_entropy,
        }
        if isinstance(fn, str) and fn in lut:
            return lut[fn]
        raise ValueError(f"Unknown loss: {fn!r}")

    def _initialize_weights(self, X):
        """Create one U(-1, 1) weight matrix per layer transition.

        The sizes chain input -> hidden_1 -> ... -> hidden_k -> output, where
        the input size is ``X.shape[1]``.
        """
        # A private RandomState gives exactly the same draws as seeding the
        # global RNG, without resetting the global state other cells rely on.
        rng = np.random if self.seed is None else np.random.RandomState(self.seed)

        sizes = [X.shape[1], *self.net_width, self.output_size]
        self.weights = [
            rng.uniform(-1, 1, size=(n_in, n_out))
            for n_in, n_out in zip(sizes[:-1], sizes[1:])
        ]

    def _forward_pass(self, X: np.ndarray) -> tuple[list[np.ndarray], list[np.ndarray]]:
        """Run ``X`` through the network, initialising weights on first use.

        Returns
        -------
        activations : list[np.ndarray]
            ``[X, a_1, ..., a_out]``: the input followed by every layer's output.
        pre_activations : list[np.ndarray]
            ``[z_1, ..., z_out]``: each layer's ``input @ W`` before activation.
        """
        if not getattr(self, "weights", None):
            self._initialize_weights(X)

        activations = [X]
        pre_activations = []

        # hidden layers
        for W in self.weights[:-1]:
            z = activations[-1] @ W
            pre_activations.append(z)
            activations.append(self.hidden_activation(z))

        # output layer
        z = activations[-1] @ self.weights[-1]
        pre_activations.append(z)
        activations.append(self.output_activation(z))

        return activations, pre_activations

In [38]:
# Forward pass with softmax + cross-entropy (class indices).
# A seeded Generator makes the demo input, and hence the printed loss,
# reproducible. The distinct name `X_demo` avoids clobbering the iris `X`
# from the data-loading cell.
demo_rng = np.random.default_rng(0)
X_demo = demo_rng.standard_normal((4, 3))
y_idx = np.array([0, 2, 0, 2])
hidden_widths = [5, 4]

net = ANN(
    input_size=3,
    output_size=3,
    net_length=len(hidden_widths),
    net_width=hidden_widths,
    hidden_activation="relu",
    output_activation="softmax",
    loss="cross_entropy",
    seed=123,
)

acts, pre = net._forward_pass(X_demo)
yhat = acts[-1]
loss_value = net.loss_fn(yhat, y_idx)

# Invariants: one probability row per sample, each summing to 1, one weight
# matrix per hidden layer plus the output layer, and a finite loss.
assert yhat.shape == (X_demo.shape[0], 3)
assert np.allclose(yhat.sum(axis=1), 1.0)
assert len(net.weights) == len(hidden_widths) + 1
assert np.isfinite(loss_value)

print("Shapes ->",
      "X:", X_demo.shape,
      "yhat:", yhat.shape,
      "num layers (weights):", len(net.weights))

print("Row sums (softmax):", np.round(yhat.sum(axis=1), 6))
print("Loss:", loss_value)

Shapes -> X: (4, 3) yhat: (4, 3) num layers (weights): 3
Row sums (softmax): [1. 1. 1. 1.]
Loss: 0.9847953461577265
