In [None]:
import typing
import tensorflow as tf
import matplotlib.pyplot as plt


class Classifier:
    """Base wrapper around a Keras model with training, inference and plotting helpers.

    ``__init__`` does not create ``self.model``; a subclass assigns it (as
    ``Alfonzo`` does) or it is loaded with :meth:`load_pretrained_model`.
    """

    def __init__(self,
                 num_of_classes: int,
                 epochs: int = 8,
                 batchsize: int = 128,
                 valsplit: float = 0.1):
        """Store the hyper-parameters shared by training and data loading.

        Args:
            num_of_classes: Number of output classes.
            epochs: Number of epochs passed to ``fit`` in :meth:`train`.
            batchsize: Batch size used when loading datasets.
            valsplit: Fraction of the training directory held out for validation.
        """
        self.num_of_classes = num_of_classes
        self.epochs = epochs
        self.batchsize = batchsize
        self.valsplit = valsplit

    def load_pretrained_model(self, path: str):
        """Replace ``self.model`` with a Keras model loaded from ``path``."""
        self.model = tf.keras.models.load_model(filepath=path)

    def export_model(self, path: str):
        """Save ``self.model`` to ``path``."""
        self.model.save(path)

    def train(self, training_data, validation_data):
        """Fit ``self.model`` for ``self.epochs`` epochs.

        Args:
            training_data: Training input accepted by ``Model.fit``.
            validation_data: Validation input evaluated at the end of each epoch.

        Returns:
            The ``History`` object returned by ``fit``. It is also stored on
            ``self.history`` for :meth:`make_graph_from_history`.
        """
        # NOTE: Keras ignores ``shuffle`` when the input is a tf.data.Dataset;
        # such datasets must already be shuffled when they are created.
        self.history = self.model.fit(training_data,
                                      epochs=self.epochs,
                                      validation_data=validation_data,
                                      shuffle=True)
        return self.history

    def classify(self, data: str, height: int, width: int):
        """Classify a single image file.

        Args:
            data: Path to the image file.
            height: Height the image is resized to before prediction.
            width: Width the image is resized to before prediction.

        Returns:
            The softmax of the model output for a batch holding only this image.
        """
        img = tf.keras.utils.load_img(data, target_size=(height, width))
        img_array = tf.keras.utils.img_to_array(img)
        # Models expect a leading batch dimension; this batch has one image.
        img_array = tf.expand_dims(img_array, 0)

        predictions = self.model.predict(img_array)
        # Softmax assumes the model emits logits (e.g. Alfonzo's final Dense
        # layer has no activation).
        return tf.nn.softmax(predictions)

    def make_graph_from_history(self):
        """Plot training/validation accuracy and loss from ``self.history``.

        Requires :meth:`train` to have run first. The model must have been
        compiled with the ``"accuracy"`` metric, and validation data must have
        been supplied to ``fit``.
        """
        metrics = self.history.history
        acc = metrics['accuracy']
        val_acc = metrics['val_accuracy']
        print("Calculating the loss")
        loss = metrics['loss']
        val_loss = metrics['val_loss']

        # Take the x-axis length from the recorded history rather than from
        # self.epochs, so the plot stays valid if the two ever differ
        # (e.g. early stopping).
        epochs_range = range(len(acc))
        print("The results are being visualized")
        plt.figure(figsize=(8, 8))

        plt.subplot(1, 2, 1)
        plt.plot(epochs_range, acc, label='Training Accuracy')
        plt.plot(epochs_range, val_acc, label='Validation Accuracy')
        plt.legend(loc='lower right')
        plt.title('Training and Validation Accuracy')

        plt.subplot(1, 2, 2)
        plt.plot(epochs_range, loss, label='Training Loss')
        plt.plot(epochs_range, val_loss, label='Validation Loss')
        plt.legend(loc='upper right')
        plt.title('Training and Validation Loss')
        plt.show()

    def test_accuracy(self, test_data_dir: str,
                      img_width: int = 512,
                      img_height: int = 512):
        """Evaluate ``self.model`` on the images under ``test_data_dir``.

        Args:
            test_data_dir: Directory passed to the module-level ``load_dataset``.
            img_width: Width the images are resized to.
            img_height: Height the images are resized to.

        Returns:
            The result of ``model.evaluate``, also stored on ``self.evaluation``.
        """
        test_ds = load_dataset(data_dir=test_data_dir,
                               img_width=img_width,
                               img_height=img_height,
                               batch_size=self.batchsize)
        self.evaluation = self.model.evaluate(test_ds)
        return self.evaluation
    
def load_dataset(data_dir: str,
                 subset: typing.Optional[str] = None,
                 validation_split: float = 0.1,
                 img_width: int = 512,
                 img_height: int = 512,
                 batch_size: int = 128,
                 seed: int = 123):
    """Load an image classification dataset from a directory.

    Args:
        data_dir: Directory passed to ``image_dataset_from_directory``.
        subset: ``"training"`` or ``"validation"`` to take one side of a split.
            ``None`` loads the whole directory.
        validation_split: Fraction held out for validation. It is only used
            when ``subset`` is given.
        img_width: Width the images are resized to.
        img_height: Height the images are resized to.
        batch_size: Number of images per batch.
        seed: Shuffling/split seed. Both subsets of one split must use the
            same seed.

    Returns:
        The dataset built by ``tf.keras.utils.image_dataset_from_directory``.
    """
    options = dict(directory=data_dir,
                   image_size=(img_height, img_width),
                   batch_size=batch_size,
                   # Always pass the seed so shuffling is reproducible even
                   # when no split is requested.
                   seed=seed)
    if subset is not None:
        # Keras requires ``subset`` and ``validation_split`` to be set
        # together, so the split fraction is only passed alongside a subset.
        options.update(subset=subset, validation_split=validation_split)
    return tf.keras.utils.image_dataset_from_directory(**options)


def pretrained_model(directory: str):
    """Load and return the TensorFlow SavedModel stored in ``directory``.

    NOTE(review): ``tf.saved_model.load`` returns a generic trackable object
    rather than a ``tf.keras.Model``, so it presumably has no ``predict`` or
    ``evaluate``. Confirm this before passing it to ``predict_image``.
    ``Classifier.load_pretrained_model`` uses ``tf.keras.models.load_model``
    instead.
    """
    return tf.saved_model.load(directory)



def predict_image(img_dir: str, img_width: int, img_height: int, model):
    """Run ``model.predict`` on a single image file and return the raw output.

    Args:
        img_dir: Path to the image file. Despite the name, this is a single
            file, not a directory.
        img_width: Width the image is resized to before prediction.
        img_height: Height the image is resized to before prediction.
        model: Object exposing a Keras-style ``predict`` method.

    Returns:
        The model output for a batch containing only this image. No softmax
        is applied, unlike ``Classifier.classify``.
    """
    img = tf.keras.utils.load_img(path=img_dir,
                                  target_size=(img_height, img_width))
    # load_img yields a PIL image; predict needs a numeric array with a
    # leading batch dimension.
    img_array = tf.expand_dims(tf.keras.utils.img_to_array(img), 0)
    return model.predict(img_array)

class Alfonzo(Classifier):
    """Fully connected image classifier built on :class:`Classifier`.

    The network divides its inputs by 255 and flattens them. It then applies
    four ReLU Dense layers (128, 64, 32, 16 units) and a final linear Dense
    layer with one logit per class.
    """

    def __init__(self,
                 num_of_classes: int,
                 epochs: int = 8,
                 batchsize: int = 128,
                 valsplit: float = 0.1):
        """Build and compile the network.

        Args:
            num_of_classes: Number of output classes (units in the last layer).
            epochs: Forwarded to :class:`Classifier`.
            batchsize: Forwarded to :class:`Classifier`.
            valsplit: Forwarded to :class:`Classifier`.
        """
        super().__init__(num_of_classes,
                         epochs=epochs,
                         batchsize=batchsize,
                         valsplit=valsplit)
        self.model = tf.keras.models.Sequential(
            [
                tf.keras.layers.Rescaling(1. / 255),
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(units=128, activation='relu'),
                tf.keras.layers.Dense(units=64, activation='relu'),
                tf.keras.layers.Dense(units=32, activation='relu'),
                tf.keras.layers.Dense(units=16, activation='relu'),
                # No activation: this layer emits logits. The loss below
                # therefore uses from_logits=True, and Classifier.classify
                # applies softmax itself.
                tf.keras.layers.Dense(units=num_of_classes)
            ])
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=["accuracy"])


# Training script: build the model, train it on the "train" directory with a
# held-out validation split, plot the learning curves, then evaluate on the
# "test" directory.
DATASET_ROOT = '/kaggle/input/odpadky/dataset'
IMG_WIDTH = 512
IMG_HEIGHT = 512
# Both subsets of the train/validation split must share one seed so Keras
# partitions the files consistently.
SPLIT_SEED = 123

alfonz = Alfonzo(3)

training_set = load_dataset(data_dir=f'{DATASET_ROOT}/train',
                            img_width=IMG_WIDTH,
                            img_height=IMG_HEIGHT,
                            subset="training",
                            validation_split=alfonz.valsplit,
                            seed=SPLIT_SEED,
                            batch_size=alfonz.batchsize)
validation_set = load_dataset(data_dir=f'{DATASET_ROOT}/train',
                              img_width=IMG_WIDTH,
                              img_height=IMG_HEIGHT,
                              subset="validation",
                              validation_split=alfonz.valsplit,
                              seed=SPLIT_SEED,
                              batch_size=alfonz.batchsize)

# Leave the batch dimension unspecified (None) so the built model accepts the
# batch size used by fit/evaluate instead of a fixed batch of 1.
alfonz.model.build(input_shape=(None, IMG_HEIGHT, IMG_WIDTH, 3))
alfonz.model.summary()
alfonz.train(training_data=training_set,
             validation_data=validation_set)

print("History is ", alfonz.history)

alfonz.make_graph_from_history()

test_ds = load_dataset(data_dir=f'{DATASET_ROOT}/test',
                       img_width=IMG_WIDTH,
                       img_height=IMG_HEIGHT)

alfonz.evaluation = alfonz.model.evaluate(test_ds)