# Polar model / Trained on MNIST-ROT / Tested on MNIST-ROT

In [1]:
from tensorflow.keras import datasets, layers, models

import cv2 as cv
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import os
import matplotlib.pyplot as plt
import time

In [2]:
# To run on GPU, can be omitted for CPU only
physical_devices = tf.config.list_physical_devices('GPU') 
tf.config.experimental.set_memory_growth(physical_devices[0], True)

In [3]:
# Function to get vertical cylinder effect - mentioned in paper
def padImage(image, pixels=5):
    bottom = image[-pixels:]
    top = image[:pixels]

    img = np.insert(image, 0, bottom, 0)
    img = np.insert(img, len(img), top, 0)
    #img = np.insert(img, [0], [0] * pixels, 1)
    #img = np.insert(img, [-1], [0] * pixels, 1)
    return img

### Download amat files here: 
http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip

## Load Dataset

In [4]:
# We assume amat files are in root
train_valid = np.loadtxt('mnist_rotation_train.amat')
test = np.loadtxt('mnist_rotation_test.amat')

X_train, y_train = train_valid[:,:-1], train_valid[:,-1]
X_valid, y_valid = X_train[-2000:].astype(np.float32), y_train[-2000:].astype(np.int32)
train_images, train_labels = X_train[:10000].astype(np.float32), y_train[:10000].astype(np.int32)
test_images,  test_labels  = test[:,:-1].astype(np.float32), test[:,-1].astype(np.int32)

# reshape to add alpha channel
train_images = np.reshape(train_images, (-1, 28, 28))
test_images = np.reshape(test_images, (-1, 28, 28))

## Prepare Dataset

In [5]:
### THE PAD HAS TO BE DONE IN THE
### POLAR SPACE

# 20 is the ceiling of (14 * sqrt(2)) - mentioned in paper
X_train_polar = [cv.linearPolar(x, tuple(np.array(x.shape)/2), 20, cv.WARP_FILL_OUTLIERS) for x in train_images]
X_train_polar = [padImage(x, pixels=5) for x in X_train_polar]
X_train_polar = np.array(X_train_polar)[...,None]

X_test_polar = [cv.linearPolar(x, tuple(np.array(x.shape)/2), 20, cv.WARP_FILL_OUTLIERS) for x in test_images]
X_test_polar = [padImage(x, pixels=5) for x in X_test_polar]
X_test_polar = np.array(X_test_polar)[...,None]

In [6]:
X_test_polar.shape

(50000, 38, 28, 1)

In [7]:
model = models.Sequential()
model.add(layers.Input(shape=X_train_polar.shape[1:]))
model.add(layers.Conv2D(32, (3, 3), activation='relu', padding='same'))
model.add(layers.MaxPooling2D())
model.add(layers.LayerNormalization(axis=-1, epsilon=0.001, center=True, scale=True))
model.add(layers.Dropout(rate=0.5))
model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(layers.MaxPooling2D())
model.add(layers.LayerNormalization(axis=-1, epsilon=0.001, center=True, scale=True))
model.add(layers.Dropout(rate=0.5))
model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))

model.add(layers.GlobalMaxPooling2D())
model.add(layers.Dense(64, activation='linear'))
model.add(layers.Activation('relu'))
model.add(layers.Dense(10))

In [8]:
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 38, 28, 32)        320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 19, 14, 32)        0         
_________________________________________________________________
layer_normalization (LayerNo (None, 19, 14, 32)        64        
_________________________________________________________________
dropout (Dropout)            (None, 19, 14, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 19, 14, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 9, 7, 64)          0         
_________________________________________________________________
layer_normalization_1 (Layer (None, 9, 7, 64)          1

## Model Training

In [9]:
name = 'polar_ROT_MNIST'
checkpoint = tf.keras.callbacks.ModelCheckpoint(name + '.h5', verbose=1, save_best_only=True, monitor='val_accuracy', mode='max')

opt = tf.keras.optimizers.Adam(learning_rate=0.0001)

model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'], )
model.fit(X_train_polar, train_labels, batch_size=32, epochs=65, 
          validation_data=(X_test_polar, test_labels),
          callbacks=[checkpoint])

Epoch 1/65

Epoch 00001: val_accuracy improved from -inf to 0.33046, saving model to polar_ROT_MNIST.h5
Epoch 2/65

Epoch 00002: val_accuracy improved from 0.33046 to 0.55720, saving model to polar_ROT_MNIST.h5
Epoch 3/65

Epoch 00003: val_accuracy improved from 0.55720 to 0.70222, saving model to polar_ROT_MNIST.h5
Epoch 4/65

Epoch 00004: val_accuracy improved from 0.70222 to 0.76604, saving model to polar_ROT_MNIST.h5
Epoch 5/65

Epoch 00005: val_accuracy improved from 0.76604 to 0.81984, saving model to polar_ROT_MNIST.h5
Epoch 6/65

Epoch 00006: val_accuracy improved from 0.81984 to 0.84998, saving model to polar_ROT_MNIST.h5
Epoch 7/65

Epoch 00007: val_accuracy improved from 0.84998 to 0.86980, saving model to polar_ROT_MNIST.h5
Epoch 8/65

Epoch 00008: val_accuracy improved from 0.86980 to 0.88340, saving model to polar_ROT_MNIST.h5
Epoch 9/65

Epoch 00009: val_accuracy improved from 0.88340 to 0.89594, saving model to polar_ROT_MNIST.h5
Epoch 10/65

Epoch 00010: val_accuracy i

<tensorflow.python.keras.callbacks.History at 0x1c470c754f0>

## Model Accuracies

In [15]:
# load best epoch model
m = tf.keras.models.load_model('polar_ROT_MNIST.h5')

# Plain accuracy
accuracy = np.mean(m.predict(X_test_polar).argmax(axis=1) == test_labels)*100

print(accuracy)

96.242


## Prediction time for test set

In [11]:
times = []

for _ in range(0, 50):
    start = time.time()

    model.predict(X_test_polar)

    end = time.time()
    times.append(end - start)

print(np.mean(times))



2.069877095222473


In [12]:
np.mean(times) / len(X_test_polar)

4.139754190444946e-05

In [13]:
!python --version

Python 3.8.5
