In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

from tensorflow.keras.layers import Conv2D, BatchNormalization, GlobalAveragePooling2D, Dense

import numpy as np
import matplotlib.pyplot as plt


# ResNet Layer

In [2]:
class IdentityResidual(tf.keras.layers.Layer):
    def __init__(self, out_channels, stride):
        super().__init__()
        self.out_channels = out_channels
        self.stride = stride

    def build(self, input_shape):
        b, h, w, in_channels = input_shape
        self.in_channels = in_channels
        self.h = h // self.stride
        self.w = w // self.stride
        self.b = b
        self.c = self.out_channels - self.in_channels

    def call(self, input_tensor):
        # Downsample spatially
        x = input_tensor[:, ::self.stride, ::self.stride, :]
        # Create padding tensor for extra channels 
        if self.out_channels != self.in_channels:
            pad = tf.zeros((self.b, self.h, self.w, self.c))
            # Append padding to the downsampled identity
            x = tf.concat((x, pad), axis=-1)
        return x

class ResNetV2Layer(tf.keras.Model):
    def __init__(self, channels, stride=1):
        super().__init__()
        conv_kwargs = {
            "padding": "same",
            "use_bias": False
        }
        self.stride = stride
        self.channels = channels
        self.relu = tf.nn.relu
        self.residual = IdentityResidual(channels, stride)
        self.conv1 = Conv2D(filters=channels, kernel_size=3, strides=self.stride, **conv_kwargs)
        self.bn1 = BatchNormalization()
        self.conv2 = Conv2D(filters=channels, kernel_size=3, **conv_kwargs)
        self.bn2 = BatchNormalization()
    
    def call(self, input_tensor, training=False):
        residual = self.residual(input_tensor)
        x = self.bn1(input_tensor, training=training)
        x = self.relu(x)
        x = self.conv1(x)
        x = self.bn2(x, training=training)
        x = self.relu(x)
        x = self.conv2(x)
        return x + residual

In [3]:
layer = ResNetV2Layer(16)
layer_2 = ResNetV2Layer(32, stride=2)
inputs = tf.random.normal((4, 32, 32, 3))
z = layer(inputs)
z = layer_2(z)
z.shape

TensorShape([4, 16, 16, 32])

In [4]:
layer = ResNetV2Layer(32, stride=2)
inputs = tf.random.normal((4, 32, 32, 16))
z = layer(inputs)
z.shape

TensorShape([4, 16, 16, 32])

In [5]:
ResNetV2Model = tf.keras.Sequential([
    Conv2D(filters=16, kernel_size=3, padding="same", use_bias=False, data_format="channels_last"),
    ResNetV2Layer(16),
    ResNetV2Layer(16),
    ResNetV2Layer(16),
    ResNetV2Layer(32, stride=2),
    ResNetV2Layer(32),
    ResNetV2Layer(32),
    ResNetV2Layer(64, stride=2),
    ResNetV2Layer(64),
    ResNetV2Layer(64),
    GlobalAveragePooling2D(),
    Dense(10)
])

In [6]:
inputs = tf.random.normal((32, 32, 32, 3))
z = ResNetV2Model(inputs)

In [7]:
ResNetV2Model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_6 (Conv2D)           (32, 32, 32, 16)          432       
                                                                 
 res_net_v2_layer_3 (ResNetV  (32, 32, 32, 16)         4736      
 2Layer)                                                         
                                                                 
 res_net_v2_layer_4 (ResNetV  (32, 32, 32, 16)         4736      
 2Layer)                                                         
                                                                 
 res_net_v2_layer_5 (ResNetV  (32, 32, 32, 16)         4736      
 2Layer)                                                         
                                                                 
 res_net_v2_layer_6 (ResNetV  (32, 16, 16, 32)         14016     
 2Layer)                                                

# Dataloading

In [8]:
# train_ds, val_ds, test_ds = tfds.load(
#     "cifar10", 
#     split=["train[:90%]", "train[90%:]", "test"],
#     as_supervised=True)

train_ds, val_ds = tfds.load(
    "cifar10",
    split=["train", "test"],
    as_supervised=True
)

In [9]:
# len(train_ds), len(val_ds), len(test_ds)

In [10]:
std = [0.229, 0.224, 0.225]
var = [x ** 2 for x in std]

augment_pipeline = tf.keras.Sequential([
    tf.keras.layers.Rescaling(scale=1./255),
    tf.keras.layers.Normalization(mean=[0.485, 0.456, 0.406], variance=var),
    tf.keras.layers.ZeroPadding2D(padding=(4, 4)),
    tf.keras.layers.RandomFlip(mode="horizontal"),
    tf.keras.layers.RandomCrop(height=32, width=32)
])

evaluate_pipeline = tf.keras.Sequential([
    tf.keras.layers.Rescaling(scale=1./255),
    tf.keras.layers.Normalization(mean=[0.485, 0.456, 0.406], variance=var),
])

augment_pipeline.compile()
evaluate_pipeline.compile()

In [11]:
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).batch(32, drop_remainder=True).map(lambda x, y: (augment_pipeline(x, training=True), y))
val_ds = val_ds.cache().batch(32, drop_remainder=True).map(lambda x, y: (evaluate_pipeline(x, training=False), y))

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


































# Setup training

In [12]:
# optimizer = tf.keras.optimizers.SGD(
#     learning_rate=0.01, momentum=0.9,
# )

In [13]:
ResNetV2Model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3, weight_decay=1e-4),
    # optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[
        tf.keras.metrics.SparseCategoricalAccuracy(),
        tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)]
)

# Run Training


In [14]:
history = ResNetV2Model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10
)

2023-01-22 17:39:28.460750: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


5 epochs SGD 0.01, 0.9: val acc 73.8, val loss 0.7473, train acc 75.5, train loss 0.7066

10 epochs adamw 1e-3, 1e-4: val acc 74.9, val loss 0.848, train acc 84.0, train_loss 0.4858

In [15]:
history.history

{'loss': [1.505010724067688,
  1.053391933441162,
  0.8717251420021057,
  0.7565324902534485,
  0.676871120929718,
  0.6126886010169983,
  0.5663599967956543,
  0.5213607549667358,
  0.48580610752105713,
  0.457795649766922],
 'sparse_categorical_accuracy': [0.4596470892429352,
  0.6254401206970215,
  0.6955626010894775,
  0.7371158599853516,
  0.7679857611656189,
  0.7888524532318115,
  0.8055977821350098,
  0.8196423053741455,
  0.832626461982727,
  0.8404489159584045],
 'sparse_categorical_crossentropy': [1.505010724067688,
  1.053391933441162,
  0.8717251420021057,
  0.7565324902534485,
  0.676871120929718,
  0.6126886010169983,
  0.5663599967956543,
  0.5213607549667358,
  0.48580610752105713,
  0.457795649766922],
 'val_loss': [1.319312334060669,
  1.0919350385665894,
  1.0077588558197021,
  0.8891614079475403,
  0.7832885980606079,
  0.7087368965148926,
  0.7203460335731506,
  0.7719316482543945,
  0.8339065313339233,
  0.8480474948883057],
 'val_sparse_categorical_accuracy': [0