In [2]:
from keras.datasets import mnist
from keras.engine.topology import Layer
from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout, Flatten, TimeDistributed
from keras.utils import to_categorical
from keras import optimizers

import t3f
import numpy as np

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [3]:
from keras.layers import Recurrent

# Define a minimal RNN layer, as a subclass of Keras' `Recurrent` layer, whose weight matrices are created with t3f.

class TTMinimalRNNCell(Recurrent):
    """Minimal recurrent layer whose two weight matrices are t3f variables.

    At every timestep the layer computes

        output = matmul(inputs, kernel) + matmul(prev_output, recurrent_kernel)

    and uses that output as its single recurrent state. No bias is added
    and no activation is applied by the code in this class.

    Arguments:
        row_dims: list of ints; first half of the kernel's factorised
            shape (`[row_dims, column_dims]`). Presumably
            `prod(row_dims)` has to equal the per-timestep input size --
            TODO confirm against t3f.
        column_dims: list of ints; second half of the kernel's
            factorised shape and both halves of the recurrent kernel's
            shape. `prod(column_dims)` is the number of units.
        tt_rank: rank forwarded to `t3f.glorot_initializer` for both
            matrices.
        init, activation, bias, bias_init: stored on the instance but
            not read by `build` or `step`; the initializer is always
            `t3f.glorot_initializer`.
        **kwargs: forwarded unchanged to `Recurrent.__init__`.
    """

    # Incremented once per `build` call; its value is embedded in the
    # variable names so that separate instances get distinct names.
    counter = 0

    def __init__(self, row_dims, column_dims, tt_rank=2, init='glorot',
                 activation='relu', bias=True, bias_init=0.1, **kwargs):
        hidden_size = np.prod(column_dims)

        # Factorised shapes handed to the t3f initializer in `build`.
        self.tt_kernel_shape = [row_dims, column_dims]
        self.tt_recurrent_shape = [column_dims, column_dims]
        self.tt_rank = tt_rank

        # Size of the output / recurrent state, and a one-slot state list.
        self.units = hidden_size
        self.output_dim = hidden_size
        self.states = [None]

        # Kept for reference only: nothing in this class reads these.
        self.init = init
        self.activation = activation
        self.bias = bias
        self.bias_init = bias_init

        super(TTMinimalRNNCell, self).__init__(**kwargs)

    def _create_tt_variable(self, name, shape):
        """Create the t3f variable `name` with a glorot initializer of `shape`."""
        glorot = t3f.glorot_initializer(shape, tt_rank=self.tt_rank)
        return t3f.get_variable(name, initializer=glorot)

    def build(self, input_shape):
        """Create the input and recurrent kernels and register their cores.

        `input_shape` is accepted for interface compatibility but is not
        used: the kernel shapes come from the constructor arguments.
        """
        suffix = TTMinimalRNNCell.counter

        # Input-to-hidden matrix first, then hidden-to-hidden matrix.
        self.kernel = self._create_tt_variable(
            'tt_rnn_kernel_matrix_{}'.format(suffix), self.tt_kernel_shape)
        self.recurrent_kernel = self._create_tt_variable(
            'tt_rnn_recurrent_matrix_{}'.format(suffix),
            self.tt_recurrent_shape)

        # The trainable weights are the cores of both matrices.
        cores = list(self.kernel.tt_cores)
        cores.extend(self.recurrent_kernel.tt_cores)
        self.trainable_weights = cores

        TTMinimalRNNCell.counter += 1
        self.built = True

    def step(self, inputs, states):
        """Run one timestep and return `(output, [new_state])`.

        The new state is the output itself.
        """
        previous = states[0]
        input_term = t3f.matmul(inputs, self.kernel)
        recurrent_term = t3f.matmul(previous, self.recurrent_kernel)

        new_state = input_term + recurrent_term
        return new_state, [new_state]

In [14]:
# Assemble the classifier in one go: every image row (28 values) is one
# timestep for the TT recurrent layer, whose 5 * 5 = 25 outputs feed a
# 10-way softmax.
model = Sequential([
    TimeDistributed(Flatten(), input_shape=(28, 28)),
    TTMinimalRNNCell(row_dims=[4, 7], column_dims=[5, 5], tt_rank=4),
    Dense(10),
    Activation('softmax'),
])

In [15]:
# Print the layer-by-layer output shapes and parameter counts.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
time_distributed_2 (TimeDist (None, 28, 28)            0         
_________________________________________________________________
tt_minimal_rnn_cell_2 (TTMin (None, 25)                420       
_________________________________________________________________
dense_2 (Dense)              (None, 10)                260       
_________________________________________________________________
activation_2 (Activation)    (None, 10)                0         
Total params: 680
Trainable params: 680
Non-trainable params: 0
_________________________________________________________________


In [16]:
# Fetch the MNIST train/test splits through Keras.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale pixel values with x / 127.5 - 1.0, which maps 0..255 onto
# [-1, 1] (assumes the usual 8-bit MNIST images -- TODO confirm).
x_train, x_test = (
    images / 127.5 - 1.0 for images in (x_train, x_test)
)

# One-hot encode the digit labels into 10 classes.
y_train, y_test = (
    to_categorical(labels, num_classes=10) for labels in (y_train, y_test)
)

In [17]:
# Adam with learning rate 1e-3; categorical cross-entropy matches the
# one-hot labels, and accuracy is reported alongside the loss.
optimizer = optimizers.Adam(lr=1e-3)
model.compile(
    loss='categorical_crossentropy',
    optimizer=optimizer,
    metrics=['accuracy'],
)

In [18]:
# Train for 100 epochs in mini-batches of 64, evaluating on the test
# split after each epoch. The call is left as the cell's last expression
# so the returned History object is displayed.
model.fit(
    x_train, y_train,
    batch_size=64,
    epochs=100,
    validation_data=(x_test, y_test),
)

Train on 60000 samples, validate on 10000 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100


Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 79/100
Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
Epoch 100/100


<keras.callbacks.History at 0x7f80aadbfe48>