In [1]:
!pip install --upgrade --no-deps --force-reinstall -q git+https://github.com/Pehlevan-Group/finite-width-bayesian
!pip install neural_tangents

  Building wheel for finite-width-bayesian (setup.py) ... [?25l[?25hdone


In [2]:
import numpy as np
import pickle
import matplotlib.pyplot as plt

import neural_tangents as nt
from neural_tangents import stax

from langevin import model
from langevin.utils import convert_nt, curr_time
import langevin.theory as theory
import langevin.optimizer as opt
import langevin.dataset as ds

import jax
import jax.numpy as jnp
from jax import random
from jax import jit, grad, vmap
from jax.config import config
config.update("jax_enable_x64", True)
key = random.PRNGKey(1)

from functools import partial
from skimage.transform import resize

import pytz
from datetime import datetime
from dateutil.relativedelta import relativedelta

def time_diff(t_start):
    """Return the wall-clock time elapsed since ``t_start`` as ``'<h>h <m>m <s>s'``.

    Args:
        t_start: ``datetime`` marking the beginning of the interval. Both
            timezone-aware and naive values are accepted; "now" is read in the
            same timezone (or naive-ness) as ``t_start``.

    Returns:
        A string such as ``'26h 3m 4s'``. The hour field is cumulative, so
        intervals longer than a day are reported in full instead of silently
        dropping the whole days.
    """
    # Match t_start's tzinfo so the subtraction is valid for aware and naive inputs alike.
    t_end = datetime.now(t_start.tzinfo)
    elapsed = int((t_end - t_start).total_seconds())
    minutes, seconds = divmod(elapsed, 60)
    hours, minutes = divmod(minutes, 60)
    return '{h}h {m}m {s}s'.format(h=hours, m=minutes, s=seconds)

In [3]:
## Experiment setup: model family, sampler, nonlinearity and data set.
model_type, opt_mode, nonlin = 'fnn', 'sgld', 'relu'

dataset_name = 'mnist'
resized = 10  ## Resize the images to 10 x 10 pixels
N_tr = 1000   ## Number of training samples requested from the data set

## Load the training inputs and targets, then report the input shape.
x_train, y_train = ds.dataset(N_tr, dataset_name, model_type, resized)
print(x_train.shape)

(1000, 100)


In [None]:
## For bottleneck experiments
## Each configuration lists three widths; the bottleneck variant pins the middle one to 50.
outer_widths = range(100, 700, 100)
no_bottleneck_widths = [[w, w, w] for w in outer_widths]
bottleneck_widths = [[w, 50, w] for w in outer_widths]

exp_type = 0 # set it to 1 for bottleneck experiments
hidden_widths = no_bottleneck_widths if exp_type == 0 else bottleneck_widths

## Sampler settings (full-batch: batch_size equals the training set size)
beta = 1
batch_size = N_tr
step_size = min(1/N_tr, 1e-4)
batch_factor = N_tr//batch_size

## Total number of sampling steps; the first third is discarded as burn-in
nT = 5000000
burn_in = nT//3

## Accumulators filled once per width configuration
K_avgs, K_nngps, Kernel_Fns = [], [], []

## Compute the theory
## For every width configuration: build the network, run the SGLD sampler, accumulate the
## empirical kernels after burn-in, evaluate the NNGP/theory kernels, then save and plot.
K_theories = []
for hidden_width in hidden_widths:
    print(model_type, ' | ', hidden_width)

    ## Create the model layers
    # `layers` is handed to stax.serial below; `layers_ker` holds the indices (into `layers`)
    # of the layers whose kernels are tracked.
    layers, layers_ker = model.model(hidden_width, nonlin=nonlin, model_type=model_type)

    ## Create the model functions for each layer
    # Entry k of each list belongs to the truncated network layers[:k+1].
    layer_fns = []
    kernel_fns = []
    emp_kernel_fns = []
    for i, layer in enumerate(layers):
        init_fn, apply_fn, kernel_fn = stax.serial(*(layers[:i+1]))
        layer_fns += [jit(apply_fn)]
        kernel_fns += [jit(kernel_fn)]
        # Bind (x_train, None) up front so the empirical NNGP only needs the parameters.
        emp_kernel_fns += [jit(partial(nt.empirical_nngp_fn(layer_fns[i]), x_train, None))]
    init_fn, apply_fn, kernel_fn = stax.serial(*layers)
    apply_fn = jit(apply_fn)
    kernel_fn = jit(kernel_fn)

    ## Initialize the model
    _, params = init_fn(key, input_shape=x_train.shape)

    ## Set Optimizer
    opt_init, opt_update, get_params = opt.sgld(step_size, beta, batch_factor)
    opt_state = opt_init(params)

    ## Set Loss Function and its grad
    # Half the summed squared error over the whole training set.
    loss_fn = jit(lambda params: jnp.sum((apply_fn(params,x_train)-y_train)**2)/2)
    g_loss = jit(grad(loss_fn))

    avg_count = 0
    K_avg = []
    t_start = datetime.now(pytz.timezone('US/Eastern'))
    for j in range(nT):
        _,key = random.split(key)
        opt_params = get_params(opt_state)
        # Pass the sampling step counter `j` as the update index (jax-style optimizers take
        # the step number first -- confirm against langevin.optimizer.sgld). A layer-loop
        # index must not be used here: it stays constant across sampling steps.
        opt_state = opt_update(j, g_loss(opt_params), opt_state)

        if j > burn_in:
            avg_count += 1
            for ker_idx, lay_idx in enumerate(layers_ker):
                # Parameters of the truncated network ending at the tracked layer.
                params = opt_params[:lay_idx+1]
                if j == burn_in + 1:
                    # First post-burn-in step: create the running sum for this layer.
                    K_avg += [emp_kernel_fns[lay_idx](params)]
                else:
                    K_avg[ker_idx] += emp_kernel_fns[lay_idx](params)

        if j % 1000 == 0:
            print('%d | loss: %f | avg_count: %d | time: %s'%(j, loss_fn(opt_params), avg_count, time_diff(t_start)), flush=True)

    ## Analytic NNGP kernels of the tracked layers
    kernel_fns_relu = []
    K_nngp =  []
    for lay_idx in layers_ker:
        kernel_fns_relu += [kernel_fns[lay_idx]]
        K_nngp += [kernel_fns[lay_idx](x_train,).nngp]

    # K_avg stores running sums; divide by avg_count to obtain the sample mean.
    K_avgs += [K_avg]
    K_nngps += [K_nngp]

    ## Compute the theory predictions
    # Reset first so a non-'fnn' run stores None below instead of failing on an undefined
    # name or silently reusing the values of a previous width.
    K_theory, Gxx = None, None
    if model_type == 'fnn':
        _, K_theory, Gxx, Gyy = theory.theory_linear(x_train, y_train, beta, kernel_fns, hidden_width)
        K_theories += [K_theory]

    ## Save the results of this width configuration
    with open('data_%s_%d_%s_%s_%s_nT_%d.pkl'%(str(hidden_width), N_tr, model_type, opt_mode, nonlin, nT), 'wb') as outfile:
        pickle.dump({'K_avg': K_avg, 'K_nngp': K_nngp, 'K_theory': K_theory, 'burn_in': burn_in,
                 'model_type': model_type, 'hidden_widths': hidden_widths, 'N_tr': N_tr,
                 'nT': nT, 'beta': beta, 'batch_size': batch_size, 'step_size': step_size,
                 'avg_count': avg_count, 'opt_mode': opt_mode}, outfile, pickle.HIGHEST_PROTOCOL)

    ## Sampled vs. theory kernel of the first tracked layer (only when a theory exists)
    if K_theory is not None:
        # Both axes with Gxx subtracted.
        plt.scatter((K_avg[0]/avg_count-Gxx).reshape(-1)[:], (K_theory[0]-Gxx).reshape(-1)[:], label='Width: %d'%hidden_width[0])
        plt.savefig('k-nngp_%s_fnn_%s.jpg'%(str(hidden_width), opt_mode))
        plt.show()

        # Raw kernels.
        plt.scatter((K_avg[0]/avg_count).reshape(-1)[:], (K_theory[0]).reshape(-1)[:], label='Width: %d'%hidden_width[0])
        plt.savefig('k_vs_nngp_%s_fnn_%s.jpg'%(str(hidden_width), opt_mode))
        plt.show()

        
## Persist the kernels of every width configuration, together with the run settings, in one file.
out_path = 'data_%d_%s_%s.pkl'%(N_tr, model_type, opt_mode)
with open(out_path, 'wb') as outfile:
    summary = {
        'K_avgs': K_avgs,
        'K_nngps': K_nngps,
        'K_theories': K_theories,
        'nonlin': nonlin,
        'model_type': model_type,
        'hidden_widths': hidden_widths,
        'N_tr': N_tr,
        'nT': nT,
        'beta': beta,
        'batch_size': batch_size,
        'step_size': step_size,
        'avg_count': avg_count,
        'opt_mode': opt_mode,
    }
    pickle.dump(summary, outfile, protocol=pickle.HIGHEST_PROTOCOL)

## Squared norm of (kernel - NNGP kernel) for every tracked layer, one row per width
## configuration, plotted against the first width of each configuration on log-log axes.
depths = jnp.arange(len(K_avgs[0]))
deviation, deviation_th = [], []

for i, hidden_width in enumerate(hidden_widths):
    width = hidden_width[0]
    K_exp, K_nngp = K_avgs[i], K_nngps[i]
    # K_exp holds running sums, hence the division by avg_count.
    deviation.append([jnp.linalg.norm(K/avg_count - K_t)**2 for K, K_t in zip(K_exp, K_nngp)])

    if model_type != 'fnn':
        continue
    K_theory = K_theories[i]
    deviation_th.append([jnp.linalg.norm(K - K_t)**2 for K, K_t in zip(K_theory, K_nngp)])

deviation = np.array(deviation)
print(deviation.shape)

first_widths = [widths[0] for widths in hidden_widths]
# The last tracked layer is left out of the empirical curve.
plt.loglog(first_widths, deviation[:,:-1], 'o')

if model_type == 'fnn':
    deviation_th = np.array(deviation_th)
    plt.loglog(first_widths, deviation_th,'k--')

plt.savefig('one_over_width_%s_%s.png'%(model_type, opt_mode))
plt.close()

## Overlay, for all widths, the sampled kernel of the first tracked layer against the theory
## prediction: first with Gxx subtracted from both axes, then the raw kernels.
# NOTE(review): Gxx is whatever the per-width loop assigned last; presumably it does not
# depend on the width -- confirm against theory.theory_linear.
for i, hidden_width in enumerate(hidden_widths):
    plt.scatter((K_avgs[i][0]/avg_count-Gxx).reshape(-1), (K_theories[i][0]-Gxx).reshape(-1), label='Width: %d'%hidden_width[0])
# Save once, after every width has been drawn; saving inside the loop only rewrote the same
# file on each iteration.
plt.savefig('k-nngp_fnn_%s.jpg'%opt_mode)
plt.close()

for i, hidden_width in enumerate(hidden_widths):
    plt.scatter((K_avgs[i][0]/avg_count).reshape(-1), (K_theories[i][0]).reshape(-1), label='Width: %d'%hidden_width[0])
plt.savefig('k_vs_nngp_fnn_%s.jpg'%opt_mode)
plt.close()