# In [1]:
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Conv2D, BatchNormalization, ReLU, Add, GlobalAveragePooling2D, Dense, MaxPooling2D
)
from tensorflow.keras.models import Model

# Define a residual block
def residual_block(x, filters, stride=1):
    """Apply one two-convolution residual unit to `x`.

    The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Its output is
    summed with a skip connection and passed through a final ReLU.

    Args:
        x: Input tensor. Its last axis is compared against `filters`, so a
            channels-last layout is assumed.
        filters: Number of output channels for every convolution in the block.
        stride: Stride of the first 3x3 convolution and of the projection
            shortcut, when one is needed (default 1).

    Returns:
        The tensor produced by ReLU(main_path(x) + skip(x)).
    """
    # The plain identity skip is only usable when the block changes neither
    # the spatial resolution (stride 1) nor the channel count.
    needs_projection = not (stride == 1 and x.shape[-1] == filters)

    # Main path, first half: (possibly strided) 3x3 conv, normalise, activate.
    main = Conv2D(filters, kernel_size=(3, 3), strides=stride, padding='same')(x)
    main = BatchNormalization()(main)
    main = ReLU()(main)

    # Main path, second half: stride-1 3x3 conv and normalise. The activation
    # is deferred until after the addition.
    main = Conv2D(filters, kernel_size=(3, 3), strides=1, padding='same')(main)
    main = BatchNormalization()(main)

    # Skip path: a 1x1 projection (plus BN) when shapes differ, else identity.
    if needs_projection:
        skip = Conv2D(filters, kernel_size=(1, 1), strides=stride, padding='same')(x)
        skip = BatchNormalization()(skip)
    else:
        skip = x

    # Merge both paths and apply the final non-linearity.
    merged = Add()([main, skip])
    return ReLU()(merged)

# Build the ResNet model
def build_resnet(input_shape, num_classes, num_blocks):
    """Assemble a ResNet-style classifier with the Keras functional API.

    Args:
        input_shape: Shape of a single input image, e.g. (224, 224, 3).
        num_classes: Size of the final softmax layer.
        num_blocks: Sequence giving the number of residual blocks in each
            stage. Stage `i` uses `64 * 2**i` filters.

    Returns:
        A `Model` mapping the input image tensor to class probabilities.
    """
    inputs = Input(shape=input_shape)

    # Stem: strided 7x7 convolution, BN, ReLU, then a strided 3x3 max-pool.
    features = Conv2D(64, kernel_size=(7, 7), strides=2, padding='same')(inputs)
    features = BatchNormalization()(features)
    features = ReLU()(features)
    features = MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(features)

    # Residual stages: the width starts at 64 and doubles with every stage.
    for stage_index, blocks_in_stage in enumerate(num_blocks):
        stage_filters = 64 * 2 ** stage_index
        for block_index in range(blocks_in_stage):
            # Only the first block of stages after the first one downsamples;
            # stage 0 keeps the resolution produced by the stem.
            downsample = stage_index > 0 and block_index == 0
            features = residual_block(features, stage_filters, 2 if downsample else 1)

    # Head: collapse the spatial axes to one vector per image, then classify.
    pooled = GlobalAveragePooling2D()(features)
    outputs = Dense(num_classes, activation='softmax')(pooled)

    return Model(inputs, outputs)

# Example usage
if __name__ == "__main__":
    # Demo configuration: a 10-way classifier over 224x224 three-channel images.
    num_classes = 10
    input_shape = (224, 224, 3)

    # Four stages of two residual blocks each give the ResNet-18 layout.
    model = build_resnet(
        input_shape=input_shape,
        num_classes=num_classes,
        num_blocks=[2, 2, 2, 2],
    )

    # Print the layer-by-layer overview of the assembled network.
    model.summary()
