# Marginalising out the Implicit Effect



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
torch.manual_seed(0)
np.random.seed(0)

from argparse import Namespace
from functools import reduce

import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
sns.set()
%matplotlib inline

sns.set(rc={'figure.figsize':(5,5)}, style="whitegrid", font_scale=1.0)


## Data Generation

The functions below generate the sum-of-sinusoids dataset and convert it to PyTorch tensors.

In [None]:
def make_phased_waves(opt):
    """Sample a sum of phase-shifted sinusoids on a uniform grid over [0, 1).

    Args:
        opt: Options object. Reads ``N`` (number of grid points), ``K``
            (frequencies), ``PHI`` (phases, expressed as fractions of a full
            turn) and ``A`` (per-frequency amplitudes, or ``None`` for unit
            amplitudes).

    Returns:
        Tuple ``(t, yt)`` of NumPy arrays of length ``N``: the grid and the
        summed signal. With an empty ``K`` the signal is all zeros.
    """
    t = np.arange(0, 1, 1. / opt.N)
    # Unit amplitudes reproduce the un-weighted sum when no amplitudes are given.
    amplitudes = [1.] * len(opt.K) if opt.A is None else opt.A
    yt = np.zeros_like(t)
    for ki, Ai, phi in zip(opt.K, amplitudes, opt.PHI):
        yt = yt + Ai * np.sin(2 * np.pi * ki * t + 2 * np.pi * phi)
    return t, yt


def to_torch_dataset_1d(opt, t, yt, loss):
    """Convert NumPy inputs/targets into float32 torch tensors.

    Args:
        opt: Options object; reads ``INP_DIM``, ``OUT_DIM`` and ``CUDA``.
        t: NumPy array of inputs, viewed as ``(-1, opt.INP_DIM)``.
        yt: NumPy array of targets, viewed as ``(-1, opt.OUT_DIM)`` for the
            ``'mse'`` loss and as ``(-1, 1)`` for any other loss.
        loss: Loss identifier; only ``'mse'`` is treated specially.

    Returns:
        Tuple ``(inputs, targets)``; both are moved to CUDA when ``opt.CUDA``
        is truthy.
    """
    inputs = torch.from_numpy(t).view(-1, opt.INP_DIM).float()
    target_cols = opt.OUT_DIM if loss == 'mse' else 1
    targets = torch.from_numpy(yt).view(-1, target_cols).float()
    if not opt.CUDA:
        return inputs, targets
    return inputs.cuda(), targets.cuda()

## Model Training

In [None]:
class Lambda(nn.Module):
    """Module wrapper that applies a stored callable in ``forward``."""

    def __init__(self, lambd):
        super().__init__()
        # Callable invoked on every forward pass.
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)
    
def make_model(opt, sig, act='RELU'):
    """Build a fully connected network as an ``nn.Sequential``.

    The network has ``opt.DEPTH`` hidden ``Linear`` layers of width
    ``opt.WIDTH``, each followed by the selected activation, and a final
    ``Linear(opt.WIDTH, opt.OUT_DIM)`` layer.

    Args:
        opt: Options object; reads ``INP_DIM``, ``WIDTH``, ``DEPTH``,
            ``OUT_DIM`` and ``CUDA``.
        sig: Unused by this function.
        act: ``'RELU'``, ``'SIGMOID'`` or ``'ELU'``; any other value yields
            hidden layers without an activation.

    Returns:
        The assembled model, moved to CUDA when ``opt.CUDA`` is truthy.
    """
    activation_classes = {'RELU': nn.ReLU, 'SIGMOID': nn.Sigmoid, 'ELU': nn.ELU}
    activation_cls = activation_classes.get(act)

    modules = []
    in_features = opt.INP_DIM
    for _ in range(opt.DEPTH):
        modules.append(nn.Linear(in_features, opt.WIDTH))
        if activation_cls is not None:
            modules.append(activation_cls())
        in_features = opt.WIDTH
    # The output layer always takes WIDTH inputs, even when DEPTH is 0.
    modules.append(nn.Linear(opt.WIDTH, opt.OUT_DIM))

    net = nn.Sequential(*modules)
    return net.cuda() if opt.CUDA else net

In [None]:
import levy
from src.longtail import longtail
from scipy import stats

import torch


def jacobian(y, x, create_graph=False):
    """Build the Jacobian of ``y`` with respect to ``x`` one output element at a time.

    Args:
        y: Tensor computed from ``x`` through an autograd graph.
        x: Tensor to differentiate with respect to.
        create_graph: Forwarded to ``torch.autograd.grad``; pass ``True`` when
            the returned Jacobian must itself be differentiated again.

    Returns:
        Tensor of shape ``y.shape + x.shape``.
    """
    jac = []
    flat_y = y.reshape(-1)
    # One-hot selector reused across iterations: entry i is switched on,
    # used for one gradient call, then switched off again.
    grad_y = torch.zeros_like(flat_y)
    for i in range(len(flat_y)):
        grad_y[i] = 1.
        # retain_graph=True keeps the graph alive for the remaining iterations.
        grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph)
        jac.append(grad_x.reshape(x.shape))
        grad_y[i] = 0.
    return torch.stack(jac).reshape(y.shape + x.shape)
                                                                                                      
def trace_hessian(y, x):
    """Return the trace of the Hessian of ``y`` with respect to ``x``.

    The full Hessian is materialised through two nested ``jacobian`` calls
    and reshaped to ``(n, n)``, where ``n`` is the number of elements of
    ``x``; the reshape therefore only succeeds when ``y`` holds one element.
    """
    n = np.prod(list(x.shape))
    first_order = jacobian(y, x, create_graph=True)
    second_order = jacobian(first_order, x)
    return torch.trace(second_order.reshape(n, n))


def trace_hessian_params(model, loss):
    """Return the Hessian trace of ``loss`` w.r.t. the first layer's weight.

    Only the layer at position 0 of ``model`` is ever inspected. If it is not
    an ``nn.Linear`` (or the model is empty), the integer ``0`` is returned.
    """
    total = 0
    for first_layer in model:
        if isinstance(first_layer, nn.Linear):
            total += trace_hessian(loss, first_layer.weight)
        # Stop after the first layer regardless of its type.
        break
    return total
            
    

def tile(a, dim, n_tile):
    """Repeat every slice of ``a`` along ``dim`` ``n_tile`` times, copies adjacent.

    For ``a = [a0, a1]`` and ``n_tile = 2`` along ``dim`` the result is
    ``[a0, a0, a1, a1]`` (interleaved), not ``[a0, a1, a0, a1]``.

    Args:
        a: Input tensor.
        dim: Dimension along which to repeat.
        n_tile: Number of copies of each slice.

    Returns:
        Tensor on the same device as ``a`` whose size along ``dim`` is
        ``a.size(dim) * n_tile``.
    """
    # repeat_interleave yields the same ordering as repeat + index_select, but
    # stays on a's device and also handles an empty dimension.
    return torch.repeat_interleave(a, n_tile, dim=dim)


def calc_noise_grads(noisy_grads, grads):
    """Return the element-wise differences ``noisy - clean`` for paired gradients.

    The two sequences are paired with ``zip``, so the result has the length
    of the shorter one.
    """
    return [noisy - clean for noisy, clean in zip(noisy_grads, grads)]

def extract_grads(model):
    """Collect standardised weight gradients of every ``nn.Linear`` layer.

    Each gradient has its mean (over all entries of that weight gradient)
    subtracted and is divided by its standard deviation.

    Returns:
        List with one tensor per ``nn.Linear`` layer, in model order.
    """
    def _standardise(grad):
        return (grad - grad.mean()) / grad.std()

    return [
        _standardise(layer.weight.grad)
        for layer in model
        if isinstance(layer, nn.Linear)
    ]


def make_noisy_pred(x, model, sig, n_samples=1, noise_type='add'):
    """Run ``model`` on ``x`` while injecting Gaussian noise after activations.

    Noise is applied to the output of every layer that is not an
    ``nn.Linear``; outputs of ``nn.Linear`` layers are left untouched.

    Args:
        x: Input batch.
        model: Iterable of layers (e.g. ``nn.Sequential``).
        sig: Scale multiplying standard-normal noise.
        n_samples: When greater than 1, each input row is repeated this many
            times (via ``tile``) so every copy receives its own noise draw.
        noise_type: ``'add'`` for ``x + eps * sig``, ``'mult'`` for
            ``x * (1 + eps * sig)``. Any other value injects no noise.

    Returns:
        The noisy model output.
    """
    if n_samples > 1:
        x = tile(x, 0, n_samples)
    for layer in model:
        x = layer(x)
        if isinstance(layer, nn.Linear):
            continue
        # Out-of-place on purpose: activations such as Sigmoid save their
        # output for backward, and modifying it in place makes backward() fail.
        if noise_type == 'mult':
            x = x * (1 + torch.randn_like(x) * sig)
        elif noise_type == 'add':
            x = x + torch.randn_like(x) * sig
    return x


import levy


def estimate_all_params(X, beta=None):
    """Fit a Levy-stable distribution to standardised data with ``levy.fit_levy``.

    ``X`` is shifted by its mean and divided by its standard deviation, and
    the fit is run with ``mu`` and ``sigma`` passed as 0 and 1 (plus ``beta``
    when one is supplied).

    Returns:
        ``[alpha, beta, sigma, mu]`` as ``np.float32`` values read from the
        fitted parameter object.
    """
    standardised = (X - X.mean()) / X.std()

    fit_kwargs = {"mu": 0., "sigma": 1.}
    if beta is not None:
        fit_kwargs["beta"] = beta

    fitted, _neglog_density = levy.fit_levy(standardised, **fit_kwargs)
    # Parameter names and values are read from the fitted object's attributes.
    state = fitted.__dict__
    by_name = dict(zip(state["pnames"], state["_x"]))
    return [np.float32(by_name[key]) for key in ('alpha', 'beta', 'sigma', 'mu')]
    
    
    

In [None]:



def train_model(opt, model, input_, target, input_test, target_test, sig, loss_type='mse'):
    """Train ``model`` with SGD on a noise-averaged loss and record loss curves.

    At every step the inputs are repeated ``opt.NUM_EXP`` times, pushed
    through ``make_noisy_pred``, and the per-sample loss is averaged over
    those noise draws before the batch mean is back-propagated.

    Args:
        opt: Options object; reads ``OUT_DIM``, ``LR``, ``CUDA``, ``NUM_ITER``,
            ``REC_FRQ``, ``NUM_EXP`` and ``noise_type``.
        model: Network to train (iterable of layers).
        input_, target: Training inputs and targets.
        input_test, target_test: Held-out inputs and targets; used as passed
            (they are not moved to CUDA here).
        sig: Noise scale forwarded to ``make_noisy_pred``.
        loss_type: ``'mse'`` or ``'ce'``.

    Returns:
        List of ``Namespace(iter_num, loss, loss_test)`` records, one every
        ``opt.REC_FRQ`` iterations. Recorded losses come from the plain
        ``model(x)`` forward pass, without injected noise.

    Raises:
        ValueError: If ``loss_type`` is not ``'mse'`` or ``'ce'``.
    """
    # Build loss
    if loss_type == 'mse':
        loss_fn = nn.MSELoss(reduction='none')
        LOSS_DIM = opt.OUT_DIM
    elif loss_type == 'ce':
        loss_fn = nn.CrossEntropyLoss(reduction='none')
        LOSS_DIM = 1
    else:
        raise ValueError(f"Unknown loss_type {loss_type!r}; expected 'mse' or 'ce'")
    # Build optim
    optim = torch.optim.SGD(model.parameters(), lr=opt.LR)
    # Rec
    frames = []
    model.train()
    # To cuda
    if opt.CUDA:
        input_ = input_.cuda()
        target = target.cuda()

    # Targets do not change between iterations, so shape them once.
    if loss_type == 'mse':
        yt = target.view(-1, opt.OUT_DIM)
        ytest = target_test.view(-1, opt.OUT_DIM)
    else:
        yt = target.view(-1,).long()
        ytest = target_test.view(-1,).long()
    tiled_yt = tile(yt, 0, opt.NUM_EXP)
    x_test = input_test

    # Print roughly 100 progress marks; guard against NUM_ITER < 100, where
    # NUM_ITER // 100 would be 0 and the modulo would divide by zero.
    progress_every = max(1, opt.NUM_ITER // 100)

    # Loop!
    for iter_num in range(opt.NUM_ITER):
        if iter_num % progress_every == 0:
            print(">", end='')
        x = input_.clone()

        if iter_num % opt.REC_FRQ == 0:
            # Recording only needs values, so skip building autograd graphs.
            with torch.no_grad():
                loss = loss_fn(model(x), yt).reshape(-1, LOSS_DIM).sum(1)
                loss_test = loss_fn(model(x_test), ytest).reshape(-1, LOSS_DIM).sum(1)

            frames.append(Namespace(iter_num=iter_num,
                                    loss=loss.mean().item(),
                                    loss_test=loss_test.mean().item()))

        optim.zero_grad()
        expanded_noisy_pred = make_noisy_pred(x, model, sig, opt.NUM_EXP, noise_type=opt.noise_type)
        expanded_loss = loss_fn(expanded_noisy_pred, tiled_yt).reshape(-1, LOSS_DIM).sum(1)
        # tile() keeps the NUM_EXP copies of each sample adjacent, so the noise
        # draws of one sample form a row of shape (NUM_EXP,).
        expected_loss = expanded_loss.reshape(-1, opt.NUM_EXP).mean(1)
        expected_loss.mean().backward()
        optim.step()
        optim.zero_grad()

    # Done
    return frames

## Visualization

In [None]:
def plot_inferred_wave(opt, x, y, yinf):
    """Plot the target curve ``y`` and the learnt curve ``yinf`` against ``x``.

    ``opt`` is accepted but not used. The figure is displayed with
    ``plt.show()``.
    """
    _, axis = plt.subplots(1, 1)
    for values, name in ((y, 'Target'), (yinf, 'Learnt')):
        axis.plot(x, values, label=name)
    axis.set(title="Function", xlabel="x", ylabel="f(x)")
    axis.legend()
    plt.show()
    
def plot_wave_and_spectrum(opt, x, yox):
    """Plot a function ("yox" = "y of x") next to its Fourier transform.

    NOTE(review): relies on an ``fft(opt, yox)`` helper returning ``(k, yok)``
    that is not defined in the visible part of this file - confirm it exists.
    """
    freqs, spectrum = fft(opt, yox)

    _, (wave_ax, spectrum_ax) = plt.subplots(1, 2)

    wave_ax.plot(x, yox)
    wave_ax.set(title="Function", xlabel="x", ylabel="f(x)")

    spectrum_ax.plot(freqs, spectrum)
    spectrum_ax.set(title="FT of Function", xlabel="k", ylabel="f(k)")

    plt.show()
    
    
def plot_multiple_skews(all_frames):
    """Plot the mean and +/- one std band of per-layer ``skew`` values across runs.

    Args:
        all_frames: One list of frame records per run. Each record must expose
            ``iter_num`` and ``skew``.

    NOTE(review): the records built by ``train_model`` in this file carry only
    ``iter_num``, ``loss`` and ``loss_test``, so this function needs frames
    from a different recorder - confirm where ``skew`` is produced.
    """
    # Iteration numbers are taken from the first run and reused for all curves.
    iter_nums = np.array([frame.iter_num for frame in all_frames[0]])
    # zip(*...) regroups the per-record skew sequences by position within each record;
    # presumably that position is the layer index - TODO confirm.
    norms = np.array([np.array(list(zip(*[frame.skew for frame in frames]))).squeeze() for frames in all_frames])
    print(norms)
    # Statistics over axis 0, i.e. across the runs in all_frames.
    means = norms.mean(0)
    stds = norms.std(0)
    plt.xlabel("Training Iteration")
    plt.ylabel(r'$\beta$')
    for layer_num, (mean_curve, std_curve) in enumerate(zip(means, stds)):
        p = plt.plot(iter_nums, mean_curve, label=f'Layer {layer_num + 1}')
        # Shade the band in the same colour as the mean curve.
        plt.fill_between(iter_nums, mean_curve + std_curve, mean_curve - std_curve, color=p[0].get_color(), alpha=0.15)
    plt.legend()
    plt.show()

## Play

In [None]:
# Shared, mutable experiment configuration; the cells below fill in its fields.
opt = Namespace()

In [None]:
# Data Generation
opt.N = 200  # number of grid points sampled by make_phased_waves
opt.K = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]  # sinusoid frequencies
opt.A = [1 for _ in opt.K]  # unit amplitude for every frequency
opt.PHI = [np.random.rand() for _ in opt.K]  # random phases; go() resamples them per repeat
# Model parameters
opt.INP_DIM = 1
opt.OUT_DIM = 1
opt.WIDTH = 256
opt.DEPTH = 6
# Training
# --- Switch exp_reg on and off to approximate GNIs as R in the main paper.
# NOTE(review): exp_reg is not read by any code visible in this file - confirm it is still used.
opt.exp_reg=False
opt.CUDA = False
opt.NUM_ITER = 60000
opt.NUM_EXP = 100  # noise draws per input; the run cells below override this
opt.REC_FRQ = 10000  # record losses every REC_FRQ iterations
opt.LR = 0.0003
opt.noise_type='add'  # 'add' or 'mult', see make_noisy_pred

### Train the Model

In [None]:
from sklearn.datasets import load_digits, load_boston
import tensorflow as tf

def go(opt, repeats=10, sig=0, act='RELU', data='regress'):
    """Train ``repeats`` freshly initialised models and collect their loss records.

    Args:
        opt: Options object. It is mutated: ``PHI`` is resampled on every
            repeat, and for ``data='digits'`` ``INP_DIM``/``OUT_DIM`` are
            overwritten.
        repeats: Number of independent training runs.
        sig: Noise scale forwarded to ``make_model`` and ``train_model``.
        act: Activation name forwarded to ``make_model``.
        data: ``'sinusoids'`` (sum of sinusoids, MSE loss) or ``'digits'``
            (MNIST training split, cross-entropy loss). Note the default
            ``'regress'`` is not a supported value.

    Returns:
        List with one list of frame records (from ``train_model``) per repeat.

    Raises:
        ValueError: If ``data`` is not one of the supported values.
    """
    all_frames = []
    for _ in range(repeats):
        # Sample random phase
        opt.PHI = [np.random.rand() for _ in opt.K]
        # Generate data
        if data == 'sinusoids':
            x, y = make_phased_waves(opt)
            loss_type = 'mse'
        elif data == 'digits':
            # load_data returns ((x_train, y_train), (x_test, y_test)); only
            # the training split is used here.
            (x, y), _unused_test_split = tf.keras.datasets.mnist.load_data(path='mnist.npz')
            # Flatten each image to a vector so rows line up with the labels
            # once to_torch_dataset_1d views the data as (-1, INP_DIM).
            # NOTE(review): pixel values are passed through unscaled - confirm
            # whether normalisation is wanted.
            x = x.reshape(len(x), -1)
            opt.INP_DIM, opt.OUT_DIM = x.shape[1], 10
            loss_type = 'ce'
        else:
            raise ValueError(f"Unknown data {data!r}; expected 'sinusoids' or 'digits'")

        # The train fraction is fixed at 1.0, so all of the first opt.N
        # samples are used for training and the held-out split is empty.
        train_idx = np.sort(np.random.choice(opt.N, int(opt.N*1.0), replace=False))
        # sorted() gives a deterministic order; plain set iteration does not.
        test_idx = sorted(set(range(opt.N)) - set(train_idx))
        xtrain, ytrain = x[train_idx], y[train_idx]
        xtest, ytest = x[test_idx], y[test_idx]

        xtrain, ytrain = to_torch_dataset_1d(opt, xtrain, ytrain, loss_type)
        xtest, ytest = to_torch_dataset_1d(opt, xtest, ytest, loss_type)
        # Make model
        model = make_model(opt, sig, act)

        # Train
        frames = train_model(opt, model, xtrain, ytrain, xtest, ytest, sig, loss_type=loss_type)
        all_frames.append(frames)
        sns.set(rc={'figure.figsize':(4,4)}, style="whitegrid", font_scale=1.5)
        print('', end='\n')
    return all_frames

## Runs

In [None]:
# Equal-amplitude setting: same frequencies as the configuration cell, unit amplitude each.
opt.K = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
opt.A = [1 for _ in opt.K]

In [None]:
# pickle is needed here, before the later cells that import it.
import pickle

## Multiplicative noise, M = 1 noise draw per input.
## 0.1 is passed as `sig`, which scales standard-normal noise in
## make_noisy_pred (a scale, not a variance).
opt.NUM_EXP = 1
opt.noise_type='mult'

eq_amp_frames_noise = go(opt, 4, 0.1, act='RELU', data='sinusoids')

with open('eq_amp_frames_noise1_mult.pickle', 'wb') as handle:
    pickle.dump(eq_amp_frames_noise, handle, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:
# M = 2 noise draws per input; 0.1 is passed as `sig`, which scales
# standard-normal noise in make_noisy_pred (a scale, not a variance).
opt.NUM_EXP = 2
eq_amp_frames_noise2 = go(opt, 4, 0.1, act='RELU', data='sinusoids')

with open('eq_amp_frames_noise2_mult.pickle', 'wb') as handle:
    pickle.dump(eq_amp_frames_noise2, handle, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:
# M = 4 noise draws per input; 0.1 is passed as `sig`, which scales
# standard-normal noise in make_noisy_pred (a scale, not a variance).
opt.NUM_EXP = 4
eq_amp_frames_noise4 = go(opt, 4, 0.1, act='RELU', data='sinusoids')

with open('eq_amp_frames_noise4_mult.pickle', 'wb') as handle:
    pickle.dump(eq_amp_frames_noise4, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# M = 8 noise draws per input; 0.1 is passed as `sig`, which scales
# standard-normal noise in make_noisy_pred (a scale, not a variance).
opt.NUM_EXP = 8
eq_amp_frames_noise8 = go(opt, 4, 0.1, act='RELU', data='sinusoids')

with open('eq_amp_frames_noise8_mult.pickle', 'wb') as handle:
    pickle.dump(eq_amp_frames_noise8, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# M = 16 noise draws per input; 0.1 is passed as `sig`, which scales
# standard-normal noise in make_noisy_pred (a scale, not a variance).
opt.NUM_EXP = 16
eq_amp_frames_noise16 = go(opt, 4, 0.1, act='RELU', data='sinusoids')

with open('eq_amp_frames_noise16_mult.pickle', 'wb') as handle:
    pickle.dump(eq_amp_frames_noise16, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
import pickle


def _load_frames(path):
    """Unpickle and return the recorded training frames stored at ``path``."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# NOTE(review): the run cells in this notebook write '*_mult.pickle' files,
# while this cell reads '*_new.pickle' - confirm which files are intended.
eq_amp_frames_noise = _load_frames('eq_amp_frames_noise1_new.pickle')
eq_amp_frames_noise2 = _load_frames('eq_amp_frames_noise2_new.pickle')
eq_amp_frames_noise4 = _load_frames('eq_amp_frames_noise4_new.pickle')
eq_amp_frames_noise8 = _load_frames('eq_amp_frames_noise8_new.pickle')
eq_amp_frames_noise16 = _load_frames('eq_amp_frames_noise16_new.pickle')

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Training loss against iteration for each number of noise draws M (sinusoid data).
sns.set(rc={'figure.figsize':(4,4), "lines.linewidth":2.5}, style="whitegrid", font_scale=1.5)
fig, axes = plt.subplots(1, 1, figsize=(4,4))
color = sns.color_palette()

# One entry per curve: (recorded runs, legend label, line style).
curves = [
    (eq_amp_frames_noise, 'M=1', "solid"),
    (eq_amp_frames_noise2, 'M=2', "dotted"),
    (eq_amp_frames_noise4, 'M=4', "dashed"),
    (eq_amp_frames_noise8, 'M=8', "dashdot"),
    (eq_amp_frames_noise16, 'M=16', (0, (3, 5, 1, 5, 1, 5))),
]

for runs, label, _ in curves:
    # Records from all repeats are pooled into one x/y series per curve.
    ax = sns.lineplot(x=[l.iter_num for r in runs for l in r],
                      y=[l.loss for r in runs for l in r],
                      ax=axes, label=label)

# Raw string: '\m' is not a valid escape sequence in a normal string literal.
ax.set(ylabel=r'$\mathcal{L}_{\mathrm{train}}$', xlabel='Training Iteration', title='')
# ax.set_ylabel(r'$\mathcal{L}_{\mathrm{test}}$')

for line, (_, _, style) in zip(ax.lines, curves):
    line.set_linestyle(style)

# Build the legend once, after styling the lines, so its handles show the
# same line styles.
leg = ax.legend(fontsize=13)  # using a size in points

# Save before showing, so the file is written while the figure is still open.
fig.savefig("impliciteffectsinusoids.pdf", bbox_inches='tight')

plt.show()


In [None]:
# Data Generation
opt.N = 500  # go() draws its sample indices from range(opt.N)
opt.WIDTH = 32
opt.DEPTH = 1
# Training
# --- Switch exp_reg on and off to approximate GNIs as R in the main paper.
# NOTE(review): exp_reg is not read by any code visible in this file - confirm it is still used.
opt.exp_reg=False
opt.CUDA = False
opt.NUM_ITER = 60000
opt.NUM_EXP = 100  # noise draws per input; the run cells below override this
opt.REC_FRQ = 10000  # record losses every REC_FRQ iterations
opt.LR = 0.0003
opt.noise_type='mult'  # 'add' or 'mult', see make_noisy_pred

import pickle

In [None]:
## M = 1 noise draw per input on the digits data; 0.1 is passed as `sig`,
## which scales standard-normal noise in make_noisy_pred (a scale, not a variance).
opt.NUM_EXP = 1
eq_amp_frames_noise_digits = go(opt, 5, 0.1, act='RELU', data='digits')

with open('eq_amp_frames_noise_digits1.pickle', 'wb') as handle:
    pickle.dump(eq_amp_frames_noise_digits, handle, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:
## M = 2 noise draws per input on the digits data; 0.1 is passed as `sig`,
## which scales standard-normal noise in make_noisy_pred (a scale, not a variance).
opt.NUM_EXP = 2
eq_amp_frames_noise_digits2 = go(opt, 5, 0.1, act='RELU', data='digits')

with open('eq_amp_frames_noise_digits2.pickle', 'wb') as handle:
    pickle.dump(eq_amp_frames_noise_digits2, handle, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:
## M = 4 noise draws per input on the digits data; 0.1 is passed as `sig`,
## which scales standard-normal noise in make_noisy_pred (a scale, not a variance).
opt.NUM_EXP = 4
eq_amp_frames_noise_digits4 = go(opt, 5, 0.1, act='RELU', data='digits')

with open('eq_amp_frames_noise_digits4.pickle', 'wb') as handle:
    pickle.dump(eq_amp_frames_noise_digits4, handle, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:
## M = 8 noise draws per input on the digits data; 0.1 is passed as `sig`,
## which scales standard-normal noise in make_noisy_pred (a scale, not a variance).
opt.NUM_EXP = 8
eq_amp_frames_noise_digits8 = go(opt, 5, 0.1, act='RELU', data='digits')

with open('eq_amp_frames_noise_digits8.pickle', 'wb') as handle:
    pickle.dump(eq_amp_frames_noise_digits8, handle, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:
## M = 16 noise draws per input on the digits data; 0.1 is passed as `sig`,
## which scales standard-normal noise in make_noisy_pred (a scale, not a variance).
opt.NUM_EXP = 16
eq_amp_frames_noise_digits16 = go(opt, 5, 0.1, act='RELU', data='digits')

with open('eq_amp_frames_noise_digits16.pickle', 'wb') as handle:
    pickle.dump(eq_amp_frames_noise_digits16, handle, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:
import pickle 
import seaborn as sns


def _load_frames(path):
    """Unpickle and return the recorded training frames stored at ``path``."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Reload the digits runs saved by the cells above, one file per number of noise draws M.
eq_amp_frames_noise_digits = _load_frames('eq_amp_frames_noise_digits1.pickle')
eq_amp_frames_noise_digits2 = _load_frames('eq_amp_frames_noise_digits2.pickle')
eq_amp_frames_noise_digits4 = _load_frames('eq_amp_frames_noise_digits4.pickle')
eq_amp_frames_noise_digits8 = _load_frames('eq_amp_frames_noise_digits8.pickle')
eq_amp_frames_noise_digits16 = _load_frames('eq_amp_frames_noise_digits16.pickle')
    

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Training loss against iteration for each number of noise draws M (digits data).
sns.set(rc={'figure.figsize':(4,4), "lines.linewidth":2.5}, style="whitegrid", font_scale=1.5)
fig, axes = plt.subplots(1, 1, figsize=(4,4))
color = sns.color_palette()


def _losses_for_plot(all_frames, shift, initial_value):
    """Return ``(iters, losses)`` pooled over all runs, with the plot adjustments applied.

    ``shift`` is added to every recorded loss, then every iteration-0 entry
    is overwritten with ``initial_value``.
    """
    losses = np.array([l.loss for r in all_frames for l in r])
    iters = np.array([l.iter_num for r in all_frames for l in r])
    losses += shift
    losses[iters == 0] = initial_value
    return iters, losses


# NOTE(review): the recorded losses are shifted by hand-picked constants and
# their iteration-0 values overwritten before plotting; the source gives no
# justification for these numbers. Confirm the adjustments are intended
# before relying on this figure.
iter1, data1 = _losses_for_plot(eq_amp_frames_noise_digits, 1.2, 19.2)

print(data1)

iter2, data2 = _losses_for_plot(eq_amp_frames_noise_digits2, 1.2, 19.1)
iter4, data4 = _losses_for_plot(eq_amp_frames_noise_digits4, -0.5, 19.3)
iter8, data8 = _losses_for_plot(eq_amp_frames_noise_digits8, -0.5, 19.3)
iter16, data16 = _losses_for_plot(eq_amp_frames_noise_digits16, -0.6, 19.4)

# One entry per curve: (iterations, losses, legend label, line style).
curves = [
    (iter1, data1, 'M=1', "solid"),
    (iter2, data2, 'M=2', "dotted"),
    (iter4, data4, 'M=4', "dashed"),
    (iter8, data8, 'M=8', "dashdot"),
    (iter16, data16, 'M=16', (0, (3, 5, 1, 5, 1, 5))),
]

for iters, losses, label, _ in curves:
    ax = sns.lineplot(x=iters, y=losses, ax=axes, label=label)

for line, (_, _, _, style) in zip(ax.lines, curves):
    line.set_linestyle(style)

# Build the legend once, after styling the lines, so its handles show the
# same line styles.
leg = ax.legend(fontsize=13)  # using a size in points

# Raw string: '\m' is not a valid escape sequence in a normal string literal.
ax.set(ylabel=r'$\mathcal{L}_{\mathrm{train}}$', xlabel='Training Iteration', title='')
fig.savefig("impliciteffectmnist.pdf", bbox_inches='tight')
# ax.set_ylabel(r'$\mathcal{L}_{\mathrm{test}}$')
plt.show()
