In [1]:
import tensorflow as tf
from tensorflow.keras.models import Sequential 
from tensorflow.keras.layers import Dense, Dropout, Lambda
import subprocess as sp
import numpy as np

In [2]:
# Load the MNIST train/test splits through Keras' dataset helper.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
    path='mnist.npz',
)


In [3]:
num_classes = 10            # number of MNIST label classes
num_samples = len(x_train)  # number of training examples (first axis of x_train)

In [4]:
BATCH_SIZE = 256
SHUFFLE_SEED = 42  # fixed seed so the per-epoch shuffle order is reproducible
AUTOTUNE = tf.data.experimental.AUTOTUNE


def preprocess(image, label):
    """Flatten one image to a float32 vector scaled by 1/255 and one-hot encode its label.

    Shared by the train and test pipelines so both see identical preprocessing.

    Args:
        image: one example from ``x_train``/``x_test`` (any shape; flattened here).
        label: the matching integer class index.

    Returns:
        ``(pixels, one_hot_label)`` where ``pixels`` is a 1-D float32 tensor and
        ``one_hot_label`` has length ``num_classes``.
    """
    # Explicit float32 cast (same dtype the implicit integer true-division produced);
    # assumes raw pixel values lie in 0..255 so the result is in [0, 1] — TODO confirm.
    pixels = tf.cast(tf.reshape(image, [-1]), tf.float32) / 255.0
    return pixels, tf.one_hot(label, num_classes)


# Training pipeline: shuffle raw examples before batching (reshuffled every epoch),
# otherwise every epoch would present identical batches in the same order.
ds_train = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=num_samples, seed=SHUFFLE_SEED, reshuffle_each_iteration=True)
    .map(preprocess, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(AUTOTUNE)
)

# Evaluation pipeline: no shuffling; repeated forever so fit() can draw a fixed
# number of validation batches every epoch.
ds_test = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(preprocess, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(AUTOTUNE)
)



In [5]:
tf.keras.backend.clear_session()

# GPU memory readings in MiB, one per training batch that could be measured.
memory_usage = []


class MemoryCheck(tf.keras.callbacks.Callback):
    """Keras callback that records GPU memory use after every training batch.

    Each reading (an int, in MiB) is appended to the module-level ``memory_usage``
    list and, when ``verbose`` is true, printed next to the progress bar.
    Readings come from ``nvidia-smi``. If it is missing, fails, or lists no
    matching process, that batch is skipped instead of aborting training, so
    ``len(memory_usage)`` can be smaller than the number of batches run.
    Note: this spawns one subprocess per batch, which adds per-step latency.
    """

    def __init__(self, verbose=True):
        super().__init__()
        self.verbose = verbose

    @staticmethod
    def _query_gpu_memory_mib():
        """Return GPU memory used (MiB) per nvidia-smi for this process, or None.

        Prefers the compute-app row whose PID equals ``os.getpid()``. If no row
        matches (PIDs reported by nvidia-smi may differ from ours in some setups,
        e.g. containers — verify for your environment), falls back to the last row
        whose process name contains "python", which is what the previous
        ``nvidia-smi | grep python`` parsing picked.
        """
        import os

        try:
            # Argument list, no shell: avoids the grep exit-code-1 crash on no match.
            report = sp.check_output(
                ['nvidia-smi',
                 '--query-compute-apps=pid,process_name,used_memory',
                 '--format=csv,noheader,nounits'],
                universal_newlines=True,
            )
        except (OSError, sp.CalledProcessError):
            # nvidia-smi missing or failing must not abort training.
            return None

        own_pid = str(os.getpid())
        fallback = None
        for line in report.splitlines():
            fields = line.split(',')
            if len(fields) < 3:
                continue
            pid = fields[0].strip()
            # The process name is everything between the first and last comma.
            name = ','.join(fields[1:-1]).strip()
            try:
                used_mib = int(fields[-1].strip())
            except ValueError:  # e.g. "[N/A]" when the driver cannot report usage
                continue
            if pid == own_pid:
                return used_mib
            if 'python' in name:
                fallback = used_mib
        return fallback

    def on_train_batch_end(self, batch, logs=None):
        mem_mib = self._query_gpu_memory_mib()
        if mem_mib is None:
            return
        memory_usage.append(mem_mib)
        if self.verbose:
            print(' {}MiB'.format(mem_mib))

mem_check = MemoryCheck()


In [6]:
@tf.function
def cube_this(x):
    """Return ``x`` raised to the third power, traced as a TensorFlow graph function.

    NOTE(review): not called by any cell shown here; remove it if it is unused.
    """
    cubed = x ** 3
    return cubed

In [7]:
from layers import Highway  # project-local layer module (implementation not shown here)

MODEL_SEED = 42
# Seed TF's global RNG so random weight initialisation repeats across runs.
tf.random.set_seed(MODEL_SEED)

# The input pipeline flattens each image with reshape(-1); derive that vector
# length from the data instead of hard-coding 784.
input_dim = int(np.prod(x_train.shape[1:]))

model = Sequential()
model.add(Dense(512, activation='relu', input_shape=(input_dim,)))
model.add(Dropout(0.2))
model.add(Highway())
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.2))
# Output width follows num_classes so it stays in sync with the one-hot labels.
model.add(Dense(num_classes, activation='softmax'))
# No explicit model.build(): giving input_shape on the first layer already builds the model.

In [8]:
# Print the per-layer table of output shapes and parameter counts.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 512)               401920    
_________________________________________________________________
dropout (Dropout)            (None, 512)               0         
_________________________________________________________________
highway (Highway)            (None, 512)               525312    
_________________________________________________________________
dense_1 (Dense)              (None, 64)                32832     
_________________________________________________________________
dropout_1 (Dropout)          (None, 64)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 10)                650       
Total params: 960,714
Trainable params: 960,714
Non-trainable params: 0
__________________________________________________

In [9]:
# RMSprop with its default settings; categorical cross-entropy pairs with one-hot targets.
model.compile(
    optimizer='RMSprop',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [10]:
EPOCHS = 10

# Both datasets repeat forever, so Keras needs explicit step counts; ceil keeps
# the final partial batch so each epoch covers every example exactly once.
steps_per_epoch = int(np.ceil(num_samples / BATCH_SIZE))
validation_steps = int(np.ceil(x_test.shape[0] / BATCH_SIZE))

# `fit_generator` is deprecated (TF >= 2.1); `fit` consumes tf.data.Dataset directly.
history = model.fit(
    ds_train,
    steps_per_epoch=steps_per_epoch,
    validation_data=ds_test,
    validation_steps=validation_steps,
    epochs=EPOCHS,
    callbacks=[mem_check],
)

Epoch 1/10
 911MiB
  1/235 [..............................] - ETA: 57s - loss: 2.3663 - accuracy: 0.0742 911MiB
  2/235 [..............................] - ETA: 36s - loss: 2.1958 - accuracy: 0.2148 911MiB
  3/235 [..............................] - ETA: 29s - loss: 2.0524 - accuracy: 0.2878 911MiB
  4/235 [..............................] - ETA: 25s - loss: 1.9125 - accuracy: 0.3535 911MiB
  5/235 [..............................] - ETA: 23s - loss: 1.8717 - accuracy: 0.3758 911MiB
  6/235 [..............................] - ETA: 22s - loss: 1.7904 - accuracy: 0.4062 911MiB
  7/235 [..............................] - ETA: 21s - loss: 1.6974 - accuracy: 0.4448 911MiB
  8/235 [>.............................] - ETA: 20s - loss: 1.6180 - accuracy: 0.4785 911MiB
  9/235 [>.............................] - ETA: 19s - loss: 1.5393 - accuracy: 0.5056 911MiB
 10/235 [>.............................] - ETA: 19s - loss: 1.4796 - accuracy: 0.5266 911MiB
 11/235 [>.............................] - ETA: 18s



Epoch 2/10
 911MiB
  1/235 [..............................] - ETA: 15s - loss: 0.3194 - accuracy: 0.9219 911MiB
  2/235 [..............................] - ETA: 16s - loss: 0.2397 - accuracy: 0.9395 911MiB
  3/235 [..............................] - ETA: 16s - loss: 0.2302 - accuracy: 0.9440 911MiB
  4/235 [..............................] - ETA: 16s - loss: 0.2286 - accuracy: 0.9424 911MiB
  5/235 [..............................] - ETA: 16s - loss: 0.2466 - accuracy: 0.9344 911MiB
  6/235 [..............................] - ETA: 16s - loss: 0.2362 - accuracy: 0.9382 911MiB
  7/235 [..............................] - ETA: 16s - loss: 0.2236 - accuracy: 0.9414 911MiB
  8/235 [>.............................] - ETA: 16s - loss: 0.2119 - accuracy: 0.9438 911MiB
  9/235 [>.............................] - ETA: 16s - loss: 0.2028 - accuracy: 0.9462 911MiB
 10/235 [>.............................] - ETA: 16s - loss: 0.1934 - accuracy: 0.9488 911MiB
 11/235 [>.............................] - ETA: 15s

 29/235 [==>...........................] - ETA: 15s - loss: 0.1680 - accuracy: 0.9515 911MiB
 30/235 [==>...........................] - ETA: 14s - loss: 0.1693 - accuracy: 0.9509 911MiB
 31/235 [==>...........................] - ETA: 14s - loss: 0.1700 - accuracy: 0.9507 911MiB
 32/235 [===>..........................] - ETA: 14s - loss: 0.1704 - accuracy: 0.9501 911MiB
 33/235 [===>..........................] - ETA: 14s - loss: 0.1707 - accuracy: 0.9500 911MiB
 34/235 [===>..........................] - ETA: 14s - loss: 0.1711 - accuracy: 0.9498 911MiB
 35/235 [===>..........................] - ETA: 14s - loss: 0.1767 - accuracy: 0.9482 911MiB
 36/235 [===>..........................] - ETA: 14s - loss: 0.1751 - accuracy: 0.9487 911MiB
 37/235 [===>..........................] - ETA: 14s - loss: 0.1764 - accuracy: 0.9481 911MiB
 38/235 [===>..........................] - ETA: 14s - loss: 0.1754 - accuracy: 0.9482 911MiB
 39/235 [===>..........................] - ETA: 14s - loss: 0.1734 - a

KeyboardInterrupt: 