In [2]:
from MiniTorch.nets.base import Net
from MiniTorch.nets.layers import Conv2D, Linear, MaxPool2d, PReLU, ReLU, SoftMax, Flatten, Tanh, Sigmoid
from MiniTorch.optimizers import SGD
from MiniTorch.losses import CCE, MSE
import matplotlib.pyplot as plt
import jax
import time
from MiniTorch.plotutils import show_conv_out
import jax.numpy as jnp

In [1]:
from sklearn.datasets import fetch_openml

# Fetch the `mnist_784` dataset from OpenML (scikit-learn caches downloads
# locally by default, so later runs avoid the network).
mnist = fetch_openml('mnist_784', version=1)
X = mnist['data']    # flattened pixel features
y = mnist['target']  # class labels

In [3]:
# Convert the pandas containers returned by fetch_openml to NumPy arrays.
# The hasattr guards keep this cell idempotent: re-running it after X / y are
# already NumPy arrays would otherwise fail on the missing `.to_numpy`.
if hasattr(X, 'to_numpy'):
    X = X.to_numpy()
if hasattr(y, 'to_numpy'):
    y = y.to_numpy()

n_samples = X.shape[0]
image_size = (28, 28)
# Reshape via image_size so the image shape is declared in one place rather
# than repeated as literal 28, 28.
X_images = X.reshape(n_samples, *image_size)

In [4]:
def get_training_data(X, Y, train_size, num_classes=10, image_shape=(28, 28)):
    """Build JAX training arrays from the first ``train_size`` samples.

    Args:
        X: array-like of flattened images, one row per sample. Each row must
            hold ``image_shape[0] * image_shape[1]`` values.
        Y: array-like of class labels. Each entry must be convertible with
            ``int()`` (e.g. string labels such as '5').
        train_size: number of leading samples to keep. If it exceeds the
            number of available samples, all available samples are used.
        num_classes: width of the one-hot label encoding (default 10).
        image_shape: (height, width) that each flattened row is reshaped to.

    Returns:
        (X_tra, Y_tra) where X_tra is float32 of shape (n, 1, H, W) with the
        raw values divided by 255, and Y_tra is an integer one-hot array of
        shape (n, num_classes). Labels outside [0, num_classes) encode as
        all-zero rows.
    """
    X_tra = jnp.asarray(X[:train_size])
    # Use the real slice length so an oversized train_size no longer breaks
    # the reshape, while a wrong row width still raises.
    n = X_tra.shape[0]

    # dtype=int gives JAX's default integer type (same as the old list path)
    # and keeps an empty selection integer-typed.
    labels = jnp.asarray([int(label) for label in Y[:train_size]], dtype=int)
    # Vectorised replacement for the old per-element double loop; it yields
    # all-zero rows for out-of-range labels exactly like the loop did, and
    # shape (0, num_classes) for empty input.
    Y_tra = jax.nn.one_hot(labels, num_classes, dtype=labels.dtype)

    X_tra = jnp.reshape(X_tra, (n, *image_shape))
    # Insert a singleton channel axis: (n, H, W) -> (n, 1, H, W).
    X_tra = jnp.expand_dims(X_tra, axis=1)
    X_tra = X_tra.astype('float32') / 255.
    return X_tra, Y_tra
X_tra, Y_tra = get_training_data(X, y, train_size=10_000)  # first 10k samples

In [8]:
# Loss criterion for the forward/backward probe; CCE presumably stands for
# categorical cross-entropy (TODO confirm in MiniTorch.losses).
crit = CCE()

In [14]:
# Keyword options shared by every Linear layer in this network.
dense_opts = dict(
    accumulate_grad_norm=True,
    accumulate_parameters=True,
    initialization='xavier',
)

net_3 = Net(
    [
        # NOTE(review): Conv2D receives `accumulate_params` while the Linear
        # layers receive `accumulate_parameters`; one of these may be a typo
        # that MiniTorch silently ignores — TODO confirm against the API.
        Conv2D(1, 4, 20, accumulate_grad_norm=True, accumulate_params=True, initialization='xavier'),
        MaxPool2d(2, 2),
        Flatten(),
        # 2880 presumably = 20 channels x 12 x 12 after conv + 2x2 pooling on
        # 28x28 input — TODO confirm Conv2D's positional argument order.
        Linear(2880, 50, **dense_opts),
        # NOTE(review): no activation between this Linear and the next, so the
        # two compose into a single linear map; confirm this is intentional.
        Linear(50, 50, **dense_opts),
        Sigmoid(),
        Linear(50, 10, **dense_opts),
        SoftMax(),
    ],
    reproducibility_key=20,
)

In [6]:
# Take the first image/label pair and add a leading batch axis of size 1.
X_example, y_example = (jnp.expand_dims(arr[0], axis=0) for arr in (X_tra, Y_tra))

In [15]:
def get_top_1_eigen_value(model, data_sample, crit):
    """Run one forward/backward pass of ``model`` on a single labelled batch.

    NOTE(review): despite its name, this does not compute an eigenvalue yet.
    It only back-propagates the loss. TODO: implement the top-1 eigenvalue
    estimate (e.g. power iteration on Hessian-vector products) or rename.

    Args:
        model: object exposing ``forward(x)`` and ``backward(grad)``.
        data_sample: ``(x, y)`` pair of inputs and targets.
        crit: loss object exposing ``loss(out, y)`` and ``backward(loss)``.

    Returns:
        Whatever ``model.backward`` returns. The original version silently
        discarded this value. It is presumably gradient information — TODO
        confirm against the MiniTorch API.
    """
    x, y = data_sample
    out = model.forward(x)
    los = crit.loss(out, y)
    ini_grad = crit.backward(los)
    return model.backward(ini_grad)

In [16]:
get_top_1_eigen_value(model=net_3, data_sample=(X_example, y_example), crit=crit)

{'dL_dW': Array([[ 0.02608397,  0.02804006,  0.02075754,  0.1542476 ,  0.03187599,
         -0.436037  ,  0.02418122,  0.05980736,  0.02364981,  0.06739344],
        [ 0.02641419,  0.02839504,  0.02102032,  0.15620033,  0.03227954,
         -0.4415571 ,  0.02448735,  0.0605645 ,  0.02394921,  0.06824663],
        [ 0.02811357,  0.03022186,  0.02237269,  0.16624963,  0.03435627,
         -0.46996513,  0.02606276,  0.06446099,  0.02549001,  0.07263735],
        [ 0.02624981,  0.02821833,  0.02088951,  0.15522826,  0.03207865,
         -0.43880922,  0.02433496,  0.0601876 ,  0.02380017,  0.06782192],
        [ 0.03126592,  0.03361061,  0.02488132,  0.18489106,  0.03820861,
         -0.5226619 ,  0.02898516,  0.07168894,  0.02834818,  0.08078211],
        [ 0.02899126,  0.03116537,  0.02307115,  0.17143986,  0.03542886,
         -0.48463717,  0.02687643,  0.06647342,  0.02628579,  0.07490505],
        [ 0.03133131,  0.0336809 ,  0.02493335,  0.18527773,  0.03828852,
         -0.52375495,  