In [1]:
!pip install --upgrade --no-deps --force-reinstall -q git+https://github.com/Pehlevan-Group/finite-width-bayesian
!pip install neural_tangents

  Building wheel for finite-width-bayesian (setup.py) ... [?25l[?25hdone
Collecting neural_tangents
  Downloading neural_tangents-0.3.8-py2.py3-none-any.whl (125 kB)
[K     |████████████████████████████████| 125 kB 5.3 MB/s 
Collecting frozendict>=1.2
  Downloading frozendict-2.0.7-py3-none-any.whl (8.3 kB)
Installing collected packages: frozendict, neural-tangents
Successfully installed frozendict-2.0.7 neural-tangents-0.3.8


In [2]:
import numpy as np
import pickle
import matplotlib.pyplot as plt

import neural_tangents as nt
from neural_tangents import stax

from langevin import model
from langevin.utils import convert_nt, curr_time
import langevin.theory as theory
import langevin.optimizer as opt
import langevin.dataset as ds

import jax
import jax.numpy as jnp
from jax import random
from jax import jit, grad, vmap
from jax.config import config
# Switch JAX to 64-bit floating point for every array created after this line.
config.update("jax_enable_x64", True)
# Global PRNG key (seed 1). It is consumed by `init_fn` when the networks are
# initialised and is re-split on every sampling step in the experiment cell.
key = random.PRNGKey(1)

from functools import partial
from skimage.transform import resize

import pytz
from datetime import datetime
from dateutil.relativedelta import relativedelta

def time_diff(t_start):
    """Format the wall-clock time elapsed since ``t_start`` as ``'<h>h <m>m <s>s'``.

    Hours are cumulative: a run that has lasted one day and two hours is
    reported as ``26h``, not ``2h``. (A calendar-style difference splits the
    interval into days + hours, and the day part would be silently lost by
    this format string.)

    Args:
        t_start: ``datetime`` at which the interval started. May be
            timezone-aware (any zone) or naive local time.

    Returns:
        str: elapsed time, e.g. ``'0h 5m 7s'``. Fractions of a second are
        truncated; a ``t_start`` in the future yields ``'0h 0m 0s'``.
    """
    from datetime import timezone  # local import keeps the function self-contained

    # Compare like with like: aware start -> aware "now", naive start -> naive "now".
    # Using UTC for the aware case makes the subtraction independent of DST shifts.
    if t_start.tzinfo is None:
        t_end = datetime.now()
    else:
        t_end = datetime.now(timezone.utc)

    elapsed = max(int((t_end - t_start).total_seconds()), 0)
    hours, remainder = divmod(elapsed, 3600)
    minutes, seconds = divmod(remainder, 60)
    return '{h}h {m}m {s}s'.format(h=hours, m=minutes, s=seconds)

In [3]:
# Experiment configuration. These names are globals read by every later cell.
# model_type: forwarded to ds.dataset and model.model; also embedded in output file names.
model_type = 'fnn'
# opt_mode: only used to tag output files/pickles; the sampler itself is
# hard-wired to opt.sgld in the experiment cell.
opt_mode = 'sgld'
# nonlin: forwarded to model.model as the activation choice.
nonlin = 'linear'
dataset_name = 'mnist'
resized = 10 ## Resize the images to 10 x 10 pixels

# Number of training examples requested from the dataset loader.
N_tr = 1000
x_train, y_train = ds.dataset(N_tr, dataset_name, model_type, resized)
# Sanity check: the recorded run printed (1000, 100), i.e. N_tr rows of
# flattened 10 x 10 images.
print(x_train.shape)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
(1000, 100)


In [None]:
## For bottleneck experiments
#
# For every architecture in `hidden_widths` this cell
#   1. builds the network and per-layer apply/kernel functions,
#   2. runs `nT` SGLD steps on the squared-error loss,
#   3. after `burn_in` steps accumulates the empirical kernel of each layer
#      listed in `layers_ker` (K_avg holds running SUMS; divide by avg_count),
#   4. computes the infinite-width NNGP kernels and the finite-width theory,
#   5. pickles the results and draws two diagnostic scatter plots.

hidden_widths = [[100,100,100],[200,200,200],[300,300,300],[400,400,400],[500,500,500],[600,600,600]]

beta = 1                        # forwarded to opt.sgld and theory.theory_linear
batch_size = N_tr               # full batch
step_size = min(1/N_tr, 1e-4)
batch_factor = N_tr//batch_size

nT = 5000000                    # total number of sampler steps per architecture
burn_in = nT//3                 # steps discarded before kernels are accumulated

K_avgs = []
K_nngps = []
Kernel_Fns = []

## Compute the theory
K_theories = []
for hidden_width in hidden_widths:
    print(model_type, ' | ', hidden_width)

    ## Create the model layers
    layers, layers_ker = model.model(hidden_width, nonlin=nonlin, model_type=model_type)

    ## Create the model functions for each layer
    # Entry i describes the network truncated after layer i.
    layer_fns = []
    kernel_fns = []
    emp_kernel_fns = []
    for i, layer in enumerate(layers):
        init_fn, apply_fn, kernel_fn = stax.serial(*(layers[:i+1]))
        layer_fns += [jit(apply_fn)]
        kernel_fns += [jit(kernel_fn)]
        # Pre-bind (x1=x_train, x2=None) so the function only takes the parameters.
        emp_kernel_fns += [jit(partial(nt.empirical_nngp_fn(layer_fns[i]), x_train, None))]
    init_fn, apply_fn, kernel_fn = stax.serial(*layers)
    apply_fn = jit(apply_fn)
    kernel_fn = jit(kernel_fn)

    ## Initialize the model
    _, params = init_fn(key, input_shape=x_train.shape)

    ## Set Optimizer
    opt_init, opt_update, get_params = opt.sgld(step_size, beta, batch_factor)
    opt_state = opt_init(params)

    ## Set Loss Function and its grad
    loss_fn = jit(lambda params: jnp.sum((apply_fn(params,x_train)-y_train)**2)/2)
    g_loss = jit(grad(loss_fn))

    avg_count = 0
    K_avg = []
    t_start = datetime.now(pytz.timezone('US/Eastern'))
    for j in range(nT):
        # NOTE(review): the fresh key is not handed to the optimizer; it only
        # advances the global key that seeds the next architecture's init.
        _,key = random.split(key)
        opt_params = get_params(opt_state)
        # The step index must be the sampling counter `j`. `i` is merely a
        # layer-loop index and would stay (almost) constant across all steps.
        opt_state = opt_update(j, g_loss(opt_params), opt_state)

        if j > burn_in:
            avg_count += 1
            for k, lay_idx in enumerate(layers_ker):
                # Parameters of the sub-network that ends at layer `lay_idx`.
                params = opt_params[:lay_idx+1]
                if j == burn_in + 1:
                    # First post-burn-in sample: create the accumulator slot.
                    K_avg += [emp_kernel_fns[lay_idx](params)]
                else:
                    K_avg[k] += emp_kernel_fns[lay_idx](params)

        if j % 1000 == 0:
            print('%d | loss: %f | avg_count: %d | time: %s'%(j, loss_fn(opt_params), avg_count, time_diff(t_start)), flush=True)

    # Analytic NNGP kernels for the same layers that were tracked empirically.
    kernel_fns_relu = []
    K_nngp =  []
    for lay_idx in layers_ker:
        kernel_fns_relu += [kernel_fns[lay_idx]]
        K_nngp += [kernel_fns[lay_idx](x_train,).nngp]

    K_avgs += [K_avg]
    K_nngps += [K_nngp]

    ## Compute the theory predictions
    _, K_theory, Gxx, Gyy = theory.theory_linear(x_train, y_train, beta, kernel_fns, hidden_width)
    K_theories += [K_theory]

    # Per-architecture checkpoint, so a crash later in the sweep loses nothing.
    with open('data_%s_%d_%s_%s_%s_nT_%d.pkl'%(str(hidden_width), N_tr, model_type, opt_mode, nonlin, nT), 'wb') as outfile:
        pickle.dump({'K_avg': K_avg, 'K_nngp': K_nngp, 'K_theory': K_theory, 'burn_in': burn_in,
                 'model_type': model_type, 'hidden_widths': hidden_widths, 'N_tr': N_tr,
                 'nT': nT, 'beta': beta, 'batch_size': batch_size, 'step_size': step_size,
                 'avg_count': avg_count, 'opt_mode': opt_mode}, outfile, pickle.HIGHEST_PROTOCOL)


    # Diagnostic 1: deviation of the first tracked kernel from Gxx, experiment vs theory.
    plt.scatter((K_avg[0]/avg_count-Gxx).reshape(-1)[:], (K_theory[0]-Gxx).reshape(-1)[:], label='Width: %d'%hidden_width[0])
    plt.savefig('k-nngp_%s_fnn_%s.jpg'%(str(hidden_width), opt_mode))
    plt.show()

    # Diagnostic 2: raw kernel entries, experiment vs theory.
    plt.scatter((K_avg[0]/avg_count).reshape(-1)[:], (K_theory[0]).reshape(-1)[:], label='Width: %d'%hidden_width[0])
    plt.savefig('k_vs_nngp_%s_fnn_%s.jpg'%(str(hidden_width), opt_mode))
    plt.show()


# Combined results for all architectures. `avg_count` is the value from the
# last architecture; it is identical for all of them because nT and burn_in
# do not change inside the sweep.
with open('data_%d_%s_%s.pkl'%(N_tr, model_type, opt_mode), 'wb') as outfile:
    pickle.dump({'K_avgs': K_avgs, 'K_nngps': K_nngps, 'K_theories': K_theories, 'nonlin': nonlin,
                 'model_type': model_type, 'hidden_widths': hidden_widths, 'N_tr': N_tr,
                 'nT': nT, 'beta': beta, 'batch_size': batch_size, 'step_size': step_size,
                 'avg_count': avg_count, 'opt_mode': opt_mode}, outfile, pickle.HIGHEST_PROTOCOL)


In [None]:
# Squared Frobenius distance between each tracked kernel and its NNGP
# counterpart, as a function of width: measured (markers) vs theory (dashed).
depths = np.arange(len(K_avgs[0]))
deviation = []
deviation_th = []

for i, hidden_width in enumerate(hidden_widths):
    width = hidden_width[0]
    K_exp = K_avgs[i]
    K_nngp = K_nngps[i]
    deviation_width = []
    for j, K in enumerate(K_exp):
        K_th = K_nngp[j]
        # K is a running sum over samples; divide by avg_count to get the mean kernel.
        deviation_width += [np.linalg.norm((K/avg_count - K_th))**2]
    deviation += [deviation_width]

    K_theory = K_theories[i]
    deviation_th += [[np.linalg.norm(K - K_t)**2 for K, K_t in zip(K_theory, K_nngp)]]

deviation = np.array(deviation)
print(deviation.shape)

# x-axis: width of the first hidden layer of every architecture.
first_widths = [hw[0] for hw in hidden_widths]

plt.figure(figsize=(6,5))
# Only the first two tracked layers are drawn.
for l in range(2):
    plt.loglog(first_widths, deviation[:,l], 'o', label='layer %d'%(l+1))

if model_type == 'fnn':
    deviation_th = np.array(deviation_th)
    for l in range(2):
        if l == 0:
            # Label just one theory curve so the legend has a single 'theory' entry.
            plt.loglog(first_widths, deviation_th[:,l],'k--', label='theory')
        else:
            plt.loglog(first_widths, deviation_th[:,l],'k--')


plt.gca().tick_params(axis='both', which = 'major', labelsize=14)
plt.gca().tick_params(axis='both', which = 'minor', labelsize=12)
plt.legend(fontsize=16)
plt.xlabel(r'Width', fontsize=20)
# The plotted values are norm(...)**2, so the axis label carries the square.
plt.ylabel(r'$||K_{exp} - K_{GP}||_F^2$', fontsize=20)
plt.tight_layout()
# plt.gca().set_aspect(0.28)
plt.savefig('one_over_width_%s_sgld.png'%model_type, dpi=600)

In [None]:
# One figure per hidden layer: measured kernel deviation (x) against the
# theoretical deviation (y), one scatter series per architecture.
data_limit = 2000  # number of flattened kernel entries drawn per series

for j in range(len(hidden_widths[0])):
    layer_no = j + 1
    plt.figure(figsize=(6,5))
    for i, hidden_width in enumerate(hidden_widths):
        mean_kernel = K_avgs[i][j]/avg_count
        k_exp = (mean_kernel-Gxx).reshape(-1)[:data_limit]
        k_th  = (K_theories[i][j]-Gxx).reshape(-1)[:data_limit]

        # Identity line spanning the range of the theoretical values.
        lin = np.linspace(k_th.min(), k_th.max(), 1000)
        print(k_th.shape)
        plt.scatter(k_exp, k_th, label='Width: %d'%hidden_width[j])
        if i == 0:
            # Axis limits are fixed by the first architecture in the sweep.
            lower, upper = lin.min()*1.2, lin.max()*1.05
            plt.plot(lin,lin,'k--')
            plt.xlim([lower, upper])
            plt.ylim([lower, upper])
        plt.legend(fontsize=16)
    for axis_name in ('x', 'y'):
        plt.ticklabel_format(style='sci', axis=axis_name, scilimits=(0,0))
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.xlabel('$K_{exp}^{(%d)} - K_{GP}^{(%d)}$'%(layer_no, layer_no), fontsize=20)
    plt.ylabel('$K_{th}^{(%d)} - K_{GP}^{(%d)}$'%(layer_no, layer_no), fontsize=20)
    plt.tight_layout()
    plt.savefig('k-nngp_fnn_sgld_layer_%d.png'%layer_no, dpi=600)
    plt.show()
