In [1]:
import os
os.chdir("../")

In [2]:
%pip install -q flax

Note: you may need to restart the kernel to use updated packages.


In [1]:
import flax
import flax.linen as nn
from flax.core import unfreeze

from flax.training import train_state 

import jax
import jax.numpy as jnp
import optax

import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow.keras.datasets as tfds
import tensorflow as tf
from tensorflow.keras import layers

In [2]:
from scipy.ndimage import rotate

In [3]:
class LeNet(nn.Module):
    
    @nn.compact
    def __call__(self, x, deterministic):
        
        #1 layer
        x = nn.Conv(features=6, kernel_size=(5,5), name="conv2d_1")(x)
        # x = nn.tanh(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2,2))
        
        #2 layer
        x = nn.Conv(features=16, kernel_size=(5,5), name="conv2d_2")(x)
        # x = nn.tanh(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2,2))
        
        #3 layer
        x = nn.Conv(features=120, kernel_size=(5,5), name="conv2d_3")(x)
        # x = nn.tanh(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2,2))
        
        x = x.reshape((x.shape[0],-1)) #flatten
        #4 fully connected layer
        x = nn.Dense(84, name="Dense_1_84")(x)
        # x = nn.tanh(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=0.5, deterministic= deterministic)(x)

        #5 fully connected layer
        x = nn.Dense(10, name="Dense_2_10")(x)
        # x = nn.softmax(x)
        
        return x


In [4]:
def get_datasets():
    train_dataset, test_dataset = tfds.mnist.load_data()
    X_train, y_train = train_dataset
    x_test, y_test = test_dataset
    
    X_train = jnp.array(X_train/255)
    x_test = jnp.array(x_test/255)
    
    X_train = X_train[..., jnp.newaxis]
    x_test = x_test[..., jnp.newaxis]
    
    return [X_train, y_train], [x_test, y_test]


In [5]:
def cross_entropy_loss(params, x, y, rng, deterministic=False):
    labels_onehot = jax.nn.one_hot(y, num_classes=10)
    logits = model.apply(params, x, deterministic=deterministic, rngs={"dropout": rng})
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean(), logits

In [6]:
def compute_accuracy(params, x, y, rng):
    loss, logits = cross_entropy_loss(params, x, y, rng, deterministic=True)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == y)
    return accuracy.item()

In [7]:
def fit(params, X, Y, batch_size, learning_rate=0.1, epochs=100, verbose=False):
    # opt = optax.adamw(learning_rate=learning_rate)
    opt = optax.sgd(learning_rate, 0.0001)
    opt_state = opt.init(params)
    
    lg_fn = jax.value_and_grad(cross_entropy_loss, has_aux=True)
    
    rng, _ = jax.random.split(jax.random.PRNGKey(0))
    losses = []

    train_ds_size = len(X)
    steps_per_epoch = train_ds_size // batch_size

    for i in range(epochs):
        rng , _ = jax.random.split(rng)
        @jax.jit
        def one_step(params, x, y, opt_state):
            (loss_val, logits), grads = lg_fn(params, x, y, rng)
            updates, opt_state = opt.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
            return params, opt_state, loss_val, logits

        # _ , rng = jax.random.split(rng)
        perms = jax.random.permutation(rng, train_ds_size)
        perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
        perms = perms.reshape((steps_per_epoch, batch_size))
        accuracy = []

        for perm in perms:
          x, y = X[perm], Y[perm]
          params, opt_state, loss_val, logits =  one_step(params, x, y, opt_state)
          accuracy.append(compute_accuracy(params, x, y, rng))

        if verbose and i % (epochs/10) == 0:
            print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (i, loss_val, jnp.mean(jnp.array(accuracy))))
    return params,losses


In [8]:
train_ds, test_ds = get_datasets()
X_train, y_train = train_ds
x_test, y_test = test_ds



In [9]:
# batch_size = 32
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
model = LeNet()
# model = CNN()
params = model.init(init_rng, jnp.ones([1, 28, 28, 1]), deterministic=True)

In [12]:
params, losses = fit(params, X_train, y_train, 64, epochs = 10, verbose=True)

train epoch: 0, loss: 0.0997, accuracy: 0.96
train epoch: 1, loss: 0.0081, accuracy: 1.00


KeyboardInterrupt: 

In [13]:
params = jnp.load("params.npy")[0]

In [14]:
params

'params'

In [None]:
one_img = X_train[y_train == 1][0]

In [None]:
one_img = rotate(one_img, 30, reshape=False)

In [None]:
images = []
for i in range(12):
#     plt.imshow(rotate(one_img, 6*i, reshape=False))
    images.append(jnp.array(rotate(one_img, 6*i, reshape=False)))
#     plt.show()

In [None]:
images = jnp.array(images)
labels = model.apply(params, images, deterministic=True)

In [None]:
plt.