In [None]:
import os
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Root folder of the dataset; the path is relative (no leading slash).
base_dir = 'drive/MyDrive/datasetX'

# The training and validation splits are sub-folders of the dataset root.
train_dir, validation_dir = (
    os.path.join(base_dir, split_name) for split_name in ('train', 'val')
)
# Further splits, currently disabled:
#auto_test_dir = os.path.join(base_dir, 'auto_test')
#test_dir = os.path.join(base_dir, 'test')

In [None]:
# Training pipeline: rescale pixel values by 1/255 and apply random
# augmentation (rotation, shifts, shear, zoom, horizontal flip).
train_datagen = ImageDataGenerator(
    rescale=1./255,
    fill_mode='nearest',
    horizontal_flip=True,
    rotation_range=40,
    shear_range=0.2,
    zoom_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
)

# Validation pipeline: rescaling only, no augmentation.
test_datagen = ImageDataGenerator(rescale=1./255)

# Both iterators read class sub-folders, resize to 224x224 and yield
# one-hot ('categorical') labels; only the batch size differs.
train_generator = train_datagen.flow_from_directory(
    train_dir,
    class_mode='categorical',
    batch_size=40,
    target_size=(224, 224),
)

validation_generator = test_datagen.flow_from_directory(
    validation_dir,
    class_mode='categorical',
    batch_size=15,
    target_size=(224, 224),
)

Found 940 images belonging to 5 classes.
Found 155 images belonging to 5 classes.


In [None]:
class STNLayer(tf.keras.layers.Layer):
    """Spatial transformer layer.

    A small localization CNN regresses six affine parameters per image. They
    are reshaped to a (2, 3) matrix, applied to a normalized [-1, 1] sampling
    grid, and the input is resampled on that grid with `bilinear_sampler`.
    The output has the same (B, H, W, C) layout and size as the input.
    """

    def __init__(self, **kwargs):
        super(STNLayer, self).__init__(**kwargs)
        # Flattened (2, 3) identity affine matrix [1, 0, 0, 0, 1, 0]; used as
        # the bias of the final regression layer.
        identity_bias = np.array([[1, 0, 0], [0, 1, 0]]).astype('float32').flatten()
        self.localization_net = tf.keras.Sequential([
            layers.Conv2D(8, (7, 7), activation='relu'),
            layers.MaxPooling2D(2, 2),
            layers.Conv2D(10, (5, 5), activation='relu'),
            layers.MaxPooling2D(2, 2),
            layers.Flatten(),
            layers.Dense(32, activation='relu'),
            # Zero kernel + identity bias: the layer outputs exactly the
            # identity transform before training. With the default random
            # kernel the initial transform would be a random perturbation.
            layers.Dense(
                6,
                activation='linear',
                kernel_initializer='zeros',
                bias_initializer=tf.keras.initializers.Constant(identity_bias),
            ),
        ])

    def call(self, x):
        """Warp a batch of images `x` (B, H, W, C) with the predicted affine transforms."""
        theta = self.localization_net(x)
        theta = tf.reshape(theta, (-1, 2, 3))
        grid = self._get_grid(x, theta)
        transformed_x = self._sample(x, grid)
        # The sampler builds its result from dynamic shapes; the warped batch
        # always has the input's size, so pin the static shape for the layers
        # that follow.
        transformed_x.set_shape(x.shape)
        return transformed_x

    def _get_grid(self, x, theta):
        """Return the transformed sampling grid, shape (B, H, W, 2) with (x, y) on the last axis.

        `theta` is the (B, 2, 3) batch of affine matrices; H and W are taken
        from `x`.
        """
        batch_size, height, width = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
        grid = self._meshgrid(height, width)
        grid = tf.expand_dims(grid, 0)
        grid = tf.tile(grid, [batch_size, 1, 1])
        # (B, 2, 3) @ (B, 3, H*W) -> (B, 2, H*W)
        grid_transformed = tf.matmul(theta, grid)
        grid_transformed = tf.reshape(grid_transformed, [batch_size, 2, height, width])
        grid_transformed = tf.transpose(grid_transformed, [0, 2, 3, 1])
        return grid_transformed

    def _meshgrid(self, height, width):
        """Return homogeneous grid coordinates, shape (3, H*W): rows are x, y and ones.

        x and y are evenly spaced in [-1, 1] and flattened row by row.
        """
        x_t = tf.linspace(-1.0, 1.0, width)
        y_t = tf.linspace(-1.0, 1.0, height)
        x_t, y_t = tf.meshgrid(x_t, y_t)
        x_t_flat = tf.reshape(x_t, (1, -1))
        y_t_flat = tf.reshape(y_t, (1, -1))
        ones = tf.ones_like(x_t_flat)
        grid = tf.concat([x_t_flat, y_t_flat, ones], 0)
        return grid

    def _sample(self, img, grid):
        """Bilinearly sample `img` at the x (grid[..., 0]) and y (grid[..., 1]) coordinates."""
        return bilinear_sampler(img, grid[..., 0], grid[..., 1])

def bilinear_sampler(img, x, y):
    """
    Performs bilinear sampling of the input image according to the normalized coordinates.

    - img: Batch of images in (B, H, W, C) layout.
    - x, y: Normalized x, y coordinates of the grid, in (B, H, W) layout.
      -1 maps to the first pixel and +1 to the last pixel along each axis;
      coordinates outside [-1, 1] are clamped to the image border.

    Returns interpolated images according to grids.
    """
    H = tf.shape(img)[1]
    W = tf.shape(img)[2]
    max_y = tf.cast(H - 1, 'int32')
    max_x = tf.cast(W - 1, 'int32')
    zero = tf.zeros([], dtype='int32')
    max_x_f = tf.cast(max_x, 'float32')
    max_y_f = tf.cast(max_y, 'float32')

    # Scale x, y from [-1, 1] to [0, W - 1] / [0, H - 1], so that an identity
    # grid built with linspace(-1, 1, W) lands on the pixel centres.
    x = 0.5 * (tf.cast(x, 'float32') + 1.0) * max_x_f
    y = 0.5 * (tf.cast(y, 'float32') + 1.0) * max_y_f

    # Ensure x, y are within the boundaries
    x = tf.clip_by_value(x, 0., max_x_f)
    y = tf.clip_by_value(y, 0., max_y_f)

    # Integer corner below/left of each sample and the fractional offset to it.
    x0_f = tf.floor(x)
    y0_f = tf.floor(y)
    dx = x - x0_f
    dy = y - y0_f

    # Only the gather indices are clipped. The weights come from dx/dy, so at
    # the right/bottom border (dx == 0 or dy == 0) the full weight stays on
    # the border pixel instead of all four weights collapsing to zero.
    x0 = tf.cast(x0_f, 'int32')
    y0 = tf.cast(y0_f, 'int32')
    x1 = tf.clip_by_value(x0 + 1, zero, max_x)
    y1 = tf.clip_by_value(y0 + 1, zero, max_y)
    x0 = tf.clip_by_value(x0, zero, max_x)
    y0 = tf.clip_by_value(y0, zero, max_y)

    # Get pixel value at corner coords
    Ia = get_pixel_value(img, x0, y0)
    Ib = get_pixel_value(img, x1, y0)
    Ic = get_pixel_value(img, x0, y1)
    Id = get_pixel_value(img, x1, y1)

    # Each weight is paired with the corner it belongs to: the closer the
    # sample is to a corner, the larger that corner's weight.
    wa = tf.expand_dims((1.0 - dx) * (1.0 - dy), axis=-1)  # (x0, y0)
    wb = tf.expand_dims(dx * (1.0 - dy), axis=-1)          # (x1, y0)
    wc = tf.expand_dims((1.0 - dx) * dy, axis=-1)          # (x0, y1)
    wd = tf.expand_dims(dx * dy, axis=-1)                  # (x1, y1)

    out = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])
    return out

def get_pixel_value(img, x, y):
    """
    Utility function to get pixel values for coordinate
    tensors x and y from a 4D tensor image.

    img: tensor of shape (B, H, W, C)
    x, y: integer tensors of shape (B, H_out, W_out) holding the column (x)
        and row (y) index to read for every output location.

    Returns a tensor of shape (B, H_out, W_out, C).
    """
    # Take the output layout from the coordinate tensors rather than from
    # img, so a sampling grid whose size differs from the image also works.
    coord_shape = tf.shape(x)
    B, H, W = coord_shape[0], coord_shape[1], coord_shape[2]

    # Batch index of every output location, shape (B, H_out, W_out).
    batch_indices = tf.range(B)
    batch_indices = tf.reshape(batch_indices, [B, 1, 1])
    batch_indices = tf.tile(batch_indices, [1, H, W])

    # Each entry along the last axis is a (batch, row, column) triple.
    indices = tf.stack([batch_indices, y, x], axis=3)

    # gather_nd keeps the leading (B, H_out, W_out) layout of the indices and
    # appends the channel axis of img.
    return tf.gather_nd(img, indices)


In [None]:
# Integrating STN into the CNN
def build_cnn_with_stn(input_shape=(224, 224, 3), num_classes=5):
    """Build a CNN classifier whose first layer is a spatial transformer.

    Args:
        input_shape: (height, width, channels) of the input images.
        num_classes: Number of units in the softmax output layer.

    Returns:
        An uncompiled Keras `Model` mapping images to class probabilities.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)
    # Apply STN transformation. The layer is named explicitly so that a later
    # `get_layer('stn_layer')` lookup still works when this function is
    # called more than once in the same session.
    transformed_x = STNLayer(name='stn_layer')(inputs)
    # Convolutional layers
    x = layers.Conv2D(32, (3, 3), activation='relu')(transformed_x)
    x = layers.MaxPooling2D(2, 2)(x)
    x = layers.Conv2D(64, (3, 3), activation='relu')(x)
    x = layers.MaxPooling2D(2, 2)(x)
    x = layers.Conv2D(128, (3, 3), activation='relu')(x)
    x = layers.MaxPooling2D(2, 2)(x)
    # Dense layers
    x = layers.Flatten()(x)
    x = layers.Dense(512, activation='relu')(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = Model(inputs=inputs, outputs=outputs)
    return model

In [None]:
def visualize_transformation(original_image, transformed_image):
    """Show an image and its transformed version side by side, without axes."""
    panels = (
        ("Original Image", original_image),
        ("Transformed Image", transformed_image),
    )
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    for axis, (caption, picture) in zip(axes, panels):
        axis.imshow(picture)
        axis.set_title(caption)
        axis.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
model = build_cnn_with_stn()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Training the model.
# steps_per_epoch / validation_steps are deliberately left unset so every
# epoch is one full pass over each generator. The previous fixed values
# (15 steps x 40 images = 600 of 940 training images, 8 steps x 15 images =
# 120 of 155 validation images) covered only part of the data, so validation
# metrics were computed on a different subset each epoch.
history = model.fit(
    train_generator,
    epochs=10,
    validation_data=validation_generator,
    verbose=2
)

In [None]:
# Sub-model that stops at the STN output. The layer is looked up by name, so
# the STN layer inside `model` must be called 'stn_layer'.
stn_model = Model(inputs=model.input, outputs=model.get_layer('stn_layer').output)

# Visualize the original and transformed images
val_images, val_labels = next(iter(validation_generator))
transformed_images = stn_model.predict(val_images)
predicted_labels = model.predict(val_images)

# A batch holds at most `batch_size` images (15 for the validation
# generator), so cap the count to avoid indexing past the end of the batch.
num_examples = min(20, len(val_images))
for i in range(num_examples):
    original_image = val_images[i]
    transformed_image = transformed_images[i]
    true_label = np.argmax(val_labels[i])  # Convert true label to class index
    predicted_label = np.argmax(predicted_labels[i])  # Convert predicted label to class index

    # Images were rescaled by 1/255 for the network; convert back to 8-bit for display.
    original_image = (original_image * 255).astype(np.uint8)
    transformed_image = (transformed_image * 255).astype(np.uint8)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    ax1.imshow(original_image)
    ax1.set_title(f"Original Image (True Label: {true_label})")
    ax1.axis("off")
    ax2.imshow(transformed_image)
    ax2.set_title(f"Transformed Image (Predicted Label: {predicted_label})")
    ax2.axis("off")
    plt.tight_layout()
    plt.show()
    # Release the figure so a loop of many plots does not accumulate open figures.
    plt.close(fig)

In [None]:
# Training vs. validation loss per epoch, taken from the fit history.
fig, ax = plt.subplots(figsize=(8, 5))
for history_key, curve_label in (
    ('loss', 'Training Loss'),
    ('val_loss', 'Validation Loss'),
):
    ax.plot(history.history[history_key], label=curve_label)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
plt.show()