In [17]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import jax.scipy as jsp

In [18]:
#Function to init layer of size nxm with key and scale
def random_layer_params(m,n,key,scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale*random.normal(b_key, (n,))

def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m,n,k) for m,n,k in zip(sizes[:-1], sizes[1:], keys)]

In [19]:
layer_sizes = [484, 512, 512, 47]
step_size = 0.01
num_epochs = 20
batch_size = 128
n_targets = 47
params = init_network_params(layer_sizes, random.PRNGKey(0))

In [20]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0,x)

def predict(params2, kernel, image):
    activations = image
    activations = jsp.signal.convolve(activations.reshape(28,28), kernel, mode="valid")
    activations = activations.reshape(22*22).squeeze()
    
    
    for w, b in params2[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)

    final_w, finals_b = params2[-1]
    logits = jnp.dot(final_w, activations) + finals_b

    return logits - logsumexp(logits)

batched_predict = vmap(predict, in_axes=(None, None, 0))

def loss2(params, window, images, targets):
    preds = batched_predict(params, window, images)
    ret = -jnp.mean(preds * targets)
    return ret

@jit
def update(params, window, x , y):
    grads = grad(loss2, (0,1))(params, window, x, y)

    return (window - step_size * grads[1], [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads[0])])

In [21]:
#utility functions

def one_hot(x, k, dtype=jnp.float32):
    #creates one hot encoding of x of size k
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, window, images, targets):
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_predict(params, window, images), axis=1)
    return jnp.mean(predicted_class == target_class)

In [22]:
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import EMNIST

def numpy_collate(batch):
    return tree_map(np.asarray, data.default_collate(batch))

class NumpyLoader(data.DataLoader):
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32))

In [23]:
# Define our dataset, using torch datasets
mnist_dataset = EMNIST('/tmp/emnist/', split="balanced", download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)

In [24]:
# Get the full train dataset (for checking accuracy while training)
train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)

# Get full test dataset
mnist_dataset_test = EMNIST('/tmp/emnist/', split="balanced", download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)

In [25]:
#Init "network"
params2 = init_network_params(layer_sizes, random.PRNGKey(0))

window = random.normal(random.PRNGKey(0), (7, 7))

In [26]:
import time

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in training_generator:
    y = one_hot(y, n_targets)
    window, params = update(params, window, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, window, train_images, train_labels)
  test_acc = accuracy(params, window, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

2023-10-19 14:13:44.558782: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 9.98GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2023-10-19 14:13:46.291657: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng28{k2=0,k3=0} for conv (f32[112800,1,22,22]{3,2,1,0}, u8[0]{0}) custom-call(f32[112800,1,28,28]{3,2,1,0}, f32[1,1,7,7]{3,2,1,0}), window={size=7x7}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0} is taking a while...
2023-10-19 14:13:46.292633: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 1.001081868s
Trying algorithm eng28{k2=0,k3=0} for conv (f32[112800,1,22,22]{3,2,1,0}, u8[0]{0}) custom-call(f32[112800,1,28,28]{3,2,1,0}, f32[

Epoch 0 in 8.41 sec
Training set accuracy 0.6390248537063599
Test set accuracy 0.6165957450866699
Epoch 1 in 6.52 sec
Training set accuracy 0.7133333683013916
Test set accuracy 0.6851063966751099
Epoch 2 in 6.48 sec
Training set accuracy 0.7521986365318298
Test set accuracy 0.7189361453056335
Epoch 3 in 5.92 sec
Training set accuracy 0.777216374874115
Test set accuracy 0.7423403859138489
Epoch 4 in 5.97 sec
Training set accuracy 0.7957181334495544
Test set accuracy 0.7555850744247437
Epoch 5 in 6.19 sec
Training set accuracy 0.8101241588592529
Test set accuracy 0.7672340273857117
Epoch 6 in 6.36 sec
Training set accuracy 0.8209840655326843
Test set accuracy 0.774893581867218
Epoch 7 in 6.29 sec
Training set accuracy 0.8306560516357422
Test set accuracy 0.780904233455658
Epoch 8 in 6.17 sec
Training set accuracy 0.838129460811615
Test set accuracy 0.7847340106964111
Epoch 9 in 6.17 sec
Training set accuracy 0.8444504141807556
Test set accuracy 0.787872314453125
Epoch 10 in 6.15 sec
Trai