In [2]:
import os
import glob
import logging
import numpy as np
import seaborn as sns
from typing import Dict
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model
from tensorflow.keras.callbacks import History

In [8]:
class CreatePlot:
    """
    Class for creating model plots.

    Every image is written into ``self.directory``, which is created on
    demand before anything is saved.
    """

    def __init__(self, directory: str = "/home/jovyan/work/model/plots") -> None:
        """
        Args:
            directory: A path to the directory to save an image
        """
        self.directory = directory

    def _file_path(self, file_name: str) -> str:
        """Return ``file_name`` inside the plot directory, creating the directory if needed."""
        os.makedirs(self.directory, exist_ok=True)
        return os.path.join(self.directory, file_name)

    def clear_directory(self) -> None:
        """
        Deletes all files directly inside the directory.

        Only entries matched by the ``*`` glob are considered, so hidden
        (dot-) files are kept. Subdirectories are skipped. A file that cannot
        be removed is logged and skipped, and the remaining files are still
        deleted. This method never raises.
        """
        for path in glob.glob(os.path.join(self.directory, "*")):
            # os.remove cannot delete real directories; skip them instead of
            # aborting the whole sweep (symlinks are removed as links).
            if os.path.isdir(path) and not os.path.islink(path):
                continue
            try:
                os.remove(path)
            except OSError as e:
                logging.error(f"Error while clearing directory {self.directory}; Reasons: {e}")

    def visualize_model(self, model: Model) -> None:
        """
        Creates a TensorFlow neural network visualization and saves it as
        ``model_schema.png`` in the plot directory.

        Args:
            model: Already existing neural network

        Raises:
            Exception: Any error raised while rendering or saving is logged
                and re-raised unchanged.
        """
        try:
            file_path = self._file_path("model_schema.png")
            plot_model(model,
                       show_dtype=True,
                       show_layer_names=True,
                       show_shapes=True,
                       to_file=file_path)
            logging.info(f"Successfully saved model visualization into file: {file_path}")
        except Exception as e:
            logging.error(f"Error while creating model visualization: {e}")
            raise

    def conf_matrix(self, conf_mat: np.ndarray, title: str = "") -> None:
        """
        Creates a confusion matrix heatmap and saves it as ``<title>_conf_mat``
        in the plot directory. No extension is given, so matplotlib picks
        the image format.

        Args:
            conf_mat: Confusion matrix to display. Floating-point matrices
                (e.g. normalized ones) are annotated with two decimals.
                Anything else is annotated as whole numbers.
            title: A name of label as a title for a plot

        Raises:
            Exception: Any error raised while plotting or saving is logged
                and re-raised unchanged.
        """
        fig = None
        try:
            # seaborn's fmt="d" raises on float values, so only use it for
            # non-float matrices.
            is_float = np.issubdtype(np.asarray(conf_mat).dtype, np.floating)
            fmt = ".2f" if is_float else "d"

            fig = plt.figure(figsize=(12, 7))
            ax = fig.add_subplot()
            sns.heatmap(conf_mat, cmap="Blues", annot=True, fmt=fmt, ax=ax)
            ax.set_xlabel("Predicted Scores")
            ax.set_ylabel("True Scores")
            ax.set_title(f"Confusion matrix for {title}_score")

            file_path = self._file_path(title + "_conf_mat")
            fig.savefig(file_path)
            logging.info(f"Successfully saved an image to the file: {file_path}")
        except Exception as e:
            logging.error(f"Error while creating confusion matrix visualization: {e}")
            raise
        finally:
            # Always release the figure, including when plotting failed.
            if fig is not None:
                plt.close(fig)

    def metrics_history(self, history: History, metrics: Dict[str, str]) -> None:
        """
        Plot history of improvement metrics during training on epochs.

        One image per entry of ``metrics`` is saved in the plot directory,
        named after (and titled with) the entry's value:

        - the ``"loss"`` entry draws the ``loss``, ``home_score_loss`` and
          ``away_score_loss`` series;
        - every other entry with value ``name`` draws the
          ``home_score_<name>`` and ``away_score_<name>`` series.

        Args:
            history: Keras ``History`` object whose ``history`` dict maps
                series names to per-epoch values
            metrics: Dictionary with names of loss and metrics algorithms
                used in model

        Raises:
            KeyError: A required series is missing from ``history.history``
                (logged and re-raised, like any other error).
        """
        hm_sc = "home_score_"
        aw_sc = "away_score_"
        try:
            # Pair each title with its own series, so titles and file names
            # cannot get out of step with the plotted data.
            plots = []
            for key, name in metrics.items():
                if key == "loss":
                    plots.append((name, [key, hm_sc + key, aw_sc + key]))
                else:
                    plots.append((name, [hm_sc + name, aw_sc + name]))

            if not plots:
                logging.warning("No loss or metrics given; nothing to plot")
                return

            for title, series_names in plots:
                self._plot_series(history.history, title, series_names)
        except Exception as e:
            logging.error(f"Error while creating metrics visualization: {e}")
            raise

    def _plot_series(self, values_by_name: Dict[str, list], title: str, series_names: list) -> None:
        """Draw the given per-epoch series on a fresh figure and save it as ``title``."""
        fig = plt.figure()
        try:
            ax = fig.add_subplot()
            for name in series_names:
                values = values_by_name[name]
                # Epochs are numbered from 1 for readability.
                ax.plot(range(1, len(values) + 1), values, label=name)

            ax.set_xlabel("Epochs")
            ax.set_ylabel("Metrics")
            ax.set_title(title)
            ax.legend()
            ax.grid(visible=True)

            file_path = self._file_path(title)
            fig.savefig(file_path)
            logging.info(f"Successfully saved an image to the file: {file_path}")
        finally:
            plt.close(fig)