In [3]:
import logging
import os
from abc import ABC, abstractmethod

import matplotlib.pyplot as plt
import numpy as np
from pyspark.sql import DataFrame
# NOTE: these shadow the builtin max/min for the rest of the notebook
from pyspark.sql.functions import max, min

In [5]:
class Graph(ABC):
    """
    Abstract base class for graphs built from the first two columns of a Spark DataFrame.

    The first column is exposed as ``self.x`` and the second as ``self.y``;
    concrete subclasses decide how to draw them and where to save the image.
    """
    @abstractmethod
    def __init__(self, 
                 df: DataFrame,
                 directory_path: str = "/home/jovyan/work/visualization/graphs/"
                 ) -> None:
        """
        Initializes parameters

        Args:
            df: DataFrame with data to display; only its first two columns are used
            directory_path: Path to the directory storing images
        """
        self.directory_path = directory_path
        self.df = df
        # Collect both columns in a single Spark job so x[i] and y[i] always
        # come from the same row: Spark does not guarantee a stable row order
        # across separate actions unless the DataFrame is explicitly ordered.
        pdf = df.select(*df.columns[:2]).toPandas()
        self.x = pdf.iloc[:, 0].to_numpy()
        self.y = pdf.iloc[:, 1].to_numpy()

    @abstractmethod
    def create_graph(self, title: str) -> None:
        """
        Creates a graph and saves it to disk

        Args:
            title: Title of graph
        """
        pass

In [6]:
class BarGraph(Graph):
    """
    Class for creating 'bar' graph from the first two columns of a DataFrame
    """
    def __init__(self, 
                 df: DataFrame, 
                 file_name: str,
                 directory_path: str = "/home/jovyan/work/visualization/graphs/"
                 ) -> None:
        """
        Initializes parameters

        Args:
            df: DataFrame whose first column gives the bar labels and second column the bar heights
            file_name: Name of the image file to write
            directory_path: Directory the image file is written to
        """
        # Reuse the base-class extraction of x/y instead of duplicating it.
        super().__init__(df, directory_path)
        self.file_name = file_name
        # os.path.join tolerates a directory_path with or without a trailing slash.
        self.file_path = os.path.join(directory_path, file_name)

    def create_graph(self, title: str) -> None:
        """
        Creates 'bar' graph and saves it to ``self.file_path``

        Args:
            title: Title of graph

        Raises:
            Exception: any error raised while drawing or saving is logged and re-raised
        """
        # Draw on a fresh figure so state from earlier plots does not bleed into this image.
        fig = plt.figure()
        try:
            plt.bar(x=self.x, height=self.y, width=0.6)
            plt.title(title, fontsize=16, style='italic')
            plt.xticks(rotation=45, ha="right")
            plt.xlabel(self.df.columns[0])
            plt.ylabel(self.df.columns[1])
            plt.tight_layout()
            plt.savefig(self.file_path)
        except Exception as e:
            logging.error(f"Error in saving the image {self.file_name}: {e}")
            raise
        finally:
            # Always release the figure, including on failure, so it does not leak.
            plt.close(fig)
        logging.info(f"Successfully saved the image to the file: {self.file_path}")

In [None]:
class PieGraph(Graph):
    """
    Class for creating 'pie' graph from the first two columns of a DataFrame
    """
    def __init__(self, 
                 df: DataFrame, 
                 file_name: str,
                 directory_path: str = "/home/jovyan/work/visualization/graphs/"
                 ) -> None:
        """
        Initializes parameters

        Args:
            df: DataFrame whose first column gives the wedge labels and second column the wedge sizes
            file_name: Name of the image file to write
            directory_path: Directory the image file is written to
        """
        # Reuse the base-class extraction of x/y instead of duplicating it.
        super().__init__(df, directory_path)
        self.file_name = file_name
        # os.path.join tolerates a directory_path with or without a trailing slash.
        self.file_path = os.path.join(directory_path, file_name)
    
    def create_graph(self, title: str) -> None:
        """
        Creates 'pie' graph and saves it to ``self.file_path``

        Args:
            title: Title of graph

        Raises:
            Exception: any error raised while drawing or saving is logged and re-raised
        """
        # Draw on a fresh figure so state from earlier plots does not bleed into this image.
        fig = plt.figure()
        try:
            n_wedges = len(self.y)
            colors = plt.get_cmap('Blues')(np.linspace(0.2, 0.9, n_wedges))
            # Explode offset grows linearly from 0.01 up to n_wedges / 100.
            explode = np.linspace(0.01, n_wedges / 100, n_wedges)
            # Wedge sizes come from the value column (second) and labels from the
            # first column, matching the legend title and the x/y convention of
            # the other graph classes.
            plt.pie(x=self.y, 
                    labels=self.x, 
                    colors=colors, 
                    wedgeprops={"linewidth": 1, "edgecolor": "white"}, 
                    explode=explode,
                    autopct='%1.2f%%',
                    center=(0,0),
                    )
            plt.title(title, fontsize=16, style='italic')
            plt.legend(title=self.df.columns[0], bbox_to_anchor=(0.85, -0.2, 0.3, 0.3))
            plt.tight_layout()
            plt.subplots_adjust(left=0.1, right=0.6)
            plt.savefig(self.file_path)
        except Exception as e:
            logging.error(f"Error in saving the image {self.file_name}: {e}")
            raise
        finally:
            # Always release the figure, including on failure, so it does not leak.
            plt.close(fig)
        logging.info(f"Successfully saved the image to the file: {self.file_path}")

In [None]:
class PlotGraph(Graph):
    """
    Class for creating 'plot' (line + scatter) graph from the first two columns of a DataFrame
    """
    def __init__(self, 
                 df: DataFrame, 
                 file_name: str,
                 directory_path: str = "/home/jovyan/work/visualization/graphs/"
                 ) -> None:
        """
        Initializes parameters

        Args:
            df: DataFrame whose first column is plotted on the x-axis and second column on the y-axis
            file_name: Name of the image file to write
            directory_path: Directory the image file is written to
        """
        # Reuse the base-class extraction of x/y instead of duplicating it.
        super().__init__(df, directory_path)
        self.file_name = file_name
        # os.path.join tolerates a directory_path with or without a trailing slash.
        self.file_path = os.path.join(directory_path, file_name)
    
    def create_graph(self, title: str, tick_step: float = 10) -> None:
        """
        Creates 'plot' graph and saves it to ``self.file_path``

        Args:
            title: Title of graph
            tick_step: Spacing between x-axis ticks; ticks run from the minimum
                x value up to (but not including) the maximum x value

        Raises:
            Exception: any error raised while drawing or saving is logged and re-raised
        """
        # Draw on a fresh figure so state from earlier plots does not bleed into this image.
        fig = plt.figure()
        try:
            plt.plot(self.x, self.y, label=self.df.columns[0], c='lightblue')
            plt.scatter(x=self.x, y=self.y, marker='o', c='blue', s=5)
            plt.title(title, fontsize=16, style='italic')
            plt.xlabel(self.df.columns[0])
            plt.ylabel(self.df.columns[1])
            # Tick range comes from the already-collected x values instead of two
            # extra Spark aggregations; nan-aware so nulls are ignored, as SQL
            # min/max aggregates do.
            plt.xticks(np.arange(np.nanmin(self.x), np.nanmax(self.x), tick_step))
            # Add the legend before tight_layout so the layout accounts for it.
            plt.legend()
            plt.tight_layout()
            plt.savefig(self.file_path)
        except Exception as e:
            logging.error(f"Error in saving the image {self.file_name}: {e}")
            raise
        finally:
            # Always release the figure, including on failure, so it does not leak.
            plt.close(fig)
        logging.info(f"Successfully saved the image to the file: {self.file_path}")