In [129]:
import numpy as np
from sklearn import datasets
import jax
from jax import numpy as jnp
import random

In [130]:
ds = datasets.load_digits()
X, y = ds['images'], ds['target']
X = X / 255

In [131]:
X.shape

(1797, 8, 8)

In [132]:
parameters = (
	np.random.normal(scale=1 / 256, size=(64 + 10 + 1, 256)).T,  #   256
	np.random.normal(scale=1 / 256, size=(256 + 1, 256)).T,  #   256
	np.random.normal(scale=1 / 256, size=(256 + 1, 256)).T,  #   256
	np.random.normal(scale=1 / 256, size=(256 + 1, 256)).T,  #   256
	np.random.normal(scale=1 / 256, size=(256 + 1, 128)).T,  # + 128
	# ------
	np.random.normal(size=(10, 1152 + 1))  #  1152
)


def layer_infer(layer, x):
	x = layer @ jnp.concatenate((x, np.ones(shape=1)))
	x = jnp.log1p(jnp.exp(x))  # softplus
	x = x / np.linalg.norm(x)  # layer norm
	return x


def last_layer_infer(layer, x):
	x = layer @ jnp.concatenate((x, np.ones(shape=1)))

	exp_x = jnp.exp(x)
	return exp_x / np.sum(exp_x)  # softmax


def inference(x, layers):
	layer0, layer1, layer2, layer3, layer4, layerFinal = layers

	x0 = layer_infer(layer0, x)
	x1 = layer_infer(layer1, x0)
	x2 = layer_infer(layer2, x1)
	x3 = layer_infer(layer3, x2)
	x4 = layer_infer(layer4, x3)

	return last_layer_infer(layerFinal, np.concatenate((x0, x1, x2, x3, x4)))


inference(np.concatenate((X[0].flatten(), np.zeros(10))), parameters)

Array([1.1384060e-03, 8.0636370e-01, 5.8803633e-02, 2.2045203e-02,
       9.5193041e-03, 4.3360114e-02, 1.1978690e-03, 4.7167134e-02,
       9.9833654e-03, 4.2123359e-04], dtype=float32)

In [133]:
def layer_infer_0(layer, x):
	x = layer @ jnp.concatenate((x, np.ones(shape=1)))
	x = jnp.log1p(jnp.exp(x))  # softplus
	return x


def layer_train(layer, x, positive):
	dlayer = jax.grad(lambda l, x: jnp.sum(layer_infer_0(l, x)) ** 2)(layer, x)
	if positive:
		dlayer *= -1
	return dlayer, layer_infer(layer, x)


def cross_entropy_loss(y_pred, y):
	return jnp.sum(-y * jnp.log(y_pred))


def last_layer_train(layer, x, y, positive):
	return jax.grad(lambda l, x, y: cross_entropy_loss(last_layer_infer(l, x), y))(layer, x,
	                                                                               y) if positive else jnp.zeros_like(
		layer), last_layer_infer(layer, x)


def train(x, y, layers, positive: bool):
	layer0, layer1, layer2, layer3, layer4, layerFinal = layers
	dl0, x0 = layer_train(layer0, x, positive)
	dl1, x1 = layer_train(layer1, x0, positive)
	dl2, x2 = layer_train(layer2, x1, positive)
	dl3, x3 = layer_train(layer3, x2, positive)
	dl4, x4 = layer_train(layer4, x3, positive)

	dlf, y = last_layer_train(layerFinal, jnp.concatenate((x0, x1, x2, x3, x4)), y, positive)
	return (dl0, dl1, dl2, dl3, dl4, dlf), y


train(np.concatenate((X[0].flatten(), np.zeros(10))), np.zeros(10), parameters, positive=True)

((Array([[  -0.       ,   -0.       ,   -3.4790595, ...,   -0.       ,
            -0.       , -177.43202  ],
         [  -0.       ,   -0.       ,   -3.4836123, ...,   -0.       ,
            -0.       , -177.66422  ],
         [  -0.       ,   -0.       ,   -3.4757967, ...,   -0.       ,
            -0.       , -177.26562  ],
         ...,
         [  -0.       ,   -0.       ,   -3.472893 , ...,   -0.       ,
            -0.       , -177.11754  ],
         [  -0.       ,   -0.       ,   -3.4796553, ...,   -0.       ,
            -0.       , -177.46242  ],
         [  -0.       ,   -0.       ,   -3.471619 , ...,   -0.       ,
            -0.       , -177.05255  ]], dtype=float32),
  Array([[ -11.112507 ,  -11.133505 ,  -11.097481 , ...,  -11.115258 ,
           -11.078258 , -177.75125  ],
         [ -11.127423 ,  -11.148449 ,  -11.112376 , ...,  -11.1301775,
           -11.093127 , -177.98984  ],
         [ -11.120061 ,  -11.141073 ,  -11.105024 , ...,  -11.122814 ,
           -11.085

In [134]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y)

In [None]:

theta = parameters
for i in range(10):
	print(f"Epoch {i} start!")
	total_loss = 0
	for x, y in zip(X_test, y_test):
		total_loss = cross_entropy_loss(inference(np.concatenate((x.flatten(), np.ones(10) / 10)), theta), np.eye(10)[y])
	print(f"Total Loss {total_loss}")

	for x, y in zip(X_train, y_train):
		positive = np.concatenate((x.flatten(), np.eye(10)[y]))
		neg_y = y
		while neg_y != y:
			neg_y = random.randint(0, 9)
		negative = np.concatenate((x.flatten(), np.eye(10)[neg_y]))

		dtheta_p, _ = train(positive, np.eye(10)[y], theta, positive=True)
		dtheta_n, _ = train(negative, np.eye(10)[neg_y], theta, positive=False)
		theta = tuple(theta - dtp * .0001 - dtn * .0001 for theta, dtp, dtn in zip(theta, dtheta_p, dtheta_n))

	print(f"Epoch {i} end!")



Epoch 0 start!
Total Loss 7.771260738372803
Epoch 0 end!
Epoch 1 start!
Total Loss 7.289566993713379
Epoch 1 end!
Epoch 2 start!
Total Loss 6.918318271636963
Epoch 2 end!
Epoch 3 start!
Total Loss 6.630514621734619
Epoch 3 end!
Epoch 4 start!
Total Loss 6.39785099029541


In [None]:
X[0], y[0]

In [None]:
# It works! yay