# Model Training

The Model Training notebook trains an autoencoder that places a Spatial Transformer Network (STN) layer in front of a ResNet50-based encoder. The STN learns spatial transformations (rotation, scaling, and translation) that align input images before encoding, which is intended to help the model generalize. ResNet50 acts as the feature extractor, capturing deep representations of the input data, while the autoencoder compresses and reconstructs the images. The notebook covers data loading, the model definition, training with the Adam optimizer and a mean squared error loss, and saving the trained model.

In [1]:
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import Dense, Flatten, Reshape, Conv2DTranspose
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

2024-09-29 11:23:39.309099: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-09-29 11:23:39.327475: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-29 11:23:39.348831: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-29 11:23:39.355258: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-29 11:23:39.371199: I tensorflow/core/platform/cpu_feature_guar

In [2]:
datagen = ImageDataGenerator(
    rescale=1./255,
    fill_mode="nearest",
    )
train_dataset = datagen.flow_from_directory(
    directory="./dataset",
    target_size=(224,224),
    class_mode=None,
    batch_size=32
)

Found 78154 images belonging to 39 classes.


This model combines a Spatial Transformer Network (STN) layer with a ResNet50-based autoencoder. The STN's localization network predicts six affine parameters per image, covering rotation, scaling, and translation, so that the input can be aligned before it reaches the encoder. The encoder uses a ResNet50 pre-trained on ImageNet as a feature extractor; its output passes through global average pooling and a fully connected layer to produce a 512-dimensional encoding. The decoder reverses this process: a dense layer reshapes the encoding into a 7×7 feature map, and five transposed-convolution and upsampling stages bring it back to a 224×224×3 image. The STN is meant to provide spatial invariance, while the autoencoder handles encoding and reconstruction.
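
To make the affine step concrete, here is a minimal NumPy sketch of how a 2×3 matrix `theta` maps a normalized sampling grid. The helper name `affine_grid` and the example matrices are illustrative assumptions, not part of the notebook; the grid construction mirrors the `[-1, 1]` coordinate convention used by the STN layer.

```python
import numpy as np

def affine_grid(theta, height, width):
    """Return transformed sampling coordinates with shape (height * width, 2)."""
    xs = np.linspace(-1.0, 1.0, width)
    ys = np.linspace(-1.0, 1.0, height)
    x_grid, y_grid = np.meshgrid(xs, ys)
    # Homogeneous coordinates (x, y, 1) for every output pixel.
    base = np.stack([x_grid, y_grid, np.ones_like(x_grid)], axis=-1).reshape(-1, 3)
    # Each row becomes (theta[0] . (x, y, 1), theta[1] . (x, y, 1)).
    return base @ theta.T

# Hypothetical example matrices: identity, and a shift of 0.5 along x.
identity = np.array([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0]])
shift_x = np.array([[1.0, 0.0, 0.5],
                    [0.0, 1.0, 0.0]])

grid_id = affine_grid(identity, 3, 3)
grid_shift = affine_grid(shift_x, 3, 3)
print(grid_id[0], grid_shift[0])
```

With the identity matrix the grid is unchanged, so sampling at those coordinates reproduces the input; the translation column moves every sampling location by the same offset.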

In [3]:
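class STN(layers.Layer):
    def __init__(self, **kwargs):
        # Accept standard layer kwargs (name, dtype, ...) so the layer can be rebuilt from its config.
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Localization network: predicts the 6 parameters of a 2x3 affine matrix.
        self.localization = tf.keras.Sequential([
            layers.Conv2D(16, (7, 7), activation='relu', input_shape=input_shape[1:]),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(32, (5, 5), activation='relu'),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dense(50, activation='relu'),
            # Start at the identity transform so early training sees unwarped images.
            layers.Dense(6, activation='linear', kernel_initializer='zeros',
                         bias_initializer=tf.keras.initializers.Constant(
                             [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))
        ])

    def call(self, inputs):
        theta = self.localization(inputs)
        theta = tf.reshape(theta, [-1, 2, 3])
        grid = self.get_grid(tf.shape(inputs), theta)
        return self.sampler(inputs, grid)

    def get_grid(self, input_shape, theta):
        # Build a normalized [-1, 1] grid and apply the affine matrix to every location.
        batch_size, height, width = input_shape[0], input_shape[1], input_shape[2]
        x_coords = tf.linspace(-1.0, 1.0, width)
        y_coords = tf.linspace(-1.0, 1.0, height)
        x_grid, y_grid = tf.meshgrid(x_coords, y_coords)
        ones = tf.ones_like(x_grid)
        grid = tf.stack([x_grid, y_grid, ones], axis=-1)
        grid = tf.reshape(grid, [1, height * width, 3])
        grid = tf.tile(grid, [batch_size, 1, 1])
        grid = tf.matmul(grid, tf.transpose(theta, [0, 2, 1]))
        return grid  # shape: (batch, height * width, 2)

    def sampler(self, inputs, grid):
        # Bilinearly sample `inputs` at the (x, y) locations in `grid`.
        # The previous version ignored `grid` and returned the input unchanged.
        shape = tf.shape(inputs)
        batch_size, height, width = shape[0], shape[1], shape[2]
        h_f = tf.cast(height, tf.float32)
        w_f = tf.cast(width, tf.float32)

        # Map normalized coordinates in [-1, 1] to pixel coordinates.
        x = (grid[..., 0] + 1.0) * 0.5 * (w_f - 1.0)
        y = (grid[..., 1] + 1.0) * 0.5 * (h_f - 1.0)
        x0, y0 = tf.floor(x), tf.floor(y)
        x1, y1 = x0 + 1.0, y0 + 1.0

        # Bilinear weights for the four neighbouring pixels.
        wa = (x1 - x) * (y1 - y)   # (x0, y0)
        wb = (x1 - x) * (y - y0)   # (x0, y1)
        wc = (x - x0) * (y1 - y)   # (x1, y0)
        wd = (x - x0) * (y - y0)   # (x1, y1)

        batch_idx = tf.tile(tf.range(batch_size)[:, None], [1, height * width])

        def gather(xi, yi):
            # Clamp to the image border, then look up pixel values.
            xi = tf.cast(tf.clip_by_value(xi, 0.0, w_f - 1.0), tf.int32)
            yi = tf.cast(tf.clip_by_value(yi, 0.0, h_f - 1.0), tf.int32)
            return tf.gather_nd(inputs, tf.stack([batch_idx, yi, xi], axis=-1))

        out = (wa[..., None] * gather(x0, y0) + wb[..., None] * gather(x0, y1)
               + wc[..., None] * gather(x1, y0) + wd[..., None] * gather(x1, y1))
        out = tf.reshape(out, shape)
        # Restore the static shape so downstream layers see (batch, H, W, C).
        return tf.ensure_shape(out, inputs.shape)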

def create_encoder(input_shape):
    inputs = layers.Input(shape=input_shape)
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)
    x = base_model(inputs)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(512, activation='relu')(x)

    return Model(inputs, x, name="encoder")

def decoder(encoded):
    x = layers.Dense(7 * 7 * 64, activation='relu')(encoded)
    x = layers.Reshape((7, 7, 64))(x)

    x = Conv2DTranspose(64, (3, 3), activation='relu', padding='same')(x)
    x = layers.UpSampling2D(size=(2, 2))(x)  # 7x7 -> 14x14

    x = Conv2DTranspose(32, (3, 3), activation='relu', padding='same')(x)
    x = layers.UpSampling2D(size=(2, 2))(x)  # 14x14 -> 28x28

    x = Conv2DTranspose(16, (3, 3), activation='relu', padding='same')(x)
    x = layers.UpSampling2D(size=(2, 2))(x)  # 28x28 -> 56x56

    x = Conv2DTranspose(8, (3, 3), activation='relu', padding='same')(x)
    x = layers.UpSampling2D(size=(2, 2))(x)  # 56x56 -> 112x112

    x = Conv2DTranspose(3, (3, 3), activation='sigmoid', padding='same')(x)
    x = layers.UpSampling2D(size=(2, 2))(x)  # 112x112 -> 224x224

    return x
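
The spatial sizes noted in the decoder comments can be checked with plain arithmetic. The sketch below uses a hypothetical helper, `decoder_shapes`, and assumes what the decoder code shows: each stride-1 transposed convolution with `'same'` padding keeps the spatial size, and each `UpSampling2D(size=(2, 2))` doubles it.

```python
def decoder_shapes(start=7, filters=(64, 32, 16, 8, 3)):
    """Trace (height, width, channels) after each conv + 2x upsampling stage."""
    size = start
    shapes = []
    for f in filters:
        # The transposed convolution sets the channel count to f;
        # the upsampling layer then doubles height and width.
        size *= 2
        shapes.append((size, size, f))
    return shapes

shapes = decoder_shapes()
for stage, shape in enumerate(shapes, start=1):
    print(f"stage {stage}: {shape}")
```

Five doublings take the 7×7 feature map to 224×224, matching the input resolution that the reconstruction loss compares against.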


In [4]:
def autoencoder_data_generator(generator):
    for batch in generator:
        yield (batch, batch)

In [5]:
train_datagen = autoencoder_data_generator(train_dataset)
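
The wrapper's behaviour can be illustrated without Keras. This sketch assumes the underlying batch iterator repeats indefinitely, so the consumer has to decide how many batches make up one epoch; `endless_batches` and `steps_per_epoch` are hypothetical stand-ins rather than objects from the notebook.

```python
import itertools

def autoencoder_pairs(batches):
    """Yield (input, target) pairs where the target is the input itself."""
    for batch in batches:
        yield (batch, batch)

# Hypothetical stand-in for an iterator that never stops producing batches.
endless_batches = itertools.cycle(["batch_0", "batch_1", "batch_2"])
steps_per_epoch = 3  # assumed number of batches in one pass over the data

# Without an explicit bound the loop would never end, so take a fixed number of steps.
one_epoch = list(itertools.islice(autoencoder_pairs(endless_batches), steps_per_epoch))
print(one_epoch)
```

Each element pairs a batch with itself, which is the form an autoencoder needs when the reconstruction target is the input.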

In [6]:
inputs = tf.keras.Input(shape=(224, 224, 3))
stn_layer = STN()(inputs)
encoder_model = create_encoder(input_shape=(224, 224, 3))
encoded = encoder_model(stn_layer)
decoded = decoder(encoded)
autoencoder = Model(inputs=inputs, outputs=decoded)
autoencoder.compile(optimizer='adam', loss='mse')
autoencoder.summary()
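# The wrapped generator never ends, so bound each epoch explicitly.
# The batch size is already set in flow_from_directory, so it is not passed here.
autoencoder.fit(train_datagen, epochs=100, steps_per_epoch=len(train_dataset))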

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
2024-09-29 11:23:57.672488: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38364 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:00:04.0, compute capability: 8.0


In [7]:
# Reloading this file requires the custom layer, e.g. custom_objects={'STN': STN}.
autoencoder.save('autoencoder.h5')
print("Saved model to disk")

Saved model to disk
