In [306]:
import tensorflow as tf
import os
import matplotlib.pyplot as plt
from skimage import io
import numpy as np
from tensorflow.keras.layers import MaxPooling2D, Flatten, Dense
random_seed = 42

In [307]:
@tf.function
def conv_function(inputs, kernel, s, p, bias):
    """Convolve `inputs` with `kernel`, add `bias`, and apply ReLU.

    Args:
        inputs: Input tensor handed to `tf.nn.conv2d`.
        kernel: Filter tensor handed to `tf.nn.conv2d`.
        s: Iterable of spatial strides; it is wrapped with a leading and a
            trailing 1 to form the stride list given to `tf.nn.conv2d`.
        p: Padding mode forwarded unchanged to `tf.nn.conv2d` (e.g. 'VALID').
        bias: Value added to the convolution result before the activation.

    Returns:
        The ReLU of the biased convolution output.
    """
    full_strides = [1, *s, 1]
    pre_activation = tf.nn.conv2d(inputs, kernel, strides=full_strides, padding=p) + bias
    return tf.nn.relu(pre_activation)

class SimpleConvLayer(tf.keras.layers.Layer):
    """Minimal 2-D convolution layer with a built-in ReLU activation.

    The layer owns a Glorot-uniform-initialised kernel tensor and a
    'random_normal'-initialised bias vector, and delegates the forward pass to
    `conv_function`.

    Args:
        num_kernels: Number of filters, i.e. the size of the last output axis.
        kernel_size: Pair giving the two spatial dimensions of each filter.
        strides: Pair of spatial strides.
        padding: Padding mode forwarded to the convolution (e.g. 'VALID').
        **kwargs: Standard `tf.keras.layers.Layer` options such as `name` or
            `dtype`, forwarded to the base class.
    """

    def __init__(self, num_kernels=32, kernel_size=(3, 3), strides=(1, 1), padding='VALID', **kwargs):
        # Forward the base-class options so subclasses that pass **kwargs
        # (e.g. name=...) do not fail with a TypeError here.
        super().__init__(**kwargs)
        self.num_kernels = num_kernels
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding

    def build(self, input_shape):
        """Create the kernel and bias weights for the given input shape.

        Args:
            input_shape: Shape of the layer input; its last entry is used as
                the number of input channels.
        """
        num_channels = input_shape[-1]
        kernel_shape = (*self.kernel_size, num_channels, self.num_kernels)
        glorot_uni_init = tf.initializers.GlorotUniform(seed=random_seed)
        self.kernels = self.add_weight(name='kernels', shape=kernel_shape, initializer=glorot_uni_init,
                                       trainable=True)
        self.bias = self.add_weight(name='bias', shape=(self.num_kernels, ), initializer='random_normal',
                                    trainable=True)
        # Mark the layer as built. Without this, an explicit `layer.build(...)`
        # followed by `layer(x)` can run build() again and create the weights
        # a second time.
        super().build(input_shape)

    def call(self, inputs):
        """Return the ReLU-activated, biased convolution of `inputs`."""
        return conv_function(inputs, self.kernels, self.strides, self.padding, self.bias)

In [308]:
def l2_reg(coef=1e-2):
    """Create an L2 penalty function.

    Args:
        coef: Multiplier applied to the sum of squares.

    Returns:
        A one-argument callable that maps `x` to `coef` times the sum of the
        squared entries of `x`.
    """
    def penalty(x):
        squared_sum = tf.reduce_sum(x ** 2)
        return squared_sum * coef

    return penalty

In [309]:
from functools import partial

class ConvWithReg(SimpleConvLayer):
    """`SimpleConvLayer` that also registers penalties on its kernels and bias.

    Args:
        kernel_regularizer: Callable mapping the kernel tensor to a scalar
            penalty, or None to skip the kernel penalty.
        bias_regularizer: Callable mapping the bias tensor to a scalar penalty,
            or None to skip the bias penalty.
        **kwargs: Forwarded to `SimpleConvLayer.__init__`.
    """

    def __init__(self, kernel_regularizer=l2_reg(), bias_regularizer=l2_reg(), **kwargs):
        super().__init__(**kwargs)
        self.kernel_regularizer = kernel_regularizer
        self.bias_regularizer = bias_regularizer

    def build(self, input_shape):
        """Create the weights, then register the configured penalties as losses.

        Args:
            input_shape: Shape of the layer input, forwarded to the parent build.
        """
        super().build(input_shape)
        # Register zero-argument callables so each penalty is recomputed from
        # the current weight values whenever the losses are read.
        if self.kernel_regularizer is not None:
            self.add_loss(partial(self.kernel_regularizer, self.kernels))
        if self.bias_regularizer is not None:
            self.add_loss(partial(self.bias_regularizer, self.bias))
        # Mark the layer as built so that calling it after an explicit
        # `build(...)` does not run build() again and register every penalty
        # (and weight) a second time.
        self.built = True

In [310]:
# Regularized conv layer with every default: 32 kernels of 3x3, stride (1, 1),
# 'VALID' padding, and an L2 penalty (coef=1e-2) on both kernels and bias.
conv_with_reg = ConvWithReg()

In [311]:
# Build by hand so the weights and the two regularization losses exist before any
# call; the last axis of input_shape (3) is taken as the channel count.
conv_with_reg.build(input_shape=(1, 200, 200, 3))

In [312]:
# Losses come back in registration order: kernel penalty first, then bias penalty.
conv_with_reg.losses

[<tf.Tensor: shape=(), dtype=float32, numpy=0.0560476>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.00065425213>]

In [313]:
class LeNet5(tf.keras.models.Model):
    """LeNet-5-style classifier: two regularized conv stages and three dense layers.

    Args:
        num_classes: Width of the final softmax layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Feature extractor: two 5x5 convolutions; one pooling layer is reused
        # after each of them.
        self.conv1 = ConvWithReg(kernel_size=(5, 5), num_kernels=6)
        self.conv2 = ConvWithReg(kernel_size=(5, 5), num_kernels=16)
        self.max_pool = MaxPooling2D(pool_size=(2, 2))
        # Classifier head: 128 -> 64 -> num_classes.
        self.dense1 = Dense(units=128, activation='relu')
        self.dense2 = Dense(units=64, activation='relu')
        self.dense3 = Dense(units=num_classes, activation='softmax')
        self.flatten = Flatten()

    def call(self, inputs):
        """Run the forward pass and return the softmax output of the last layer."""
        stage_one = self.conv1(inputs)
        stage_one = self.max_pool(stage_one)
        stage_two = self.conv2(stage_one)
        stage_two = self.max_pool(stage_two)
        flat = self.flatten(stage_two)
        hidden = self.dense1(flat)
        hidden = self.dense2(hidden)
        return self.dense3(hidden)

In [314]:
# `num_classes` is not assigned in any earlier cell, so define it here to keep
# Restart & Run All working. 10 matches the recorded output shape (1, 10) and
# the 650 parameters (64 * 10 + 10) of the last dense layer in the summary.
num_classes = 10
lenet5 = LeNet5(num_classes)

In [315]:
# Build for 100x100 inputs with 3 channels; None leaves the batch size unspecified.
lenet5.build(input_shape=(None, 100, 100, 3))

In [316]:
# Print the per-layer parameter counts of the built model.
lenet5.summary()

Model: "le_net5_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv_with_reg_37 (ConvWithRe multiple                  456       
_________________________________________________________________
conv_with_reg_38 (ConvWithRe multiple                  2416      
_________________________________________________________________
max_pooling2d_10 (MaxPooling multiple                  0         
_________________________________________________________________
dense_30 (Dense)             multiple                  991360    
_________________________________________________________________
dense_31 (Dense)             multiple                  8256      
_________________________________________________________________
dense_32 (Dense)             multiple                  650       
_________________________________________________________________
flatten_10 (Flatten)         multiple                  0

In [317]:
# One random 100x100 image with 3 channels, created directly with its leading
# batch axis. The op-level seed makes the values repeat after a kernel restart.
x = tf.random.normal((1, 100, 100, 3), seed=random_seed)

In [318]:
# Forward pass through the full network on the single-image batch.
y = lenet5(x)



In [319]:
# One row per input image, one column per class.
y.shape

TensorShape([1, 10])

In [320]:
# Fresh stand-alone layer with default settings, used to check the conv output shape.
conv = ConvWithReg()

In [321]:
# Explicit build for 3-channel 100x100 inputs with an unspecified batch size.
# NOTE(review): the next cell calls the layer; confirm that call does not run
# build() a second time and register the weights/losses twice.
conv.build(input_shape=(None, 100, 100, 3))

In [322]:
# Run the single conv layer on the same random batch (overwrites the earlier `y`).
y = conv(x)



In [323]:
# 3x3 'VALID' convolution with stride 1 trims 100x100 to 98x98; 32 channels, one per kernel.
y.shape

TensorShape([1, 98, 98, 32])

In [324]:
# Rebind `conv` to a smaller layer: 10 kernels of 3x3, remaining settings left at their defaults.
conv = ConvWithReg(num_kernels=10, kernel_size=(3, 3))