In [0]:
# GoogLeNet (Inception v1) — Szegedy et al., "Going Deeper with Convolutions": https://arxiv.org/abs/1409.4842

import tensorflow as tf
from tensorflow.keras import layers

In [0]:
class Inception(tf.keras.Model):
  """Inception module (GoogLeNet, Szegedy et al. 2014, Figure 2b).

  Four parallel branches read the same input; their outputs are concatenated
  along the last axis (the channel axis under Keras' default channels_last
  data format):
    1. 1x1 conv                                  -> f_1 channels
    2. 1x1 reduction (f_3r) followed by 3x3 conv -> f_3 channels
    3. 1x1 reduction (f_5r) followed by 5x5 conv -> f_5 channels
    4. 3x3 max pool followed by 1x1 projection   -> f_p channels

  Every conv and the pool use stride 1 with 'same' padding, so each branch
  keeps the input's spatial size. That is required for the concatenation.
  The output therefore has f_1 + f_3 + f_5 + f_p channels.

  Args:
    f_1: filters of the 1x1 branch.
    f_3r: filters of the 1x1 reduction in front of the 3x3 conv.
    f_3: filters of the 3x3 conv.
    f_5r: filters of the 1x1 reduction in front of the 5x5 conv.
    f_5: filters of the 5x5 conv.
    f_p: filters of the 1x1 projection after the max pool.
  """

  def __init__(self, f_1, f_3r, f_3, f_5r, f_5, f_p):
    super(Inception, self).__init__()
    # Stride 1 like every other branch, so all outputs share one spatial size.
    self.conv1 = layers.Conv2D(f_1, 1, 1, 'same', activation='relu')
    self.conv3r = layers.Conv2D(f_3r, 1, 1, 'same', activation='relu')
    self.conv3 = layers.Conv2D(f_3, 3, 1, 'same', activation='relu')
    self.conv5r = layers.Conv2D(f_5r, 1, 1, 'same', activation='relu')
    self.conv5 = layers.Conv2D(f_5, 5, 1, 'same', activation='relu')
    self.convp = layers.Conv2D(f_p, 1, 1, 'same', activation='relu')
    # 3x3 / stride 1 / 'same' keeps the spatial size. The MaxPool2D() defaults
    # (2x2, stride 2, 'valid') would halve it and break the concatenation.
    self.pool = layers.MaxPool2D(3, 1, 'same')

  def call(self, x, training=False, mask=None):
    """Apply the four branches to `x` and concatenate them channel-wise.

    Overriding `call` (not `__call__`) lets Keras' own `__call__` build and
    track the sublayers and forward `training`.
    """
    branch_1x1 = self.conv1(x)
    branch_3x3 = self.conv3(self.conv3r(x))
    branch_5x5 = self.conv5(self.conv5r(x))
    branch_pool = self.convp(self.pool(x))
    return layers.concatenate([branch_1x1, branch_3x3, branch_5x5, branch_pool])

class MiddleLoss(tf.keras.Model):
  """Auxiliary classifier head of GoogLeNet (Szegedy et al. 2014, Section 5).

  Layer stack:
  - 5x5 average pool with stride 3 ('valid')
  - 1x1 conv with 128 filters + ReLU
  - flatten
  - 1024-unit fully connected + ReLU
  - dropout (rate 0.7)
  - 1000-way softmax

  In this notebook it is attached to intermediate inception outputs. Per the
  paper, its loss is used only during training.
  """

  def __init__(self):
    super(MiddleLoss, self).__init__()
    self.pool = layers.AveragePooling2D(5, 3)
    # Conv2D's third positional parameter is `strides`, so padding comes after it.
    self.conv = layers.Conv2D(128, 1, 1, 'same', activation='relu')
    self.flatten = layers.Flatten()
    self.dense1 = layers.Dense(1024, activation='relu')
    self.dropout = layers.Dropout(0.7)
    self.dense2 = layers.Dense(1000, 'softmax')

  def call(self, x, training=False, mask=None):
    """Return the auxiliary softmax predictions for feature map `x`.

    `training` is forwarded to the dropout layer so dropout is only active
    while training.
    """
    x = self.pool(x)
    x = self.conv(x)
    x = self.flatten(x)
    x = self.dense1(x)
    x = self.dropout(x, training=training)
    return self.dense2(x)

In [0]:
class GoogleNet(tf.keras.Model):
  """GoogLeNet / Inception v1 (Szegedy et al. 2014, Table 1).

  `call` returns a list `[y1, y2, y3]` of 1000-way softmax outputs:
  - y1: the auxiliary head after inception 4a;
  - y2: the auxiliary head after inception 4d;
  - y3: the main classifier.

  Notes:
  - BatchNormalization stands in for the paper's local response normalization.
  - For a 224x224 input (the paper's size), the last feature map is 7x7. The
    7x7 'valid' average pool then reduces it to 1x1, as in Table 1.
  """

  def __init__(self):
    # Keras requires the base initializer to run before any layer is assigned.
    super(GoogleNet, self).__init__()
    self.conv1 = layers.Conv2D(64, 7, 2, 'same', activation='relu')
    self.pool1 = layers.MaxPool2D(3, 2, 'same')
    # Normalization layers are created once here so their weights are tracked
    # and trained, instead of a fresh untrained layer on every forward pass.
    self.bn1 = layers.BatchNormalization()
    self.conv2r = layers.Conv2D(64, 1, 1, 'same', activation='relu')
    self.conv2 = layers.Conv2D(192, 3, 1, 'same', activation='relu')
    self.bn2 = layers.BatchNormalization()
    self.pool2 = layers.MaxPool2D(3, 2, 'same')
    self.incep3a = Inception(64, 96, 128, 16, 32, 32)
    self.incep3b = Inception(128, 128, 192, 32, 96, 64)
    self.pool3 = layers.MaxPool2D(3, 2, 'same')
    self.incep4a = Inception(192, 96, 208, 16, 48, 64)
    self.loss1 = MiddleLoss()
    self.incep4b = Inception(160, 112, 224, 24, 64, 64)
    self.incep4c = Inception(128, 128, 256, 24, 64, 64)
    self.incep4d = Inception(112, 144, 288, 32, 64, 64)
    self.loss2 = MiddleLoss()
    self.incep4e = Inception(256, 160, 320, 32, 128, 128)
    self.pool4 = layers.MaxPool2D(3, 2, 'same')
    self.incep5a = Inception(256, 160, 320, 32, 128, 128)
    self.incep5b = Inception(384, 192, 384, 48, 128, 128)
    # 'valid' so a 7x7 map collapses to 1x1. 'same' kept it 7x7 and fed
    # 7*7*1024 features into the classifier.
    self.pool5 = layers.AveragePooling2D(7, 1, 'valid')
    self.dropout = layers.Dropout(0.4)
    self.flatten = layers.Flatten()
    self.dense = layers.Dense(1000, 'softmax')

  def call(self, x, training=False, mask=None):
    """Forward pass. Returns `[aux_after_4a, aux_after_4d, main]` predictions.

    `training` is forwarded to batch norm, dropout and the sublayers.
    """
    x = self.conv1(x)
    x = self.pool1(x)
    x = self.bn1(x, training=training)

    x = self.conv2r(x)
    x = self.conv2(x)
    x = self.bn2(x, training=training)
    x = self.pool2(x)

    x = self.incep3a(x, training=training)
    x = self.incep3b(x, training=training)
    x = self.pool3(x)

    x = self.incep4a(x, training=training)
    y1 = self.loss1(x, training=training)
    x = self.incep4b(x, training=training)
    x = self.incep4c(x, training=training)
    x = self.incep4d(x, training=training)
    y2 = self.loss2(x, training=training)
    x = self.incep4e(x, training=training)
    x = self.pool4(x)

    x = self.incep5a(x, training=training)
    x = self.incep5b(x, training=training)
    x = self.pool5(x)

    x = self.flatten(x)
    x = self.dropout(x, training=training)
    y3 = self.dense(x)
    return [y1, y2, y3]

In [0]:
@tf.function
def train_step(model, inputs, labels, loss_object, optimizer, train_loss,
               train_accuracy, aux_weight=0.3):
  """Run one optimization step on a batch.

  The model must return `[aux1, aux2, main]` predictions, as GoogleNet does.
  Following the paper, each auxiliary loss is added to the main loss with a
  discount weight (`aux_weight`, 0.3 by default).

  Args:
    model: callable returning three predictions for `inputs`.
    inputs: a batch of input images.
    labels: targets in the format `loss_object` and `train_accuracy` expect.
    loss_object: callable `(labels, predictions) -> scalar loss`.
    optimizer: optimizer whose `apply_gradients` updates the model weights.
    train_loss: metric updated with the total (weighted) loss.
    train_accuracy: metric updated with the main classifier's predictions.
    aux_weight: multiplier applied to each auxiliary-classifier loss.
  """
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    aux1, aux2, main = predictions
    main_loss = loss_object(labels, main)
    aux_loss = loss_object(labels, aux1) + loss_object(labels, aux2)
    loss = main_loss + aux_weight * aux_loss
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  train_loss(loss)
  train_accuracy(labels, main)

@tf.function
def test_step(model, t_inputs, t_labels, loss_object, test_loss, test_accuracy):
  """Evaluate one batch and update the test metrics.

  Only the main classifier (third model output) is scored. The auxiliary
  heads are not used at evaluation time.

  Args:
    model: callable returning `[aux1, aux2, main]` predictions.
    t_inputs: a batch of input images.
    t_labels: targets in the format `loss_object` and `test_accuracy` expect.
    loss_object: callable `(labels, predictions) -> scalar loss`.
    test_loss: metric updated with the main-classifier loss.
    test_accuracy: metric updated with the main-classifier predictions.
  """
  predictions = model(t_inputs, training=False)
  main = predictions[2]
  t_loss = loss_object(t_labels, main)
  test_loss(t_loss)
  test_accuracy(t_labels, main)