# Preliminaries

In [None]:
import numpy as np
import matplotlib
import matplotlib.ticker as ticker
matplotlib.rcParams['text.usetex'] = True
import matplotlib.pyplot as plt
import copy
import json
import json_stream
from tqdm import tqdm
import scipy.linalg as linalg
import scipy.optimize as opt
import math

In [None]:
# define linear function for fitting
def lin_fit_fun(x,m,b):
    """Straight line ``m*x + b``; used as the model for the log-log fits below."""
    slope_term=m*x
    return slope_term+b

# Read in data

In [None]:
# Path of the JSON file containing the training data. The file should contain the fields:
# "learning_rate",
# "gamma" (learning rate decay factor),
# "margin",
# "momentum",
# "batch_size",
# "epochs",
# "input_size",
# "output_size",
# "depth",
# "width",
# "test_accuracy",
# "losses",
# "err_rates" (validation),
# "step" (number of epochs between weight measurements),
# "parameters" (a list of lists containing the weights at each measurement)

filepath="train_data/mnist_paper_relu_lr-4.json"

In [None]:
# open the training-data file and load it through json_stream
# NOTE(review): the handle is never closed; presumably json_stream reads lazily and needs `file`
# to stay open while `in_data` is accessed in the cells below -- confirm before adding a close()
file=open(filepath,"r")
in_data=json_stream.load(file,persistent=True)

In [None]:
# display the learning rate stored in the data file
in_data["learning_rate"]

In [None]:
# display gamma (the learning rate decay factor)
in_data["gamma"]

In [None]:
# display the momentum stored in the data file
in_data["momentum"]

In [None]:
# display the batch size stored in the data file
in_data["batch_size"]

In [None]:
# display one minus the stored test accuracy
1-in_data["test_accuracy"]

In [None]:
# loss values recorded during training (plotted against ep_arr below)
losses=in_data["losses"]

In [None]:
# validation error rates recorded during training (plotted against ep_arr below)
err_rates=in_data["err_rates"]

In [None]:
# number of epochs between weight measurements
step=in_data["step"]
print(step)

In [None]:
# total number of epochs, and the epoch value of each measurement: 0, step, 2*step, ... up to epochs
epochs=in_data["epochs"]
print(epochs)
ep_arr=np.arange(0,epochs+1,step)

In [None]:
# number of layers used to build the bond matrix: stored depth plus one
num_layers=in_data["depth"]+1
print(num_layers)

In [None]:
# weights at each measurement: params[t] is the list of weight arrays at measurement t
params=in_data["parameters"]
# convert every stored weight list to a numpy array
params=[[np.array(item) for item in t_param] for t_param in params]
# number of measurements
print(len(params))

In [None]:
# layer sizes: the column count of every weight matrix, followed by the row count of
# the weight matrix at index num_layers-2
layers=[W.shape[1] for W in params[0]]+[params[0][num_layers-2].shape[0]]
# N is the sum of all layer sizes (used below as the bond matrix dimension)
N=sum(layers)
print(layers)

## loss and accuracy

In [None]:
# float copy of the epoch array for the symlog plots; its first entry (epoch 0) is moved to 0.7
# for plotting purposes
m_ep_arr=ep_arr.astype(float)
m_ep_arr[0]=0.7

In [None]:
# fit log-log loss data to line, modify starting index as needed
# (the fitted [slope, intercept] is only used by the commented-out plot lines in this cell)
lin_popt,_=opt.curve_fit(lin_fit_fun,np.log(ep_arr[10:]),np.log(losses[10:]))
print(lin_popt)

# plot losses
plt.figure(figsize=(8,6))
plt.plot(ep_arr,losses,linewidth=5,alpha=0.8,label='Data')
# plt.plot(ep_arr[1:],np.exp(lin_fit_fun(np.log(ep_arr[1:]),*lin_popt)),'k--',label='Power-law fit')
plt.legend(fontsize=15)
plt.xlabel('Epoch',fontsize=20)
plt.ylabel('Loss',fontsize=20)
plt.ylim((None,max(losses)+5))
# plt.savefig('paper_plots/loss_lr-4_2.pdf',bbox_inches='tight')
plt.show()

# plot losses on log-symlog plot (x axis is linear below 1, logarithmic above)
plt.figure(figsize=(8,6))
ax=plt.gca()
plt.plot(m_ep_arr[:],losses[:],linewidth=5,alpha=0.8,label='Data')
# plt.plot(ep_arr[1:],np.exp(lin_fit_fun(np.log(ep_arr[1:]),*lin_popt)),'k',linestyle='dashed',label='Power-law fit')
plt.yscale('log')
ax.set_xscale('symlog',linthresh=1)
# add an extra tick at the shifted first point and label it as epoch 0
plt.xticks(list(plt.xticks()[0]) + [m_ep_arr[0]],list(plt.xticks()[1])+['$0$'])
# minor ticks at 2..9 times each power of ten
ll=ticker.SymmetricalLogLocator(linthresh=1,base=10)
ll.set_params([2,3,4,5,6,7,8,9])
ax.xaxis.set_minor_locator(ll)
# upper x limit is hard-coded to 600
plt.xlim([m_ep_arr[0]-0.1,600])
plt.legend(fontsize=15)
plt.xlabel('Epoch',fontsize=20)
plt.ylabel('Loss',fontsize=20)
# plt.savefig('paper_plots/loss_lr-4_2_log.pdf',bbox_inches='tight')
plt.show()

In [None]:
# fit log-log validation error rate data in different regimes (adjust the index windows as needed)
# This fit has to be active: otherwise `lin_popt` still holds the *loss* fit from the previous
# cell and the "Power-law fit" curves below would not describe the error rate at all.
lin_popt,_=opt.curve_fit(lin_fit_fun,np.log(ep_arr[2:40]),np.log(err_rates[2:40]))
print(lin_popt)

# lin_popt_2,_=opt.curve_fit(lin_fit_fun,np.log(ep_arr[60:]),np.log(err_rates[60:]))
# print(lin_popt_2)

# plot validation error rate
plt.figure(figsize=(8,6))
plt.plot(ep_arr,err_rates,linewidth=5,alpha=0.8,label='Data')
# the fit is drawn from the second measurement on: ep_arr[0] is 0 and log(0) is -inf
plt.plot(ep_arr[1:],np.exp(lin_fit_fun(np.log(ep_arr[1:]),*lin_popt)),'k',linestyle='dashed',label='Power-law fit')
# plt.plot(ep_arr[1:],np.exp(lin_fit_fun(np.log(ep_arr[1:]),*lin_popt_2)),'r',linestyle='dashed',label='Late Power-law fit')
plt.legend(fontsize=15)
plt.ylim([None,1])
plt.xlabel('Epoch',fontsize=20)
plt.ylabel('Validation error rate',fontsize=20)
# plt.savefig('paper_plots/error_rate_lr-4_2.pdf',bbox_inches='tight')
plt.show()

# plot validation error rate on log-symlog plot
plt.figure(figsize=(8,6))
ax=plt.gca()
plt.plot(m_ep_arr[:],err_rates[:],linewidth=5,alpha=0.8,label='Data')
plt.plot(ep_arr[1:],np.exp(lin_fit_fun(np.log(ep_arr[1:]),*lin_popt)),'k',linestyle='dashed',label='Power-law fit')
# plt.plot(ep_arr[1:],np.exp(lin_fit_fun(np.log(ep_arr[1:]),*lin_popt_2)),'r',linestyle='dashed',label='Late Power-law fit')
plt.yscale('log')
ax.set_xscale('symlog',linthresh=1)
# minor ticks at 2..9 times each power of ten
ll=ticker.SymmetricalLogLocator(linthresh=1,base=10)
ll.set_params([2,3,4,5,6,7,8,9])
ax.xaxis.set_minor_locator(ll)
# add an extra tick at the shifted first point and label it as epoch 0
plt.xticks(list(plt.xticks()[0]) + [m_ep_arr[0]],list(plt.xticks()[1])+['$0$'])
plt.legend(fontsize=15)
# plt.ylim([None,1])
plt.xlim([m_ep_arr[0]-0.1,600])
plt.xlabel('Epoch',fontsize=20)
plt.ylabel('Validation error rate',fontsize=20)
# plt.savefig('paper_plots/error_rate_lr-4_2_log.pdf',bbox_inches='tight')
plt.show()

# The bond matrix

In [None]:
# define bond matrix using weights at a specific training time
def bmat(tparams):
    """Assemble the symmetric bond matrix from the weight matrices ``tparams``.

    Block (i, i-1) of the layer-by-layer block matrix is ``tparams[i-1]``; every other
    block is zero with shape ``(layers[i], layers[j])``. The result is that block matrix
    plus its transpose. Reads the module-level ``num_layers`` and ``layers``.
    """
    blocks=[
        [tparams[col] if col==row-1 else np.zeros((layers[row],layers[col]))
         for col in range(num_layers)]
        for row in range(num_layers)
    ]
    lower=np.block(blocks)
    return lower+lower.T

## Spectrum

In [None]:
# choose two training times to compare. Generally start and end of training.
# (indices into params: 0 is the first measurement, -1 the last)
ep_start=0
ep_end=-1

# create bond matrices at those times
B_start=bmat(params[ep_start])
B_end=bmat(params[ep_end])

# find eigenvalues and eigenvectors; the eigenvector matrices are transposed so that
# each row (rather than each column) is one eigenvector
eigs_start,evs_start=linalg.eigh(B_start)
evs_start=evs_start.T
eigs_end,evs_end=linalg.eigh(B_end)
evs_end=evs_end.T

In [None]:
# collect the spectrum results to be written to file
out_dict={
    "ref_file": filepath,
    "eigs_start": list(eigs_start),
    "eigs_end": list(eigs_end)
}

# Writing is disabled; uncomment BOTH lines below to save the results.
# (A `with` header whose body is commented out is a syntax error, and opening the
# file in "w" mode would truncate any previously saved results.)
# with open("train_data/Jevals_mnist_relu_paper_lr-4_2.json","w") as outfile:
#     json.dump(out_dict,outfile)

In [None]:
# use this cell to read in previously computed results
# (overwrites eigs_start/eigs_end; also rebinds the name `file` used when loading the training data)
with open("train_data/Jevals_mnist_paper_relu_lr-4.json","r") as file:
    in_dict=json.load(file)

eigs_start=np.array(in_dict["eigs_start"])
eigs_end=np.array(in_dict["eigs_end"])

In [None]:
# plot bond matrix spectrum zoomed in to neglect central spike
# a throwaway histogram of eigs_end defines the 100 bin edges shared by both histograms
counts_end, bins,_=plt.hist(eigs_end,bins=100, density=True)
plt.close()
plt.figure(figsize=(8,6))
counts_start,_,_=plt.hist(eigs_start,bins=bins, density=True,alpha=0.5,label="Before Training")
counts_end, bins,_=plt.hist(eigs_end,bins=bins, density=True,alpha=0.5,label="After Training")
plt.xlim([-7,7])
# the y range is capped at 0.3, which is what produces the zoom
plt.ylim([0,0.3])
plt.legend(fontsize=15)
plt.xlabel('Eigenvalues of $J$',fontsize=20)
plt.ylabel('Spectral Density',fontsize=20)
# plt.savefig('paper_plots/spectrum_hist_relu_lr-4_zoom.pdf',bbox_inches='tight')
plt.show()

# plot bond matrix spectrum (same as above without the y limit)
counts_end, bins,_=plt.hist(eigs_end,bins=100, density=True)
plt.close()
plt.figure(figsize=(8,6))
counts_start,_,_=plt.hist(eigs_start,bins=bins, density=True,alpha=0.5,label="Before Training")
counts_end, bins,_=plt.hist(eigs_end,bins=bins, density=True,alpha=0.5,label="After Training")
plt.xlim([-7,7])
plt.legend(fontsize=15)
plt.xlabel('Eigenvalues of $J$',fontsize=20)
plt.ylabel('Spectral Density',fontsize=20)
# plt.savefig('paper_plots/spectrum_hist_relu_lr-4.pdf',bbox_inches='tight')
plt.show()

# plot bond matrix spectrum on log plot (logarithmic y axis)
counts_end, bins,_=plt.hist(eigs_end,bins=100, density=True)
plt.close()
plt.figure(figsize=(8,6))
counts_start,_,_=plt.hist(eigs_start,bins=bins, density=True,alpha=0.5,label="Before Training")
counts_end, bins,_=plt.hist(eigs_end,bins=bins, density=True,alpha=0.5,label="After Training")
plt.xlim([-7,7])
plt.yscale('log')
plt.legend(fontsize=15)
plt.xlabel('Eigenvalues of $J$',fontsize=20)
plt.ylabel('Spectral Density',fontsize=20)
# plt.savefig('paper_plots/spectrum_hist_relu_lr-4_log.pdf',bbox_inches='tight')
plt.show()

## row-wise sum of squares of bond matrix elements

In [None]:
# compute row-wise sums of squares at times defined above
# (the sum runs over axis 0, i.e. down the columns; bmat() returns a matrix plus its
# transpose, so column sums and row sums coincide)
J2s_start=np.sum(B_start**2,axis=0)
J2s_end=np.sum(B_end**2,axis=0)

In [None]:
# collect the sums of squares to be written to file
out_dict={
    "ref_file": filepath,
    "J2s_start": list(J2s_start),
    "J2s_end": list(J2s_end)
}

# Writing is disabled; uncomment BOTH lines below to save the results.
# (A `with` header whose body is commented out is a syntax error, and opening the
# file in "w" mode would truncate any previously saved results.)
# with open("train_data/J2s_mnist_paper_lr-4_2.json","w") as outfile:
#     json.dump(out_dict,outfile)

In [None]:
# Bin edges shared by the before/after histograms. They are computed from the J^2 data
# itself (both times, so neither distribution is clipped); bins taken from the eigenvalue
# spectrum would cover a different range than the non-negative sums of squares.
bins=np.histogram_bin_edges(np.concatenate([J2s_start,J2s_end]),bins=100)

# plot histogram of sums
plt.figure(figsize=(8,6))
counts_start,_,_=plt.hist(J2s_start,bins=bins, density=True,alpha=0.5,label="Before Training")
counts_end,_,_=plt.hist(J2s_end,bins=bins, density=True,alpha=0.5,label="After Training")
plt.legend(fontsize=15)
plt.xlabel('$J^2$',fontsize=20)
# this is a distribution of sums of squares, not a spectrum
plt.ylabel('Probability Density',fontsize=20)
# plt.savefig('paper_plots/spectrum_hist_lr-7.pdf',bbox_inches='tight')
plt.show()

# plot histogram of sums on log plot (logarithmic y axis)
plt.figure(figsize=(8,6))
counts_start,_,_=plt.hist(J2s_start,bins=bins, density=True,alpha=0.5,label="Before Training")
counts_end,_,_=plt.hist(J2s_end,bins=bins, density=True,alpha=0.5,label="After Training")
plt.yscale('log')
plt.legend(fontsize=15)
plt.xlabel('$J^2$',fontsize=20)
plt.ylabel('Probability Density',fontsize=20)
# plt.savefig('paper_plots/spectrum_hist_lr-7_log.pdf',bbox_inches='tight')
plt.show()

In [None]:
# smallest sum of squared bond-matrix entries (summed along axis 0) at every measurement
mins=[min(np.sum(bmat(param)**2,axis=0)) for param in tqdm(params)]

In [None]:
# collect the minimum sums of squares to be written to file
out_dict={
    "ref_file": filepath,
    "J2_mins": list(mins)
}

# Writing is disabled; uncomment BOTH lines below to save the results.
# (A `with` header whose body is commented out is a syntax error, and opening the
# file in "w" mode would truncate any previously saved results.)
# with open("train_data/J2mins_mnist_paper_lr-4_2.json","w") as outfile:
#     json.dump(out_dict,outfile)

In [None]:
# use this cell to read in the results from a file (overwrites `mins` computed above)
with open("train_data/J2mins_mnist_relu_paper_lr-4.json","r") as file:
    in_dict=json.load(file)

mins=np.array(in_dict["J2_mins"])

In [None]:
# plot minimum across training time to see how much it changes
# (uncomment the two scale lines for a log-log view)
plt.figure(figsize=(8,6))
plt.plot(ep_arr,mins)
# plt.xscale('log')
# plt.yscale('log')
plt.show()

In [None]:
# calculate the ratio of the change across training to original value:
# (last minimum - first minimum) / first minimum
(mins[-1]-mins[0])/mins[0]

## Maximum eigenvalue

In [None]:
# largest eigenvalue of the bond matrix at every measurement
max_eigs=[max(linalg.eigvalsh(bmat(param))) for param in tqdm(params)]

In [None]:
# collect the maximum eigenvalues to be written to file
out_dict={
    "ref_file": filepath,
    "max_eigs": list(max_eigs)
}

# Writing is disabled; uncomment BOTH lines below to save the results.
# (A `with` header whose body is commented out is a syntax error, and opening the
# file in "w" mode would truncate any previously saved results.)
# with open("train_data/maxs_mnist_paper_lr-4_2.json","w") as outfile:
#     json.dump(out_dict,outfile)

In [None]:
# use this cell to read in the results from a file (overwrites `max_eigs` computed above)
with open("train_data/maxs_mnist_paper_lr-4_2.json","r") as file:
    in_dict=json.load(file)

max_eigs=np.array(in_dict["max_eigs"])

In [None]:
# fit log-log max eigenvalue data in different regimes
# early window: measurements 10..79; lin_popt is [slope, intercept] in log-log space
lin_popt,_=opt.curve_fit(lin_fit_fun,np.log(ep_arr[10:80]),np.log(max_eigs[10:80]))
print(lin_popt)

# late window: measurements 300 onwards
lin_popt_2,_=opt.curve_fit(lin_fit_fun,np.log(ep_arr[300:]),np.log(max_eigs[300:]))
print(lin_popt_2)

# plot maximum eigenvalue across training
# NOTE: ep_arr starts at 0, so np.log(ep_arr) below evaluates log(0) for the first point
plt.figure(figsize=(8,6))
plt.plot(ep_arr, max_eigs,linewidth=5,alpha=0.8,label='Data')
plt.plot(ep_arr,np.exp(lin_fit_fun(np.log(ep_arr),*lin_popt)),'k',linestyle='dashed',label='Power-law fit')
plt.plot(ep_arr,np.exp(lin_fit_fun(np.log(ep_arr),*lin_popt_2)),'r',linestyle='dashed',label='Late Power-law fit')
plt.ylim([2.4,None])
# plt.xlim([0,20])
plt.legend(fontsize=15)
plt.ylabel('$\\lambda_{\\max}(J)$',fontsize=20)
plt.xlabel('Epoch',fontsize=20)
# plt.savefig('paper_plots/max_eig_lr-4_2.pdf',bbox_inches='tight')
plt.show()

# plot maximum eigenvalue across training on log-symlog plot
plt.figure(figsize=(8,6))
ax=plt.gca()
plt.plot(m_ep_arr[:], max_eigs[:],linewidth=5,alpha=0.8,label='Data')
# the fits are drawn from the second measurement on, which skips epoch 0
plt.plot(ep_arr[1:],np.exp(lin_fit_fun(np.log(ep_arr[1:]),*lin_popt)),'k',linestyle='dashed',label='Power-law fit')
plt.plot(ep_arr[1:],np.exp(lin_fit_fun(np.log(ep_arr[1:]),*lin_popt_2)),'r',linestyle='dashed',label='Late Power-law fit')
plt.ylim([2.5,None])
plt.yscale('log')
ax.set_xscale('symlog',linthresh=1)
# minor ticks at 2..9 times each power of ten
ll=ticker.SymmetricalLogLocator(linthresh=1,base=10)
ll.set_params([2,3,4,5,6,7,8,9])
ax.xaxis.set_minor_locator(ll)
# add an extra tick at the shifted first point and label it as epoch 0
plt.xticks(list(plt.xticks()[0]) + [m_ep_arr[0]],list(plt.xticks()[1])+['$0$'])
plt.xlim([m_ep_arr[0]-0.1,600])
plt.legend(fontsize=15)
plt.ylabel('$\\lambda_{\\max}(J)$',fontsize=20)
plt.xlabel('Epoch',fontsize=20)
# plt.savefig('paper_plots/max_eig_lr-4_2_log.pdf',bbox_inches='tight')
plt.show()

# Eigenvalues of M (defined in TAP equations)

In [None]:
# define function to find eigenvalues of the matrix defined by the linearized TAP equations given an inverse temp and bond matrix
def eigsVbeta(betas,bmat):
    """Return the eigenvalues of M(beta) for every beta in ``betas``.

    M(beta) has off-diagonal entries ``beta*bmat[i,j]`` and diagonal entries
    ``-beta**2 * sum_j bmat[i,j]**2``.

    Parameters:
        betas: iterable of inverse temperatures.
        bmat: square 2-D array (the bond matrix). Inside this function the name refers
            to this argument, not to the module-level ``bmat()`` builder. Not modified.

    Returns:
        Array with one row per beta; each row holds the eigenvalues returned by
        ``linalg.eigh(..., eigvals_only=True)``.
    """
    # the row sums of squares do not depend on beta, so compute them once
    row_sq=np.sum(bmat**2,axis=1)

    eig_list=[]
    for beta in tqdm(betas):
        # beta*bmat allocates a fresh array, so the caller's matrix is never touched
        Mat=beta*bmat
        np.fill_diagonal(Mat,(-beta*row_sq)*beta)

        eig_list.append(linalg.eigh(Mat,eigvals_only=True))

    return np.array(eig_list)

In [None]:
# set two training times to compare. Usually start and end of training
# (indices into params: 0 is the first measurement, -1 the last)
ep_start=0
ep_end=-1

# create bond matrices at those times
B_start=bmat(params[ep_start])
B_end=bmat(params[ep_end])

In [None]:
# define array of betas to use for the calculation: 20 evenly spaced values from 0 to 1.2
betas=np.linspace(0,1.2,20)

# find eigenvalues of M at every beta, before and after training (one row per beta)
eig_list_start=eigsVbeta(betas,B_start)
eig_list_end=eigsVbeta(betas,B_end)

In [None]:
# collect the eigenvalue-vs-beta results to be written to file
out_dict={
    "ref_file": filepath,
    "eig_list_start": [list(eig_list) for eig_list in eig_list_start],
    "eig_list_end": [list(eig_list) for eig_list in eig_list_end],
    "betas": list(betas)
}

# Writing is disabled; uncomment BOTH lines below to save the results.
# (A `with` header whose body is commented out is a syntax error, and opening the
# file in "w" mode would truncate any previously saved results.)
# with open("train_data/eig_min_ba_mnist_paper_lr-4_2.json","w") as outfile:
#     json.dump(out_dict,outfile)

In [None]:
# use this cell to read in results from a file
# (overwrites eig_list_start, eig_list_end and betas computed above)
with open("train_data/eig_min_ba_mnist_paper_lr-4_2.json","r") as file:
    in_dict=json.load(file)

eig_list_start=np.array(in_dict["eig_list_start"])
eig_list_end=np.array(in_dict["eig_list_end"])
betas=np.array(in_dict["betas"])

In [None]:
# plot eigs v beta at first time (one curve per eigenvalue index)
plt.figure(figsize=(8,6))
plt.plot(betas,eig_list_start)
# horizontal reference line at eigenvalue 1
plt.axhline(1)
plt.ylim([-1,2])
plt.xlabel('$\\beta$')
plt.ylabel('eigenvalues')
#plt.text(4,-0.5,f'epoch {ep_num*5}')
plt.show()

# plot eigs v beta at second time
plt.figure(figsize=(8,6))
plt.plot(betas,eig_list_end)
# horizontal reference line at eigenvalue 1
plt.axhline(1)
plt.ylim([-1,2])
plt.xlabel('$\\beta$')
plt.ylabel('eigenvalues')
#plt.text(4,-0.5,f'epoch {ep_num*5}')
plt.show()

In [None]:
# plot minimum eig of 1-M before and after training
# (one minus the last column of each eigenvalue array)
# NOTE(review): this cell uses beta_ts, which is only defined in the "Transition temperature"
# section further down -- that section has to be run (or its results loaded) first
plt.figure(figsize=(8,6))
plt.plot(betas,1-eig_list_start[:,-1],label='Before training')
plt.plot(betas,1-eig_list_end[:,-1],label='After training')
plt.axhline(0,linestyle='dashed',c='gray')
# vertical markers: a hard-coded beta of 0.8 and the last entry of beta_ts
plt.axvline(0.8,linestyle='dotted',c='gray')
plt.axvline(beta_ts[-1],linestyle='dotted',c='gray')
plt.ylabel('$\\lambda_{\\min}(I_N-M)$',size=20)
plt.xlabel('$\\beta$',size=20)
plt.legend(fontsize=15)
# plt.ylim([-0.1,0.1])
# plt.savefig('paper_plots/min_eig_evo_lr-4_2.pdf',bbox_inches='tight')
plt.show()

# Transition temperature

In [None]:
# define quadratic function for fitting
def quad_fit_fun(x,a,b,c):
    """Parabola ``c - a*(x-b)**2``: extremum value ``c`` at ``x = b``, curvature set by ``a``."""
    offset=x-b
    return c-a*offset**2

In [None]:
# define ending index (generally the end of training), array of betas, and initialize the crossover value to 1.
end_idx=len(params)
betas=np.linspace(0.1,0.9,20)
cross_val=1
beta_ts=[]


def _max_M_eigs(tparams,betas):
    """Return the largest eigenvalue of M(beta) for each beta, for one set of weights."""
    J=bmat(tparams)
    n=J.shape[0]
    # neither the bond matrix nor its row sums of squares depend on beta: build them once
    row_sq=np.sum(J**2,axis=1)
    l_evals=[]
    for beta in betas:
        # off-diagonal entries beta*J_ij, diagonal entries -beta^2 * sum_j J_ij^2
        Mat=beta*J
        np.fill_diagonal(Mat,(-beta*row_sq)*beta)
        # eigh sorts eigenvalues in ascending order, so index n-1 is the largest one
        l_evals.append(linalg.eigh(Mat,eigvals_only=True,subset_by_index=[n-1,n-1])[0])
    return l_evals


def _fit_peak(betas,l_evals):
    """Fit quad_fit_fun around the selected beta and return its parameters (a, b, c)."""
    # find at which beta the largest eigenvalue first becomes greater than one
    # if this does not happen, take the beta which maximizes the curve
    cross_idx=next((i for i,ev in enumerate(l_evals) if ev>=1),None)
    if cross_idx is None:
        cross_idx=int(np.argmax(l_evals))
    # four-point window starting two points before the selected beta; the start is clamped
    # at 0 because a negative start would wrap around and produce an empty slice
    lo=max(cross_idx-2,0)
    hi=lo+4
    qpopt,_=opt.curve_fit(quad_fit_fun,betas[lo:hi],l_evals[lo:hi])
    return qpopt


# start with weights before training
a,b,c=_fit_peak(betas,_max_M_eigs(params[0],betas))
# set the crossover value to the maximum of this quadratic fit
cross_val=c
print(cross_val)
# store the value of beta corresponding to the maximum as the transition temperature
beta_ts.append(b)

# carry through the transition temp calculation for the other times using the new crossover value
for k in tqdm(range(1,end_idx)):
    a,b,c=_fit_peak(betas,_max_M_eigs(params[k],betas))
    if c>=cross_val:
        # lower root of  c - a*(beta-b)^2 = cross_val
        beta_ts.append(b-np.sqrt((c-cross_val)/a))
    else:
        # the fitted parabola never reaches cross_val: store the complex root instead
        beta_ts.append(b-1j*np.sqrt((cross_val-c)/a))

beta_ts=np.array(beta_ts)

In [None]:
# collect the transition betas to be written to file; real and imaginary parts are stored
# separately because beta_ts may contain complex entries
out_dict={
    "ref_file": filepath,
    "beta_ts_real": list(np.real(beta_ts)),
    "beta_ts_imag": list(np.imag(beta_ts))
}

# Writing is disabled; uncomment BOTH lines below to save the results.
# (A `with` header whose body is commented out is a syntax error, and opening the
# file in "w" mode would truncate any previously saved results.)
# with open("train_data/betas_mnist_relu_paper_lr-4_adj.json","w") as outfile:
#     json.dump(out_dict,outfile)

In [None]:
# use this cell to read in the results from a file (overwrites beta_ts computed above)
# NOTE(review): this path says "lr-6" while every other path in this notebook says "lr-4" --
# confirm that this is the intended file
with open("train_data/betas_mnist_paper_lr-6_adj.json","r") as file:
    in_dict=json.load(file)

# recombine the separately stored real and imaginary parts into one complex array
beta_ts=np.array(in_dict["beta_ts_real"])+1j*np.array(in_dict["beta_ts_imag"])

In [None]:
# fit log-log transition temp data to line in different regimes
# early window: measurements 10..79
popt,_=opt.curve_fit(lin_fit_fun,np.log(ep_arr[10:80]),np.log(np.array(1/beta_ts[10:80])))
print(popt)

# late window: measurements 300 onwards
popt_2,_=opt.curve_fit(lin_fit_fun,np.log(ep_arr[300:]),np.log(np.array(1/beta_ts[300:])))
print(popt_2)

# plot transition temperature (T_c = 1/beta_t)
plt.figure(figsize=(8,6))
plt.plot(ep_arr, 1/beta_ts,linewidth=5,alpha=0.8,label='Data')
plt.plot(ep_arr[0:],np.exp(lin_fit_fun(np.log(ep_arr[0:]),*popt)),'k',linestyle='dashed',label='Power-law fit')
plt.plot(ep_arr[0:],np.exp(lin_fit_fun(np.log(ep_arr[0:]),*popt_2)),'r',linestyle='dashed',label='Late Power-law fit')
plt.ylim([1/0.9,None])
plt.ylabel("$T_c$",fontsize=20)
plt.xlabel("Epoch",fontsize=20)
plt.legend(fontsize=15)
# plt.xlim([None,120])
# plt.xlim([0,20])
# plt.savefig('paper_plots/ttemp_lr-4_2.pdf',bbox_inches = "tight")
plt.show()

# plot transition temperature on log-symlog plot
plt.figure(figsize=(8,6))
ax=plt.gca()
plt.plot(m_ep_arr[:], 1/beta_ts[:],linewidth=5,alpha=0.8,label='Data')
plt.plot(ep_arr[1:],np.exp(lin_fit_fun(np.log(ep_arr[1:]),*popt)),'k',linestyle='dashed',label='Power-law fit')
plt.plot(ep_arr[1:],np.exp(lin_fit_fun(np.log(ep_arr[1:]),*popt_2)),'r',linestyle='dashed',label='Late Power-law fit')
plt.ylabel("$T_c$",fontsize=20)
plt.xlabel("Epoch",fontsize=20)
plt.yscale('log')
ax.set_xscale('symlog',linthresh=1)
# minor ticks at 2..9 times each power of ten
ll=ticker.SymmetricalLogLocator(linthresh=1,base=10)
ll.set_params([2,3,4,5,6,7,8,9])
ax.xaxis.set_minor_locator(ll)
# add an extra tick at the shifted first point and label it as epoch 0
plt.xticks(list(plt.xticks()[0]) + [m_ep_arr[0]],list(plt.xticks()[1])+['$0$'])
plt.xlim([m_ep_arr[0]-0.1,600])
plt.ylim([1/0.9,None])
plt.legend(fontsize=15)
# plt.savefig('paper_plots/ttemp_lr-4_2_log.pdf',bbox_inches = "tight")
plt.show()

## Eigenvalues of M at transition temps

In [None]:
# calculate eigenvalues of M at the transition temperature before and after training
beta_t_start=beta_ts[0]
beta_t_end=beta_ts[-1]

param_start=params[0]
param_end=params[-1]

# build M at the first measurement: diagonal entry i becomes -beta * (sum of squares of row i),
# then the whole matrix is multiplied by beta
# NOTE(review): the matrix is cast to complex because beta_ts may hold complex entries; if beta
# has a nonzero imaginary part M is presumably no longer Hermitian, which eigh assumes -- confirm
Mat_start=bmat(param_start).astype(complex)
for i in range(N):
    Mat_start[i,i]=-beta_t_start*np.sum(Mat_start[i,:]**2)
Mat_start=beta_t_start*Mat_start
Mevals_start=linalg.eigh(Mat_start,eigvals_only=True)

# same construction at the last measurement
Mat_end=bmat(param_end).astype(complex)
for i in range(N):
    Mat_end[i,i]=-beta_t_end*np.sum(Mat_end[i,:]**2)
Mat_end=beta_t_end*Mat_end
Mevals_end=linalg.eigh(Mat_end,eigvals_only=True)

# largest eigenvalue of M at each of the two times
print(max(Mevals_start))
print(max(Mevals_end))

In [None]:
# collect the eigenvalues of M at the transition temperature to be written to file
out_dict={
    "ref_file": filepath,
    "Mevals_start": list(Mevals_start),
    "Mevals_end": list(Mevals_end)
}

# Writing is disabled; uncomment BOTH lines below to save the results.
# (A `with` header whose body is commented out is a syntax error, and opening the
# file in "w" mode would truncate any previously saved results.)
# with open("train_data/tevals_mnist_paper_lr-4_2.json","w") as outfile:
#     json.dump(out_dict,outfile)

In [None]:
# number of histogram bins; the bin edges of the "before" histogram are reused for "after"
bin_num=100

# plot eigenvalues of 1-M zoomed in to ignore sharp peaks
plt.figure(figsize=(8,6))
counts_start,bins,_=plt.hist(1-Mevals_start,bins=bin_num,alpha=0.5,density=True,label="Before Training")
counts_end,_,_=plt.hist(1-Mevals_end,bins=bins,alpha=0.5,density=True,label="After Training")
# the y range is capped at 1.5, which is what produces the zoom
plt.ylim((0,1.5))
# plt.xlim((0,1.1))
plt.xlabel('Eigenvalues of $I_N-M$ at Transition Temperature',fontsize=20)
plt.ylabel('Spectral Density',fontsize=20)
plt.legend(fontsize=15)
# plt.savefig('paper_plots/I-M_before_after_lr-4_2_zoom.pdf',bbox_inches='tight')
plt.show()

# plot eigenvalues of 1-M
plt.figure(figsize=(8,6))
counts_start,bins,_=plt.hist(1-Mevals_start,bins=bin_num,alpha=0.5,density=True,label="Before Training")
counts_end,_,_=plt.hist(1-Mevals_end,bins=bins,alpha=0.5,density=True,label="After Training")
plt.xlabel('Eigenvalues of $I_N-M$ at Transition Temperature',fontsize=20)
plt.ylabel('Spectral Density',fontsize=20)
plt.legend(fontsize=15)
# plt.savefig('paper_plots/I-M_before_after_lr-4_2.pdf',bbox_inches='tight')
plt.show()

# plot eigenvalues of 1-M on log plot (logarithmic y axis)
plt.figure(figsize=(8,6))
counts_start,bins,_=plt.hist(1-Mevals_start,bins=bin_num,alpha=0.5,density=True,label="Before Training")
counts_end,_,_=plt.hist(1-Mevals_end,bins=bins,alpha=0.5,density=True,label="After Training")
plt.xlabel('Eigenvalues of $I_N-M$ at Transition Temperature',fontsize=20)
plt.ylabel('Spectral Density',fontsize=20)
plt.yscale('log')
plt.legend(fontsize=15)
# plt.savefig('paper_plots/I-M_before_after_lr-4_2_log.pdf',bbox_inches='tight')
plt.show()

# density in the first (lowest) bin before and after training
print(counts_start[0])
print(counts_end[0])

# Evolution of first level spacing of 1-M

In [None]:
# first level spacing (gap between the two largest eigenvalues of M at the transition beta)
# and the same gap divided by the median spacing, at every measurement
gaps=[]
norm_gaps=[]

for j in tqdm(range(len(beta_ts))):
    beta=beta_ts[j]
    Mat=bmat(params[j]).astype(complex)
    for i in range(N):
        Mat[i,i]=-beta*np.sum(Mat[i,:]**2)
    Mat=beta*Mat
    evals=linalg.eigh(Mat,eigvals_only=True)
    # differences between consecutive eigenvalues; the last one belongs to the top pair
    spacings=np.diff(evals)
    top_gap=spacings[-1]
    gaps.append(top_gap)
    norm_gaps.append(top_gap/np.median(spacings))

In [None]:
# collect the level spacings to be written to file
out_dict={
    "ref_file": filepath,
    "gaps": gaps,
    "norm_gaps": norm_gaps
}

# Writing is disabled; uncomment BOTH lines below to save the results.
# (A `with` header whose body is commented out is a syntax error, and opening the
# file in "w" mode would truncate any previously saved results.)
# with open("train_data/gaps_mnist_paper_lr-4.json","w") as outfile:
#     json.dump(out_dict,outfile)

In [None]:
# use this cell to read in the results from a file (overwrites gaps and norm_gaps computed above)
with open("train_data/gaps_mnist_paper_lr-4_2.json","r") as file:
    in_dict=json.load(file)

gaps=in_dict['gaps']
norm_gaps=in_dict['norm_gaps']

In [None]:
# fit log-log first spacing data to linear function (measurements 10..39)
# (the fit result is only used by the commented-out plot lines in this cell)
popt,_=opt.curve_fit(lin_fit_fun,np.log(ep_arr[10:40]),np.log(np.array(gaps[10:40])))
print(popt)

# plot first level spacing
plt.figure(figsize=(8,6))
plt.plot(ep_arr,gaps,alpha=0.8,label="Data")
# plt.plot(ep_arr[1:],np.exp(lin_fit_fun(np.log(ep_arr[1:]),*popt)),'k',linestyle='dashed',label='Power-law fit')
plt.legend(fontsize=15)
plt.xlabel('Epoch',fontsize=20)
plt.ylabel('$s_{0}$',fontsize=20)
# plt.savefig('paper_plots/gaps_lr-4_2.pdf',bbox_inches='tight')
plt.show()

# plot first level spacing on log-log plot
# the first measurement (epoch 0) is left off the curve and drawn as a horizontal dotted line instead
plt.figure(figsize=(8,6))
plt.plot(ep_arr[1:],gaps[1:],alpha=0.8,label="Data")
# plt.plot(ep_arr[1:],np.exp(lin_fit_fun(np.log(ep_arr[1:]),*popt)),'k',linestyle='dashed',label='Power-law fit')
plt.axhline(gaps[0],linestyle='dotted',color='gray')
plt.xscale('log')
plt.yscale('log')
# plt.ylim([.025,None])
plt.legend(fontsize=15)
plt.xlabel('Epoch',fontsize=20)
plt.ylabel('$s_{0}$',fontsize=20)
# plt.savefig('paper_plots/gaps_lr-4_2_log.pdf',bbox_inches='tight')
plt.show()

In [None]:
# fit log-log normalized first level spacing to linear function (measurements 250 onwards)
# (the fit result is only used by the commented-out plot lines in this cell)
popt,_=opt.curve_fit(lin_fit_fun,np.log(ep_arr[250:]),np.log(np.array(norm_gaps[250:])))
print(popt)

# plot normalized first level spacing
plt.figure(figsize=(8,6))
plt.plot(ep_arr,norm_gaps,linewidth=3,alpha=0.8,label="Data")
# plt.plot(ep_arr,np.exp(lin_fit_fun(np.log(ep_arr),*popt)),'k',linestyle='dashed',label='Power-law fit')
# plt.ylim([None,120])
plt.legend(fontsize=15)
plt.xlabel('Epoch',fontsize=20)
plt.ylabel('$s_0/s_{\\mathrm{typ}}$',fontsize=20)
# plt.savefig('paper_plots/norm_gaps_lr-4_2.pdf',bbox_inches='tight')
plt.show()

# plot normalized first level spacing on log-symlog plot
plt.figure(figsize=(8,6))
ax=plt.gca()
plt.plot(m_ep_arr[:],norm_gaps[:],linewidth=3,alpha=0.8,label="Data")
# plt.plot(ep_arr[1:],np.exp(lin_fit_fun(np.log(ep_arr[1:]),*popt)),'k',linestyle='dashed',label='Power-law fit')
plt.yscale('log')
ax.set_xscale('symlog',linthresh=1)
# minor ticks at 2..9 times each power of ten
ll=ticker.SymmetricalLogLocator(linthresh=1,base=10)
ll.set_params([2,3,4,5,6,7,8,9])
ax.xaxis.set_minor_locator(ll)
# add an extra tick at the shifted first point and label it as epoch 0
plt.xticks(list(plt.xticks()[0]) + [m_ep_arr[0]],list(plt.xticks()[1])+['$0$'])
plt.xlim([m_ep_arr[0]-0.1,600])
plt.legend(fontsize=15)
plt.xlabel('Epoch',fontsize=20)
plt.ylabel('$s_0/s_{\\mathrm{typ}}$',fontsize=20)
# plt.savefig('paper_plots/norm_gaps_lr-4_2_log.pdf',bbox_inches='tight')
plt.show()