TLDR: This repository is an attempt to understand the concepts behind, and develop a JAX-based implementation of, the Koopman operator for neural network optimization, an alternative to purely gradient-based optimization.
Optimization of neural network parameters can be viewed as a dynamical system in a discrete temporal domain. The Koopman operator is a powerful tool in the study of dynamical systems, with a recent resurgence in data-driven Koopman operator theory. It allows for linear dynamics in an infinite-dimensional function space that are equivalent to the non-linear dynamics in the finite-dimensional state space. In this report we explore the work of Dogra et al., which applies Koopman operator theory to neural network optimization. We lay the theoretical groundwork, implement their method in JAX, and discuss its validity.
The gradient descent algorithm is given as:
$$\theta_{t+1} = \theta_t - \eta\, \nabla_\theta L(\theta_t)$$
where $\theta_t$ are the weights at iteration $t$, $\eta$ is the learning rate and $L$ is the loss function.
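As a minimal sketch, the update above looks like the following in JAX. The quadratic loss and its minimizer are hypothetical stand-ins for an actual network loss, chosen only so the example is self-contained:

```python
import jax
import jax.numpy as jnp

def loss(theta):
    # Illustrative quadratic loss with minimum at theta = (1, -2);
    # stands in for a neural-network loss (hypothetical choice).
    return jnp.sum((theta - jnp.array([1.0, -2.0])) ** 2)

@jax.jit
def gd_step(theta, lr=0.1):
    # theta_{t+1} = theta_t - lr * grad L(theta_t)
    return theta - lr * jax.grad(loss)(theta)

theta = jnp.zeros(2)
for _ in range(200):
    theta = gd_step(theta)
# theta converges toward the minimizer (1, -2)
```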
We can make the following observations:
- Neural network optimization is analogous to a dynamical system with weights $\theta \in \Theta$ describing the state of the system.
- The loss function induces a potential energy field in the domain of weights. This is a scalar field, and the weights at which the minimum of this field lies are the desired optima.
- At any point, this field exerts a force proportional in magnitude to the gradient of the field, directed opposite to the gradient.
- Following this gradient with appropriate dynamics 'may' allow the weights to reach the required optima autonomously.
Additional Notes: Check this out for a visual guide.
Stochastic gradient descent and mini-batch gradient descent are effective at escaping inflection points and shallow minima, and are also more robust to the choice of initial state. They achieve this by re-sampling the data points on which the loss is computed at each iteration. In the dynamical-system view, this can be visualized as a stochastic potential energy field, since the loss term changes slightly each iteration; the loss field is a random field whose realizations agree in expectation because each mini-batch is drawn from the same underlying data distribution. Consider the general case of mini-batch gradient descent of the form:
$$\theta_{t+1} = \theta_t - \eta\, \nabla_\theta \frac{1}{|\mathcal{B}_t|} \sum_{i \in \mathcal{B}_t} \ell(\theta_t; x_i, y_i)$$
where $\mathcal{B}_t$ is the mini-batch sampled at iteration $t$ and $\ell$ is the per-example loss.
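Mini-batch sampling can be sketched as follows. The synthetic noiseless regression data and all constants here are hypothetical, used only to make the example runnable; each step sees a different realization of the "loss field":

```python
import jax
import jax.numpy as jnp

# Hypothetical synthetic regression data standing in for a training set.
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (256, 3))
true_w = jnp.array([2.0, -1.0, 0.5])
y = X @ true_w

def batch_loss(w, xb, yb):
    # Mean squared error on the sampled mini-batch only.
    return jnp.mean((xb @ w - yb) ** 2)

@jax.jit
def sgd_step(w, xb, yb, lr=0.05):
    return w - lr * jax.grad(batch_loss)(w, xb, yb)

w = jnp.zeros(3)
for t in range(500):
    key, sub = jax.random.split(key)
    idx = jax.random.choice(sub, 256, (32,), replace=False)  # mini-batch B_t
    w = sgd_step(w, X[idx], y[idx])
# w approaches true_w despite each step seeing a different loss realization
```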
Momentum, on the other hand, improves optimization efficacy by changing the dynamics altogether: it makes the acceleration proportional to the force exerted by the field, unlike gradient descent, which makes the velocity proportional to the force. This modification results in 'inertia', i.e. the velocity does not become zero the moment the force (gradient) becomes zero, which allows the optimizer to escape inflection points and shallow minima. Consider the following dynamical system, where the acceleration is proportional to the net force, comprising the force exerted by the potential energy field and a drag force opposing the motion:
$$m\, \ddot{\theta} = -\nabla_\theta L(\theta) - \mu\, \dot{\theta}$$
This can be written as an ODE in standard first-order form by introducing a velocity variable $v = \dot{\theta}$:
$$\dot{\theta} = v, \qquad \dot{v} = -\frac{1}{m}\left( \nabla_\theta L(\theta) + \mu\, v \right)$$
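A simple Euler-style discretization of these momentum dynamics gives the familiar heavy-ball update. The quadratic loss and the coefficients below are illustrative choices, not values from the source:

```python
import jax
import jax.numpy as jnp

def loss(theta):
    # Illustrative quadratic bowl; any differentiable loss works here.
    return 0.5 * jnp.sum(theta ** 2)

@jax.jit
def momentum_step(theta, v, lr=0.1, mu=0.9):
    # Discretized heavy-ball dynamics: the velocity accumulates the gradient
    # "force" and decays through the drag term, giving inertia.
    v = mu * v - lr * jax.grad(loss)(theta)
    theta = theta + v
    return theta, v

theta, v = jnp.array([5.0, -3.0]), jnp.zeros(2)
for _ in range(300):
    theta, v = momentum_step(theta, v)
# theta spirals into the minimum at the origin instead of halting
# the instant the gradient vanishes
```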
The Koopman operator theory involves defining observables, which are functions of the state, and then identifying the dynamics of these observables in their infinite-dimensional function space. Under some constraints, these observables follow linear dynamics, which can be described by a linear operator colloquially called the Koopman operator. In this section we follow the description in the original paper and describe the Koopman operator treatment of a discrete-time dynamical system.
Consider a discrete-time dynamical system $x_{k+1} = F(x_k)$ with state $x \in \mathcal{M}$. For an observable $g : \mathcal{M} \to \mathbb{R}$, the Koopman operator $\mathcal{K}$ is defined by $(\mathcal{K} g)(x) = g(F(x))$: it advances observables one step along the dynamics, and it acts linearly on observables even when $F$ is non-linear.
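A standard textbook illustration (my own toy example, not from the paper): a non-linear system whose observables close under a finite-dimensional linear Koopman operator. All constants are arbitrary:

```python
import jax.numpy as jnp

# Non-linear system with a finite-dimensional Koopman-invariant subspace:
# x1' = a*x1, x2' = b*x2 + c*x1**2 (discrete time, arbitrary coefficients).
a, b, c = 0.9, 0.5, 1.0

def F(x):
    return jnp.array([a * x[0], b * x[1] + c * x[0] ** 2])

def g(x):
    # Observables (x1, x2, x1^2): adding the squared term closes the
    # dynamics linearly.
    return jnp.array([x[0], x[1], x[0] ** 2])

# Restricted to these observables, the Koopman operator is this 3x3 matrix:
K = jnp.array([[a, 0.0, 0.0],
               [0.0, b, c],
               [0.0, 0.0, a ** 2]])

x = jnp.array([2.0, -1.0])
# Linear evolution of the observables matches the non-linear state evolution
assert jnp.allclose(g(F(x)), K @ g(x))
```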
The authors make three choices towards the implementation of this idea:
- The choice of observables: the authors choose the identity function $\mathbf{I}$ as the observable for the Koopman treatment. This allows directly evaluating the dynamics of the weights, since the dynamics of the observables are then the dynamics of the state.$^{\dagger}$
- The choice of Koopman operator approximation method: to approximate the Koopman operator, the authors use the finite section method, which uses two shifted matrices of all collected observables; essentially DMD (dynamic mode decomposition).
- The computational compromise: the theoretically coherent thing to do would be to record every weight of the neural network as a column of the matrix and apply the finite section method to it. The authors posit that this might be computationally intractable due to the size of the neural network, and thus offer a spectrum of sub-optimal approximations. This spectrum spans building the matrix for each weight (a scalar), for each node, for each layer, or for the entire network (the theoretically correct choice). The authors use node-wise dynamics, justifying it by the trade-off against computational complexity. Here a node refers to all the weights from layer $i$ used for a single activation in layer $i+1$.
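The finite-section/DMD step on identity observables can be sketched as follows. The synthetic converging trajectory stands in for weight snapshots recorded during training, and working in displacement coordinates around a known fixed point is a simplification for this toy example only (in practice the fixed point is unknown and the affine part must be handled differently, e.g. with an extra constant observable):

```python
import jax.numpy as jnp

# Minimal DMD sketch: given snapshots theta_0..theta_{T-1} of (a node's)
# weights, fit a linear operator K with theta_{t+1} ~ K theta_t and use it
# to extrapolate the optimization dynamics far ahead.

# Synthetic trajectory converging geometrically to theta_star
# (stands in for weights recorded during training).
theta_star = jnp.array([1.0, -2.0, 0.5])
T = 20
traj = jnp.stack([theta_star + 0.9 ** t * jnp.array([1.0, 2.0, -1.0])
                  for t in range(T)])

# Displacement coordinates so the fixed point maps to the origin.
Y = (traj - theta_star).T            # shape (n, T)
X1, X2 = Y[:, :-1], Y[:, 1:]         # two shifted snapshot matrices

# Least-squares fit K = X2 X1^+ (finite section / DMD on the snapshots)
K = X2 @ jnp.linalg.pinv(X1)

# Extrapolate m steps ahead: theta_star + K^m (theta_0 - theta_star)
m = 100
pred = theta_star + jnp.linalg.matrix_power(K, m) @ Y[:, 0]
# pred lands near the converged weights theta_star
```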
We test the method on training a simple feedforward MNIST classifier. Our results do not match those reported by the authors.
Comparison of ADADelta against node-wise Koopman and layer-wise Koopman, respectively, for 1 epoch:
Figure 1: Nodewise Koopman vs ADADelta for 1 epoch
Figure 2: Layerwise Koopman vs ADADelta for 1 epoch
Comparison of ADADelta against node-wise Koopman and layer-wise Koopman, respectively, for 2 epochs: