# MODEL SUBCLASSING

## Permits us to create recursively composable layers and models

### -> Lets us create layers whose attributes are themselves other layers

In [1]:
# Import Dependencies
import tensorflow as tf
from tensorflow.keras.layers import Input, Normalization, Conv2D, MaxPooling2D, Dense, Flatten, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer

# FEATURE EXTRACTOR LAYER CLASS

In [2]:
# Layer Class
class FeatureExtractor(Layer):
    """Two-stage convolutional feature extractor built by layer subclassing.

    Each stage is Conv2D -> BatchNormalization -> MaxPooling2D; the second
    stage's Conv2D uses twice as many filters as the first.

    Args:
        filters: number of filters in the first Conv2D (the second uses 2 * filters).
        kernel_size: kernel size for both Conv2D layers.
        strides: stride for both Conv2D layers (int or (rows, cols) tuple);
            both MaxPooling2D layers use twice this stride.
        padding: padding mode for both Conv2D layers (e.g. 'valid').
        activation: activation applied by both Conv2D layers.
        pool_size: window size for both MaxPooling2D layers.
        **kwargs: standard Layer arguments (e.g. ``name``, ``dtype``, ``trainable``),
            forwarded to ``Layer.__init__`` so the layer can be named and rebuilt
            from its serialized config.
    """

    def __init__(self, filters, kernel_size, strides, padding, activation, pool_size, **kwargs):
        super().__init__(**kwargs)

        # Pooling stride is tied to the conv stride (2x). Conv2D also accepts a
        # (rows, cols) tuple, for which `2 * strides` would repeat the tuple
        # instead of doubling it -- so double element-wise in that case.
        if isinstance(strides, (tuple, list)):
            pool_strides = tuple(2 * s for s in strides)
        else:
            pool_strides = 2 * strides

        # Stage 1
        self.conv1 = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, activation=activation)
        self.batch1 = BatchNormalization()
        self.pool1 = MaxPooling2D(pool_size=pool_size, strides=pool_strides)

        # Stage 2 (doubles the number of filters)
        self.conv2 = Conv2D(filters=filters*2, kernel_size=kernel_size, strides=strides, padding=padding, activation=activation)
        self.batch2 = BatchNormalization()
        self.pool2 = MaxPooling2D(pool_size=pool_size, strides=pool_strides)

    # Call Function -> similar to Functional API
    def call(self, x, training=None):
        """Run ``x`` through both conv/batch-norm/pool stages.

        ``training`` is forwarded explicitly to the BatchNormalization layers,
        whose behaviour differs between training and inference.
        """
        x = self.conv1(x)
        x = self.batch1(x, training=training)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.batch2(x, training=training)
        x = self.pool2(x)
        return x

featureExtractorSubclassed = FeatureExtractor(filters=8, kernel_size=3, strides=1, padding='valid', activation='relu', pool_size=2)

# CREATING FINAL MODEL WITH SUB-CLASSED FEATURE EXTRACTOR USING FUNCTIONAL API

In [3]:
IMAGE_SIZE = 224  # side length (pixels) of the square 3-channel model input

# Functional-API graph that uses the subclassed FeatureExtractor like any built-in layer.
# NOTE(review): the name `LeNetModel` bound at the end of this cell is rebound later
# when the subclassed `class LeNetModel` is defined.
funcInput = Input(name="ModelInput", shape=(IMAGE_SIZE, IMAGE_SIZE, 3))

# Convolutional features, flattened into a vector for the dense head
X = Flatten()(featureExtractorSubclassed(funcInput))

# Dense head: every hidden ReLU Dense layer is followed by BatchNormalization
for hiddenUnits in (100, 10):
    X = Dense(hiddenUnits, activation='relu')(X)
    X = BatchNormalization()(X)

# One sigmoid unit (presumably a binary-classification probability)
funcOutput = Dense(1, activation='sigmoid')(X)

LeNetModel = Model(funcInput, funcOutput, name="LeNetModel")
LeNetModel.summary()

2025-01-20 18:34:29.321957: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3 Pro
2025-01-20 18:34:29.321974: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 18.00 GB
2025-01-20 18:34:29.321980: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 6.00 GB
2025-01-20 18:34:29.321995: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-01-20 18:34:29.322003: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


# CREATING FINAL MODEL WITH SUB-CLASSED FEATURE EXTRACTOR USING MODEL-SUBCLASSING

In [None]:
# Model Class
class LeNetModel(Model):
    """LeNet-style classifier assembled by model subclassing.

    Reuses the subclassed ``FeatureExtractor`` layer, followed by a dense head:
    Flatten -> Dense(100, relu) -> BatchNorm -> Dense(10, relu) -> BatchNorm
    -> Dense(1, sigmoid).

    NOTE: defining this class rebinds the global name ``LeNetModel``, which the
    functional-API cell earlier in the notebook used for the functional model;
    that model is no longer reachable by this name once this cell has run.

    Args:
        **kwargs: standard Model arguments (e.g. ``name``), forwarded to
            ``Model.__init__``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Calling the subclassed feature extractor as a layer inside the final model
        self.featureExtractor = FeatureExtractor(filters=8, kernel_size=3, strides=1, padding='valid', activation='relu', pool_size=2)
        self.flatten = Flatten()
        self.dense1 = Dense(100, activation='relu')
        self.batch1 = BatchNormalization()
        self.dense2 = Dense(10, activation='relu')
        self.batch2 = BatchNormalization()
        self.dense3 = Dense(1, activation="sigmoid")

    # Call Function -> similar to Functional API
    def call(self, x, training=None):
        """Forward pass; ``training`` is forwarded to every BatchNorm-containing layer."""
        x = self.featureExtractor(x, training=training)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.batch1(x, training=training)
        x = self.dense2(x)
        x = self.batch2(x, training=training)
        x = self.dense3(x)
        return x

leNetSubclassedModel = LeNetModel()

# A subclassed model creates its weights lazily. Calling it once on a single dummy
# batch of zeros (same input size as the functional model) builds every layer so
# that summary() can report shapes and parameter counts.
leNetSubclassedModel(tf.zeros([1, IMAGE_SIZE, IMAGE_SIZE, 3]))

leNetSubclassedModel.summary()

## Everything else (compiling, training, and all other operations) works the same as for any other Keras model.