In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import os
import numpy as np

In [None]:
path_to_data = r"../input/waste-dataset/data"

# The training and validation subsets are carved out of the same directory,
# so both calls MUST use identical split settings (seed, validation fraction,
# shuffle flag). If they drifted apart, images could land in both subsets.
# Keep the shared arguments in one place so they cannot diverge.
SPLIT_SEED = 123
VALIDATION_SPLIT = 0.2
INPUT_IMAGE_SIZE = (256, 256)  # images are resized to this on load

_split_kwargs = dict(
    validation_split=VALIDATION_SPLIT,
    seed=SPLIT_SEED,
    image_size=INPUT_IMAGE_SIZE,
    shuffle=True,
)

train_ds = tf.keras.utils.image_dataset_from_directory(
    path_to_data,
    subset="training",
    **_split_kwargs,
)

val_ds = tf.keras.utils.image_dataset_from_directory(
    path_to_data,
    subset="validation",
    **_split_kwargs,
)

train_ds

In [None]:
# Inspect a single batch only: looping over the whole dataset decodes every
# image just to print shapes and floods the output with one line per batch.
# (The final batch of an epoch may be smaller than the others.)
for image, label in train_ds.take(1):
    print(label.shape)

# Static (images, labels) structure of every element.
train_ds.element_spec

In [None]:
class_names = train_ds.class_names

# Preview the first nine images of one training batch in a 3x3 grid, each
# titled with its class name. Raw pixels are cast to uint8 for display.
fig, axes = plt.subplots(3, 3, figsize=(20, 20))
for batch_images, batch_labels in train_ds.take(1):
    for idx, ax in enumerate(axes.flat):
        ax.imshow(batch_images[idx].numpy().astype("uint8"))
        ax.set_title(class_names[batch_labels[idx]])
        ax.axis("off")

In [None]:
# NOTE(review): IMG_SIZE is not referenced by the visible cells. Images are
# loaded at 256x256 and no Resizing layer is applied below.
IMG_SIZE = 100

# Despite the name, this only rescales: it multiplies pixel values by 1/255
# (mapping [0, 255] to [0, 1]). No resizing happens here.
resize_and_rescale = tf.keras.Sequential(
    [layers.Rescaling(scale=1.0 / 255)]
)

In [None]:
def histogram_equalise(image, nbins=256):
    """
    Perform histogram equalisation on a single image tensor.

    Args:
        image: A float tensor with values in the range [0, 1], e.g. of shape
            (height, width, 1). Multi-channel images are accepted, but a single
            histogram is computed over all channels together.
        nbins: Number of intensity levels used to quantise the image
            (256 matches 8-bit images).
    Returns:
        A tf.float32 tensor with the same shape as ``image`` and values in
        [0, 1]. A constant image (only one occupied level) has no contrast to
        stretch and is returned unchanged (as float32).
    """
    image = tf.cast(image, tf.float32)
    levels = nbins - 1

    # Quantise the [0, 1] floats to integer levels 0..nbins-1. Casting the raw
    # floats straight to int would collapse every pixel to 0 or 1.
    pixel_levels = tf.cast(
        tf.round(tf.clip_by_value(image, 0.0, 1.0) * levels), tf.int32
    )

    # Histogram over exactly nbins levels, then its cumulative distribution.
    hist = tf.math.bincount(
        tf.reshape(pixel_levels, [-1]),
        minlength=nbins,
        maxlength=nbins,
        dtype=tf.int32,
    )
    cdf = tf.cumsum(hist)

    # cdf_min is the count at the darkest occupied level; that level maps to 0
    # and the brightest occupied level maps to 1.
    cdf_min = tf.reduce_min(tf.boolean_mask(cdf, cdf > 0))
    span = cdf[-1] - cdf_min

    # Lookup table: level -> equalised intensity in [0, 1]. The max(..., 1)
    # guards the division for constant images, which are handled below.
    lut = tf.cast(cdf - cdf_min, tf.float32) / tf.cast(tf.maximum(span, 1), tf.float32)
    lut = tf.clip_by_value(lut, 0.0, 1.0)

    equalized = tf.gather(lut, pixel_levels)
    return tf.where(span > 0, equalized, image)

class EqualiseHistogram(tf.keras.layers.Layer):
    """Keras layer that applies ``histogram_equalise`` to each image of a batch.

    ``histogram_equalise`` documents its input as float32 in [0, 1], so place
    this layer after the rescaling step.
    """

    def __init__(self, **kwargs):
        # Forward name/dtype/trainable etc. so that
        # EqualiseHistogram.from_config(layer.get_config()) round-trips.
        super().__init__(**kwargs)

    def call(self, inputs):
        """Equalise every image along the batch axis of ``inputs`` independently."""
        inputs = tf.cast(inputs, tf.float32)
        return tf.map_fn(
            histogram_equalise,
            inputs,
            fn_output_signature=tf.float32,
        )

    def compute_output_shape(self, input_shape):
        # Equalisation is element-wise, so the shape is unchanged.
        return input_shape

    def get_config(self):
        return super().get_config()


In [None]:
# NOTE(review): eq_hist is not referenced by any of the visible cells.
eq_hist = EqualiseHistogram()

# Random geometric augmentation: flips along both axes, plus a rotation whose
# factor is a fraction of 2*pi (0.2 -> up to +/-72 degrees).
data_augmentation = tf.keras.Sequential(
    [
        layers.RandomFlip(mode="horizontal_and_vertical"),
        layers.RandomRotation(factor=0.2),
    ]
)

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

# The model has no Rescaling layer of its own, so BOTH splits must be scaled
# the same way. Previously only train_ds was rescaled, and validation ran on
# raw [0, 255] pixels. Augmentation is applied to the training split only,
# with training=True so the random layers are active regardless of the Keras
# learning phase. Prefetch once, as the final stage of each pipeline.
# NOTE: this cell reassigns train_ds/val_ds; re-running it without first
# re-running the loading cell would rescale twice.
train_ds = (
    train_ds
    .map(lambda x, y: (resize_and_rescale(x), y), num_parallel_calls=AUTOTUNE)
    .map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
    .prefetch(buffer_size=AUTOTUNE)
)

val_ds = (
    val_ds
    .map(lambda x, y: (resize_and_rescale(x), y), num_parallel_calls=AUTOTUNE)
    .prefetch(buffer_size=AUTOTUNE)
)

train_ds

In [None]:
# Preview nine images of one batch after the train_ds mapping step. Pixels
# have been multiplied by 1/255, so they are shown without a uint8 cast.
fig, axes = plt.subplots(3, 3, figsize=(20, 20))
for aug_images, aug_labels in train_ds.take(1):
    for idx, ax in enumerate(axes.flat):
        ax.imshow(aug_images[idx])
        ax.set_title(class_names[aug_labels[idx]])
        ax.axis("off")

In [None]:
class ResidualCNN:
    """ResNet-style image classifier built from basic two-convolution blocks.

    The stage layout [64]*3, [128]*4, [256]*6, [512]*3 is the ResNet-34
    configuration. As in ResNet, the stem (7x7/2 conv + 3x3/2 max-pool) already
    downsamples by 4, so only the first block of stages 2-4 uses stride 2;
    stage 1 keeps the stem's resolution. A 1x1 projection shortcut is used
    only where a block changes resolution or channel count; elsewhere the
    shortcut is the identity.

    The model contains no rescaling layer: training and validation data must
    already be preprocessed identically.

    NOTE(review): the ``train_ds`` / ``val_ds`` defaults are evaluated once,
    when this ``class`` statement runs, so they capture whatever those globals
    held at that moment. Pass the datasets explicitly.

    Args:
        activation: Keras activation (name or callable) used throughout.
        kernel_size: Kernel size of the two convolutions in each block.
        train_ds: Training dataset of (images, integer labels) batches.
        val_ds: Validation dataset with the same structure.
        shape: Input image shape (height, width, channels).
        num_classes: Number of output classes (size of the softmax layer).
    """

    def __init__(self, activation="relu", kernel_size=3, train_ds=train_ds, val_ds=val_ds,
                 shape=(256, 256, 3), num_classes=4):
        self.activation_name = activation
        # Kept for backward compatibility with code that reads it; the network
        # itself creates a fresh Activation layer at every use.
        self.activation = tf.keras.layers.Activation(activation)
        self.kernel_size = kernel_size
        self.shape = shape
        self.num_classes = num_classes
        self.train_ds = train_ds
        self.val_ds = val_ds
        self._model = None

        # Filters of each residual block, grouped by stage (ResNet-34 layout).
        self.resnet_filter_architecture = [([64] * 3), ([128] * 4), ([256] * 6), ([512] * 3)]

        # Build the model architecture
        self.model_architecture()

    def _new_activation(self):
        """Return a new Activation layer using the configured activation."""
        return tf.keras.layers.Activation(self.activation_name)

    def model_architecture(self):
        """Build the functional Keras model and store it in ``self._model``."""
        x_input = tf.keras.layers.Input(shape=self.shape)

        # Stem: 7x7/2 convolution and 3x3/2 max-pool (overall /4).
        x = tf.keras.layers.Conv2D(64, kernel_size=7, strides=2, padding='same', kernel_initializer="he_normal")(x_input)
        x = tf.keras.layers.BatchNormalization()(x)
        x = self._new_activation()(x)
        x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

        # Residual stages.
        for stage_index, filter_numbers in enumerate(self.resnet_filter_architecture):
            for block_index, filters in enumerate(filter_numbers):
                # Downsample only at the start of stages 2-4; the stem has
                # already reduced the resolution before stage 1.
                strides = 2 if (block_index == 0 and stage_index > 0) else 1
                x_class = self.classify(filters=filters, strides=strides)(x)
                if strides != 1 or x.shape[-1] != filters:
                    # Shapes differ: project the shortcut with a 1x1 conv.
                    x_skip = self.skipify(filters=filters, strides=strides)(x)
                else:
                    x_skip = x
                x = tf.keras.layers.Add()([x_class, x_skip])
                x = self._new_activation()(x)

        # Head: GlobalAvgPool2D already yields (batch, channels); no Flatten needed.
        x = tf.keras.layers.GlobalAvgPool2D()(x)
        x_output = tf.keras.layers.Dense(self.num_classes, activation='softmax')(x)

        self._model = tf.keras.models.Model(inputs=x_input, outputs=x_output)

    def classify(self, filters, strides):
        """Residual branch: conv(strides) -> BN -> activation -> conv(1) -> BN."""
        return tf.keras.Sequential([
            tf.keras.layers.Conv2D(
                kernel_size=self.kernel_size,
                filters=filters,
                strides=strides,
                padding="same",
                kernel_initializer="he_normal"
            ),
            tf.keras.layers.BatchNormalization(),
            self._new_activation(),
            tf.keras.layers.Conv2D(
                kernel_size=self.kernel_size,
                filters=filters,
                strides=1,  # only the first convolution may downsample
                padding="same",
                kernel_initializer="he_normal"
            ),
            tf.keras.layers.BatchNormalization(),
        ])

    def skipify(self, filters, strides):
        """Projection shortcut: 1x1 conv(strides) -> BN, matching the branch's output shape."""
        return tf.keras.Sequential([
            tf.keras.layers.Conv2D(
                kernel_size=1,
                filters=filters,
                strides=strides,
                padding="same",
                kernel_initializer="he_normal"
            ),
            tf.keras.layers.BatchNormalization(),
        ])

    def train_model(self, epochs=50):
        """Compile the model (Adam, sparse categorical cross-entropy) and fit it.

        Args:
            epochs: Number of training epochs (default 50).
        Returns:
            The ``History`` object returned by ``Model.fit``.
        Raises:
            ValueError: If the model has not been built.
        """
        if self._model is None:
            raise ValueError("Model is not defined. Ensure model_architecture is properly set.")
        self._model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        return self._model.fit(self.train_ds, validation_data=self.val_ds, epochs=epochs)

    def summary(self):
        """Print the Keras model summary.

        Raises:
            ValueError: If the model has not been built.
        """
        if self._model is not None:
            return self._model.summary()
        else:
            raise ValueError("Model is not defined. Ensure model_architecture is properly set.")


In [None]:
# Pass the datasets explicitly instead of relying on the constructor
# defaults, which were bound when the class statement ran.
resnet_cnn = ResidualCNN(val_ds=val_ds, train_ds=train_ds)
resnet_cnn.summary()

In [None]:
# Compile and fit the model on the datasets stored on resnet_cnn
# (see ResidualCNN.train_model for optimizer, loss and epoch count).
resnet_cnn.train_model()