# Imports

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, Add, ReLU, MaxPooling2D, GlobalAveragePooling2D
from tensorflow.keras.activations import relu
from tensorflow.keras.datasets import mnist
import tensorflow_datasets as tfds

# Model

In [11]:
class IdentityBlock(Model):
    
    def __init__(self, filters, kernel_size):
        super().__init__()
        self.conv1=Conv2D(filters, kernel_size, padding='same')
        self.bn1=BatchNormalization()
        self.conv2=Conv2D(filters, kernel_size, padding='same')
        self.bn2=BatchNormalization()
        self.act=tf.keras.layers.Activation('relu')
        self.add=Add()
        
    def call(self, inputs):
        x=self.conv1(inputs)
        x=self.bn1(x)
        x=self.act(x)
        x=self.conv2(x)
        x=self.bn2(x)
        x=self.act(x)
        x=self.add([x, inputs])
        x=self.act(x)
        return x       

In [12]:
class ResNet(Model):
    
    def __init__(self, num_classes):
        super().__init__()
        self.conv=Conv2D(64, 3, padding='same')
        self.bn=BatchNormalization()
        self.max_pool=MaxPooling2D((3, 3))
        self.identity1=IdentityBlock(64, 3)
        self.identity2=IdentityBlock(64, 3)
        self.global_pool=GlobalAveragePooling2D()
        self.classifier=Dense(num_classes, activation='softmax')
        
    def call(self, inputs):
        
        x=self.conv(inputs)
        x=self.bn(x)
        x=self.max_pool(x)
        x=self.identity1(x)
        x=self.identity2(x)
        x=self.global_pool(x)
        x=self.classifier(x)
        
        return x                  

# Train (first_soloution)

In [13]:
def preprocess(features):
    return tf.cast(features['image'], tf.float32) / 255. , features['label']

resnet = ResNet(10)
resnet.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

dataset = tfds.load('mnist', split=tfds.Split.TRAIN, data_dir='./data', batch_size=32)
dataset = dataset.map(preprocess)

resnet.fit(dataset, epochs=1)



<keras.callbacks.History at 0x18ab014d5b0>

# Train (second_soloution)

In [14]:
(train_examples, train_labels),(test_examples, test_labels)=mnist.load_data()

train_dataset = tf.data.Dataset.from_tensor_slices((train_examples, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_examples, test_labels))

BATCH_SIZE = 32

train_dataset = train_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)

resnet = ResNet(10)
resnet.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
print(dataset)
resnet.fit(dataset, epochs=1)

<MapDataset element_spec=(TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>


<keras.callbacks.History at 0x18ab159fd00>