In [1]:
import os

# Quiet TensorFlow's native (C++) logging before tensorflow is imported below:
# '2' suppresses INFO and WARNING messages and keeps errors.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')

In [2]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from keras.layers import Layer, Conv2D, MaxPooling2D, BatchNormalization, GlobalAveragePooling2D, Activation
from keras.layers import Add, Dense
from keras.models import Model
import tensorflow_datasets as tfds

# Ask TensorFlow to grow GPU memory on demand instead of reserving it up front.
# Iterating (rather than indexing [0]) keeps the notebook runnable on CPU-only
# machines (empty list -> no-op). It also applies the same setting to every GPU,
# which TensorFlow requires on multi-GPU hosts.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

In [3]:
class IdentityBlock(Model):
    """Residual block whose shortcut is the unmodified input.

    The residual branch has two same-padded convolutions, each followed by
    batch normalization. A ReLU follows the first one only. The branch output is
    added to the input and passed through a final ReLU. Because the shortcut is
    added without any projection, the input must already have ``filters``
    channels.

    Args:
        filters: number of output channels of both convolutions.
        kernel_size: kernel size of both convolutions.
        **kwargs: forwarded to the ``Model`` base class.
    """

    def __init__(self, filters, kernel_size, **kwargs):
        super().__init__(**kwargs)

        # Both convolutions share the same configuration.
        conv_args = dict(filters=filters, kernel_size=kernel_size, padding='same')
        self.conv1 = Conv2D(**conv_args)
        self.bn1 = BatchNormalization()
        self.conv2 = Conv2D(**conv_args)
        self.bn2 = BatchNormalization()
        # One weightless ReLU layer is reused at both activation points.
        self.act = Activation('relu')
        self.add = Add()

    def call(self, inputs):
        shortcut = inputs
        residual = self.act(self.bn1(self.conv1(inputs)))
        residual = self.bn2(self.conv2(residual))
        return self.act(self.add([residual, shortcut]))

In [4]:
class ResNet(Model):
    """Small ResNet-style classifier.

    Pipeline: 7x7 conv stem -> batch norm -> ReLU -> 3x3 max pool ->
    two 64-filter identity blocks -> global average pooling -> softmax head.

    Args:
        num_classes: number of units in the final softmax layer.
        **kwargs: forwarded to the ``Model`` base class.
    """

    def __init__(self, num_classes, **kwargs):
        super(ResNet, self).__init__(**kwargs)

        # Stem. 64 filters matches the identity blocks below, whose residual Add
        # requires the incoming channel count to equal their filter count.
        self.conv1 = Conv2D(filters=64, kernel_size=7, padding='same')
        self.bn1 = BatchNormalization()
        # strides default to pool_size, so this downsamples spatially by 3.
        self.maxpooling = MaxPooling2D(pool_size=(3, 3))
        self.act = Activation('relu')

        self.identity1 = IdentityBlock(64, 3)
        self.identity2 = IdentityBlock(64, 3)

        self.globalavgpooling = GlobalAveragePooling2D()
        self.classifier = Dense(units=num_classes, activation='softmax')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        # The ReLU was constructed in __init__ but previously never applied,
        # which left the stem (conv -> BN -> pool) purely linear.
        x = self.act(x)
        x = self.maxpooling(x)

        x = self.identity1(x)
        x = self.identity2(x)

        x = self.globalavgpooling(x)
        x = self.classifier(x)
        return x

In [5]:
def preprocess(features):
    """Turn a tfds feature dict into an ``(image, label)`` training pair.

    The image is cast to float32 and divided by 255. The label is passed
    through unchanged.
    """
    pixels = tf.cast(features['image'], tf.float32)
    label = features['label']
    return pixels / 255., label

In [6]:
# Run configuration: tune these here instead of editing literals inline.
SEED = 42
NUM_CLASSES = 10
BATCH_SIZE = 32
EPOCHS = 3
SHUFFLE_BUFFER = 10_000

# Seed Python, NumPy and TensorFlow RNGs so that weight initialization and
# shuffle order repeat across runs (GPU kernels may still be nondeterministic).
keras.utils.set_random_seed(SEED)

model = ResNet(NUM_CLASSES)
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='rmsprop',
    metrics=['accuracy'],
)

# Input pipeline:
# - Shuffle the raw examples. Without this, every epoch saw the same fixed
#   order, since tfds.load does not shuffle by default. Shuffling before
#   `map` keeps the buffer holding the smaller, un-cast images. The order is
#   reshuffled every epoch.
# - Preprocess in parallel, batch, then prefetch so input preparation overlaps
#   with training steps.
dataset = tfds.load('mnist', split=tfds.Split.TRAIN)
dataset = (
    dataset
    .shuffle(SHUFFLE_BUFFER, seed=SEED)
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

history = model.fit(dataset, epochs=EPOCHS)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.src.callbacks.History at 0x7fcc132eb050>