# Seminar 3. Loss surfaces (advanced task)

One of the active research fields in deep learning is the study of loss surfaces. Work in this area focuses on several questions:
* What is the shape of the loss surface? Do loss optima look like "basins of attraction" (point optima)? Or do they form connected components at the bottom of the landscape?
* Why do neural networks actually learn so well? Why do local optima found by SGD generalize well?
* How does SGD traverse a highly nonlinear, nonconvex loss surface?

Today you will conduct several experiments, following recently published papers, to shed light on these questions.

The task consists of three __independent__ parts. You may do them in any order.

In [None]:
import lab_cnn_solution as code
import torch.nn as nn
import torch
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline
import os
import losssurf_lab_utils as utils

Module lab_cnn_solution contains the solution for the basic part of the seminar (training a CNN on MNIST). You will use this code. The function learn_one_model defined below trains a model using this code:

In [None]:
# it is _highly_ recommended to do the task on GPU
device = torch.device('cuda') # change to cpu if needed
train_loader, test_loader = code.get_mnist()
criterion = nn.CrossEntropyLoss()

In [None]:
def learn_one_model(num_epochs = 10, path_to_save = None):
    """
    learns 1 CNN on MNIST and returns learned model
    :num_epochs: int
    :path_to_save: if not None, 
                   weights after _each_ epoch are saved to the corresponding folder
                   if None, weights are not saved
    """
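    # create the folder only if saving was requested (path_to_save may be None)
    if path_to_save is not None and not os.path.exists(path_to_save):
        os.mkdir(path_to_save)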
    cnn = code.CNN().to(device)
    optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)
    train_log, train_acc_log, val_log, val_acc_log = \
        code.train(cnn, optimizer, train_loader, test_loader, criterion, \
                   num_epochs, device, 0, path_to_save)
    return cnn

__Technical code__

__Disclaimer:__ you will use this code in all three parts of the task, but you will only need to modify it yourself in the part "Learning low-loss valley ...". In the two other parts, you only need to run the corresponding cell.

To study the loss surface, we need to be able to reparametrize the weights of the neural network, i. e. compute the weights from some learnable values and backpropagate through the weights. The class below implements this logic. Take a look at this code.

This code also allows the visualization of the loss landscape around optima.
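
The core idea can be illustrated on a single linear layer. The sketch below is not part of the seminar code, and the names `a`, `theta` and `t` are chosen only for illustration. A weight is computed from a fixed tensor and a learnable tensor, used in a functional forward pass, and the gradient reaches the learnable tensor through the computed weight.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

a = torch.randn(3, 4)                          # fixed ("support") tensor
theta = torch.nn.Parameter(torch.zeros(3, 4))  # learnable ("param") tensor
t = 0.25                                       # fixed "index" for this forward pass

# The weight is not a Parameter itself: it is computed from a, theta and t.
weight = (1 - t) * a + t * theta

x = torch.randn(5, 4)
out = F.linear(x, weight)      # forward pass with the computed weight
loss = out.pow(2).mean()
loss.backward()                # backpropagates through the weight into theta

grad_shape = tuple(theta.grad.shape)
print(grad_shape)
```

Class Base applies the same principle to every tensor of a whole network: the computed weights are written into logic_net before each forward pass.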

In [None]:
class Base(nn.Module):
    """
    This class separates logic and parameters of the net:
    it stores params in param_net and forwards through net without params (called logic_net)
    
    By overriding methods,
    it is possible to implement diverse weight reparametrization schemes
    
    terminology:
    * weight: tensor that is used during forward pass through net
              it is composed from params, supports and indexes
              via parametrization scheme
    * param: learnable tensor, torch.nn.Parameter
    * support: already learned tensor
    * index: random tensor, allows walking through surface
    """
    def __init__(self, net_class, supports_files=[], num_param_nets=1,
                 net_args=(), net_kwargs={}):
        super(Base, self).__init__()
        self.num_param_nets = num_param_nets
        self.net_class, self.net_args, self.net_kwargs = \
               net_class, net_args, net_kwargs
        self.logic_net = net_class(*net_args, **net_kwargs)
        self.clear_net_params(self.logic_net, params_to_buffers=False)
        # delete everything learnable from self.logic_net
        self.load_supports(supports_files)
        self.create_params()
        self.create_index()
        self.freeze_index = False
    
    def clear_net_params(self, net, params_to_buffers=False):
        """
        del all params in net (as it shouldn't store params)
        and remember what we deleted so we can assign it in the future
        """
        def clear_net_params_rec(net_module, structure):
            structure["items"] = []
            items = {}
            structure["children"] = {}
            for name, p in net_module._parameters.items():
                structure["items"].append(name)
                if p is not None:
                    items[name] = p.data if params_to_buffers else \
                                  torch.zeros_like(p.data)
                else:
                    items[name] = None
            net_module._parameters.clear()
            for name in items:
                net_module.register_buffer(name, items[name])
            for mname, child in net_module._modules.items():
                structure["children"][mname] = {}
                clear_net_params_rec(child, structure["children"][mname])
        self.structure = {}
        clear_net_params_rec(net, self.structure)

    def load_supports(self, supports_files):
        """
        descendants may override this method
        """
        self.support_nets = nn.ModuleList()
        for file in supports_files:
            if type(file) == str:
                support_state_dict = torch.load(file)
            else:
                support_state_dict = file # state dict is passed instead of file
            support_net = self.net_class(*self.net_args, **self.net_kwargs)
            support_net.load_state_dict(support_state_dict)
            self.clear_net_params(support_net, params_to_buffers=True)
            self.support_nets.append(support_net)
        
    def create_params(self):
        """
        descendants may override this method
        """
        self.param_nets = nn.ModuleList(\
                        [self.net_class(*self.net_args, **self.net_kwargs)\
                         for _ in range(self.num_param_nets)])
        # params are stored within param_nets
        
    def create_index(self):
        """
        descendants may override this method
        """
        self.register_buffer("index", None)
    
    def process_weights_item(self, param_list, support_list=[]):
        """
        descendants may override this method
        """
        return param_list[0] + 0
        
    def gen_index(self):
        """
        descendants may override this method
        """
        pass
        
    def gen_weights(self):
        """
        assigns newly generated values to the weights of self.logic_net
        generated values depend on params, supports and index
        descendants may override this method
        """
        def gen_weights_rec(param_nets_modules,
                            support_nets_modules,
                            logic_net_module,
                            structure):
            for name in structure["items"]:
                param_list = [pnm._parameters[name] for pnm in param_nets_modules]
                support_list = [snm._buffers[name] for snm in support_nets_modules]
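                # look at whichever list is non-empty: a model may have no supports
                # (plain Base) or no params (Segment with num_param_nets=0)
                logic_net_module._buffers[name] = \
                    self.process_weights_item(param_list, support_list) \
                    if (param_list + support_list)[0] is not None else None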
            for mname, lnm in logic_net_module._modules.items():
                pnms = [pnm._modules[mname] for pnm in param_nets_modules]
                snms = [snm._modules[mname] for snm in support_nets_modules]
                gen_weights_rec(pnms, snms, lnm, structure["children"][mname])
        
        gen_weights_rec(self.param_nets, 
                        self.support_nets, 
                        self.logic_net,
                        self.structure)
    
    def train(self, mode=True):
        super(Base, self).train(mode) 
        if mode:
            self.freeze_index = False
        
    
    def forward(self, x):
        if not self.freeze_index:
            self.gen_index()
        self.gen_weights()
        return self.logic_net(x)
    

Class Base only separates parameters from the logic and otherwise works as a usual neural network. The class below inherits from Base and changes the parametrization. Class Segment varies the weights along the segment $[a, b]$, where $a$ and $b$ are fixed endpoints:
$$w = (1-t) a + t b, \quad t \in [0, 1]$$
This class doesn't have learnable parameters. 

In [None]:
class Segment(Base):
    def process_weights_item(self, param_list, support_list=[]):
        t = self.index.item()
        return (1-t)*support_list[0]+t*support_list[1]
    
    def create_index(self):
        self.register_buffer("index", torch.zeros(1))
    
    def gen_index(self):
        self.index.uniform_()

### Computing eigenvectors and eigenvalues of the Hessian
In this task, you will need 1 trained CNN on MNIST (no saved weights are needed). Train it:

In [None]:
cnn = ...

[Sagun et al., 2017](https://arxiv.org/abs/1706.04454) study the Hessian near local optima, in particular its eigenvalues and eigenvectors. If an eigenvalue is large, the loss changes a lot in the direction of the corresponding eigenvector. Conversely, if an eigenvalue is small, the loss is flat in the corresponding direction. Experiments show that near optima there are a few large eigenvalues and many small ones. From this the authors conclude that optima form large connected components at the bottom of the landscape.
The authors also note that the number of large eigenvalues usually equals the number of classes. 
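
As a toy illustration of the link between eigenvalues and flatness (not taken from the paper), consider a two-dimensional quadratic loss whose Hessian has one large and one small eigenvalue. The values `lam_sharp` and `lam_flat` below are arbitrary assumptions.

```python
# Toy quadratic loss L(w) = 0.5 * (lam_sharp * w0^2 + lam_flat * w1^2).
# Its Hessian is diag(lam_sharp, lam_flat), with eigenvectors (1, 0) and (0, 1).
lam_sharp, lam_flat = 100.0, 0.01   # hypothetical eigenvalues

def loss(w0, w1):
    return 0.5 * (lam_sharp * w0 ** 2 + lam_flat * w1 ** 2)

step = 0.1
rise_sharp = loss(step, 0.0) - loss(0.0, 0.0)   # move along the sharp eigenvector
rise_flat = loss(0.0, step) - loss(0.0, 0.0)    # move along the flat eigenvector

print(rise_sharp, rise_flat)
```

For a quadratic, the same step changes the loss by $0.5 \lambda \, \text{step}^2$ along each eigenvector, so the ratio of the two increases equals the ratio of the eigenvalues.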

We suggest that you compute the eigenvalues and eigenvectors of the Hessian near an optimum and check the hypothesis described above. 

To compute the eigenvalues and eigenvectors, one may use the [Lanczos algorithm](https://en.wikipedia.org/wiki/Lanczos_algorithm). This algorithm relies on computing Hessian-vector products. We will use the implementation of the algorithm from scipy (eigsh). You only have to pass a function that computes the Hessian-vector product.
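
To see how eigsh works with a matrix that is only available through products, here is a minimal sketch on a stand-in problem. The diagonal matrix `diag` is an assumption made for illustration and replaces the Hessian, so that the correct eigenvalues are known in advance.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, eigsh

# Hypothetical stand-in for the Hessian: a diagonal matrix with known eigenvalues.
diag = np.arange(1.0, 51.0)          # eigenvalues 1, 2, ..., 50
n = diag.size

def matvec(v):
    # eigsh only ever asks for products A @ v; A itself is never formed.
    return diag * np.ravel(v)

A = LinearOperator((n, n), matvec=matvec, dtype=np.float64)
top_vals, top_vecs = eigsh(A, k=3, which="LA")   # 3 largest (algebraic) eigenvalues

print(np.round(top_vals, 6))
```

In the seminar the role of `matvec` is played by a function that runs a Hessian-vector product through the network.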

This product can be computed in common autograd packages as follows. Let $v$ be a pre-computed numerical vector (such as the gradient), treated as a constant. One first computes the scalar $a = \nabla L^T v$ and then takes the gradient of this expression, which gives $\nabla a = Hv$. Let's implement this in PyTorch.
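
Before working with the full network, the trick can be checked on a toy problem. The sketch below assumes a quadratic loss $L(w) = 0.5\, w^T A w$ with a symmetric matrix `A` (a made-up example, not the seminar model), for which the Hessian is exactly `A`, so the result can be compared with `A @ v`.

```python
import torch

torch.manual_seed(0)

# Toy loss L(w) = 0.5 * w^T A w with symmetric A, so the Hessian is exactly A.
M = torch.randn(4, 4, dtype=torch.float64)
A = M + M.T
w = torch.randn(4, dtype=torch.float64, requires_grad=True)
v = torch.randn(4, dtype=torch.float64)   # fixed direction, not differentiated

loss = 0.5 * w @ A @ w
# create_graph=True keeps the graph of the gradient so it can be differentiated again.
(grad,) = torch.autograd.grad(loss, w, create_graph=True)
a = (grad * v).sum()                      # scalar a = grad^T v
(hv,) = torch.autograd.grad(a, w)         # gradient of a w.r.t. w is H v

print(torch.allclose(hv, A @ v, atol=1e-8))
```

The Hessian is never built explicitly: the cost is two backward passes, which is what makes the approach feasible for networks with many parameters.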

In [None]:
from scipy.sparse.linalg import LinearOperator, eigsh

In [None]:
def eval_hess_vec_prod(vec, params, net, criterion, dataloader, device):
    """
    Evaluate product of the Hessian of the loss function with a direction vector "vec".
    The product result is saved in the grad of net.
    Args:
        vec: a list of tensors with the same dimensions as "params".
        params: the parameter list of the net.
        net: model with trained parameters.
        criterion: loss function.
        dataloader: dataloader for the dataset.
        device: cuda or cpu
        
    Hint:
    1. Perform the usual pass through dataloader and compute the loss for each minibatch.
    Also, perform a backward pass for each minibatch
    (NOT clearing the gradient, as we need the gradient computed on the FULL dataset).
    Use grad_f = torch.autograd.grad(loss, inputs=params, create_graph=True)
    2. After that, loop through grad_f and vec in parallel and sum g * v.
    3. Finally, perform one more backward pass.
    """
    net.zero_grad() # clears grad for every parameter in the net
    ### your code here

The following function passes your PyTorch implementation of the Hessian-vector product to scipy's eigsh.

In [None]:
def get_hessian_eigs(cnn, dataloader, criterion, device, num_max=20):
    params = [p for p in cnn.parameters()]
    N = sum(p.numel() for p in params)

    def hess_vec_prod(vec):
        vec = utils.npvec_to_tensorlist(vec, params, device)
        eval_hess_vec_prod(vec, params, cnn, criterion, dataloader, device)
        return utils.gradtensor_to_npvec(cnn)

    A = LinearOperator((N, N), matvec=hess_vec_prod)
    eigvals, eigvecs = eigsh(A, k=num_max, tol=1e-2)
    return eigvals, eigvecs

In [None]:
eigvals, eigvecs = get_hessian_eigs(cnn, train_loader, criterion, device, num_max=20)

Now let's plot the eigenvalues:

In [None]:
plt.scatter(np.arange(len(eigvals)), eigvals)
plt.xlabel("Number of eigenvalue")
plt.ylabel("Eigenvalue")

Also, let's visualize the loss along the lines $w_* + v_0 t, \, t \in [-0.1, 0.1]$ and $w_* + v_k t, \, t \in [-0.1, 0.1], \, k \gg 0$.

In [None]:
vecs_dict = {0:eigvecs[:, -1], -1:eigvecs[:, 0]}
v = utils.sd2tensor(cnn.state_dict()).cpu().numpy()
width = 0.1
segments = {}
for lab, ev in vecs_dict.items():
    l = utils.tensor2sd(torch.from_numpy(v-width*ev), cnn.state_dict())
    r = utils.tensor2sd(torch.from_numpy(v+width*ev), cnn.state_dict())
    segment = Segment(code.CNN, [l, r], 0).to(device)
    segments[lab] = segment
index_grid = torch.linspace(0, 1, 21).to(device)
res = utils.plot_along_manifold(segments, train_loader, test_loader, code.evaluate_loss_acc, \
                          criterion, device, index_grid, plot=True)

Some other works study eigenvalues and eigenvectors along the SGD path. 
[This anonymous work](https://openreview.net/pdf?id=ByeTHsAqtX) finds that the subspace spanned by the top eigenvectors (corresponding to a few largest eigenvalues) doesn't change much after some epoch. The authors also find that the gradient mostly lies in this top subspace. You may conduct the corresponding experiments to reproduce their results.
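
One possible way to quantify "the gradient mostly lies in the top subspace" is the share of the squared gradient norm captured by the projection onto that subspace. The sketch below is an assumption about a reasonable metric, not necessarily the one used in the paper. The helper `fraction_in_subspace` and the random inputs `V` and `g` are hypothetical stand-ins for the eigenvectors returned by eigsh and a flattened gradient.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical inputs: k orthonormal "top eigenvectors" as columns of V, and a gradient g.
n, k = 100, 5
V, _ = np.linalg.qr(rng.standard_normal((n, k)))   # columns are orthonormal
g = rng.standard_normal(n)

def fraction_in_subspace(g, V):
    """Share of the squared norm of g that lies in span(V); V must have orthonormal columns."""
    coeffs = V.T @ g
    return float(coeffs @ coeffs / (g @ g))

frac_random = fraction_in_subspace(g, V)                # generic vector: only part lies in span(V)
frac_inside = fraction_in_subspace(V @ np.ones(k), V)   # vector built from V: lies entirely in span(V)

print(frac_random, frac_inside)
```

A value close to 1 means the vector lies almost entirely in the subspace.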

[Jastrzębski et al., 2019](https://openreview.net/pdf?id=SkgEaj05t7) find that during training the maximum eigenvalue first rises and then starts decreasing. This means that at the beginning of training the loss curvature gets sharper and sharper, and from some moment SGD starts fluctuating around an optimum (and cannot actually come closer to it). They relate this moment to the batch size and learning rate. You may also repeat this experiment.

### Learning low-loss valley between two independently found optima
In this part, you will need 2 CNNs learned on MNIST (with saved weights). Train them:

In [None]:
### your code here

[Garipov et al., 2018](https://arxiv.org/pdf/1802.10026.pdf) learn paths between two independently learned optima. Along these paths, the loss is low and near-constant. This shows that the optima of the loss lie on a connected manifold rather than being separated from each other.

The path is learned in the form of a simple polychain, i.e. a polygonal chain with one bend ($a$ and $b$ represent the two learned weight vectors, $\theta$ is a learnable bend):
$$
w = \begin{cases} 2(t \theta + (0.5-t)a), \quad  0 \leq t \leq 0.5  \\ 
2((t-0.5)b + (1-t)\theta),\quad  0.5 < t \leq 1\end{cases}
$$
The model is learned using SGD. When processing one minibatch, one first samples $t \in [0, 1]$, then sets the weights and performs a forward pass through the CNN. Gradients are taken w.r.t. $\theta$. 
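
The formula can be sanity-checked on scalars before implementing it for tensors. In the sketch below, the function `polychain_point` and the values of `a`, `theta` and `b` are made up for illustration. It confirms that the path starts at $a$, passes through $\theta$ at $t = 0.5$ and ends at $b$.

```python
import random

def polychain_point(t, a, theta, b):
    """Weight on the path a -> theta -> b for t in [0, 1] (scalars, for illustration)."""
    if t <= 0.5:
        return 2 * (t * theta + (0.5 - t) * a)
    return 2 * ((t - 0.5) * b + (1 - t) * theta)

a, theta, b = 1.0, 5.0, -3.0        # hypothetical scalar "weights"

start = polychain_point(0.0, a, theta, b)
bend = polychain_point(0.5, a, theta, b)
end = polychain_point(1.0, a, theta, b)
print(start, bend, end)

# During training, one t is sampled per minibatch:
t = random.random()
w = polychain_point(t, a, theta, b)
```

Both branches give $\theta$ at $t = 0.5$, so the path is continuous at the bend.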

To implement this algorithm, you only need to implement the reparametrization scheme in class Polychain. Then you may use functions from module lab_cnn_solution to train the polychain.

In [None]:
class Polychain(Base):
    def process_weights_item(self, param_list, support_list=[]):
        ### your code here
    
    def create_index(self):
        ### your code here
    
    def gen_index(self):
        ### your code here

In [None]:
fn1 = ... # specify endpoint 1
fn2 = ... # specify endpoint 2
polychain = Polychain(code.CNN, [fn1, fn2]).to(device)
segment = Segment(code.CNN, [fn1, fn2]).to(device)

We will compare the learned polychain with the straight segment $[a, b]$:

In [None]:
### your code here
### learn polychain as usual model in pytorch

Let's now plot quality along learned polychain $[a, \theta, b]$ and along segment $[a, b]$.

In [None]:
# plot quality along two paths
manifolds = {"segment":segment, "polychain":polychain}
index_grid = torch.linspace(0, 1, 21).to(device)
res = utils.plot_along_manifold(manifolds, train_loader, test_loader, code.evaluate_loss_acc, \
                          criterion, device, index_grid, plot=True)

As for further experiments, you may try to connect different types of optima (for example, optima learned with small and large batch sizes) and check this connectivity hypothesis for other networks / training procedures (e.g. use L2 regularization or dropout). [Gotmare et al., 2018](https://arxiv.org/pdf/1806.06977.pdf) conduct such an analysis.

### Visualize SGD path using PCA
In this part, you will need 1 CNN trained on MNIST with weights saved per epoch. Train it:

In [None]:
path_to_save = ...
cnn = ...

[Lorch et al., 2016](https://icmlviz.github.io/icmlviz2016/assets/papers/24.pdf) visualize the trajectory of SGD using PCA. Although training happens in a high-dimensional space, its projection onto a 2d plane may reveal distinctive properties of the training process. Let's make such a visualization as well.

In [None]:
from sklearn.decomposition import PCA

In [None]:
fns = [path_to_save+"/model_ep%d.cpt"%epoch for epoch in range(num_epochs)]

In [None]:
ws = np.array([utils.sd2tensor(torch.load(fn)).cpu().numpy() for fn in fns])

In [None]:
ws.shape

In [None]:
### your code here
# learn PCA on ws and select two principal components
mean = ...
direction1 = ...
direction2 = ...

In [None]:
### print explained variance ratio
### your code here

In [None]:
class Plane2dConnection(Base):
    def __init__(self, net_class, mean, direction1, direction2,
                 net_args=(), net_kwargs={}):
        super(Plane2dConnection, self).__init__(net_class, [mean, direction1, direction2],\
                                                  0, net_args,\
                                                  net_kwargs)
    
    def process_weights_item(self, param_list, support_list=[]):
        t1 = self.index[0].item()
        t2 = self.index[1].item()
        return support_list[0] + t1*support_list[1] + t2*support_list[2]
    
    def create_index(self):
        self.register_buffer("index", torch.zeros(2))
    
    def gen_index(self):
        self.index.uniform_().mul_(2).add_(-1)

In [None]:
example_sd = code.CNN().state_dict()
plane = Plane2dConnection(code.CNN, \
                      utils.tensor2sd(torch.from_numpy(mean).to(device),\
                                            example_sd),\
                      utils.tensor2sd(torch.from_numpy(direction1).to(device),\
                                            example_sd),\
                      utils.tensor2sd(torch.from_numpy(direction2).to(device),\
                                            example_sd)).to(device)
components = np.array([direction1, direction2])
utils.plot_2d(components, mean, ws, plane, device, train_loader, test_loader, \
                code.evaluate_loss_acc, criterion)

We suggest that you repeat this procedure for different values of num_epochs, because the principal components differ a lot for small and large num_epochs.

However, you shouldn't rely too much on this type of analysis. [Antognini and Sohl-Dickstein, 2018](https://arxiv.org/pdf/1806.08805.pdf) theoretically analyse the properties of PCA applied to high-dimensional random walks. They find that the explained variance and the projection of the trajectory look similar for SGD and for a random walk.
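
To compare with your SGD results, you can run the same analysis on a synthetic random walk. The sketch below is a minimal illustration under assumed settings: the values of `num_steps` and `dim` and the Gaussian step distribution are arbitrary choices, and PCA is done directly via SVD of the centered trajectory instead of sklearn.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for a training trajectory: a Gaussian random walk.
num_steps, dim = 200, 1000
walk = np.cumsum(rng.standard_normal((num_steps, dim)), axis=0)

# PCA via SVD of the centered trajectory (rows are points in weight space).
centered = walk - walk.mean(axis=0)
_, sing_vals, components = np.linalg.svd(centered, full_matrices=False)
explained_ratio = sing_vals ** 2 / np.sum(sing_vals ** 2)

projection = centered @ components[:2].T   # 2d coordinates of every step
print(explained_ratio[:2])
```

Plotting `projection` and comparing `explained_ratio` with the values obtained for the saved SGD weights shows how much of the picture a structureless walk reproduces.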