In [0]:
import jdc
import missingno as msno
# For plots
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objs as go
import numpy as np
from pyspark.sql import DataFrame
from typing import List, Optional, TYPE_CHECKING

In [0]:
if TYPE_CHECKING:
    import pandas as pd

In [0]:
# Name of the target column; used as the default value of the ``target_col``
# argument of ``Plot.__init__``.
TARGET_COL: str = 'ConvertedCompYearly'

In [0]:
class Plot:
    """Plotting helper bound to a single column of a Spark DataFrame.

    The constructor stores the DataFrame and the styling options, and eagerly
    collects the values of ``col_name`` to the driver. Those values are kept in
    ``default_x`` / ``default_y`` so plotting code can use them as axis data.
    """

    def __init__(
        self,
        df: DataFrame,
        col_name: str,
        target_col: str = TARGET_COL,
        title: str = "",
        base_color: str = "lightslategray",
        top_color: str = "crimson",
        top_3_colors: List[str] = ["#DC143C", "#EC3257", "#F05C79"],
        *args
    ) -> None:
        """The constructor function of the class.
        :param df: A pyspark.sql.dataframe.DataFrame object.
        :param col_name: The relevant column name.
        :param target_col: The target column.
        :param title: The plot title. When empty, ``col_name`` is used instead.
        :param base_color: The basic color of the plot.
        :param top_color: The color of the highest value in the plot.
        :param top_3_colors: The colors of the top 3 highest values in the plot.
            If set to an empty list, column colors are set by the base_color
            and top_color params.
        :param args: Extra positional arguments; accepted but not used.
        """
        self.df = df
        self.col_name = col_name
        self.target_col = target_col
        self.title = title if title else col_name
        self.base_color = base_color
        self.top_color = top_color
        # Store a copy: the default list object is created once with the
        # function and would otherwise be shared by every instance. A caller's
        # explicit None is kept as None.
        self.top_3_colors = list(top_3_colors) if top_3_colors is not None else None

        # Eagerly pull the column's values to the driver (first field of each row).
        self.default_x = [row[0] for row in self.df.select(self.col_name).collect()]
        # default_y refers to the very same list object as default_x.
        self.default_y = self.default_x

In [0]:
%%add_to Plot
def get_colors(self, x):
  """Build one color per element of ``x``, highlighting the leading entries.

  If ``self.top_3_colors`` is non-empty and ``x`` has more elements than there
  are highlight colors, the highlight colors come first and every remaining
  element gets ``self.base_color``. Otherwise only the first element is
  highlighted (with ``self.top_color``) and the rest get ``self.base_color``.

  :param x: A sized collection of plotted values; only its length is used.
  :return: A list with exactly ``len(x)`` colors (empty when ``x`` is empty).
  """
  n_items = len(x)
  if n_items == 0:
    # Nothing to color; indexing element 0 of an empty list would raise.
    return []

  highlight = list(self.top_3_colors) if self.top_3_colors else []
  if highlight and n_items > len(highlight):
    # Use the real number of highlight colors instead of assuming exactly 3,
    # so the result always has one color per element.
    return highlight + [self.base_color] * (n_items - len(highlight))

  return [self.top_color] + [self.base_color] * (n_items - 1)

In [0]:
%%add_to Plot
def plot_hist(self, histnorm: Optional[str] = '', min_val=0, max_val=0, cumulative_enabled=False) -> None:
  """Display a histogram of the values of ``self.col_name``.

  :param histnorm: Passed to ``go.Histogram`` as ``histnorm`` and also used as
      the y-axis title. When falsy (the default empty string), the y-axis is
      titled 'Count'.
  :param min_val: Start of the x bins. A falsy value (including the default 0)
      is replaced by the minimum of the column.
  :param max_val: End of the x bins. A falsy value (including the default 0)
      is replaced by the maximum of the column.
  :param cumulative_enabled: Passed to ``go.Histogram`` as ``cumulative_enabled``.
  :return: None - No returned value.
  """
  colors = self.get_colors(x=self.default_x)
  y_title = histnorm if histnorm else 'Count'

  # Query the DataFrame only for the bounds the caller did not supply; each
  # lookup runs its own aggregation followed by collect().
  if not min_val:
    min_val = self.df.agg({self.col_name: "min"}).collect()[0][f"min({self.col_name})"]
  if not max_val:
    max_val = self.df.agg({self.col_name: "max"}).collect()[0][f"max({self.col_name})"]

  fig = go.Figure(go.Histogram(
    x=self.default_x,
    histnorm=histnorm,
    cumulative_enabled=cumulative_enabled,
    marker_color=colors,
    xbins=dict(
      start=min_val,
      end=max_val
    ),
    autobinx=False
  ))

  fig.update_layout(
    title_text=f'<b>{self.title}</b>',
    title_x=0.5,
    font=dict(family='Arial', size=14),
    paper_bgcolor='rgb(248, 248, 255)',
    plot_bgcolor='rgb(248, 248, 255)',
    width=1700,
    yaxis=dict(title=f'<b>{y_title}</b>')
  )

  fig.show()
  return None

In [0]:
%%add_to Plot
def plot_count_bar(self, val_col: str='count') -> None:
  """Show a horizontal bar chart of a value column against the class column.

  The bar lengths are read from ``val_col`` of ``self.df``; the bar labels are
  the collected values of the plot's own column (``self.default_y``).

  :param val_col: Name of the column holding the bar lengths.
  :return: None - No returned value.
  """
  bar_lengths = list(map(lambda row: row[0], self.df.select(val_col).collect()))
  background = 'rgb(248, 248, 255)'

  trace = go.Bar(
    x=bar_lengths,
    y=self.default_y,
    orientation='h',
    marker_color=self.get_colors(x=bar_lengths),
  )

  layout = {
    'title_text': f'<b>{self.title}</b>',
    'title_x': 0.5,
    'font': {'family': 'Arial', 'size': 14},
    'paper_bgcolor': background,
    'plot_bgcolor': background,
    'width': 1700,
    'xaxis': {'title': f'<b>{val_col}</b>'},
  }

  fig = go.Figure(trace)
  fig.update_layout(**layout)
  fig.show()

In [0]:
%%add_to Plot
def plot_percentage_bar(self, val_col: str='Percentage (%)') -> None:
  """Show the values of a column as segments of one stacked horizontal bar.

  The segment sizes are read from ``val_col`` of ``self.df``. Every segment is
  given the same y label (the column name) and carries the matching entry of
  ``self.default_y`` as its text. The x-axis ticks get a '%' suffix.

  :param val_col: Name of the column holding the segment sizes.
  :return: None - No returned value.
  """
  segment_sizes = list(map(lambda row: row[0], self.df.select(val_col).collect()))
  background = 'rgb(248, 248, 255)'

  trace = go.Bar(
    x=segment_sizes,
    # The same y label is repeated for every value; combined with the 'stack'
    # bar mode below, all segments are drawn on a single bar.
    y=np.full(len(segment_sizes), f'{self.col_name}'),
    text=self.default_y,
    orientation="h",
    marker={
      'color': self.get_colors(x=segment_sizes),
      'line': {'color': 'rgb(248, 248, 249)', 'width': 1},
    },
  )

  layout = {
    'title_text': f'<b>{self.title}</b>',
    'title_x': 0.5,
    'font': {'family': 'Arial', 'size': 14},
    'barmode': 'stack',
    'paper_bgcolor': background,
    'plot_bgcolor': background,
    'showlegend': False,
    'width': 1700,
    'height': 300,
    'xaxis': {'title': f'<b>{val_col}</b>', 'ticksuffix': '%'},
    'yaxis': {'showticklabels': False},
  }

  fig = go.Figure(trace)
  fig.update_layout(**layout)
  fig.show()

In [0]:
%%add_to Plot
def plot_boxplot(self) -> None:
  """Display one box plot of the target column per category of ``self.col_name``.

  The rows are grouped by ``self.col_name`` and the ``self.target_col`` values
  of each group are drawn as a separate box.

  :return: None - No returned value.
  """
  # Only the grouping column and the target column are used below, so select
  # them before converting instead of moving every column to pandas.
  # dict.fromkeys drops the duplicate if both names are the same.
  needed_cols = list(dict.fromkeys([self.col_name, self.target_col]))
  pandas_df = self.df.select(*needed_cols).toPandas()

  categories_vals = pandas_df.groupby(self.col_name)[self.target_col].apply(list)
  x = categories_vals.index.to_list()
  y = categories_vals.values
  # With no categories there is nothing to color (and nothing to draw).
  colors = self.get_colors(x=x) if x else []

  fig = go.Figure()

  for xd, yd, clr in zip(x, y, colors):
    fig.add_trace(go.Box(
      y=yd,
      name=xd,
      boxpoints='outliers',
      jitter=0.5,
      whiskerwidth=0.9,
      marker_color=clr,
      marker_size=2,
      line_width=1
    ))

  fig.update_layout(
    title_text=f'<b>{self.title}</b>',
    title_x=0.5,
    font=dict(family='Arial', size=14),
    paper_bgcolor='rgb(248, 248, 255)',
    plot_bgcolor='rgb(248, 248, 255)',
    yaxis=dict(
      autorange=True,
      showgrid=False,
    ),
    margin=dict(
      l=80,
      r=30,
      b=80,
      t=100,
    ),
    showlegend=False
  )
  fig.show()
  return None

In [0]:
# Create Q-Q plot of residuals - for checking normality assumption

In [0]:
%%add_to Plot
def plot_qq(self, n_cols: int = 3) -> None:
  """Draw a grid of normal Q-Q (probability) plots, one per category.

  For every distinct value of ``self.col_name`` the ``self.target_col`` values
  of the matching rows (cast to int) are plotted against a normal distribution.

  :param n_cols: Number of subplot columns in the grid (default 3).
  :return: None - No returned value.
  :raises ValueError: If ``n_cols`` is smaller than 1.
  """
  # Imported locally: neither name is provided by the notebook's visible
  # import cell, so the method would otherwise fail with NameError.
  from pyspark.sql import functions as f
  from scipy import stats

  if n_cols < 1:
    raise ValueError("n_cols must be a positive integer")

  unique_vals = [row[0] for row in self.df.select(self.col_name).distinct().collect()]
  if not unique_vals:
    # Nothing to plot; a subplot grid needs at least one row.
    return None

  pandas_df = self.df.withColumn(self.target_col, f.col(self.target_col).cast('int')).toPandas()

  # Ceiling division keeps the row count an int even when the number of
  # categories is an exact multiple of n_cols.
  n_rows = -(-len(unique_vals) // n_cols)

  # squeeze=False always yields a 2-D axes array, including the one-row case.
  fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)

  for idx, val in enumerate(unique_vals):
    ax = axes[idx // n_cols][idx % n_cols]
    category_values = pandas_df.loc[pandas_df[self.col_name] == val, self.target_col]
    stats.probplot(category_values, dist="norm", plot=ax)
    # f-string formatting also works for non-string category values.
    ax.set_title(f"Probability Plot - {val}")

  # Hide the grid cells left over when the last row is not full.
  for unused_ax in axes.flat[len(unique_vals):]:
    unused_ax.set_visible(False)

  fig.set_size_inches(20, 20)
  fig.tight_layout()
  fig.subplots_adjust(top=.9, hspace=.9)
  fig.show()
  return None

In [0]:
from time import sleep

In [0]:
def plot_multiple_qq_plots(
    df: DataFrame,
    cols_to_group: List[str] = ["Country", "Currency", "Ethnicity"],
    pause_seconds: float = 5,
) -> None:
    """Draw the Q-Q plots of several columns, one column after another.

    For each column name the function prints the name in bold (ANSI escape
    codes), builds a ``Plot`` for that column, calls its ``plot_qq`` method and
    then sleeps for ``pause_seconds``.

    A ``NameError`` raised while processing a column is caught and its repr is
    printed; the remaining columns are then skipped.

    :param df: A pyspark.sql.dataframe.DataFrame object.
    :param cols_to_group: A list of column names. The default list is only
        read, never modified.
    :param pause_seconds: Seconds to sleep after each column's plot (default 5).
    :return: None - No returned value.
    """
    try:
        for col_name in cols_to_group:
            print(f"\033[1m{col_name}\033[0m")
            Plot(df=df, col_name=col_name).plot_qq()
            sleep(pause_seconds)
    except NameError as ex:
        print(repr(ex))

In [0]:
%%add_to Plot
def plot_correlation(self, cols: List[str] = [], figsize=(6, 4)) -> None:
  """Plot the Pearson correlation of ``self.col_name`` with other numeric columns.

  The correlations are sorted and drawn as a horizontal bar chart; each bar is
  annotated with its value (3 decimals). ``self.col_name`` itself is dropped
  from the chart.

  :param cols: A list of numeric columns. When empty, the columns whose Spark
      dtype is 'int' or 'double' are used. ``self.col_name`` is added when
      it is missing from the list.
  :param figsize: A tuple of the figure size (width, height).
  :returns: None - No returned value
  """
  if cols:
    # Work on a copy so the caller's list (or the shared default) is not modified.
    cols = list(cols)
  else:
    # NOTE(review): only 'int' and 'double' are detected; other numeric Spark
    # types (e.g. bigint, float) must be passed explicitly via ``cols``.
    cols = [name for name, dtype in self.df.dtypes if dtype in ('int', 'double')]

  # The correlation series is indexed by self.col_name below, so the column
  # has to be part of the selection; otherwise the lookup raises KeyError.
  if self.col_name not in cols:
    cols.append(self.col_name)

  colors = self.get_colors(x=cols)
  # One color fewer than columns (self.col_name is dropped from the chart),
  # reversed because the values are sorted in ascending order.
  reverse_colors = colors[:len(cols) - 1][::-1]

  correlations = self.df.select(cols).toPandas().corr()[self.col_name]
  ax = correlations.sort_values().drop(self.col_name).plot(
    kind='barh', color=reverse_colors, title=self.title, figsize=figsize
  )
  for p in ax.patches:
    ax.annotate("%.3f" % p.get_width(), (p.get_x() + p.get_width(), p.get_y()), xytext=(5, 7), textcoords='offset points')

  return None

In [0]:
%%add_to Plot
def plot_missing_values(self, figsize=(26, 8)) -> None:
  """Draw the missingno bar chart for the whole DataFrame.

  The Spark DataFrame is converted to pandas and passed to ``msno.bar``, drawn
  in the plot's ``top_color``.

  :param figsize: A tuple of the figure size (width, height).
  :returns: None - No returned value
  """
  pandas_df = self.df.toPandas()
  msno.bar(df=pandas_df, color=self.top_color, figsize=figsize)

In [0]:
%%add_to Plot
def plot_pred_vs_actual(self) -> None:
  """Display a scatter plot of predictions against actual values.

  ``self.target_col`` (labelled 'actual values') is drawn on the x-axis and
  ``self.col_name`` (labelled 'prediction') on the y-axis, together with an
  OLS trendline.

  :returns: None - No returned value
  """
  x, y = self.target_col, self.col_name

  # Only the two plotted columns are needed, so select them before converting
  # to pandas. dict.fromkeys drops the duplicate if both names are the same.
  needed_cols = list(dict.fromkeys([x, y]))
  df = self.df.select(*needed_cols).toPandas()

  fig = px.scatter(
    data_frame=df,
    x=x,
    y=y,
    # plotly express looks labels up by column name, so the keys must be the
    # column names rather than the literal strings 'x' and 'y'.
    labels={x: 'actual values', y: 'prediction'},
    title=f'<b>{self.title}</b>',
    color_discrete_sequence=[self.top_color],
    trendline="ols",
    trendline_color_override="blue",
  )

  fig.update_layout(
    title_font_family="Arial",
    title_font_size=18,
    title_font_color=self.top_color,
    title_x=0.5,
    title_y=0.9,
    title_xanchor='center',
    title_yanchor='top',
  )
  fig.show()
  return None

In [0]:
def plot_feature_importance(*, df: "pd.DataFrame", colors: Optional[List[str]] = []) -> None:
  """Display a horizontal bar chart titled 'Feature Importance'.

  The bar lengths are ``df.values`` and the bar labels are ``df.index``.

  :param df: pandas object holding the importances (keyword-only).
  :param colors: Optional[List[str]]. If left as the default empty list, every
      bar is drawn in "crimson".
  :return: None - No returned value.
  """
  bar_colors = colors or ["crimson"] * len(df)
  background = 'rgb(248, 248, 255)'

  trace = go.Bar(
    x=df.values,
    y=df.index,
    orientation='h',
    marker_color=bar_colors,
  )

  layout = {
    'title_text': '<b>Feature Importance</b>',
    'title_x': 0.5,
    'xaxis_title': 'Importance',
    'yaxis_title': 'Feature Name',
    'font': {'family': 'Arial', 'size': 14},
    'paper_bgcolor': background,
    'plot_bgcolor': background,
    'height': 1000,
    'width': 1400,
  }

  fig = go.Figure(trace)
  fig.update_layout(**layout)
  fig.show()