In [None]:
!pip install --upgrade --no-deps --force-reinstall -q git+https://github.com/Pehlevan-Group/finite-width-bayesian
!pip install neural_tangents

In [4]:
import numpy as np
import pickle
import matplotlib.pyplot as plt

import neural_tangents as nt
from neural_tangents import stax
import langevin.optimizer as opt
import langevin.theory as theory
import langevin.model as model
import langevin.dataset as ds
from langevin.utils import convert_nt, curr_time

import jax
import jax.numpy as jnp
from jax import random
from jax import jit, grad, vmap
from jax.config import config
# Enable 64-bit (double-precision) arrays in JAX for the kernel computations.
# NOTE(review): `from jax.config import config` is deprecated/removed in newer JAX
# releases (use `jax.config.update(...)`); confirm against the installed version.
config.update("jax_enable_x64", True)
# Fixed PRNG key; it is reused for every model initialisation in the training
# cell, so parameter initialisation is reproducible across widths and runs.
key = random.PRNGKey(1)

from functools import partial
from skimage.transform import resize

import pytz
from datetime import datetime
from dateutil.relativedelta import relativedelta

def time_diff(t_start, t_end=None):
    """Format the wall-clock time elapsed since ``t_start`` as ``'<h>h <m>m <s>s'``.

    Hours are cumulative, so runs longer than a day are reported correctly
    (e.g. ``'26h 3m 4s'``) instead of wrapping around every 24 hours.

    Args:
        t_start: start time. It is subtracted from a timezone-aware datetime,
            so it must be timezone-aware as well.
        t_end: optional end time; defaults to "now" in US/Eastern.

    Returns:
        The elapsed time as a string; fractional seconds are truncated.
    """
    if t_end is None:
        t_end = datetime.now(pytz.timezone('US/Eastern'))
    elapsed = int((t_end - t_start).total_seconds())
    hours, rest = divmod(elapsed, 3600)
    minutes, seconds = divmod(rest, 60)
    return '{h}h {m}m {s}s'.format(h=hours, m=minutes, s=seconds)


In [None]:
# ---- What to run -----------------------------------------------------------
dataset_name = 'mnist'
model_type = 'cnn'
opt_mode = 'sgld'
nonlin = 'linear'

# ---- Data ------------------------------------------------------------------
# N_tr training examples; `resized` is the image side length used downstream
# (kernels are later reshaped with resized**2 entries per image).
N_tr = 50
resized = 10
x_train, y_train = ds.dataset(N_tr, dataset_name, model_type, resized)

# ---- Widths swept: two hidden layers of equal width ---------------------------
hidden_widths = [[w, w] for w in (50, 100, 150, 200, 250, 300, 350, 400, 500, 600)]

# ---- Langevin dynamics hyper-parameters --------------------------------------
beta = 1
batch_size = N_tr                    # full batch ...
batch_factor = N_tr // batch_size    # ... so this factor is 1
step_size = 1 / 1000

# ---- Chain length ------------------------------------------------------------
# 3,000,000 steps are needed for an accurate posterior mean (lower it only for
# quick smoke tests). The first third of the chain is discarded as burn-in.
nT = 3000000
burn_in = nT // 3



In [None]:
## For every width: run SGLD, sum each layer's empirical kernel over the
## post-burn-in samples (divide by avg_count for the posterior mean), compute
## the NNGP and finite-width theory kernels, and pickle the results per width
## and in one aggregate file.
K_avgs = []      # per width: list of summed empirical kernels, one per layer in layers_ker
K_nngps = []     # per width: NNGP kernels returned by theory.theory_linear
K_theories = []  # per width: finite-width theory kernels returned by theory.theory_linear

for hidden_width in hidden_widths:
    print(model_type, ' | ', hidden_width)

    ## Create the model layers
    layers, layers_ker = model.model(hidden_width, nonlin = nonlin, model_type = model_type)
    ## Create the model functions for each layer
    init_fn, apply_fn, kernel_fn, layer_fns, kernel_fns, emp_kernel_fns = model.network_fns(layers, x_train)
    ## Initialize the model (the same PRNG key is used for every width)
    _, params = init_fn(key, input_shape=x_train.shape)

    ## Set Optimizer
    opt_init, opt_update, get_params = opt.sgld(step_size, beta, batch_factor)
    opt_state = opt_init(params)

    ## Full-batch squared-error loss and its gradient
    loss_fn = jit(lambda params: jnp.sum((apply_fn(params,x_train)-y_train)**2)/2)
    g_loss = jit(grad(loss_fn))

    avg_count = 0
    K_avg = []
    t_start = curr_time()
    for j in range(nT):
        opt_params = get_params(opt_state)
        opt_state = opt_update(j, g_loss(opt_params), opt_state)

        if j > burn_in:
            avg_count += 1
            for i, idx in enumerate(layers_ker):
                # Empirical kernel at layer idx uses only the parameters up to that layer.
                K_sample = emp_kernel_fns[idx](opt_params[:idx+1])
                if avg_count == 1:
                    # First post-burn-in sample: start the running sums.
                    K_avg += [K_sample]
                else:
                    K_avg[i] += K_sample

        if j % 1000 == 0:
            print('%d | loss: %f | avg_count: %d | time: %s'%(j, loss_fn(opt_params), avg_count, time_diff(t_start)), flush=True)

    K_nngp, K_theory, Gxx, Gyy = theory.theory_linear(x_train, y_train, beta, kernel_fns, hidden_width)
    # Collect this width's summed kernels so the aggregate pickle below contains them.
    K_avgs += [K_avg]
    K_nngps += [K_nngp]
    K_theories += [K_theory]

    with open('data_%s_%d_%s_%s_%s.pkl'%(str(hidden_width), N_tr, model_type, opt_mode, nonlin), 'wb') as outfile:
        pickle.dump({'K_avg': K_avg, 'K_nngp': K_nngp, 'K_theory': K_theory, 
                 'model_type': model_type, 'hidden_widths': hidden_widths, 'N_tr': N_tr, 
                 'nT': nT, 'beta': beta, 'batch_size': batch_size, 'step_size': step_size,
                 'avg_count': avg_count, 'opt_mode': opt_mode,
                 'hidden_width': hidden_width}, outfile, pickle.HIGHEST_PROTOCOL)

# avg_count is the last width's count; it equals nT - burn_in - 1 for every width.
with open('data_%d_%s_%s_%s.pkl'%(N_tr, model_type, opt_mode, nonlin), 'wb') as outfile:
    pickle.dump({'K_avgs': K_avgs, 'K_nngps': K_nngps, 'K_theories': K_theories, 
                 'model_type': model_type, 'hidden_widths': hidden_widths, 'N_tr': N_tr, 
                 'nT': nT, 'beta': beta, 'batch_size': batch_size, 'step_size': step_size,
                 'avg_count': avg_count, 'opt_mode': opt_mode}, outfile, pickle.HIGHEST_PROTOCOL)

    

In [7]:
## Reload the per-width results written by the training cell, so the analysis
## below also works after a kernel restart without re-running SGLD.
K_avgs = []
K_nngps = []
K_theories = []
avg_counts = []
nTs = []

nonlin = 'linear'
data_files = ['data_%s_%d_%s_%s_%s.pkl'%(str(s), N_tr, model_type, opt_mode, nonlin) for s in hidden_widths]

for data_file in data_files:
    # NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
    with open(data_file, 'rb') as infile:
        data = pickle.load(infile)
    N_tr = data['N_tr']
    model_type = data['model_type']
    K_avgs += [data['K_avg']]
    K_nngps += [data['K_nngp']]
    K_theories += [data['K_theory']]
    avg_counts += [data['avg_count']]
    nTs += [data['nT']]
    beta = data['beta']

## Preprocess NT kernels: convert them to arrays and turn each width's summed
## kernels into posterior means using that width's own sample count.
for i, K_width in enumerate(K_avgs):
    for j, K in enumerate(K_width):
        K_avgs[i][j] = convert_nt(K) / avg_counts[i]

In [8]:
## Frobenius distance of each layer's kernel from the NNGP (GP) kernel, for the
## SGLD estimate (deviation) and for the finite-width theory (deviation_th).
## Rows index widths, columns index layers.
deviation = []
deviation_th = []

for K_exp, K_nngp, K_theory in zip(K_avgs, K_nngps, K_theories):
    # Theory kernels exist only for the first len(K_theory) layers. The remaining
    # layers get NaN rather than a distance to a zero placeholder.
    K_theory_padded = list(K_theory) + [None] * (len(K_exp) - len(K_theory))
    dev_width = []
    dev_width_th = []
    for K, K_nn, K_th in zip(K_exp, K_nngp, K_theory_padded):
        # The Frobenius norm does not depend on how the axes are arranged, so
        # compare element-wise on flattened (C-order) arrays.
        K_nn_flat = np.ravel(K_nn)
        dev_width += [np.linalg.norm(np.ravel(K) - K_nn_flat)]
        dev_width_th += [np.nan if K_th is None else np.linalg.norm(np.ravel(K_th) - K_nn_flat)]
    deviation += [dev_width]
    deviation_th += [dev_width_th]

deviation = np.array(deviation)
deviation_th = np.array(deviation_th)

In [None]:
## Deviation of each layer's kernel from the GP kernel vs. width (log-log).
## Markers are SGLD estimates; dashed lines are the finite-width theory, drawn in
## the same colour as the matching layer.
n_layers_plotted = 2
widths = [width[0] for width in hidden_widths]

fig, ax = plt.subplots(figsize=(6,5))
for l in range(n_layers_plotted):
    ax.loglog(widths, deviation[:,l], 'o', label='layer %d'%(l+1), color = 'C%d'%l)
for l in range(n_layers_plotted):
    ax.loglog(widths, deviation_th[:,l], '--', color = 'C%d'%l)

ax.set_xlabel(r'Width', fontsize=20)
ax.set_ylabel(r'$||K_{exp} - K_{GP}||_F$', fontsize=20)
ax.tick_params(axis='both', which = 'major', labelsize=14)
ax.tick_params(axis='both', which = 'minor', labelsize=12)
ax.legend(fontsize=16)
fig.tight_layout()
fig.savefig('one_over_width_%s_sgld.png'%model_type, dpi=600)

In [None]:
## Heat maps for one width and one pair of training images: the SGLD kernel,
## the GP kernel, and their magnified difference at layer lay_idx + 1.
lay_idx = 1
img_idx1 = 14
img_idx2 = img_idx1
width_idx = 0    # index into hidden_widths of the width to show
diff_scale = 10  # the difference panel is multiplied by this to share the colour scale

def _to_pixel_matrix(K):
    """Swap axes 3 and 4 of a 6-D kernel and reshape it to (N_tr, N_tr, resized**2, resized**2)."""
    # NOTE(review): assumes the NT layout (N, N, H, H', W, W'), so the swap gives
    # (N, N, H, W, H', W') and the last two axes index pixels. Confirm this.
    return np.moveaxis(K, 3, 4).reshape(N_tr, N_tr, resized**2, resized**2)

# The last two entries of the experiment and NNGP kernel lists are not reshaped per pixel.
K_exp = [_to_pixel_matrix(K) for K in K_avgs[width_idx][:-2]]
K_nngp_th = [_to_pixel_matrix(K) for K in K_nngps[width_idx][:-2]]
K_th = [_to_pixel_matrix(K) for K in K_theories[width_idx]]

K_exp_pair = K_exp[lay_idx][img_idx1, img_idx2]
K_gp_pair = K_nngp_th[lay_idx][img_idx1, img_idx2]
K_th_pair = K_th[lay_idx][img_idx1, img_idx2]

# Shared colour scale, taken from the SGLD kernel panel.
vmin = np.min(K_exp_pair)
vmax = np.max(K_exp_pair)

# Mean |ratio| to the GP kernel for experiment and theory (close to 1 when close to the GP kernel).
print(np.abs(K_exp_pair / K_gp_pair).mean())
print(np.abs(K_th_pair / K_gp_pair).mean())

panels = [K_exp_pair, K_gp_pair, (K_exp_pair - K_gp_pair) * diff_scale]
titles = ['$K_{exp}^{(%d)}$'%(lay_idx+1),
          '$K_{GP}^{(%d)}$'%(lay_idx+1),
          '$%d\\,(K_{exp}^{(%d)} - K_{GP}^{(%d)})$'%(diff_scale, lay_idx+1, lay_idx+1)]

fig, axs = plt.subplots(1, 3)
fig.subplots_adjust(wspace=0, hspace=0)
for panel_ax, panel, title in zip(axs, panels, titles):
    im = panel_ax.imshow(panel, cmap='RdBu_r', vmin=vmin, vmax=vmax)
    # Hide tick marks and tick labels on every panel.
    panel_ax.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)
    panel_ax.set_title(title)
cbar_ax = fig.add_axes([0.93, 0.32, 0.02, 0.35])
fig.colorbar(im, cax=cbar_ax)

fig.savefig('kernel_conv_at_layer_%d.png'%(lay_idx+1), dpi=600, bbox_inches='tight', pad_inches=0)
plt.show()


In [None]:
## Scatter of (SGLD - GP) against (theory - GP) kernel entries at layer
## lay_idx + 1 for a few widths. The dashed line marks y = x, i.e. perfect agreement.
data_limit = 4000  # only the first data_limit flattened kernel entries are plotted per width

lay_idx = 0
width_idxs = [0, 2, -1]  # indices into hidden_widths
kernel_shape = (N_tr, N_tr, resized**2, resized**2)
widths = [width[0] for width in hidden_widths]

fig, ax = plt.subplots(figsize=(6,6))
for width_idx in width_idxs:
    K_exp = [K.reshape(kernel_shape) for K in K_avgs[width_idx][:2]]
    # Use this width's own NNGP entry rather than a variable left over from another cell.
    K_nngp_th = [K.reshape(kernel_shape) for K in K_nngps[width_idx][:2]]
    K_th = [K.reshape(kernel_shape) for K in K_theories[width_idx]]

    x = (K_exp[lay_idx] - K_nngp_th[lay_idx]).reshape(-1)[:data_limit]
    y = (K_th[lay_idx] - K_nngp_th[lay_idx]).reshape(-1)[:data_limit]

    if width_idx == width_idxs[0]:
        # y = x reference line spanning the first plotted width's x range.
        diag = np.linspace(x.min(), x.max(), 1000)
        ax.plot(diag, diag, 'k--')

    print(np.mean(K_exp[lay_idx]/K_th[lay_idx]), np.abs(x/y).mean())
    ax.scatter(x, y, label='width = %d'%widths[width_idx])

ax.tick_params(axis='both', labelsize=14)
ax.legend(fontsize=16, loc='upper left')
ax.ticklabel_format(style='sci', axis='both', scilimits=(0,0))
ax.set_xlabel('$K_{exp}^{(%d)} - K_{GP}^{(%d)}$'%(lay_idx + 1,lay_idx + 1), fontsize=20)
ax.set_ylabel('$K_{th}^{(%d)} - K_{GP}^{(%d)}$'%(lay_idx + 1,lay_idx + 1), fontsize=20)
fig.tight_layout()
fig.savefig('k-nngp_cov_cnn_layer_%d.png'%(lay_idx+1), dpi=600)
plt.show()

## MNIST Digits

In [None]:
## One example digit and the Gram matrices of the inputs and labels.
x_flat = x_train.reshape(N_tr, -1)
# Input Gram matrix of the flattened images, normalised by the number of features.
Gxx_fcn = x_flat @ x_flat.T / x_flat.shape[1]
Gxx = np.moveaxis(np.tensordot(x_train, x_train, (3, 3)), (3,2), (1,4)) ## Tensordot in channel axis
# Label Gram matrix, normalised by the number of label columns.
Gyy = y_train @ y_train.T / y_train.shape[1]

idx = 12

def _show_matrix(img, title, filename):
    """Display `img` with hidden axes and the given title, save it at 600 dpi, then show it."""
    plt.imshow(img)
    plt.gca().get_xaxis().set_visible(False)
    plt.gca().get_yaxis().set_visible(False)
    plt.title(title, fontsize=22)
    plt.savefig(filename, dpi=600)
    plt.show()

# NOTE(review): the digit in the title is hard-coded and must match image idx.
_show_matrix(x_train[idx].squeeze(), '$Digit: 2$', 'single_mnist_image.png')
_show_matrix(Gxx_fcn, '$G_{xx}$', 'gxx_fcn.png')
# NOTE(review): unlike the kernel heat maps, no moveaxis is applied here before
# flattening, so rows/cols may not index pixels. Confirm this is intended.
_show_matrix(Gxx[idx,idx].reshape(resized**2,resized**2), '$G_{xx}$ (Tensor)', 'gxx_cnn.png')
_show_matrix(Gyy, '$G_{yy}$', 'gyy_fcn.png')