In [1]:
%env XLA_PYTHON_CLIENT_MEM_FRACTION=.8

import numpy as np
import jax.numpy as jnp
import jax.scipy as jsc
import matplotlib.pyplot as plt
from jax import random
from scipy.linalg import circulant
from jax import grad, jit, vmap, value_and_grad
import optax
import matplotlib
matplotlib.rcParams.update({'font.size': 15})

import os
import sys

# meta.txt presumably holds the project root path (TODO confirm); its newlines
# are converted to '/' so `home` can be used directly as a path prefix.
# `with` guarantees the file is closed even if reading fails.
with open("../../../../meta.txt") as meta_file:
    home = meta_file.read().replace("\n", "/")

# Put <home>mypylib on the import path (presumably where the chunGP helper
# imported below lives); guard against appending it twice on re-runs.
p = os.path.abspath(home+'mypylib')
if p not in sys.path:
    sys.path.append(p)

import chunGP as gp

def loadData(dataName,N,P,seed=0):
    """Load a dataset, L2-normalise each input row and split it into train/test.

    Args:
        dataName: one of 'MNIST', 'Fashion', 'CIFAR' (colour) or 'CIFARG'
            (CIFAR loaded with grayscale=True).
        N: number of samples requested from the loader; the split below
            assumes the returned X / Y have exactly N rows.
        P: number of those rows drawn (without replacement) for training;
            the remaining N - P rows form the test set.
        seed: seeds NumPy's global RNG (which drives the split) and is also
            passed to the loader.

    Returns:
        X, Y, x_train, x_test, y_train, y_test, train_ID, test_ID, where
        train_ID / test_ID are complementary boolean masks of length N.

    Raises:
        ValueError: if `dataName` is not one of the supported names.
    """
    np.random.seed(seed)

    All=True
    if dataName=='MNIST':
        X,_,_,_,Y,_,_,_,_,_=gp.getMNIST(
            N_train=N,N_test=100,normalize=True,seed=seed,All=All)
    elif dataName=='Fashion':
        X,_,_,_,Y,_,_,_,_,_=gp.getFashion(
            N_train=N,N_test=100,normalize=True,seed=seed,All=All,home=home)
    elif dataName in ('CIFAR','CIFARG'):
        grayscale = dataName=='CIFARG'
        X,_,_,_,Y,_,_,_,_,_=gp.getCIFAR(
            N_train=N,N_test=100,normalize=True,seed=seed,
            grayscale=grayscale,All=All,home=home)
    else:
        # Previously fell through and failed later with an unbound `X`.
        raise ValueError(
            "unknown dataName {!r}; expected one of "
            "'MNIST', 'Fashion', 'CIFAR', 'CIFARG'".format(dataName))

    # Scale every sample (row) to unit Euclidean norm.
    X=X/np.sqrt(np.sum(np.square(X),axis=1))[:,None]

    train_idx=np.random.choice(np.arange(N),P,replace=False)

    train_ID=np.zeros(N,dtype=bool)
    train_ID[train_idx]=True
    test_ID=np.invert(train_ID)

    x_train=X[train_ID,:]
    x_test=X[test_ID,:]

    y_train=Y[train_ID,:]
    y_test=Y[test_ID,:]

    return X,Y,x_train,x_test,y_train,y_test,train_ID,test_ID



env: XLA_PYTHON_CLIENT_MEM_FRACTION=.8


In [2]:
@jit
def accuracy(y_pred,y):
    """Fraction of positions at which `y_pred` and `y` agree (elementwise equality, averaged)."""
    matches = jnp.equal(y_pred, y)
    return matches.mean()
def getNN(yr_test,y_class):
    """Assign each row of `yr_test` to its Euclidean-nearest row of `y_class`.

    Args:
        yr_test: (n, d) array of predictions.
        y_class: (k, d) array of reference rows (presumably class prototypes).

    Returns:
        (y_class[idx, :], idx), where idx[i] is the index of the row of
        `y_class` closest to yr_test[i]; ties resolve to the lowest index.
    """
    # Distances are computed with jax.numpy directly (no sklearn dependency).
    # Broadcasting builds an (n, k, d) difference tensor, so memory scales with
    # n*k*d; this is cheap when y_class has only a few rows.
    pdist = jnp.linalg.norm(yr_test[:, None, :] - y_class[None, :, :], axis=-1)
    tp_test=jnp.argmin(pdist,axis=1)
    return y_class[tp_test,:],tp_test
def getMax(yr_test):
    """Column index of the largest entry in each row (argmax over axis 1)."""
    return jnp.argmax(yr_test, axis=1)
    
def random_layer_params(m, n, key, scale=1e-2):
    """Sample one dense layer's (m, n) weight matrix; no bias is created.

    Entries are standard normal times scale / sqrt(m). The key is split in
    two; the first subkey drives the sample and the second is discarded.
    """
    subkeys = random.split(key)
    weight_key = subkeys[0]
    std = scale / jnp.sqrt(m)
    return std * random.normal(weight_key, (m, n))

def init_network_params(sizes, key,scale=1e-2):
    """Build the weight matrices of a bias-free fully-connected network.

    Layer i maps sizes[i] -> sizes[i+1]. One key per entry of `sizes` is
    split off; the last one is unused because there are len(sizes)-1 layers.
    """
    layer_keys = random.split(key, len(sizes))
    shapes = zip(sizes[:-1], sizes[1:])
    params = []
    for (fan_in, fan_out), layer_key in zip(shapes, layer_keys):
        params.append(random_layer_params(fan_in, fan_out, layer_key, scale=scale))
    return params

def init_network_params_save(sizes, key,scale=1e-2):
    """Sample each layer's weights and write it to ./finite_weights/<i>.npy.

    Uses the same key splitting and shapes as init_network_params (layer i maps
    sizes[i] -> sizes[i+1]). Layers go to disk one at a time instead of being
    returned, so only one matrix is held in memory at once. Returns None.
    """
    # np.save does not create parent directories; make sure the target exists.
    os.makedirs('./finite_weights', exist_ok=True)
    keys = random.split(key, len(sizes))
    for i, (m, n, k) in enumerate(zip(sizes[:-1], sizes[1:], keys)):
        W=random_layer_params(m, n, k,scale=scale)
        np.save('./finite_weights/{}'.format(i),W)
        # Drop the reference before the next (possibly very large) layer is drawn.
        del W

def ReLU(A):
    """Elementwise rectifier: entries of A that are > 0 pass through, all others become 0."""
    is_positive = A > 0
    return jnp.where(is_positive, A, 0)
@jit
def predict(params, X, pa):
    """Forward pass of the bias-free network with per-row quantile thresholds.

    For each hidden matrix in params[:-1]: H = X @ W, then the (1 - pa)
    quantile of each row of H is subtracted before ReLU, so about a fraction
    `pa` of each row's units stays positive. params[-1] is a linear readout.
    """
    hidden_weights = params[:-1]
    readout = params[-1]
    activations = X
    for weight in hidden_weights:
        pre_act = jnp.matmul(activations, weight)
        thresholds = jnp.quantile(pre_act, 1 - pa, axis=1)
        activations = ReLU(pre_act - thresholds[:, None])
    return jnp.matmul(activations, readout)
# Not jitted (the @jit was commented out); weights are read with np.load below.
def predict_load(L,final_W, X, pa,scale=1):
    """Forward pass using L-1 hidden matrices stored in ./finite_weights/.

    Hidden layer i is read from './finite_weights/{i}.npy' on every call; its
    pre-activations are multiplied by `scale`, shifted by each row's (1 - pa)
    quantile and passed through ReLU. `final_W` is the linear readout.
    """
    activations = X
    for layer in np.arange(L-1):
        weight = np.load('./finite_weights/{}.npy'.format(layer))
        pre_act = jnp.matmul(activations, weight) * scale
        thresholds = jnp.quantile(pre_act, 1 - pa, axis=1)
        activations = ReLU(pre_act - thresholds[:, None])
        # Release this matrix before the next one is loaded.
        del weight
    return jnp.matmul(activations, final_W)
@jit
def predictLast(params, X, pa):
    """Last hidden-layer activations of the network (no readout applied).

    Runs the same quantile-thresholded ReLU layers as `predict` over
    params[:-1] and returns the result; params[-1] is not used.
    """
    activations = X
    for weight in params[:-1]:
        pre_act = jnp.matmul(activations, weight)
        cutoff = jnp.quantile(pre_act, 1 - pa, axis=1)
        activations = ReLU(pre_act - cutoff[:, None])
    return activations
# Not jitted (the @jit was commented out); weights are read with np.load below.
def predictLast_load(L, X, pa,scale=1):
    """Last hidden-layer activations using L-1 matrices from ./finite_weights/.

    Same layer recipe as `predict_load` (scale, per-row (1 - pa) quantile
    shift, ReLU) but returns the features instead of applying a readout.
    """
    activations = X
    for layer in np.arange(L-1):
        weight = np.load('./finite_weights/{}.npy'.format(layer))
        pre_act = jnp.matmul(activations, weight) * scale
        cutoff = jnp.quantile(pre_act, 1 - pa, axis=1)
        activations = ReLU(pre_act - cutoff[:, None])
        # Release this matrix before the next one is loaded.
        del weight
    return activations
@jit
def fiterror(params, pa, X, Y):
    """Return (mean squared error, argmax accuracy) of predict(params, X, pa) against Y."""
    prediction = predict(params, X, pa)
    mse = jnp.mean(jnp.square(Y - prediction))
    acc = accuracy(getMax(prediction), getMax(Y))
    return mse, acc
# Not jitted (the @jit was commented out); predict_load reads weights via np.load.
def fiterror_load(L,final_W, pa,X,Y,scale=1):
    """Return (mean squared error, argmax accuracy) of predict_load(...) against Y."""
    prediction = predict_load(L, final_W, X, pa, scale=scale)
    mse = jnp.mean(jnp.square(Y - prediction))
    acc = accuracy(getMax(prediction), getMax(Y))
    return mse, acc

def getosig(tau):
    """Compute sqrt(pi / (I0 - tau*sqrt(2*pi))), with I0 = gp.I2([0], [tau], n=10000).

    runTrial uses this value as the `scale` that multiplies each hidden
    layer's pre-activations for the threshold gp.getTau(pa).
    """
    I0 = gp.I2(np.array([0]), np.array([tau]), n=10000)
    denominator = I0 - tau * np.sqrt(2 * np.pi)
    return np.sqrt(np.pi / denominator)

@jit
def trainLast(XL,y_train,s0):
    """Fit the linear readout on last-layer features by (ridge) least squares.

    Returns W = XL^T (XL XL^T + s0*I)^+ y_train, where rows of XL are samples.
    The Moore-Penrose pseudo-inverse keeps s0 = 0 well defined when the Gram
    matrix is singular; for s0 = 0 this is the minimum-norm least-squares
    solution XL^+ y_train.

    Args:
        XL: 2-D feature matrix, one row per training sample.
        y_train: targets with the same number of rows as XL.
        s0: non-negative ridge strength (previously accepted but ignored).
    """
    gram = jnp.matmul(XL, XL.T) + s0*jnp.eye(jnp.shape(XL)[0], dtype=XL.dtype)
    return jnp.matmul(jnp.matmul(XL.T, jnp.linalg.pinv(gram)), y_train)



In [None]:
def runTrial(dataName,N,P,seed,X,Y,x_train,y_train,train_ID,wwidth=1000,s0=0.0,save=False):
    """Sweep depth L and active fraction pa, fitting only the readout.

    Hidden weights are sampled once with PRNGKey(0) (scale=1) and written to
    ./finite_weights/ by init_network_params_save. For every (pa, L) pair, the
    first L-1 matrices are reloaded and pre-activations are multiplied by
    getosig(gp.getTau(pa)). The readout is fitted on (x_train, y_train) with
    trainLast, and the MSE and argmax accuracy are then measured on (X, Y).

    Args:
        dataName, N, P: only used for the saved filename / metadata.
        seed: currently unused (weights always come from PRNGKey(0)).
        X, Y: evaluation inputs / targets.
        x_train, y_train: training inputs / targets for the readout.
        train_ID: currently unused.
        wwidth: hidden-layer width.
        s0: regulariser forwarded to trainLast.
        save: if True, write the results to
            ./sweep_data/finite_<dataName>_P<P>_N<N>_wwidth<wwidth>.npz.

    Returns:
        (GE, GA): NumPy arrays of shape (len(pas), len(Ls)) holding the MSE
        and the accuracy for each (pa, L).

    NOTE(review): the driver passes the full X / Y returned by loadData, which
    include the training rows, so GE / GA mix train and test points. If a pure
    test error is intended, evaluate on X[~train_ID], Y[~train_ID].
    """
    Ls=jnp.arange(2,20)                  # depths L = 2 .. 19
    pas=jnp.linspace(0.4999,0.01,20)     # active fractions, from 0.4999 down to 0.01

    # max(Ls) weight matrices are saved; the deepest point (L = max(Ls)) reads
    # the first max(Ls)-1 of them.
    layer_sizes=[np.shape(x_train)[1]]+[wwidth for _ in range(np.max(Ls)-1)]+[np.shape(y_train)[1]]
    init_network_params_save(layer_sizes, random.PRNGKey(0),scale=1)

    GE=np.zeros((jnp.shape(pas)[0],jnp.shape(Ls)[0]))
    GA=np.zeros((jnp.shape(pas)[0],jnp.shape(Ls)[0]))
    for j, pa in enumerate(pas):
        # Pre-activation rescaling for this active fraction.
        scale=getosig(gp.getTau(pa))
        print(j)
        for g, L in enumerate(Ls):
            print(g)
            XL=predictLast_load(L, x_train, pa,scale=scale)
            final_W=trainLast(XL,y_train,s0)
            GE[j,g],GA[j,g]=fiterror_load(L, final_W, pa, X, Y,scale=scale)
    if save:
        # np.savez does not create parent directories; without this a long
        # sweep would be lost to a FileNotFoundError at the very end.
        os.makedirs('./sweep_data', exist_ok=True)
        filename='finite_'+dataName+'_P{}_N{}_wwidth{}'.format(P,N,wwidth)
        np.savez('./sweep_data/'+filename,P=P,N=N,wwidth=wwidth,s0=s0,Ls=Ls,pas=pas,errAf=GE,acc=GA)
    return GE,GA
        
# Training-set sizes: 6 log-spaced values from 1e2 to 1e4, rounded up.
# Each run draws P training rows out of N = 2P loaded samples.
Ps=np.ceil(np.power(10,np.linspace(2,4,6))).astype(int)
Ns=2*Ps

dataNames=['MNIST','Fashion','CIFAR','CIFARG']

wwidth=15000   # hidden-layer width
s0=0           # regulariser passed through runTrial to trainLast

for dataName in dataNames:
    # Print the dataset actually being processed (not the literal 'dataName').
    print(dataName)
    for P in Ps[1:2]:   # currently only P = Ps[1]
        N=2*P
        seed=0
        X,Y,x_train,x_test,y_train,y_test,train_ID,test_ID=loadData(dataName,N,P,seed=seed)

        GE,GA=runTrial(dataName,N,P,seed,X,Y,x_train,y_train,train_ID,wwidth=wwidth,s0=s0,save=True)

        

dataName
Downloading train-images-idx3-ubyte.gz...
Downloading t10k-images-idx3-ubyte.gz...
Downloading train-labels-idx1-ubyte.gz...
Downloading t10k-labels-idx1-ubyte.gz...
Download complete.
Save complete.
# of training images:60000
# of test images:10000
0
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
1
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
2
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
3
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
4
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
5
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
6
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
7
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
8
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
9
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
10
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
11
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
12
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
13
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
14
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
15
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

In [None]:
# For each depth L (column), the row index into `pas` with the lowest error.
minid=np.argmin(GE,axis=0)

# Error map (colour range clipped to [0.05, 0.1]) with the best f per L marked.
# X/Y have the same shape as C, so request cell-centred ('nearest') shading
# explicitly rather than relying on the version-dependent default.
fig,ax=plt.subplots(1,1,figsize=(5,4))
im=ax.pcolor(Ls,pas,GE,cmap='jet',vmax=0.1,vmin=0.05,shading='nearest')
ax.scatter(Ls,pas[minid],c='w',marker='D',s=10)
cb=fig.colorbar(im, ax=ax)
cb.set_label('err experiment')
ax.set_xlabel('L')
ax.set_ylabel('f')

# Accuracy map.
fig,ax=plt.subplots(1,1,figsize=(5,4))
im=ax.pcolor(Ls,pas,GA,cmap='jet',shading='nearest')
cb=fig.colorbar(im, ax=ax)
cb.set_label('accuracy')
ax.set_xlabel('L')
ax.set_ylabel('f')



In [None]:
# For each depth L (column), the row index into `pas` with the lowest error.
minid=np.argmin(GE,axis=0)

# Error map (colour range [0, 0.1]) with the best f per L marked.
# X/Y have the same shape as C, so request cell-centred ('nearest') shading
# explicitly rather than relying on the version-dependent default.
fig,ax=plt.subplots(1,1,figsize=(5,4))
im=ax.pcolor(Ls,pas,GE,cmap='jet',vmax=0.1,vmin=0,shading='nearest')
ax.scatter(Ls,pas[minid],c='w',marker='D',s=10)
cb=fig.colorbar(im, ax=ax)
cb.set_label('err experiment')
ax.set_xlabel('L')
ax.set_ylabel('f')

# Accuracy map.
fig,ax=plt.subplots(1,1,figsize=(5,4))
im=ax.pcolor(Ls,pas,GA,cmap='jet',shading='nearest')
cb=fig.colorbar(im, ax=ax)
cb.set_label('accuracy')
ax.set_xlabel('L')
ax.set_ylabel('f')



In [None]:
# Sanity check of the sweep grid: rows index `pas`, columns index `Ls`.
GE.shape