In [1]:
import jax
import jax.numpy as np
import numpy as onp
from jax import random, jit, vmap, grad
from jax.experimental import optimizers,stax
from jax.experimental.stax import Dense, Conv, Relu, MaxPool, Flatten, Softmax,BatchNorm,Dropout,AvgPool
from jax.experimental.optimizers import optimizer, make_schedule, exponential_decay
#from jax.ops import index, index_update

import itertools
from functools import partial
from tqdm import trange
from torch.utils import data
import matplotlib.pyplot as plt

from scipy.integrate import odeint

from jax.nn import sigmoid,relu,log_sigmoid,one_hot
from jax.lax import scan
import h5py
from sklearn.metrics import confusion_matrix



In [2]:
import functools
import operator as op

from jax import lax
from jax import random
import jax.numpy as jnp

from jax.nn import (relu, log_softmax, softmax, softplus, sigmoid, elu,
                    leaky_relu, selu, gelu)
from jax.nn.initializers import glorot_normal, normal, ones, zeros

In [3]:
def plot_confusion_matrix(cm, classes,
                          normalize=True,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm/ np.sum(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2%' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()

In [4]:
from keras.datasets import mnist
import tensorflow as tf

(x_train, y_train), (x_test, y_test) = mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [5]:
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
featurewise_center=True,
featurewise_std_normalization=True,
rescale=None,
rotation_range=10,
width_shift_range=0.1,
height_shift_range=0.1,
zca_whitening = True,
shear_range=0.1,
zoom_range=0.1,
horizontal_flip=False,
fill_mode='nearest',
)



In [6]:
class DataGenerator(data.Dataset):
    def __init__(self, images, labels, 
                 batch_size=128, 
                 rng_key=random.PRNGKey(1234)):
        'Initialization'
        self.images = images
        self.labels = labels
        self.N = labels.shape[0]
        self.batch_size = batch_size
        self.key = rng_key

    @partial(jit, static_argnums=(0,))
    def __data_generation(self, key, images, labels):
        'Generates data containing batch_size samples'
        idx = random.choice(key, self.N, (self.batch_size,), replace=False)
        images = images[idx,...]
        labels = labels[idx,...]
        return images, labels

    def __getitem__(self, index):
        'Generate one batch of data'
        self.key, subkey = random.split(self.key)
        images, labels = self.__data_generation(self.key, self.images, self.labels)
        return images, labels

In [19]:
x_train=np.asarray(x_train)
x_test=np.asarray(x_test)
# x_train=x_train/255.
# x_test=x_test/255.
x_train=x_train
x_test=x_test

In [8]:
#RGBtogray=np.array([0.299,0.587,0.114])
#x_train_gray = np.dot(x_train[:,:,:,:3], RGBtogray)
#x_test_gray = np.dot(x_test[:,:,:,:3], RGBtogray)


#x_train_gray = x_train_gray.reshape(-1,32,32,1)
#x_test_gray = x_test_gray.reshape(-1,32,32,1)


In [9]:
num_classes=y_train.max()+1
y_train=y_train.flatten()
y_train=one_hot(y_train,num_classes)

print(x_train.shape)
print(y_train.shape)

train_dataset = DataGenerator(x_train, y_train, batch_size=256)

(60000, 28, 28)
(60000, 10)


In [10]:
num_classes=y_test.max()+1
y_test=y_test.flatten()
y_test=one_hot(y_test,num_classes)
test_dataset = DataGenerator(x_test, y_test, batch_size=256)

In [11]:
print(x_test.shape)
print(y_test.shape)

(10000, 28, 28)
(10000, 10)


In [12]:
# Architecture
def VGG16():
    init, apply = stax.serial(Conv(64, (3, 3),padding="SAME"), #1
                              Relu, 
                              BatchNorm(),
                              Dropout(0.7,mode='train'),
                              Conv(64, (3, 3),padding="SAME"), #2
                              Relu,
                              BatchNorm(),
                              MaxPool((2, 2), (2, 2)),
                              Conv(128, (3,3),padding="SAME"), #3
                              Relu,
                              BatchNorm(),
                              Dropout(0.6,mode='train'),
                              Conv(128, (3,3),padding="SAME"), #4
                              Relu,
                              BatchNorm(),
                              MaxPool((2, 2), (2, 2)),
                              Conv(256, (3,3),padding="SAME"), #5
                              Relu,
                              BatchNorm(),
                              Dropout(0.6,mode='train'),
                              Conv(256, (3,3),padding="SAME"), #6
                              Relu,
                              BatchNorm(),
                              Dropout(0.6,mode='train'),
                              Conv(256, (3,3),padding="SAME"), #7
                              Relu,
                              BatchNorm(),
                              MaxPool((2, 2), (2, 2)),
                              Conv(512, (3,3),padding="SAME"), #8
                              Relu,
                              BatchNorm(),
                              Dropout(0.6,mode='train'),
                              Conv(512, (3,3),padding="SAME"), #9
                              Relu,
                              BatchNorm(),
                              Dropout(0.6,mode='train'),
                              Conv(512, (3,3),padding="SAME"), #10
                              Relu,
                              BatchNorm(),
                              MaxPool((2, 2), (2, 2)),
                              Conv(512, (3,3),padding="SAME"), #11
                              Relu,
                              BatchNorm(),
                              Dropout(0.6,mode='train'),
                              Conv(512, (3,3),padding="SAME"), #12
                              Relu,
                              BatchNorm(),
                              Dropout(0.6,mode='train'),
                              Conv(512, (3,3),padding="SAME"), #13
                              Relu,
                              BatchNorm(),
                              MaxPool((2, 2), (2, 2)),
                              Dropout(0.5,mode='train'),
                              Flatten,
                              Dense(512), 
                              Relu, 
                              BatchNorm(axis=(0,1)),
                              Dropout(0.5,mode='train'),
                              Dense(10), 
                              Softmax)
    return init, apply

In [13]:
# Architecture
def VGG16_test():
    init, apply = stax.serial(Conv(64, (3, 3),padding="SAME"), #1
                              Relu, 
                              BatchNorm(),
                              Dropout(0.7,mode='test'),
                              Conv(64, (3, 3),padding="SAME"), #2
                              Relu,
                              BatchNorm(),
                              MaxPool((2, 2), (2, 2)),
                              Conv(128, (3,3),padding="SAME"), #3
                              Relu,
                              BatchNorm(),
                              Dropout(0.6,mode='test'),
                              Conv(128, (3,3),padding="SAME"), #4
                              Relu,
                              BatchNorm(),
                              MaxPool((2, 2), (2, 2)),
                              Conv(256, (3,3),padding="SAME"), #5
                              Relu,
                              BatchNorm(),
                              Dropout(0.6,mode='test'),
                              Conv(256, (3,3),padding="SAME"), #6
                              Relu,
                              BatchNorm(),
                              Dropout(0.6,mode='test'),
                              Conv(256, (3,3),padding="SAME"), #7
                              Relu,
                              BatchNorm(),
                              MaxPool((2, 2), (2, 2)),
                              Conv(512, (3,3),padding="SAME"), #8
                              Relu,
                              BatchNorm(),
                              Dropout(0.6,mode='test'),
                              Conv(512, (3,3),padding="SAME"), #9
                              Relu,
                              BatchNorm(),
                              Dropout(0.6,mode='test'),
                              Conv(512, (3,3),padding="SAME"), #10
                              Relu,
                              BatchNorm(),
                              MaxPool((2, 2), (2, 2)),
                              Conv(512, (3,3),padding="SAME"), #11
                              Relu,
                              BatchNorm(),
                              Dropout(0.6,mode='test'),
                              Conv(512, (3,3),padding="SAME"), #12
                              Relu,
                              BatchNorm(),
                              Dropout(0.6,mode='test'),
                              Conv(512, (3,3),padding="SAME"), #13
                              Relu,
                              BatchNorm(),
                              MaxPool((2, 2), (2, 2)),
                              Dropout(0.5,mode='test'),
                              Flatten,
                              Dense(512), 
                              Relu, 
                              BatchNorm(axis=(0,1)),
                              Dropout(0.5,mode='test'),
                              Dense(10), 
                              Softmax)
    return init, apply

In [17]:
class CNNclassifier:
    # Initialize the class
    def __init__(self, rng_key=random.PRNGKey(0)):
        # MLP init and apply functions
        self.net_init, self.net_apply = VGG16()
        self.test_init,self.test_apply=VGG16_test()
        # _, params = self.net_init(rng_key, (256,32,32,3))
        _, params = self.net_init(rng_key, (256,28,28))

        # Optimizer initialization and update functions
        lr = optimizers.exponential_decay(1e-3, decay_steps=100, decay_rate=0.999)
        self.opt_init, \
        self.opt_update, \
        self.get_params = optimizers.adam(lr)
        self.opt_state = self.opt_init(params)

        # Logger
        self.itercount = itertools.count()
        #for training
        self.loss_log = []
        self.acc_log=[]
        #for testing
        self.loss_log2=[]
        self.acc_log2=[]
    
    @partial(jit, static_argnums=(0,))
    def accuracy(self, params, batch):
        images, labels = batch
        outputs = self.predict(params, images)
        pred_class = np.argmax(outputs,1)
        true_class = np.argmax(labels,1)
        return np.sum((pred_class == true_class)) / images.shape[0]

    def loss(self, params, batch,it):
        images, labels = batch
        outputs = self.net_apply(params, images,rng=random.PRNGKey(it))
        loss = -labels*np.log(outputs+1e-7)
        return np.mean(loss)

    @partial(jit, static_argnums=(0,))
    def step(self, i, opt_state, batch):
        params = self.get_params(opt_state)
        gradients = grad(self.loss)(params, batch, i)
        return self.opt_update(i, gradients, opt_state)

    def train(self, dataset,test_dataset, nIter = 10):
        data = iter(dataset)
        test_data=iter(test_dataset)
        pbar = trange(nIter)
        # Main training loop
        for it in pbar:
            # Run one gradient descent update
            batch = next(data)
            self.opt_state = self.step(next(self.itercount), self.opt_state, batch)  
            if it % 50 == 0:
                batch_test=next(test_data)
                # Logger
                params = self.get_params(self.opt_state)
                loss = self.loss(params, batch,it)
                self.loss_log.append(loss)
                acc = self.accuracy(params, batch)
                self.acc_log.append(acc)

                loss2 = self.loss(params, batch_test, it)
                acc2 = self.accuracy(params, batch_test)
                self.loss_log2.append(loss2)
                self.acc_log2.append(acc2)

                pbar.set_postfix({'Loss': loss,
                                  'Loss(test)': loss2,
                                  'Accuracy': acc,
                                  'Accuracy(test)': acc2})

    @partial(jit, static_argnums=(0,))
    def predict(self, params, inputs):
        outputs = self.test_apply(params, inputs,rng=random.PRNGKey(0))
        return outputs

In [20]:
# Initialize model
model = CNNclassifier()

IndexError: ignored

In [16]:
# Train model
model.train(train_dataset, test_dataset, nIter=20000)
opt_params = model.get_params(model.opt_state)
# Plot loss
plt.figure()
plt.plot(model.loss_log, lw=2)
plt.yscale('log')
plt.xlabel('Iter #')
plt.ylabel('Loss')

  0%|          | 0/20000 [00:01<?, ?it/s]


TypeError: ignored

In [None]:
test_init, test_apply=VGG16_test()

In [None]:
def predict(params, inputs):
  outputs = test_apply(params, inputs,rng=random.PRNGKey(0))
  return outputs

In [None]:
# Compute classification accuracy on the entire test data-set
@jit
def accuracy(params, batch):
    images, labels = batch
    outputs = predict(params, images)
    pred_class = np.argmax(outputs,1)
    true_class = np.argmax(labels,1)
    return np.sum((pred_class == true_class)) / images.shape[0]

acc = accuracy(opt_params, (x_test, y_test))
print('Classification accuracy: {}%'.format(100*acc))

In [None]:
plt.figure(dpi = 150)
plt.xlabel(r'Number of Iterations ($\times 50$)')
plt.ylabel('Cross Entropy Loss')
plt.semilogy(model.loss_log, label = 'Train')
plt.semilogy(model.loss_log2, label = 'Test')
plt.legend()

In [None]:
plt.figure(dpi = 150)
plt.xlabel(r'Number of Iterations ($\times 50$)')
plt.ylabel('Accuracy')
plt.semilogy(model.acc_log, label = 'Train')
plt.semilogy(model.acc_log2, label = 'Test')
plt.legend()

In [None]:
len(model.loss_log)

In [None]:
# calculate the number of parameters
def num_params(params):
  num_params = 0
  for i in range(len(params)):
    tup = params[i]
    tup_len = 0
    for j in range(len(tup)):
      tup_j_len = len(params[i][j].ravel())
      tup_len += tup_j_len
    num_params += tup_len
  return num_params

In [None]:
print('Number of parameters: ', num_params(opt_params))

In [None]:
# save outputs
onp.save('vgg16_loss', model.loss_log)
onp.save('vgg16_test_loss', model.loss_log2)

In [None]:
# save acc
onp.save('vgg16_acc', model.acc_log)
onp.save('vgg16_test_acc', model.acc_log2)