In [None]:
!pip install --upgrade --no-deps --force-reinstall -q git+https://github.com/Pehlevan-Group/finite-width-bayesian
!pip install neural_tangents

In [1]:
import numpy as np
import pickle
import matplotlib.pyplot as plt

import neural_tangents as nt
from neural_tangents import stax

from langevin import model
from langevin.utils import convert_nt, curr_time
import langevin.theory as theory
import langevin.optimizer as opt
import langevin.dataset as ds

import jax
import jax.numpy as jnp
from jax import random
from jax import jit, grad, vmap
from jax.config import config
# Switch JAX to 64-bit mode (JAX otherwise defaults to 32-bit precision).
config.update("jax_enable_x64", True)
# Base PRNG key with fixed seed 1; this same key is handed to init_fn for every
# network width in the training cell, so all widths start from the same seed.
key = random.PRNGKey(1)

from functools import partial
from skimage.transform import resize

import pytz
from datetime import datetime
from dateutil.relativedelta import relativedelta

def time_diff(t_start):
    """Return the wall-clock time elapsed since ``t_start`` as ``'<h>h <m>m <s>s'``.

    Parameters
    ----------
    t_start : datetime.datetime
        Moment the timed section began. "Now" is read in the same timezone as
        ``t_start`` (``t_start.tzinfo``), so the subtraction is valid for both
        timezone-aware and naive start times.

    Returns
    -------
    str
        Elapsed time with the hours field carrying the *total* number of
        hours. Runs longer than a day therefore report e.g. ``'26h 3m 5s'``
        instead of silently dropping the day component.
    """
    t_end = datetime.now(t_start.tzinfo)
    # Work from the total number of whole seconds so nothing is lost when the
    # interval exceeds 24 hours.
    elapsed_seconds = int((t_end - t_start).total_seconds())
    hours, remainder = divmod(elapsed_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return '{h}h {m}m {s}s'.format(h=hours, m=minutes, s=seconds)

In [65]:
# ---- Experiment configuration -------------------------------------------------
# Which data / architecture / sampler / nonlinearity to use. These strings are
# also embedded in the names of the result files written further down.
dataset_name, model_type = 'mnist', 'cnn1d'
opt_mode, nonlin = 'sgld', 'linear'

# ---- Training data ------------------------------------------------------------
# N_tr training points; `resized` is forwarded to ds.dataset
# (presumably the side length the images are resized to -- confirm in langevin.dataset).
N_tr, resized = 50, 5
x_train, y_train = ds.dataset(N_tr, dataset_name, model_type, resized)
print(x_train.shape)

# ---- Architectures ------------------------------------------------------------
# One two-entry width list per network, both entries equal.
hidden_widths = [[width, width] for width in (250, 400, 500, 600, 700, 750)]

# ---- Sampler settings ---------------------------------------------------------
# beta, step_size and batch_factor are the arguments passed to opt.sgld.
beta = 1
step_size = 1/2000
batch_size = N_tr                    # full batch
batch_factor = N_tr//batch_size

# ---- Chain length -------------------------------------------------------------
nT = 2000                            # total number of update steps per width
burn_in = nT//4                      # steps with index <= burn_in are not averaged

(50, 25, 1)


In [None]:
# For every width: build the network, run the SGLD chain, accumulate the empirical
# kernels after burn-in, compute the theory kernels, and save everything to disk.
K_avgs = []       # one entry per width: list of kernels SUMMED over post-burn-in steps
K_nngps = []      # one entry per width: first output of theory.theory_linear
Kernel_Fns = []
## Compute the theory
K_theories = []   # one entry per width: second output of theory.theory_linear
for hidden_width in hidden_widths:
    print(model_type, ' | ', hidden_width)

    ## Create the model layers
    layers, layers_ker = model.model(hidden_width, nonlin = nonlin, model_type = model_type)
    ## Create the model functions for each layer
    init_fn, apply_fn, kernel_fn, layer_fns, kernel_fns, emp_kernel_fns = model.network_fns(layers, x_train)
    ## Initialize the model (the same PRNG key is used for every width)
    _, params = init_fn(key, input_shape=x_train.shape)

    ## Set Optimizer
    opt_init, opt_update, get_params = opt.sgld(step_size, beta, batch_factor)
    opt_state = opt_init(params)

    ## Set Loss Function and its grad: half the summed squared error on the training set
    loss_fn = jit(lambda params: jnp.sum((apply_fn(params,x_train)-y_train)**2)/2)
    g_loss = jit(grad(loss_fn))

    avg_count = 0     # number of post-burn-in samples accumulated into K_avg
    K_avg = []        # running SUM (not mean) of the empirical kernel at each index in layers_ker
    t_start = curr_time()
    for j in range(nT):
        # Parameters *before* this step's update; these are the ones sampled below.
        opt_params = get_params(opt_state)
        opt_state = opt_update(j, g_loss(opt_params), opt_state)

        if j > burn_in:
            avg_count += 1
            for i, idx in enumerate(layers_ker):
                if j == burn_in + 1:
                    # First sample after burn-in: create the accumulator for this layer.
                    K_avg += [emp_kernel_fns[idx](opt_params[:idx+1])]
                else:
                    K_avg[i] += emp_kernel_fns[idx](opt_params[:idx+1])

        if j % 1000 == 0:
            print('%d | loss: %f | avg_count: %d | time: %s'%(j, loss_fn(opt_params), avg_count, time_diff(t_start)), flush=True)

    # Record this width's accumulated kernels so the combined file written after the
    # loop is not saved with an empty 'K_avgs'. As in the per-width files, these are
    # raw sums; divide by 'avg_count' to obtain the average.
    K_avgs += [K_avg]

    K_nngp, K_theory, Gxx, Gyy = theory.theory_linear(x_train, y_train, beta, kernel_fns, hidden_width)
    K_nngps += [K_nngp]
    K_theories += [K_theory]

    # Per-width results file; the load cell rebuilds this exact file name.
    with open('data_%s_%d_%s_%s_%s.pkl'%(str(hidden_width), N_tr, model_type, opt_mode, nonlin), 'wb') as outfile:
        pickle.dump({'K_avg': K_avg, 'K_nngp': K_nngp, 'K_theory': K_theory,
                     'model_type': model_type, 'hidden_widths': hidden_widths, 'N_tr': N_tr,
                     'nT': nT, 'beta': beta, 'batch_size': batch_size, 'step_size': step_size,
                     'avg_count': avg_count, 'opt_mode': opt_mode}, outfile, pickle.HIGHEST_PROTOCOL)

    if model_type == 'fnn':
        # Scatter of (empirical average - Gxx) against (theory - Gxx) for the first kernel.
        plt.scatter((K_avg[0]/avg_count-Gxx).reshape(-1)[:], (K_theory[0]-Gxx).reshape(-1)[:], label='Width: %d'%hidden_width[0])
        plt.savefig('k-nngp_%s_fnn_%s_%s.jpg'%(str(hidden_width), opt_mode, nonlin))
        plt.close()

        # Same comparison without subtracting Gxx.
        plt.scatter((K_avg[0]/avg_count).reshape(-1)[:], (K_theory[0]).reshape(-1)[:], label='Width: %d'%hidden_width[0])
        plt.savefig('k_vs_nngp_%s_fnn_%s_%s.jpg'%(str(hidden_width), opt_mode, nonlin))
        plt.close()


# Combined results for all widths. 'avg_count' is the value from the last width
# (every width runs the same nT and burn_in, so the count is identical).
with open('data_%d_%s_%s_%s.pkl'%(N_tr, model_type, opt_mode, nonlin), 'wb') as outfile:
    pickle.dump({'K_avgs': K_avgs, 'K_nngps': K_nngps, 'K_theories': K_theories,
                 'model_type': model_type, 'hidden_widths': hidden_widths, 'N_tr': N_tr,
                 'nT': nT, 'beta': beta, 'batch_size': batch_size, 'step_size': step_size,
                 'avg_count': avg_count, 'opt_mode': opt_mode}, outfile, pickle.HIGHEST_PROTOCOL)

    

In [None]:
# Reload the per-width result files written by the training cell and turn the
# accumulated kernel sums into averages.
K_avgs = []
K_nngps = []
K_theories = []   # loaded as well, so the analysis cell does not rely on kernel state left by training
avg_counts = []
nTs = []

nonlin = 'linear'
# Same file-name pattern as the one used when the per-width files were written, built
# from the configuration values instead of hard-coded '50' / 'cnn1d' / 'sgld'.
data_files = ['data_%s_%d_%s_%s_%s.pkl'%(str(s), N_tr, model_type, opt_mode, nonlin) for s in hidden_widths]

for data_file in data_files:
    # NOTE: pickle.load can execute arbitrary code; only load files produced by this notebook.
    with open(data_file, 'rb') as infile:
        data = pickle.load(infile)
    N_tr = data['N_tr']
    model_type = data['model_type']
    K_avgs += [data['K_avg']]
    K_nngps += [data['K_nngp']]
    K_theories += [data['K_theory']]
    avg_counts += [data['avg_count']]
    nTs += [data['nT']]
    beta = data['beta']

## Preprocess NT kernels
# Each stored kernel is a sum over post-burn-in samples: pass it through convert_nt and
# divide by the sample count saved in the SAME file (not by whatever `avg_count`
# happens to be left in the kernel from the training cell).
for i, K_width in enumerate(K_avgs):
    for j, K in enumerate(K_width):
        K_avgs[i][j] = convert_nt(K)/avg_counts[i]

In [None]:
# Compare averaged empirical kernels and theory kernels against the NNGP kernels,
# as a function of width, then show a few example kernels.
depths = jnp.arange(len(K_avgs[0]))
widths = []
deviation = []      # per width: squared norm of (averaged empirical kernel - NNGP kernel), per layer
deviation_th = []   # per width: squared norm of (theory kernel - NNGP kernel), per layer

for i, hidden_width in enumerate(hidden_widths):
    widths += [hidden_width[0]]
    K_exp = K_avgs[i]          # already converted and normalised by the load cell
    K_nngp = K_nngps[i]
    K_theory = K_theories[i]

    # zip() stops at the shorter list, so only layers present in both are compared.
    deviation += [[jnp.linalg.norm(K - K_t)**2 for K, K_t in zip(K_exp, K_nngp)]]
    deviation_th += [[jnp.linalg.norm(K - K_t)**2 for K, K_t in zip(K_theory, K_nngp)]]

widths = np.array(widths)
deviation = np.array(deviation)
deviation_th = np.array(deviation_th)
print(deviation.shape)
# Experimental points omit the last two kernel entries; theory lines use all columns.
plt.loglog(widths, deviation[:,:-2], 'o')
plt.loglog(widths, deviation_th, '-')
# plt.loglog(widths, 1/widths, '--')

plt.savefig('one_over_width_%s_%s_%s.png'%(model_type, opt_mode, nonlin))
plt.show()

fig, axs = plt.subplots(1,2)

# K_avgs was already divided by the sample count when it was loaded, so it is shown
# as-is here (dividing again would normalise twice).
axs[0].imshow(K_avgs[0][-1])
axs[1].imshow(K_nngps[0][-1])
plt.show()

# Per-layer kernels for the LAST width, in their raw (un-converted) form.
# `K_avg` and `avg_count` are the accumulated sums and sample count left over from
# the training cell, so that cell must have been run in this kernel session.
K_exp = [K/avg_count for K in K_avg[:-2]]
# The NNGP kernels are theoretical values, not sums over samples, so they are
# NOT divided by avg_count. `K_nngp` is the last width's entry from the loop above.
# NOTE(review): K_exp is built from the raw K_avg while K_nngp is compared elsewhere
# against convert_nt output -- confirm both share the same axis layout for the
# [img_idx1, img_idx2] indexing below.
K_nngp_th = [K for K in K_nngp[:-2]]

fig, axs = plt.subplots(1,2)
lay_idx = 1
img_idx1 = 40
img_idx2 = 40
# K_exp is already an average; no further division by avg_count.
axs[0].imshow(K_exp[lay_idx][img_idx1,img_idx2])
axs[1].imshow(K_exp[lay_idx][img_idx1,img_idx2] - K_nngp_th[lay_idx][img_idx1,img_idx2])
plt.savefig('K_cov_%s_%s_%s.jpg'%(model_type, opt_mode, nonlin))
plt.show()