# Measuring skew and kurtosis of gradients



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from argparse import Namespace
from functools import reduce

import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
# Seaborn defaults first; the sns.set call below then overrides figure size, style and font scale.
sns.set()
# IPython magic: render matplotlib figures inline in the notebook.
%matplotlib inline

sns.set(rc={'figure.figsize':(5,5)}, style="whitegrid", font_scale=1.0)


## Data Generation

Code for generating the synthetic dataset: sums of phase-shifted sinusoids sampled on $[0, 1)$.

In [None]:
def make_phased_waves(opt):
    """Sample a sum of phase-shifted sinusoids on the unit interval.

    Args:
        opt: namespace providing ``N`` (number of samples), ``K`` (frequencies),
            ``PHI`` (phases, in cycles) and ``A`` (amplitudes, or ``None`` for
            unit amplitudes).

    Returns:
        ``(t, yt)``: ``t`` holds exactly ``opt.N`` evenly spaced points in
        ``[0, 1)``, and ``yt[n] = sum_i A_i * sin(2*pi*K_i*t[n] + 2*pi*PHI_i)``.
        With an empty ``K``, ``yt`` is all zeros.
    """
    # ``np.arange(0, 1, 1./N)`` can return N + 1 points because of float rounding
    # in the step; an integer range divided by N always has exactly N points.
    t = np.arange(opt.N) / opt.N
    amplitudes = opt.A if opt.A is not None else [1.] * len(opt.K)
    # Start from zeros so an empty frequency list yields a zero signal instead of
    # the TypeError ``reduce`` raises on an empty sequence.
    yt = np.zeros_like(t)
    for ki, Ai, phi in zip(opt.K, amplitudes, opt.PHI):
        yt = yt + Ai * np.sin(2 * np.pi * ki * t + 2 * np.pi * phi)
    return t, yt


def to_torch_dataset_1d(opt, t, yt, loss):
    """Convert numpy inputs and targets into float32 torch tensors.

    Inputs are reshaped to ``(-1, opt.INP_DIM)``. Targets are reshaped to
    ``(-1, opt.OUT_DIM)`` when ``loss == 'mse'`` and to a single column
    ``(-1, 1)`` for any other loss. Both tensors are moved to the GPU when
    ``opt.CUDA`` is set.
    """
    target_cols = opt.OUT_DIM if loss == 'mse' else 1
    inputs = torch.from_numpy(t).view(-1, opt.INP_DIM).float()
    targets = torch.from_numpy(yt).view(-1, target_cols).float()
    if opt.CUDA:
        inputs, targets = inputs.cuda(), targets.cuda()
    return inputs, targets

## Model Training

In [None]:
class Lambda(nn.Module):
    """Wrap an arbitrary callable as an ``nn.Module`` so it can be used as a layer."""

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Apply the wrapped callable to ``x`` and return its result."""
        fn = self.lambd
        return fn(x)
    
def make_model(opt, sig, act='RELU'):
    """Build a fully connected ``nn.Sequential`` network.

    The network has ``opt.DEPTH`` hidden blocks of ``Linear(-> opt.WIDTH)``, each
    followed by an activation, and then a final ``Linear(-> opt.OUT_DIM)``.

    Args:
        opt: namespace providing ``INP_DIM``, ``WIDTH``, ``DEPTH``, ``OUT_DIM``
            and ``CUDA``.
        sig: unused here; kept for call-site compatibility.
        act: ``'RELU'``, ``'SIGMOID'`` or ``'ELU'``. Any other value inserts no
            activation, so the hidden blocks are purely linear (as before).

    Returns:
        The model, moved to the GPU when ``opt.CUDA`` is set.
    """
    # ``nn.Sigmoid`` is the module class; the former ``nn.sigmoid`` does not exist
    # and raised AttributeError for act='SIGMOID'.
    activations = {'RELU': nn.ReLU, 'SIGMOID': nn.Sigmoid, 'ELU': nn.ELU}
    act_cls = activations.get(act)
    layers = []
    # Track the running feature size so that DEPTH == 0 yields
    # Linear(INP_DIM, OUT_DIM) rather than Linear(WIDTH, OUT_DIM).
    in_dim = opt.INP_DIM
    for _ in range(opt.DEPTH):
        layers.append(nn.Linear(in_dim, opt.WIDTH))
        if act_cls is not None:
            layers.append(act_cls())
        in_dim = opt.WIDTH
    layers.append(nn.Linear(in_dim, opt.OUT_DIM))
    model = nn.Sequential(*layers)
    if opt.CUDA:
        model = model.cuda()
    return model

In [None]:
import levy
from src.longtail import longtail
from scipy import stats


def tile(a, dim, n_tile):
    """Repeat every slice of ``a`` along ``dim`` ``n_tile`` times, keeping copies adjacent.

    For ``a = [a0, a1]`` along ``dim`` the result is ``[a0, a0, ..., a1, a1, ...]``,
    with each slice repeated ``n_tile`` times in place. These are ``np.repeat``
    semantics, not ``np.tile`` semantics. Callers that average over the copies of
    each sample must therefore group them as ``(-1, n_tile)``.

    ``torch.repeat_interleave`` gives the same ordering as the former
    repeat + ``index_select`` construction. It also keeps the result on ``a``'s
    device (the CPU ``LongTensor`` index broke CUDA inputs) and handles empty dims.
    """
    return torch.repeat_interleave(a, n_tile, dim=dim)


def calc_noise_grads(noisy_grads, grads):
    """Return the element-wise differences ``noisy - clean`` for paired gradients.

    Pairs are matched by position; surplus entries in the longer sequence are
    ignored (``zip`` semantics).
    """
    return [noisy - clean for noisy, clean in zip(noisy_grads, grads)]

def extract_grads(model):
    """Collect the standardized weight gradient of every ``nn.Linear`` in ``model``.

    Each layer's ``weight.grad`` is shifted by its own mean and divided by its own
    standard deviation, both taken over all of its entries. Each returned tensor
    therefore has zero mean and unit standard deviation. Bias gradients are ignored.
    ``backward()`` must already have populated ``weight.grad``.
    """
    linear_layers = (layer for layer in model if isinstance(layer, nn.Linear))
    standardized = []
    for linear in linear_layers:
        g = linear.weight.grad
        standardized.append((g - g.mean()) / g.std())
    return standardized

            
def extract_dh_dW(layer, a):
    """Per-sample Jacobian of ``a`` w.r.t. ``layer.weight``, summed over units.

    Args:
        layer: module with a ``weight`` parameter that ``a`` was computed from.
        a: tensor whose first dimension indexes samples (``B``).

    Returns:
        numpy array ``J`` of shape ``(B, *layer.weight.shape)`` with
        ``J[b] = sum_i d a[b, i] / d layer.weight``. Weights that ``a`` does not
        depend on contribute zeros.
    """
    per_sample = []
    for j in range(a.shape[0]):
        # Summing the Jacobian rows of sample j over units i is one vector-Jacobian
        # product with a one-hot-per-sample vector: one backward call per sample.
        v = torch.zeros_like(a)
        v[j] = 1.
        grad = torch.autograd.grad(a,
                                   layer.weight,
                                   grad_outputs=v,
                                   retain_graph=True,
                                   allow_unused=True)[0]
        per_sample.append(torch.zeros_like(layer.weight) if grad is None else grad)
    # Stack along a new leading sample axis. The former stack(dim=2) followed by
    # view(B, ...) reinterpreted memory in the wrong order and scrambled entries.
    jacobian = torch.stack(per_sample, dim=0)
    # detach().cpu() so this also works for CUDA tensors.
    return jacobian.detach().cpu().numpy().copy()


def make_pred(x, model):
    """Run ``model`` layer by layer, recording intermediate activations.

    Returns ``(output, acts)``. ``acts`` starts with the input ``x``, followed by
    the output of every non-``nn.Linear`` layer. The model's final output is always
    recorded, even when the last layer is an ``nn.Linear``.
    """
    recorded = [x]
    last = len(model) - 1
    h = x
    for idx, layer in enumerate(model):
        h = layer(h)
        is_final = idx == last
        if is_final or not isinstance(layer, nn.Linear):
            recorded.append(h)
    return h, recorded

def _perturb(h, sig, noise_type):
    """Inject Gaussian noise: ``h + sig*eps`` ('add') or ``h * (1 + sig*eps)`` ('mult').

    Any other ``noise_type`` returns ``h`` unchanged.
    """
    if noise_type == 'add':
        return h + torch.randn_like(h) * sig
    if noise_type == 'mult':
        return h * (1 + torch.randn_like(h) * sig)
    return h


def make_noisy_pred(x, model, sig, n_samples=1, calc_grads=False, noise_type='add', act='RELU'):
    """Forward pass with Gaussian noise injected at the input and after each activation.

    Noise is applied to the (optionally tiled) input, after every non-``nn.Linear``
    layer, and to the model's final output.

    Args:
        x: input batch; ``requires_grad`` is switched on in place on the caller's tensor.
        model: ``nn.Sequential`` to evaluate.
        sig: noise scale.
        n_samples: when > 1, each sample is repeated ``n_samples`` times (adjacent
            copies, see ``tile``) before noise is added.
        calc_grads: if True, also collect ``extract_dh_dW`` of the layer preceding
            each noisy activation (the final output is excluded).
        noise_type: 'add' or 'mult'; other values inject no noise.
        act: unused; kept for call-site compatibility.

    Returns:
        ``(output, acts, dh_dw)``. ``acts[0]`` is the noisy input; the remaining
        entries are the noisy activations and the noisy output.
    """
    x.requires_grad_(True)
    dh_dw = []
    h = tile(x, 0, n_samples) if n_samples > 1 else x
    h = _perturb(h, sig, noise_type)
    acts = [h]
    final_idx = len(model) - 1
    for idx, layer in enumerate(model):
        h = layer(h)
        if isinstance(layer, nn.Linear) and idx != final_idx:
            continue
        h = _perturb(h, sig, noise_type)
        if calc_grads and idx != final_idx:
            dh_dw.append(extract_dh_dW(model[idx - 1], h))
        acts.append(h)
    return h, acts, dh_dw


def estimate_all_params(X, beta=None):
    """Standardize ``X``, measure its skew/kurtosis, and fit tail models to its positive half.

    Args:
        X: array of samples. It is standardized with its own mean and std
            before any statistic is computed.
        beta: unused; kept for call-site compatibility.

    Returns:
        ``(X_pos, params, skew, kurt)``:
        ``X_pos`` holds the flattened positive standardized samples;
        ``params`` is the result of ``longtail.fit_distributions(X_pos)``;
        ``skew`` and ``kurt`` are scipy's biased skewness and (Fisher) kurtosis
        of all standardized samples.
    """
    z = (X - X.mean()) / X.std()
    flat = z.reshape(-1)
    skewness = stats.skew(flat, axis=0, bias=True)
    excess_kurt = stats.kurtosis(flat, axis=0, bias=True)
    # Only the right tail (positive values) is passed to the distribution fit.
    positive = z[z > 0].reshape(-1)
    return positive, longtail.fit_distributions(positive), skewness, excess_kurt
    
    

In [None]:

import itertools 

def plot_tails(X, ax, X_name=None, params=None, **kwargs):
    """Plot the empirical density of ``X`` on a log scale with fitted PDFs on ``ax``.

    Args:
        X: non-empty array-like of samples (converted with ``np.asarray``).
        ax: matplotlib Axes to draw on.
        X_name: optional label for the data series and the x-axis.
        params: mapping ``{scipy.stats distribution name: params}``. The first two
            entries of each value are used as ``loc`` and ``scale``. When ``None``,
            parameters are fitted with ``longtail.fit_distributions``.
        **kwargs: accepted for backward compatibility; unused.
    """
    # The former ``X is not np.ndarray`` compared identity with the type and was
    # always True; asarray only copies when a conversion is actually needed.
    X = np.asarray(X)
    if params is None:
        print("Estimating distributions parameters...")
        # The former bare ``fit_distributions`` name is undefined in this file.
        # NOTE(review): assumes ``longtail.fit_distributions`` accepts ``verbose`` — confirm.
        params = longtail.fit_distributions(X, verbose=True)

    label = X_name or "data"
    # Global style change; it also sets the linewidth used by the PDF curves below.
    sns.set(rc={'figure.figsize':(4*(2),4), "lines.linewidth":2.0}, style="whitegrid", font_scale=1.2)

    # Empirical density on ~5*log(n) equal-width bins spanning [min, max].
    # np.histogram puts the maximum in the last bin and widens a degenerate
    # (constant-input) range instead of dividing by a zero bin width.
    num_bins = max(1, int(np.log(len(X)) * 5))
    counts, edges = np.histogram(X, bins=num_bins)
    bin_width = edges[1] - edges[0]
    bin_centers = (edges[:-1] + edges[1:]) / 2.
    density = counts / len(X) / bin_width
    ax.scatter(bin_centers, density, s=5., color="dodgerblue", label=label)

    x_space = np.linspace(X.min(), X.max(), 1000)
    for name, param in params.items():
        distr = getattr(stats, name)
        ax.plot(x_space, distr.pdf(x_space, loc=param[0], scale=param[1]), label=name)

    # The first two fitted curves are drawn solid and dashed (zip tolerates fewer
    # than two fits). The legend is built afterwards so it inherits these styles.
    for line, style in zip(ax.lines, ("solid", "dashed")):
        line.set_linestyle(style)
    ax.legend()
    ax.set_yscale('log')
    if X_name is not None:
        ax.set_xlabel(X_name)
    ax.grid(True)
    


def plot_tails_formatted(X, params, b, k, ax, xlabel):
    """Draw ``plot_tails`` on ``ax`` and annotate it with the given skew ``b`` and kurtosis ``k``."""
    annotation = '\n'.join([
        r'$\mathrm{skew}=%.2f$' % (b, ),
        r'$\mathrm{kurt}=%.2f$' % (k, ),
    ])
    plot_tails(X, ax=ax, params=params)
    ax.set(xlabel=xlabel, ylabel=None)
    # Rounded wheat text box near the upper-left corner, positioned in axes-fraction coordinates.
    box_style = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax.text(0.05, 0.75, annotation, transform=ax.transAxes, fontsize=14,
            verticalalignment='top', bbox=box_style)
    
    

def _to_numpy(t):
    """Return a detached NumPy copy of tensor ``t`` (``.cpu()`` makes this CUDA-safe)."""
    return t.detach().cpu().numpy().copy()


def _set_plot_style(figsize):
    """Apply the seaborn style shared by all diagnostic figures, with the given figure size."""
    sns.set(rc={'figure.figsize': figsize, "lines.linewidth": 2.0}, style="whitegrid", font_scale=1.2)


def _record_noise_diagnostics(opt, model, optim, loss_fn, loss_dim, x, yt, sig, iter_num):
    """Measure gradient/activation/loss noise at ``iter_num``, plot its tails, and return a frame.

    Leaves the parameter gradients zeroed. The returned ``Namespace`` holds
    ``iter_num``, the mean clean loss, and the per-hidden-layer skew and kurtosis
    of the weight-gradient noise.
    """
    print('noisy loss')
    x.requires_grad_(True)

    # Clean forward/backward pass: reference (standardized) weight gradients.
    pred, acts = make_pred(x, model)
    loss = loss_fn(pred, yt).reshape(-1, loss_dim).sum(1)
    loss.mean().backward()
    w_grads = [_to_numpy(wg) for wg in extract_grads(model)]
    optim.zero_grad()

    # Calculate effects pertaining to explicit regularizer
    noise_kw = dict(calc_grads=False, noise_type=opt.noise_type, act=opt.act)
    noisy_pred, noisy_acts, _ = make_noisy_pred(x, model, sig, **noise_kw)
    noisy_loss = loss_fn(noisy_pred, yt).reshape(-1, loss_dim).sum(1)
    accumulated_noise = [_to_numpy(noisy_a - a) for a, noisy_a in zip(acts, noisy_acts)]
    expanded_noisy_pred, _, _ = make_noisy_pred(x, model, sig, opt.NUM_EXP, **noise_kw)
    expanded_loss = loss_fn(expanded_noisy_pred, tile(yt, 0, opt.NUM_EXP)).reshape(-1, loss_dim).sum(1)
    # tile() keeps the NUM_EXP noisy copies of each sample adjacent, so sample b's
    # copies form row b of a (B, NUM_EXP) view. np.repeat below uses the same order.
    # Reshaping to (NUM_EXP, B) instead would average losses of different samples.
    expected_loss = expanded_loss.reshape(-1, opt.NUM_EXP).mean(1)
    imp_reg_noise = _to_numpy(expanded_loss) - np.repeat(_to_numpy(expected_loss), opt.NUM_EXP, axis=0)
    imp_reg = noisy_loss.mean() - expected_loss.mean()

    # Keep gradients of the non-leaf noisy activations so they can be read after backward().
    for noisy_act in noisy_acts:
        noisy_act.retain_grad()
    imp_reg.backward()
    noisy_w_grads = [_to_numpy(wg) for wg in extract_grads(model)]
    noisy_acts_grad = [_to_numpy(na.grad) for na in noisy_acts[1:]]
    optim.zero_grad()
    grad_noise = calc_noise_grads(noisy_w_grads, w_grads)

    # Backward pass: tails of the weight-gradient noise. There is one panel per hidden
    # layer; the output layer (grad_noise[-1]) is skipped. Fits are reused by the 2x2 figure.
    backward_label = r'$\partial E_\mathcal{L}(\mathcal{D};\mathbf{w},\mathbf{\epsilon})/\partial \mathbf{W}_%d$'
    fits = [estimate_all_params(gn.reshape(-1)) for gn in grad_noise[:-1]]
    skew = [fit[2] for fit in fits]
    kurtosis = [fit[3] for fit in fits]

    _set_plot_style((4 * opt.DEPTH, 4))
    # squeeze=False keeps a 2-D Axes array even when DEPTH == 1.
    fig, axes = plt.subplots(1, opt.DEPTH, squeeze=False)
    for i, ((X, params, b, k), ax) in enumerate(zip(fits, axes[0])):
        plot_tails_formatted(X, params, b, k, ax, backward_label % (i + 1))
    axes[0][0].set(ylabel="pdf")
    plt.savefig("backpasstails%s_%s_%i.pdf"%(opt.noise_type, opt.act, iter_num), bbox_inches='tight')
    plt.show()

    _set_plot_style((8.5, 6))
    fig, axes = plt.subplots(2, 2)
    fig.tight_layout()
    # Annotate with the measured statistics, not rescaled or absolute values. zip
    # stops after four panels instead of raising IndexError when DEPTH > 4.
    for i, ((X, params, b, k), ax) in enumerate(zip(fits, axes.flat)):
        plot_tails_formatted(X, params, b, k, ax, backward_label % (i + 1))
    axes[0, 0].set(ylabel="pdf")
    axes[1, 0].set(ylabel="pdf")
    plt.savefig("backpasstails2x2%s_%s_%i.pdf"%(opt.noise_type, opt.act, iter_num), bbox_inches='tight')
    plt.show()

    # Gradient of the noise-induced loss w.r.t. the last noisy activation (the output).
    # NOTE(review): the subscript keeps the original label index (len(grad_noise) - 2);
    # confirm which h_l this figure should be labelled with.
    _set_plot_style((4, 4))
    fig, ax = plt.subplots(1, 1)
    X, params, b, k = estimate_all_params(noisy_acts_grad[-1].reshape(-1))
    plot_tails_formatted(X, params, b, k, ax, r'$\partial E_\mathcal{L}((\mathbf{x}, \mathbf{y});\mathbf{w},\mathbf{\epsilon}) / \partial \mathbf{h}_%d(\mathbf{x};\mathbf{w},\mathbf{\epsilon})$' % (len(grad_noise) - 2))
    ax.set(ylabel="pdf")
    plt.show()

    # Forward pass: noise accumulated in the input and each hidden activation
    # (the output is skipped). The real input noise is plotted rather than a freshly
    # drawn Gaussian, and the measured skew is annotated rather than a random number.
    forward_noise = accumulated_noise[:-1]
    _set_plot_style((4 * len(forward_noise), 4))
    fig, axes = plt.subplots(1, len(forward_noise), squeeze=False)
    for i, (an, ax) in enumerate(zip(forward_noise, axes[0])):
        X, params, b, k = estimate_all_params(an.reshape(-1))
        plot_tails_formatted(X, params, b, k, ax, r'$\mathcal{E}_{%d}(\mathbf{x};\mathbf{w},\mathbf{\epsilon})$' % (i))
    axes[0][0].set(ylabel="pdf")
    plt.savefig("forwardtails%s_%s_%i.pdf"%(opt.noise_type, opt.act, iter_num), bbox_inches='tight')
    plt.show()

    # Per-copy deviation of the noisy loss from its Monte-Carlo expectation.
    _set_plot_style((4, 4))
    fig, ax = plt.subplots(1, 1)
    X, params, b, k = estimate_all_params(imp_reg_noise.reshape(-1))
    plot_tails_formatted(X, params, b, k, ax, r'$E_\mathcal{L}((\mathbf{x}, \mathbf{y});\mathbf{w},\mathbf{\epsilon})$')
    ax.set(ylabel="pdf")
    plt.show()

    return Namespace(iter_num=iter_num,
                     loss=loss.mean().item(),
                     skew=skew,
                     kurtosis=kurtosis)


def train_model(opt, model, input_, target, sig, loss_type='mse'):
    """Train ``model`` with noise-injected full-batch SGD, recording noise diagnostics.

    Every ``opt.REC_FRQ`` iterations the gradient-, activation- and loss-noise
    distributions are estimated and plotted (``_record_noise_diagnostics``). Each
    SGD step minimises the loss of a single noisy forward pass (``make_noisy_pred``).

    Args:
        opt: namespace; reads OUT_DIM, LR, CUDA, NUM_ITER, REC_FRQ, NUM_EXP, DEPTH,
            noise_type and act.
        model: ``nn.Sequential`` to train (updated in place).
        input_, target: full-batch inputs and targets.
        sig: noise scale passed to ``make_noisy_pred``.
        loss_type: 'mse' for regression. Any other value uses per-sample
            cross-entropy on integer class targets.

    Returns:
        list of ``Namespace`` frames, one per diagnostic iteration.
    """
    # Per-sample losses (reduction='none') so noise statistics are taken before averaging.
    if loss_type == 'mse':
        loss_fn = nn.MSELoss(reduction='none')
        LOSS_DIM = opt.OUT_DIM
    else:
        # Previously no loss was built for non-mse types (NameError on first use).
        loss_fn = nn.CrossEntropyLoss(reduction='none')
        LOSS_DIM = 1
    optim = torch.optim.SGD(model.parameters(), lr=opt.LR)
    frames = []
    model.train()
    if opt.CUDA:
        input_ = input_.cuda()
        target = target.cuda()
    # At least 1, so NUM_ITER < 100 does not cause a modulo by zero.
    progress_every = max(1, opt.NUM_ITER // 100)
    for iter_num in range(opt.NUM_ITER):
        if iter_num % progress_every == 0:
            print(">", end='')
        x = input_
        if loss_type == 'mse':
            yt = target.view(-1, opt.OUT_DIM)
        else:
            yt = target.view(-1,).long()

        if iter_num % opt.REC_FRQ == 0:
            frames.append(_record_noise_diagnostics(opt, model, optim, loss_fn, LOSS_DIM, x, yt, sig, iter_num))
            print(frames[-1])

        optim.zero_grad()
        noisy_pred, _, _ = make_noisy_pred(x, model, sig, calc_grads=False, noise_type=opt.noise_type, act=opt.act)
        noisy_loss = loss_fn(noisy_pred, yt).reshape(-1, LOSS_DIM).sum(1)
        noisy_loss.mean().backward()
        optim.step()
        optim.zero_grad()

    return frames

## Visualization

In [None]:
def plot_inferred_wave(opt, x, y, yinf):
    """Overlay the target function and the model's prediction on a new figure and show it.

    ``opt`` is accepted for call-site compatibility and is not used.
    """
    fig, ax = plt.subplots(1, 1)
    ax.set_title("Function")
    for series, name in ((y, 'Target'), (yinf, 'Learnt')):
        ax.plot(x, series, label=name)
    ax.set_xlabel("x")
    ax.set_ylabel("f(x)")
    ax.legend()
    plt.show()
    
def plot_wave_and_spectrum(opt, x, yox):
    """Plot ``yox`` ("y of x") against ``x`` on a new single-axes figure.

    Despite the name, no FFT/spectrum is computed or drawn here, and
    ``plt.show()`` is not called. ``opt`` is unused.
    """
    _, axis = plt.subplots(1, 1)
    axis.set_title("Function")
    axis.plot(x, yox)
    axis.set_xlabel("x")
    axis.set_ylabel("f(x)")
    
    
def plot_multiple_skews(all_frames):
    """Plot the per-layer skew (mean ± std over runs) against training iteration.

    ``all_frames`` holds one list of frames per run, each frame having ``iter_num``
    and a per-layer ``skew`` list. Iteration numbers are taken from the first run.
    """
    iter_nums = np.array([f.iter_num for f in all_frames[0]])
    per_run = []
    for run in all_frames:
        # Transpose frames x layers -> layers x frames.
        layer_major = list(zip(*[f.skew for f in run]))
        per_run.append(np.array(layer_major).squeeze())
    skews = np.array(per_run)
    print(skews)
    means, stds = skews.mean(0), skews.std(0)
    plt.xlabel("Training Iteration")
    plt.ylabel(r'$\beta$')
    for layer_num, (mean_curve, std_curve) in enumerate(zip(means, stds)):
        line = plt.plot(iter_nums, mean_curve, label=f'Layer {layer_num + 1}')[0]
        plt.fill_between(iter_nums, mean_curve + std_curve, mean_curve - std_curve,
                         color=line.get_color(), alpha=0.15)
    plt.legend()
    plt.show()

## Play

In [None]:
opt = Namespace()  # experiment settings; attributes are assigned in the cells below

In [None]:
# Data Generation
opt.N = 200  # number of evenly spaced sample points per wave on [0, 1)
opt.K = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]  # sinusoid frequencies
opt.A = [1 for _ in opt.K]  # one amplitude per frequency (all equal to 1)
opt.PHI = [np.random.rand() for _ in opt.K]  # phases in [0, 1); scaled by 2*pi in make_phased_waves
# Model parameters
opt.INP_DIM = 100  # input features (columns of x)
opt.OUT_DIM = 100  # regression outputs (columns of y)
opt.WIDTH = 512  # hidden-layer width
opt.DEPTH = 4  # number of hidden Linear+activation blocks built by make_model
# Training
# --- Switch exp_reg on and off to approximate GNIs as R in the main paper.
# NOTE(review): no code visible in this notebook reads opt.exp_reg — confirm it is still used.
opt.exp_reg=False
opt.CUDA = False
opt.NUM_ITER = 101  # SGD steps per run
opt.NUM_EXP = 100  # noisy copies per sample used to estimate the expected loss
opt.REC_FRQ = 100  # record diagnostics every REC_FRQ steps (here: iterations 0 and 100)
opt.LR = 0.0003  # SGD learning rate

### Plot the Functions

... as a sanity check. 

In [None]:
# Build the (N, INP_DIM) input and (N, OUT_DIM) target matrices. make_phased_waves
# is deterministic for a fixed opt, so every column is the same wave: compute it
# once and tile it (previously recomputed per column after a discarded first call).
t, wave = make_phased_waves(opt)
x = np.tile(t.reshape(-1, 1), (1, opt.INP_DIM))
y = np.tile(wave.reshape(-1, 1), (1, opt.OUT_DIM))

sns.set(rc={'figure.figsize':(10,4), "lines.linewidth":2.0}, style="whitegrid", font_scale=1.2)

plot_wave_and_spectrum(opt, x, y)

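Note: `plot_wave_and_spectrum` currently plots only the function itself; no spectrum is computed. If the spectrum were plotted, only the positive frequencies would be shown, so each peak would have height $0.5$ (the other half of the power lies in the negative frequencies).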

### Train the Model

In [None]:
from sklearn.datasets import load_digits

def go(opt, repeats=10, sig=0, act='RELU', data='regress'):
    """Train ``repeats`` independently initialised models and collect their diagnostic frames.

    Args:
        opt: experiment namespace. ``opt.PHI`` is resampled on every repeat and,
            for ``data='class'``, ``opt.INP_DIM``/``opt.OUT_DIM`` are overwritten.
        repeats: number of independent runs.
        sig: noise scale forwarded to ``make_model`` and ``train_model``.
        act: activation name forwarded to ``make_model``.
        data: 'regress' (phased-sine regression, MSE) or 'class' (sklearn digits,
            cross-entropy).

    Returns:
        list with one list of frames (from ``train_model``) per run.

    Raises:
        ValueError: if ``data`` is neither 'regress' nor 'class'. Previously this
            surfaced later as an UnboundLocalError.
    """
    if data not in ('regress', 'class'):
        raise ValueError(f"unknown data kind {data!r}; expected 'regress' or 'class'")
    all_frames = []
    for _ in range(repeats):
        # Sample random phase (drawn for both data kinds, as before).
        opt.PHI = [np.random.rand() for _ in opt.K]
        if data == 'regress':
            # make_phased_waves is deterministic for a fixed opt, so every column is
            # the same wave; build it once and tile it.
            t, wave = make_phased_waves(opt)
            x = np.tile(t.reshape(-1, 1), (1, opt.INP_DIM))
            y = np.tile(wave.reshape(-1, 1), (1, opt.OUT_DIM))
            # Unit Gaussian input noise makes the otherwise identical input columns distinct.
            x += np.random.randn(*x.shape)
            loss_type = 'mse'
        else:
            x, y = load_digits(n_class=10, return_X_y=True)
            opt.INP_DIM, opt.OUT_DIM = x.shape[1], 10
            loss_type = 'ce'

        x, y = to_torch_dataset_1d(opt, x, y, loss_type)
        model = make_model(opt, sig, act)
        frames = train_model(opt, model, x, y, sig, loss_type=loss_type)
        all_frames.append(frames)
        print('', end='\n')
    return all_frames

## Runs

In [None]:
# Equal-amplitude setting: ten frequencies, each with amplitude 1.
opt.K = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
opt.A = [1 for _ in opt.K]

In [None]:
# Add Gaussian noise with standard deviation 0.1 (variance 0.01). make_noisy_pred
# draws torch.randn_like(...) * sig with sig = 0.1 and injects it at the input,
# after every activation, and on the output.
opt.act='RELU'
opt.noise_type='add'  # additive noise: h + sig * eps
eq_amp_frames_noise = go(opt, 5, 0.1, act=opt.act, data='regress')  # 5 independent runs, sig = 0.1