In [5]:
# !conda install pydot -y
# !conda install pydotplus -y

In [2]:
import matplotlib.pyplot as plt
from tensorflow.keras.utils import plot_model
from tensorflow.keras.models import Model
import logging
import os
import numpy as np
from typing import Dict, List
import seaborn as sns
from tensorflow.keras.callbacks import History

In [8]:
class CreatePlot:
    """
    Class for creating model plots and saving them as image files.

    Every plotting method writes its image into ``self.directory`` (created
    on demand), closes the figure it drew on even when plotting fails, logs
    the outcome and re-raises any error to the caller.
    """

    def __init__(self,
                 directory: str = "/home/jovyan/work/model/plots"
                 ) -> None:
        """
        Args:
            directory: A path to the directory to save images into
        """
        self.directory = directory

    def _prepare_path(self, file_name: str) -> str:
        """Create ``self.directory`` if needed and return the path for *file_name* inside it."""
        os.makedirs(self.directory, exist_ok=True)
        return os.path.join(self.directory, file_name)

    def visualize_model(self, model: "Model") -> None:
        """
        Creates a TensorFlow neural network visualization

        Saves the diagram as ``model_schema.png`` in ``self.directory``.
        ``plot_model`` needs pydot and Graphviz to be installed.

        Args:
            model: Already existing neural network
        """
        try:
            file_path = self._prepare_path("model_schema.png")
            plot_model(model,
                       show_dtype=True,
                       show_layer_names=True,
                       show_shapes=True,
                       to_file=file_path)
            logging.info("Successfully created model visualization")
        except Exception as e:
            logging.error(f"Error while creating model visualization: {e}")
            raise

    def differences(self,
                    true_value: np.ndarray,
                    predicted_value: np.ndarray,
                    title: str = ""
                    ) -> None:
        """
        Visualizes differences between true and predicted labels

        Draws ``|true - predicted|`` against the true value as a scatter plot
        and saves it as ``<title>_difference`` in ``self.directory``.

        Args:
            true_value: An array containing true values (flattened before use)
            predicted_value: An array with the same number of elements as
                ``true_value``, e.g. an ``(n, 1)`` model output (flattened)
            title: A name of label used in the plot title and file name

        Raises:
            ValueError: If the two arrays have different numbers of elements
        """
        try:
            # Flatten both sides so (n, 1) arrays line up with (n,) arrays;
            # subtracting (n,) from (n, 1) would silently broadcast to (n, n).
            true_flat = np.asarray(true_value).ravel()
            predicted_flat = np.asarray(predicted_value).reshape(true_flat.shape)
            residuals = np.abs(true_flat - predicted_flat)
            file_path = self._prepare_path(f"{title}_difference")

            # Draw on a fresh figure so earlier plots cannot leak into this one.
            fig = plt.figure()
            try:
                plt.scatter(true_flat, residuals)
                plt.xlabel('True Values')
                plt.ylabel('Absolute Difference')
                plt.title(f"Differences between true and predicted labels in {title}_score")
                plt.grid(True)
                plt.savefig(file_path)
            finally:
                # Release the figure even if drawing or saving failed.
                plt.close(fig)
            logging.info(f"Successfully saved an image to the file: {file_path}")
        except Exception as e:
            logging.error(f"Error in visualization of differences between true and predicted labels: {e}")
            raise

    def conf_matrix(self,
                    conf_mat: np.ndarray,
                    title: str = "",
                    fmt: str = "d"
                    ) -> None:
        """
        Creates a confusion matrix plot

        Saves a heatmap as ``<title>_confusion_matrix`` in ``self.directory``.

        Args:
            conf_mat: Confusion matrix to display
            title: A name of label used in the plot title and file name
            fmt: Format code for the cell annotations. The default ``"d"``
                only works for integer counts; pass e.g. ``".2f"`` for a
                normalized (float) matrix.
        """
        try:
            file_path = self._prepare_path(f"{title}_confusion_matrix")
            fig = plt.figure(figsize=(12, 7))
            try:
                sns.heatmap(conf_mat, cmap="Blues", annot=True, fmt=fmt)
                plt.xlabel("Predicted Scores")
                plt.ylabel("True Scores")
                plt.title(f"Confusion matrix for {title}_score")
                plt.savefig(file_path)
            finally:
                # Release the figure even if drawing or saving failed.
                plt.close(fig)
            logging.info(f"Successfully saved an image to the file: {file_path}")
        except Exception as e:
            logging.error(f"Error while creating confusion matrix visualization: {e}")
            raise

    def metrics_history(self, history: "History", metrics: List[str]) -> None:
        """
        Plot history of improvement metrics during training on epochs

        Saves the line plot as ``metrics`` in ``self.directory``.

        Args:
            history: A Keras ``History`` object; ``history.history`` maps
                metric names to per-epoch values
            metrics: A list of keys from ``history.history`` to plot

        Raises:
            KeyError: If a requested metric is missing from ``history.history``
        """
        try:
            # Look up every metric before drawing so an unknown key fails fast
            # instead of leaving a half-drawn figure behind.
            series = [(metric, history.history[metric]) for metric in metrics]
            file_path = self._prepare_path("metrics")

            fig = plt.figure()
            try:
                for metric, values in series:
                    plt.plot(values, label=metric)
                plt.xlabel("Epochs")
                plt.ylabel("Metrics")
                plt.title("Metrics history")
                plt.legend()
                plt.grid(visible=True)
                plt.savefig(file_path)
            finally:
                # Release the figure even if drawing or saving failed.
                plt.close(fig)
            logging.info(f"Successfully saved an image to the file: {file_path}")
        except Exception as e:
            logging.error(f"Error while creating metrics visualization: {e}")
            raise
