# Import modules

In [2]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Custom keras model


In [118]:
class MLP(keras.Sequential):
    def __init__(self, n_features, flank):
        super(MLP, self).__init__()
        self.flank = flank
        self.layer1 = keras.layers.Dense(units=100,input_shape=[(2*flank+1)*n_features],activation='sigmoid')
        self.layer2 = keras.layers.Dense(units=3,activation='softmax')
        self.padded_x = tf.Variable(1., shape=tf.TensorShape(None))

    def call(self,x):
        x = self._window(x)
        x = self.layer1(x)
        return self.layer2(x)

    def _window(self,x):
        batch_size = x.shape[0]
        ta = tf.TensorArray(tf.float32, size=batch_size)
        dim = (2*self.flank+1)*x.shape[2]

        for i in range(batch_size):
            n = x[i].shape[0]
            padded_x = tf.pad(x[i], ((self.flank,self.flank),(0,0)), 'constant')            
            self.padded_x = padded_x
            ta2 = tf.TensorArray(tf.float32, size=n)
            for j in range(n):
                window = self.padded_x[j:j+2*self.flank+1]
                vector = tf.reshape(window,(-1,))
                ta2.write(j, vector)
            x_w = ta2.stack()
            ta.write(i, x_w)

        return ta.stack()
    
    def train_step(self,data):
        x,y = data
        with tf.GradientTape() as tape:
            pred = self(x)
            loss = self.compiled_loss(y, pred) 
        grad = tape.gradient(loss,self.trainable_variables)
        self.optimizer.apply_gradients(zip(grad,self.trainable_variables))

        # self.compiled_metrics.update_state(y, pred)
        # metr = self.compiled_metrics.result()
        # self.compiled_metrics.reset_states()

        return {'loss' : loss}#, 'metric' : metr}

In [80]:
class MyClassifier(keras.Model):
    global IS_DEBUG
    def __init__(self):
        super(MyClassifier,self).__init__()
        self.NN1 = MLP(n_features=24, flank=8)
        self.NN2 = MLP(n_features= 3, flank=9)
        self.NN3 = MLP(n_features=20, flank=8)
        self.NN4 = MLP(n_features= 3, flank=9)
        
    def compile(self,*args,**kwargs):
        super(MyClassifier,self).compile(*args,**kwargs)
        self.NN1.compile(*args,**kwargs)
        self.NN2.compile(*args,**kwargs)
        self.NN3.compile(*args,**kwargs)
        self.NN4.compile(*args,**kwargs)

    def call(self,x):
        x1 = x[:,:,:24]
        x2 = x[:,:,24:44]
        x1 = self.NN1(x1)
        x1 = self.NN2(x1)
        x2 = self.NN3(x2)
        x2 = self.NN4(x2)
        return x1+x2//2

    def train_step(self,data):
        x,y = data
        x1 = x[:,:,:24]
        x2 = x[:,:,24:44]

        d1 = self.NN1.train_step((x1,y))
        x1 = self.NN1(x1)
        d2 = self.NN2.train_step((x1,y))

        d3 = self.NN3.train_step((x2,y))
        x2 = self.NN3(x2)
        d4 = self.NN4.train_step((x2,y))

        return {
            'loss_1' : d1['loss'],
            'loss_2' : d2['loss'],
            'loss_3' : d3['loss'],
            'loss_4' : d4['loss'],
            # 'metr_1' : d1['metric'],
            # 'metr_2' : d2['metric'],
            # 'metr_3' : d3['metric'],
            # 'metr_4' : d4['metric'],            
        }

    # def train_step(self,data):
    #     x,y = data
    #     x1 = x[:,:,:24]
    #     x2 = x[:,:,24:44]

    #     with tf.GradientTape() as tape:
    #         pred1 = self.NN1(x1)
    #         loss1 = self.NN1.compiled_loss(y, pred1) 
    #     grad = tape.gradient(loss1, self.NN1.trainable_variables)
    #     self.NN1.optimizer.apply_gradients(zip(grad,self.NN1.trainable_variables))

    #     print('NN1 metrics: ',self.NN1.compiled_metrics)
    #     if self.NN1.compiled_metrics:
    #         self.NN1.compiled_metrics.update_state(y,pred1)
    #         metr1 = float(self.NN1.compiled_metrics.result())
    #         self.NN1.compiled_metrics.reset_states()

    #     with tf.GradientTape() as tape:
    #         pred2 = self.NN2(pred1)
    #         loss2 = self.NN2.compiled_loss(y, pred2)
    #     grad = tape.gradient(loss2, self.NN2.trainable_variables)
    #     self.NN2.optimizer.apply_gradients(zip(grad,self.NN2.trainable_variables))

    #     if self.NN2.compiled_metrics:
    #         self.NN2.compiled_metrics.update_state(y,pred2)
    #         metr2 = float(self.NN2.compiled_metrics.result())
    #         self.NN2.compiled_metrics.reset_states()

    #     with tf.GradientTape() as tape:
    #         pred3 = self.NN3(x2)
    #         loss3 = self.NN3.compiled_loss(y, pred3)
    #     grad = tape.gradient(loss3, self.NN3.trainable_variables)
    #     self.NN3.optimizer.apply_gradients(zip(grad,self.NN3.trainable_variables))

    #     if self.NN3.compiled_metrics:
    #         self.NN3.compiled_metrics.update_state(y,pred3)
    #         metr3 = float(self.NN3.compiled_metrics.result())
    #         self.NN3.compiled_metrics.reset_states()        

    #     with tf.GradientTape() as tape:
    #         pred4 = self.NN4(pred3)
    #         loss4 = self.NN4.compiled_loss(y, pred4)        
    #     grad = tape.gradient(loss4, self.NN4.trainable_variables)
    #     self.NN4.optimizer.apply_gradients(zip(grad,self.NN4.trainable_variables))

    #     if self.NN4.compiled_metrics:
    #         self.NN4.compiled_metrics.update_state(y,pred4)
    #         metr4 = float(self.NN4.compiled_metrics.result())
    #         self.NN4.compiled_metrics.reset_states()

    #     return {'loss_1' : loss1,
    #             'loss_2' : loss2, 
    #             'loss_3' : loss3, 
    #             'loss_4' : loss4}


# Generate dummy targets and predictors

In [7]:
predictors = np.random.uniform(0,1,size=(1155,100,44))
targets_raw = np.random.randint(low=0,high=2,size=(1155,100,))
targets = np.eye(3)[targets_raw]


# Main procedure (working)

In [124]:
IS_DEBUG=1
clf = MyClassifier()
loss = keras.losses.CategoricalCrossentropy(reduction=keras.losses.Reduction.NONE)
optimizer = keras.optimizers.SGD(learning_rate=1e-2)

clf.compile(loss=loss,optimizer=optimizer,metrics=[keras.metrics.CategoricalAccuracy()])

clf.fit(predictors, targets, batch_size = 5, epochs=2)


Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7feff2325bd0>

In [126]:
clf.metrics_names

[]