### Importing the necessary libraries

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, MaxPooling2D, UpSampling2D, Flatten, Dense, Reshape, Lambda, Concatenate
from tensorflow.keras import layers, Model, Input
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
import skimage as sk
from skimage import io, color, measure
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

### Extracting the dataset

In [None]:
import os
import zipfile

# Source archive (relative to the current working directory) and destination folder.
zip_path = "dataset.zip"
extract_to = os.path.join(os.getcwd(), "dataset")

# Ensure the destination folder exists; an existing folder is not an error.
os.makedirs(extract_to, exist_ok=True)

# Unpack every member of the archive into the destination folder.
with zipfile.ZipFile(zip_path, mode='r') as archive:
    archive.extractall(path=extract_to)

print(f"Extracted '{zip_path}' to '{extract_to}'")

### Loading the images and their corresponding labels from the dataset

In [None]:
from utils import load_data, count_labels

# Parameters
# Path to the data directory. Presumably the folder inside the extracted
# archive (extraction goes to ./dataset) — confirm against the zip layout.
data_dir = 'dataset/dataset'
img_height, img_width, img_channels = 256, 256, 1  # Input image dimensions

# Maps each composition value (dict key) to a contiguous class index (dict value).
composition_labels = {
    27: 0,
    31: 1,
    35: 2,
    39: 3,
    40: 4,
    42: 5,
    44: 6,
    46: 7,
    48: 8
}

# Derive the number of classes from the mapping so the two cannot drift apart
# when a composition is added or removed.
num_classes = len(composition_labels)

# The class indices must be exactly 0..num_classes-1, otherwise they would not
# fit a one-hot vector of length num_classes.
if sorted(composition_labels.values()) != list(range(num_classes)):
    raise ValueError(
        "composition_labels must map onto the class indices 0..num_classes-1"
    )

images, labels = load_data(data_dir, img_height, img_width, img_channels, num_classes, composition_labels)

print(f"{len(images)} images found")
print(f"{len(labels)} labels found")
count_labels(labels)

### Model Architecture

In [None]:
latent_dim = (16, 16, 1) # Latent space dimensions (height, width, channels)

class CVAE(tf.keras.Model):
    """Conditional variational autoencoder (CVAE) for fixed-size images.

    The encoder maps an image plus its class-label vector to the mean and
    log-variance of a latent Gaussian shaped ``latent_dim``. The decoder maps a
    latent sample plus the same label vector back to an image. Image dimensions
    are read from the module-level ``img_height``, ``img_width`` and
    ``img_channels`` settings.

    Args:
        latent_dim: Latent tensor shape as an ``(h, w, c)`` tuple.
        num_classes: Length of the label vector used for conditioning
            (presumably one-hot — confirm against ``load_data``).

    Raises:
        ValueError: If ``img_height`` or ``img_width`` is not a multiple of 16.
            The decoder upsamples its base grid four times by 2, so its output
            only matches the input size when the input divides evenly by 16.
    """

    # Four stride-2 stages in the decoder (and four 2x2 poolings in the encoder).
    _SCALE_FACTOR = 16

    def __init__(self, latent_dim, num_classes):
        super(CVAE, self).__init__()
        if img_height % self._SCALE_FACTOR or img_width % self._SCALE_FACTOR:
            raise ValueError(
                f"img_height and img_width must be multiples of "
                f"{self._SCALE_FACTOR}, got {img_height}x{img_width}"
            )
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()

    def _latent_units(self):
        """Return the number of scalars in one latent tensor (h * w * c)."""
        return self.latent_dim[0] * self.latent_dim[1] * self.latent_dim[2]

    def build_encoder(self):
        """Build the encoder: ``[image, label] -> [mean, log_var]``.

        The label is projected to one image-sized channel and concatenated with
        the image before four conv + 2x2 max-pool stages.
        """
        input_img = Input(shape=(img_height, img_width, img_channels))
        input_label = Input(shape=(self.num_classes,))
        # Turn the label vector into an extra image channel for conditioning.
        label_embedding = Dense(img_height * img_width)(input_label)
        label_embedding = Reshape((img_height, img_width, 1))(label_embedding)

        x = Concatenate()([input_img, label_embedding])
        x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
        x = MaxPooling2D((2, 2), padding='same')(x)
        x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
        x = MaxPooling2D((2, 2), padding='same')(x)
        x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
        x = MaxPooling2D((2, 2), padding='same')(x)
        x = Conv2D(256, (3, 3), activation='relu', padding='same')(x)
        x = MaxPooling2D((2, 2), padding='same')(x)
        x = Flatten()(x)
        mean = Dense(self._latent_units())(x)
        log_var = Dense(self._latent_units())(x)
        mean = Reshape(self.latent_dim)(mean)
        log_var = Reshape(self.latent_dim)(log_var)
        return Model([input_img, input_label], [mean, log_var], name='encoder')

    def build_decoder(self):
        """Build the decoder: ``[z, label] -> reconstructed image``.

        The label is projected to a latent-shaped tensor and concatenated with
        ``z``. A dense layer then expands the result to a
        ``(img_height/16, img_width/16, 512)`` grid, which four stride-2
        transposed convolutions upsample back to the input resolution.
        """
        latent_input = Input(shape=self.latent_dim)
        input_label = Input(shape=(self.num_classes,))
        label_embedding = Dense(self._latent_units())(input_label)
        label_embedding = Reshape(self.latent_dim)(label_embedding)

        x = Concatenate()([latent_input, label_embedding])
        # Concatenating two latent-shaped tensors doubles the unit count.
        x = Reshape((self._latent_units() * 2,))(x)
        # Derive the starting grid from the image size instead of hard-coding
        # 16x16, so 4 x (stride 2) upsampling lands exactly on the input size.
        base_h = img_height // self._SCALE_FACTOR
        base_w = img_width // self._SCALE_FACTOR
        x = Dense(base_h * base_w * 512, activation='relu')(x)
        x = Reshape((base_h, base_w, 512))(x)
        x = Conv2DTranspose(256, (3, 3), activation='relu', strides=2, padding='same')(x)
        x = Conv2DTranspose(128, (3, 3), activation='relu', strides=2, padding='same')(x)
        x = Conv2DTranspose(64, (3, 3), activation='relu', strides=2, padding='same')(x)
        x = Conv2DTranspose(32, (3, 3), activation='relu', strides=2, padding='same')(x)
        # Sigmoid keeps pixels in [0, 1]; presumably inputs are scaled the same way — confirm.
        output_img = Conv2D(img_channels, (3, 3), activation='sigmoid', padding='same')(x)
        return Model([latent_input, input_label], output_img, name='decoder')

    def sampling(self, args):
        """Draw ``z = mean + exp(log_var / 2) * eps`` with ``eps ~ N(0, 1)``.

        This is the reparameterization trick: it keeps ``z`` differentiable
        with respect to ``mean`` and ``log_var``.
        """
        mean, log_var = args
        epsilon = tf.random.normal(shape=tf.shape(mean), mean=0., stddev=1.)
        return mean + tf.exp(log_var / 2) * epsilon

    def call(self, inputs):
        """Encode, sample, decode, and register the VAE loss via ``add_loss``.

        Args:
            inputs: ``[images, labels]`` pair.

        Returns:
            The reconstructed images.
        """
        input_img, input_label = inputs
        mean, log_var = self.encoder([input_img, input_label])
        z = self.sampling([mean, log_var])
        reconstructed = self.decoder([z, input_label])
        # Both terms are element-wise means rather than sums. This weights KL
        # relative to reconstruction differently from the textbook ELBO.
        reconstruction_loss = tf.reduce_mean(MeanSquaredError()(input_img, reconstructed))
        kl_loss = -0.5 * tf.reduce_mean(1 + log_var - tf.square(mean) - tf.exp(log_var))
        self.add_loss(reconstruction_loss + kl_loss)
        return reconstructed

# Instantiate the model. The loss is registered inside CVAE.call via add_loss,
# so only an optimizer is supplied at compile time.
cvae = CVAE(latent_dim=latent_dim, num_classes=num_classes)
cvae.compile(optimizer=Adam())

# Print the layer-by-layer structure of both sub-networks.
for sub_network in (cvae.encoder, cvae.decoder):
    sub_network.summary()

### Training the model

In [None]:
# Create the Adam optimizer with an explicit initial learning rate and compile
# with it, so this learning rate is the one actually used for training. No
# training has happened yet, so recompiling discards no optimizer state.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
cvae.compile(optimizer=optimizer)

# Set up callbacks
# Stop after 30 epochs without val_loss improvement, then roll back to the
# best epoch's weights. Without the rollback, the weights saved afterwards
# would be the ones from 30 epochs past the optimum.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=30,
    restore_best_weights=True,
)

reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.85,     # Multiply lr by 0.85 (a 15% reduction) when triggered
    patience=5,      # Epochs with no improvement before the learning rate is reduced
    min_lr=1e-6,     # LR won't reduce below this
    min_delta=1e-3,  # Minimum change in val_loss that counts as an improvement
)

# Split the dataset into training and validation sets, stratified on the class
# index so every class keeps roughly the same proportion in both splits.
# NOTE(review): argmax assumes `labels` is one-hot encoded — confirm with load_data.
labels_integer = np.argmax(labels, axis=1)
train_images, val_images, train_labels, val_labels = train_test_split(
    images, labels,
    test_size=0.2,
    stratify=labels_integer,
    random_state=42,  # Fixed seed so the split is reproducible across runs
)

from tensorflow.keras.callbacks import ModelCheckpoint  # NOTE: currently unused

# Train with both callbacks. The loss comes from add_loss inside CVAE.call and
# no loss is passed to compile, so the target images are not consumed by a
# compiled loss.
cvae.fit(
    [train_images, train_labels],
    train_images,
    epochs=500,
    batch_size=16,
    validation_data=([val_images, val_labels], val_images),
    callbacks=[reduce_lr, early_stopping]
)

### Saving the model weights after training

In [None]:
# Persist the trained weights to an HDF5 file in the current working directory.
# NOTE(review): Keras 3 (TF >= 2.16) requires the filename to end in
# '.weights.h5' — confirm which Keras version this notebook runs on.
cvae.save_weights(filepath='saved_model_weights.h5')