# Understanding the difficulty of training deep feedforward neural networks
by Xavier Glorot and Yoshua Bengio (2010)

## TODO

- Add a small subset on which to compute the various monitoring metrics (mean / std / histogram of activations and gradients)

## Imports

In [None]:
%load_ext autoreload
%autoreload 2
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray
from sklearn.model_selection import train_test_split
from tqdm import tqdm

### Deep Learner

In [None]:
from deep_learner import Tensor
from deep_learner._core.types import Device
from deep_learner.datasets import cifar10, mnist
from deep_learner.metrics.accuracy import accuracy
from deep_learner.nn import (
    SGD,
    CrossEntropyLoss,
    Linear,
    Module,
    Optimizer,
    Sequential,
    Sigmoid,
    Softmax,
    Softsign,
    Tanh,
)
from deep_learner.utils import batch

### Datasets

In [None]:
# Load both datasets once, up front (CIFAR-10 first, then MNIST).
# NOTE(review): MNIST_DATASET is not used by any of the cells below — confirm it is still needed.
CIFAR10_DATASET, MNIST_DATASET = cifar10(), mnist()

In [None]:
# Hold out 10,000 CIFAR-10 training samples as a validation split.
# NOTE(review): assumes CIFAR10_DATASET is ordered (train_X, train_Y, test_X, test_Y)
# — confirm against deep_learner.datasets.cifar10.
train_X, val_X, train_Y, val_Y = train_test_split(
    CIFAR10_DATASET[0],
    CIFAR10_DATASET[1],
    shuffle=True,
    random_state=0,
    test_size=10_000,
)

# Carve a fixed 300-sample monitoring subset out of the test split; the per-layer
# statistics are computed on it during training. The rest stays as the test set.
test_X, monitoring_X, test_Y, monitoring_Y = train_test_split(
    CIFAR10_DATASET[2],
    CIFAR10_DATASET[3],
    shuffle=True,
    random_state=0,
    test_size=300,
)

## Models

In [None]:
def _build_mlp(activation_fn, num_hidden):
    """Build a fully-connected classifier with ``num_hidden`` hidden layers of 1000 units.

    Layout, in construction order: Linear(3072 -> 1000) + activation, then
    ``num_hidden - 1`` repetitions of Linear(1000 -> 1000) + activation, then
    Linear(1000 -> 10) + Softmax.
    """
    # 3072 is presumably a flattened 32x32x3 CIFAR-10 image — TODO confirm.
    layers = [Linear(3072, 1_000), activation_fn()]
    for _ in range(num_hidden - 1):
        layers.extend((Linear(1_000, 1_000), activation_fn()))
    # NOTE(review): confirm deep_learner's CrossEntropyLoss expects probabilities
    # rather than logits; otherwise this trailing Softmax is applied twice.
    layers.extend((Linear(1_000, 10), Softmax()))
    return Sequential(*layers)


# Keys look like "model_<activation>_<num_hidden>". Insertion order is depth-major:
# every activation with 1 hidden layer first, then 2 hidden layers, and so on.
models = {}
for num_hidden in range(1, 6):
    for activation_fn in (Sigmoid, Softsign, Tanh):
        models[f"model_{activation_fn.__name__.lower()}_{num_hidden}"] = _build_mlp(
            activation_fn, num_hidden
        )

print(models)

### Hooks to collect activation and gradient statistics

In [None]:
# One initially-empty accumulator per model, keyed like `models`. `train` fills each
# accumulator with {layer_name: {"act_mean": [...], "act_std": [...]}}.
models_statistics = {name: dict() for name in models.keys()}


def get_activations(accumulator):
    """Build a forward hook that logs the mean and std of a module's output.

    Every call of the returned hook appends one float to ``accumulator["act_mean"]``
    and one to ``accumulator["act_std"]``. Both keys must map to lists by the time
    the hook fires.
    """

    def record_output_stats(module, outputs, *inputs, **kwargs):
        # Only `outputs` is read; extra positional/keyword arguments are tolerated.
        output_values = outputs.data
        accumulator["act_mean"].append(float(output_values.mean()))
        accumulator["act_std"].append(float(output_values.std()))

    return record_output_stats


def get_gradients(accumulator):
    """Build a gradient hook that logs the mean and std of the gradient it receives.

    Every call of the returned hook appends one float to ``accumulator["grad_mean"]``
    and one to ``accumulator["grad_std"]``. Both keys must map to lists by the time
    the hook fires. (Its registration in ``train`` is currently commented out.)
    """

    def record_grad_stats(grad):
        grad_values = grad.data
        accumulator["grad_mean"].append(float(grad_values.mean()))
        accumulator["grad_std"].append(float(grad_values.std()))

    return record_grad_stats

## Hyperparameters

In [None]:
# The paper does not spell out how the learning rate was tuned; it appears to be
# some kind of search validated on a held-out validation set. Until such a search
# is implemented here, the learning rate is fixed to 10^-3.
LEARNING_RATE: float = 0.001  # optimizer step size
BATCH_SIZE: int = 10  # samples per mini-batch
EPOCHS: int = 10  # passes over the training split
METRICS_PER_EPOCH: int = 50  # target number of monitoring passes per epoch

## Training code

In [None]:
def _record_monitoring_statistics(
    model: Module,
    optimizer: Optimizer,
    loss_fn: Module,
    monitoring_X: Tensor,
    monitoring_Y: Tensor,
    monitoring_acc: dict[str, dict[str, list[float]]],
) -> None:
    """Run one forward/backward pass on the monitoring set and log per-layer statistics.

    A forward hook built by ``get_activations`` is attached to every ``Linear``
    child of ``model`` only for the duration of this pass, so regular training
    steps are not instrumented. The hooks are removed even if the pass raises.

    NOTE(review): hooking the ``Linear`` children records pre-activation outputs,
    while the paper's activation plots show post-nonlinearity values — consider
    hooking the activation modules instead.
    """
    handles = []
    try:
        for child_name, child in model.named_children():
            if not isinstance(child, Linear):
                continue
            if child_name not in monitoring_acc:
                monitoring_acc[child_name] = {
                    "act_mean": [],
                    "act_std": [],
                    # "grad_mean": [],
                    # "grad_std": []
                }
            handles.append(
                child.register_forward_hook(get_activations(monitoring_acc[child_name]))
            )
            # handles.append(child.register_backward_hook(get_gradients(monitoring_acc[child_name])))

        # Discard the gradients left over from the training step so the
        # monitoring backward pass starts from zero.
        optimizer.zero_grad()

        monitoring_preds = model(monitoring_X)
        monitoring_loss = loss_fn(monitoring_preds, monitoring_Y)
        monitoring_loss.backward()
    finally:
        for handle in handles:
            handle.remove()


def train(
    model: Module,
    epochs: int,
    batch_size: int,
    optimizer: Optimizer,
    loss_fn: Module,
    train_X: NDArray,
    train_Y: NDArray,
    monitoring_X: NDArray,
    monitoring_Y: NDArray,
    monitoring_acc: dict[str, dict[str, list[float]]],
    device: Device,
) -> Module:
    """Train ``model`` with mini-batch updates and periodically log layer statistics.

    Args:
        model: network to train; it is moved to ``device`` first.
        epochs: number of passes over ``train_X``.
        batch_size: samples per mini-batch (the last batch may be smaller).
        optimizer: optimizer stepping ``model``'s parameters.
        loss_fn: loss module called as ``loss_fn(predictions, labels)``.
        train_X, train_Y: training inputs and labels. Accuracy takes the argmax
            over the last axis of the labels, so they are presumably one-hot.
        monitoring_X, monitoring_Y: fixed subset used for the statistics passes.
        monitoring_acc: per-layer accumulator, filled in place with
            ``{layer_name: {"act_mean": [...], "act_std": [...]}}``.
        device: device on which training runs.

    Returns:
        The trained ``model``.
    """
    model.to(device)

    num_batches = len(train_X) // batch_size + bool(len(train_X) % batch_size)
    # Run a monitoring pass every `monitoring_every` updates. The step is clamped to 1
    # so that having fewer batches than METRICS_PER_EPOCH does not cause a modulo by zero.
    monitoring_every = max(1, num_batches // METRICS_PER_EPOCH)
    monitoring_X = Tensor(monitoring_X).to(device)
    monitoring_Y = Tensor(monitoring_Y).to(device)

    for epoch in tqdm(range(epochs), desc="Epoch:", total=epochs):
        model.train()

        batch_counter = 0
        epoch_loss = Tensor(0)
        epoch_accuracy = Tensor(0)

        for batch_X, batch_Y in batch(train_X, train_Y, batch_size=batch_size):
            optimizer.zero_grad()

            inputs = Tensor(batch_X).to(device)
            labels = Tensor(batch_Y).to(device)

            predictions = model(inputs)
            loss = loss_fn(predictions, labels)
            loss.backward()
            optimizer.step()

            epoch_accuracy += accuracy(
                Tensor(np.argmax(predictions.detach().to(Device.CPU).data, axis=-1)),
                Tensor(np.argmax(labels.detach().to(Device.CPU).data, axis=-1)),
            )

            batch_counter += 1
            epoch_loss += loss.detach().to(Device.CPU)

            if batch_counter % monitoring_every == 0:
                _record_monitoring_statistics(
                    model,
                    optimizer,
                    loss_fn,
                    monitoring_X,
                    monitoring_Y,
                    monitoring_acc,
                )

        # Guard against an empty training set (no batches -> division by zero).
        seen_batches = max(batch_counter, 1)
        print(
            f"[+] [Epoch {epoch + 1}] mean accuracy: {epoch_accuracy.data / seen_batches:.2%}, mean loss: {epoch_loss.data / seen_batches:.6f}"
        )

    return model

In [None]:
# Only the first NUM_MODELS_TO_TRAIN models (in `models` insertion order) are trained
# for now; set it to len(models) to run the whole sweep.
NUM_MODELS_TO_TRAIN: int = 3
device = Device.CPU

for model_name, model in list(models.items())[:NUM_MODELS_TO_TRAIN]:
    print(f"[+] Training {model_name}...")
    train(
        model,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        optimizer=SGD(model, LEARNING_RATE),
        loss_fn=CrossEntropyLoss(),
        train_X=train_X,
        train_Y=train_Y,
        monitoring_X=monitoring_X,
        monitoring_Y=monitoring_Y,
        monitoring_acc=models_statistics[model_name],
        device=device,
    )

    # Evaluate on the test split (the samples not reserved for monitoring).
    test_predictions = model(Tensor(test_X).to(device))
    test_accuracy = accuracy(
        Tensor(np.argmax(test_predictions.detach().to(Device.CPU).data, axis=-1)),
        Tensor(np.argmax(test_Y, axis=-1)),
    )

    print(f"[+] Accuracy on test set: {test_accuracy.data:.2%}\n")

In [None]:
# One colour and one marker per Linear layer. The deepest models have 6 Linear layers
# (5 hidden + output), so both lists hold 6 distinct entries. Indexing is modulo the
# list length, so a deeper model can never raise IndexError.
colors = ["r", "g", "b", "c", "m", "y"]
markers = ["o", "^", "s", "p", "h", "D"]

# Mirror the monitoring schedule of `train`: each epoch has `num_batches` updates, and
# statistics are recorded after every `monitoring_every`-th update of the epoch.
num_batches = len(train_X) // BATCH_SIZE + bool(len(train_X) % BATCH_SIZE)
monitoring_every = max(1, num_batches // METRICS_PER_EPOCH)


def _update_counts(num_records, num_batches, monitoring_every):
    """Return, for each of the first ``num_records`` recorded statistics, the total
    number of parameter updates performed when it was recorded (the plots' x-axis)."""
    records_per_epoch = num_batches // monitoring_every
    counts = []
    for record in range(num_records):
        epoch, index_in_epoch = divmod(record, records_per_epoch)
        counts.append(epoch * num_batches + (index_in_epoch + 1) * monitoring_every)
    return counts


for model_name, layers_statistics in models_statistics.items():
    if not layers_statistics:
        # This model was not trained (the training cell may train only a subset),
        # so there is nothing to plot: skip it instead of drawing an empty figure.
        continue

    fig, axs = plt.subplots(1, 2, figsize=(20, 10))
    num_layers = len(layers_statistics)

    for index, (layer_name, layer_statistics) in enumerate(layers_statistics.items()):
        # Stagger error bars and markers across layers so overlapping curves stay readable.
        line_style = {
            "errorevery": (index, num_layers),
            "markevery": (index, num_layers),
            "color": colors[index % len(colors)],
            "marker": markers[index % len(markers)],
            "label": layer_name,
        }
        axs[0].errorbar(
            _update_counts(len(layer_statistics["act_mean"]), num_batches, monitoring_every),
            layer_statistics["act_mean"],
            yerr=layer_statistics["act_std"],
            **line_style,
        )
        # Gradient statistics only exist once the gradient hooks are enabled in `train`.
        if layer_statistics.get("grad_mean"):
            axs[1].errorbar(
                _update_counts(len(layer_statistics["grad_mean"]), num_batches, monitoring_every),
                layer_statistics["grad_mean"],
                yerr=layer_statistics["grad_std"],
                **line_style,
            )

    axs[0].legend()
    axs[0].set_xlabel("Number of updates")
    axs[0].set_ylabel("Layer activations mean values")
    axs[0].set_title("Activations")

    # Avoid matplotlib's "No artists with labels" warning while gradients are not collected.
    if axs[1].has_data():
        axs[1].legend()
    axs[1].set_xlabel("Number of updates")
    axs[1].set_ylabel("Layer gradients mean values")
    axs[1].set_title("Gradients")

    fig.suptitle(f"Statistics for {model_name}")
    fig.show()