In [None]:
class MyModel(Model):
    """Small binary classifier: Dense(1024, relu) -> Dropout(0.2) -> Dense(1, sigmoid).

    The last layer applies a sigmoid activation, so the model's output is
    already squashed into [0, 1] (it is not a raw logit, despite the
    attribute being named ``logits``).
    """

    def __init__(self):
        super().__init__()
        self.dense = Dense(1024, activation='relu')
        self.dropout = Dropout(0.2)
        # Output layer; the attribute name is kept as-is so existing references
        # to `model.logits` keep working.
        self.logits = Dense(1, activation='sigmoid')

    def call(self, inputs, training=None):
        """Run the forward pass.

        Args:
            inputs: Batch of input features fed to the first Dense layer.
            training: Optional flag forwarded to the Dropout layer. Pass
                ``True`` in a custom training loop so dropout is applied and
                ``False`` for inference. When left as ``None`` the mode is
                whatever Keras resolves by default for the Dropout call.

        Returns:
            The output of the final sigmoid Dense layer.
        """
        x = self.dense(inputs)
        # Forward the mode explicitly so dropout follows the caller's intent.
        x = self.dropout(x, training=training)
        out = self.logits(x)
        return out
    
# Instantiate the subclassed model.
subclassing_model = MyModel()

# Optimizer and loss used by the hand-written training loop.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
# from_logits=False: the model's final layer already applies a sigmoid.
# (Fixed typo: the class is `BinaryCrossentropy`, not `BinarayCrossentropy`.)
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=False)

In [None]:
from sklearn.utils import gen_batches, shuffle

In [None]:
# Training hyper-parameters for the custom loop.
batch_size = 128   # number of samples per mini-batch
total_epoch = 10   # number of passes over the training data

In [None]:
# Mini-batch index ranges over the training set, computed once and reused
# for every epoch (each entry is used to index the shuffled arrays).
train_batches = [*gen_batches(len(train_x), batch_size)]

In [None]:
# Metric accumulators updated by the training loop, grouped by data split.

# Training split
train_loss = tf.keras.metrics.Mean()
train_acc = tf.keras.metrics.BinaryAccuracy()
train_auc = tf.keras.metrics.AUC()

# Validation split
valid_loss = tf.keras.metrics.Mean()
valid_acc = tf.keras.metrics.BinaryAccuracy()
valid_auc = tf.keras.metrics.AUC()

In [None]:
for epoch in range(total_epoch):

    # Reset *all* accumulators at the start of each epoch. Previously only the
    # loss metrics were reset, so accuracy/AUC were running values over every
    # epoch so far while the losses were per-epoch.
    for metric in (train_loss, train_acc, train_auc,
                   valid_loss, valid_acc, valid_auc):
        metric.reset_state()

    ####################################################################################################
    # training
    # Shuffle features and labels together so batches differ between epochs.
    (shuffle_x, shuffle_y) = shuffle(train_x, train_y)

    for batch in train_batches:

        batch_x = shuffle_x[batch]
        batch_y = shuffle_y[batch]

        with tf.GradientTape() as tape:
            # training=True requests training-mode behaviour (active Dropout).
            # NOTE: the variable is named `logits`, but the model ends in a
            # sigmoid, so these are probabilities (loss uses from_logits=False).
            logits = subclassing_model(batch_x, training=True)
            loss = loss_fn(batch_y, logits)

        # Back-propagate and apply one optimizer step.
        gradients = tape.gradient(loss, subclassing_model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, subclassing_model.trainable_variables))

        train_loss.update_state(loss)
        train_acc.update_state(batch_y, logits)
        train_auc.update_state(batch_y, logits)

    ####################################################################################################
    # validation
    # predict() runs the model in inference mode over the whole validation set.
    logits = subclassing_model.predict(valid_x, verbose=0)
    loss = loss_fn(valid_y, logits)
    valid_loss.update_state(loss)
    valid_acc.update_state(valid_y, logits)
    valid_auc.update_state(valid_y, logits)

    # Per-epoch summary line (epoch index is zero-based).
    msg = "epoch: {:>5d} - loss: {:>.5f} - accuracy: {:>.3%} - auc: {:>.3%} - val_loss: {:>.5f} - val_accuracy: {:>.3%} - val_auc: {:>.3%}"
    print(msg.format(epoch,
                     train_loss.result().numpy(), train_acc.result().numpy(), train_auc.result().numpy(),
                     valid_loss.result().numpy(), valid_acc.result().numpy(), valid_auc.result().numpy()))

In [None]:
# Run the model's predict() on valid_x and show the returned predictions
# as the cell output.
subclassing_pred = subclassing_model.predict(valid_x)
subclassing_pred