In [1]:
!pip install --upgrade --no-deps --force-reinstall -q git+https://github.com/Pehlevan-Group/finite-width-bayesian
!pip install neural_tangents

  Building wheel for finite-width-bayesian (setup.py) ... [?25l[?25hdone
Collecting neural_tangents
  Downloading neural_tangents-0.3.8-py2.py3-none-any.whl (125 kB)
[K     |████████████████████████████████| 125 kB 5.4 MB/s 
[?25hCollecting frozendict>=1.2
  Downloading frozendict-2.0.7-py3-none-any.whl (8.3 kB)
Installing collected packages: frozendict, neural-tangents
Successfully installed frozendict-2.0.7 neural-tangents-0.3.8


In [2]:
import numpy as np
import pickle
import matplotlib.pyplot as plt

import neural_tangents as nt
from neural_tangents import stax

from langevin import model
from langevin.utils import convert_nt, curr_time
import langevin.theory as theory
import langevin.optimizer as opt
import langevin.dataset as ds

import jax
import jax.numpy as jnp
from jax import random
from jax import jit, grad, vmap
from jax.config import config
## Enable 64-bit precision in JAX for all computations below.
config.update("jax_enable_x64", True)
## Fixed PRNG key; the training loop reuses it (unsplit) to initialise every model.
key = random.PRNGKey(1)

from functools import partial
from skimage.transform import resize

import pytz
from datetime import datetime
from dateutil.relativedelta import relativedelta

def time_diff(t_start, t_end=None):
    """Format the elapsed time between ``t_start`` and ``t_end`` as ``'<h>h <m>m <s>s'``.

    Args:
        t_start: start time (must be comparable with ``t_end``, i.e. both
            timezone-aware or both naive).
        t_end: end time; defaults to "now" in US/Eastern.

    Returns:
        A string such as ``'26h 5m 7s'``. Hours are the TOTAL number of hours,
        so runs longer than a day are reported correctly (the previous
        ``relativedelta``-based version dropped whole days/months).
    """
    if t_end is None:
        t_end = datetime.now(pytz.timezone('US/Eastern'))
    total_seconds = int((t_end - t_start).total_seconds())
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return '{h}h {m}m {s}s'.format(h=hours, m=minutes, s=seconds)

In [3]:
## Experiment identifiers (also used to build the output file names)
dataset_name, model_type = 'mnist', 'cnn1d'
opt_mode, nonlin = 'sgld', 'linear'

## Training set: N_tr images, each resized to `resized` x `resized` pixels
N_tr, resized = 50, 5

## One pair of hidden-layer widths per run (passed to model.model)
hidden_widths = [[w, w] for w in (250, 400, 500, 600, 700, 750)]

## Optimiser hyper-parameters; batch_size == N_tr means full batch
beta = 1
batch_size = N_tr
batch_factor = N_tr // batch_size
step_size = 1 / 2000

## Number of SGLD steps; steps up to burn_in are not averaged.
## Set this to 2000000 to obtain accurate posterior mean
nT = 2000
burn_in = nT // 4

x_train, y_train = ds.dataset(N_tr, dataset_name, model_type, resized)
print(x_train.shape)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
(50, 25, 1)


In [None]:
## For every width: run SGLD, accumulate the empirical kernels after burn-in,
## compute the theory prediction, and pickle the results (per width and combined).
## At least one post-burn-in step is required, otherwise avg_count stays 0.
assert nT > burn_in + 1, 'nT must exceed burn_in + 1 to collect any kernel samples'

K_avgs = []      # per width: list of per-layer kernel SUMS (divide by avg_count for the mean)
K_nngps = []     # per width: NNGP kernels returned by theory.theory_linear
K_theories = []  # per width: theory kernels returned by theory.theory_linear

for hidden_width in hidden_widths:
    print(model_type, ' | ', hidden_width)

    ## Create the model layers
    layers, layers_ker = model.model(hidden_width, nonlin = nonlin, model_type = model_type)
    ## Create the model functions for each layer
    init_fn, apply_fn, kernel_fn, layer_fns, kernel_fns, emp_kernel_fns = model.network_fns(layers, x_train)
    ## Initialize the model (the same PRNG key is reused for every width)
    _, params = init_fn(key, input_shape=x_train.shape)

    ## Set Optimizer
    opt_init, opt_update, get_params = opt.sgld(step_size, beta, batch_factor)
    opt_state = opt_init(params)

    ## Halved squared-error loss over the whole training set, and its gradient
    loss_fn = jit(lambda params: jnp.sum((apply_fn(params, x_train) - y_train)**2)/2)
    g_loss = jit(grad(loss_fn))

    avg_count = 0
    K_avg = []  # one running sum per entry of layers_ker
    t_start = curr_time()
    for j in range(nT):
        opt_params = get_params(opt_state)
        opt_state = opt_update(j, g_loss(opt_params), opt_state)

        if j > burn_in:
            avg_count += 1
            for i, idx in enumerate(layers_ker):
                ## Evaluate emp_kernel_fns[idx] on the parameters of layers 0..idx
                K_emp = emp_kernel_fns[idx](opt_params[:idx+1])
                if j == burn_in + 1:
                    K_avg += [K_emp]
                else:
                    K_avg[i] += K_emp

        if j % 1000 == 0:
            print('%d | loss: %f | avg_count: %d | time: %s'%(j, loss_fn(opt_params), avg_count, time_diff(t_start)), flush=True)

    K_nngp, K_theory, Gxx, Gyy = theory.theory_linear(x_train, y_train, beta, kernel_fns, hidden_width)
    ## K_avg must be collected here; previously it never was, so the combined
    ## pickle at the end always stored an empty K_avgs.
    K_avgs += [K_avg]
    K_nngps += [K_nngp]
    K_theories += [K_theory]

    with open('data_%s_%d_%s_%s_%s.pkl'%(str(hidden_width), N_tr, model_type, opt_mode, nonlin), 'wb') as outfile:
        pickle.dump({'K_avg': K_avg, 'K_nngp': K_nngp, 'K_theory': K_theory,
                     'model_type': model_type, 'hidden_widths': hidden_widths,
                     'hidden_width': hidden_width,  # the width this file actually belongs to
                     'N_tr': N_tr, 'nT': nT, 'beta': beta, 'batch_size': batch_size, 'step_size': step_size,
                     'avg_count': avg_count, 'opt_mode': opt_mode}, outfile, pickle.HIGHEST_PROTOCOL)

    if model_type == 'fnn':
        plt.scatter((K_avg[0]/avg_count-Gxx).reshape(-1)[:], (K_theory[0]-Gxx).reshape(-1)[:], label='Width: %d'%hidden_width[0])
        plt.savefig('k-nngp_%s_fnn_%s_%s.jpg'%(str(hidden_width), opt_mode, nonlin))
        plt.close()

        plt.scatter((K_avg[0]/avg_count).reshape(-1)[:], (K_theory[0]).reshape(-1)[:], label='Width: %d'%hidden_width[0])
        plt.savefig('k_vs_nngp_%s_fnn_%s_%s.jpg'%(str(hidden_width), opt_mode, nonlin))
        plt.close()

## Combined results for all widths. avg_count is the same for every width
## because nT and burn_in do not depend on the width.
with open('data_%d_%s_%s_%s.pkl'%(N_tr, model_type, opt_mode, nonlin), 'wb') as outfile:
    pickle.dump({'K_avgs': K_avgs, 'K_nngps': K_nngps, 'K_theories': K_theories,
                 'model_type': model_type, 'hidden_widths': hidden_widths, 'N_tr': N_tr,
                 'nT': nT, 'beta': beta, 'batch_size': batch_size, 'step_size': step_size,
                 'avg_count': avg_count, 'opt_mode': opt_mode}, outfile, pickle.HIGHEST_PROTOCOL)

    

In [5]:
## Reload the per-width results written by the training loop, so this and the
## following cells work on a fresh kernel without re-running training.
K_avgs = []
K_nngps = []
K_theories = []
avg_counts = []
nTs = []

nonlin = 'linear'
data_files = ['data_%s_%d_%s_%s_%s.pkl'%(str(s), N_tr, model_type, opt_mode, nonlin) for s in hidden_widths]

## NOTE: pickle.load can execute arbitrary code -- only load files this notebook produced.
for data_file in data_files:
    with open(data_file, 'rb') as infile:
        data = pickle.load(infile)
    N_tr = data['N_tr']
    model_type = data['model_type']
    K_avgs += [data['K_avg']]
    K_nngps += [data['K_nngp']]
    K_theories += [data['K_theory']]
    avg_counts += [data['avg_count']]
    nTs += [data['nT']]
    beta = data['beta']

## Preprocess NT kernels: convert each stored kernel SUM to an array and divide
## by that run's own sample count (previously the global `avg_count` from the
## training cell was used, which fails on a fresh kernel).
K_avgs = [[convert_nt(K) / count for K in K_width]
          for K_width, count in zip(K_avgs, avg_counts)]

In [6]:
## Frobenius-norm distance from the NNGP kernel, per width and per layer:
## `deviation` for the experimental kernels, `deviation_th` for the theory.
depths = jnp.arange(len(K_avgs[0]))
deviation = []
deviation_th = []

for K_exp, K_nngp, K_theory in zip(K_avgs, K_nngps, K_theories):
    dev_width = []
    dev_width_th = []
    ## K_theory is padded with two scalar 0s (presumably it has two fewer layers
    ## than K_exp); for those layers the "theory" deviation is ||K_nngp||.
    for K, K_nn, K_th in zip(K_exp, K_nngp, [*K_theory, 0, 0]):
        if len(K.shape) == 6:
            ## Collapse (N, N, a, b, a, b) -> (N, N, a*b, a*b). The pixel count is read
            ## from K instead of the previous hard-coded 100 (valid only for 10x10).
            n_pix = K.shape[2] * K.shape[3]
            shape = (K.shape[0], K.shape[1], n_pix, n_pix)
            K = K.reshape(shape)
            K_nn = K_nn.reshape(shape)
            if jnp.ndim(K_th) > 0:  # the scalar 0 padding has no .reshape
                K_th = K_th.reshape(shape)

        dev_width += [jnp.linalg.norm(K - K_nn)]
        dev_width_th += [jnp.linalg.norm(K_th - K_nn)]

    deviation += [dev_width]
    deviation_th += [dev_width_th]

deviation = np.array(deviation)
deviation_th = np.array(deviation_th)

In [8]:
## Scatter (K_exp - K_GP) against (K_th - K_GP) for one layer, one colour per width.
lay_idx = 0       # layer to plot (0-based; shown as lay_idx + 1 in the axis labels)
pix_size = 6000   # number of kernel entries (first ones in flattened order) plotted per width
widths = [width[0] for width in hidden_widths]
kernel_shape = (N_tr, N_tr, resized**2, resized**2)

plt.figure(figsize=(6,5.5))
x_min, x_max = np.inf, -np.inf
for width_idx, width in enumerate(widths):
    K_exp = K_avgs[width_idx][lay_idx].reshape(kernel_shape)
    ## Use this width's own NNGP kernel; previously the `K_nngp` variable left
    ## over from the deviation cell's loop (always the last width's) was used.
    K_gp = K_nngps[width_idx][lay_idx].reshape(kernel_shape)
    K_th = K_theories[width_idx][lay_idx].reshape(kernel_shape)

    x = (K_exp - K_gp).reshape(-1)[:pix_size]
    y = (K_th - K_gp).reshape(-1)[:pix_size]
    plt.scatter(x, y, label='width = %d'%width)
    x_min = min(x_min, float(np.min(x)))
    x_max = max(x_max, float(np.max(x)))
    print(np.mean(K_exp/K_th), np.mean(x/y))

## Single y = x reference line spanning every width's data
plt.plot([x_min, x_max], [x_min, x_max], 'k--')
plt.legend(fontsize=12)
plt.gca().tick_params(axis='both', which = 'major', labelsize=14)
plt.gca().tick_params(axis='both', which = 'minor', labelsize=14)
plt.xlabel('$K_{exp}^{(%d)} - K_{GP}^{(%d)}$'%(lay_idx + 1,lay_idx + 1), fontsize=20)
plt.ylabel('$K_{th}^{(%d)} - K_{GP}^{(%d)}$'%(lay_idx + 1,lay_idx + 1), fontsize=20)
plt.tight_layout()
plt.savefig('k-nngp_cov_cnn1d_layer_%d.png'%(lay_idx+1), dpi=600)
plt.close()

0.9710410164304019 2.1609058940264463
0.9257817125812909 11.009963712609785
1.2664631569170828 3.35836258084833
1.07996555025582 5.873307340313838
0.9494100220237902 7.272438913285475
1.0288474503046459 -0.4219553860220873
