In [0]:
import pyspark.sql.functions as F
from pyspark.sql.types import DateType
from pyspark.sql.window import Window as W
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pyspark.sql.connect.dataframe import DataFrame
import copy
import re
from datetime import datetime
import os

In [0]:
class CreateDF:
    """Helpers for loading Spark DataFrames from the ``google_fit`` catalog.

    Relies on a global ``spark`` session (as provided in a Databricks
    notebook); it is not defined in this class.
    """

    __base_catalog = "google_fit"
    __base_schema = "silver"
    __base_table = "daily_activity_metrics"
    # Fully-qualified ``catalog.schema.table`` loaded when no table is requested.
    __base_table_name = f"{__base_catalog}.{__base_schema}.{__base_table}"
    # Columns dropped from every DataFrame returned by ``from_table``.
    __excepted_cols = ['etl_timestamp', 'file_path']

    @classmethod
    def get_tables(cls, schema_name: str) -> list[str]:
        """Return fully-qualified names of the tables in ``google_fit.<schema_name>``.

        Names are built as ``google_fit.<database>.<tableName>`` from the
        ``SHOW TABLES`` output; names matching the regex ``_sqldf`` are skipped.
        """
        return [
            row["table"]
            for row in spark.sql(
                f"SHOW TABLES IN {cls.__base_catalog}.{schema_name}"
            ).selectExpr(
                f"concat_ws('.', '{cls.__base_catalog}', database, tableName) as table"
            ).filter(""" table not rlike "_sqldf" """)
            .collect()
        ]

    @classmethod
    def from_table(cls, schema_name: str = None, table_name: str = None):
        """Load a ``google_fit`` table as a Spark DataFrame.

        With no arguments, ``google_fit.silver.daily_activity_metrics`` is
        loaded. When only ``table_name`` is given, the ``silver`` schema is
        assumed. The ``etl_timestamp`` and ``file_path`` columns are dropped
        from the result.

        Raises:
            Exception: if ``schema_name`` is given without ``table_name``, if
                the schema or table does not exist, or if reading the table fails.
        """
        if schema_name is None and table_name is None:
            ref_table = spark.table(cls.__base_table_name)
        else:
            if table_name is None:
                raise Exception("table_name is required when schema_name is given!")
            if schema_name is None:
                schema_name = cls.__base_schema
            full_name = f"{cls.__base_catalog}.{schema_name}.{table_name}"
            try:
                existing = cls.get_tables(schema_name)
            except Exception as e:
                # SHOW TABLES fails when the schema itself is missing.
                raise Exception("Given table or schema does not exist!") from e
            # get_tables yields fully-qualified names; compare case-insensitively
            # because Spark resolves identifiers case-insensitively by default.
            if full_name.lower() not in {t.lower() for t in existing}:
                raise Exception("Given table or schema does not exist!")
            try:
                ref_table = spark.table(full_name)
            except Exception as e:
                raise Exception(
                    f"Some error occurred when reading the referenced table into a DataFrame: {e}"
                ) from e
        return ref_table.select(
            *[c for c in ref_table.columns if c not in cls.__excepted_cols]
        )


In [0]:
class DeclarativeAggregations:
    """Declaratively configure and apply window aggregations to a Spark DataFrame.

    Typical flow:

    1. Construct with a DataFrame that has an ``entity`` column.
    2. Optionally restrict it with :meth:`define_entities`.
    3. Register aggregations with :meth:`build_agg_config`.
    4. Materialise them as new columns of ``df_trans`` with :meth:`build_trans_df`.

    Every aggregation is applied with ``.over(window)``, so ``df_trans`` keeps
    one output row per input row.
    """

    # Supported metric -> list of registered configs. Deep-copied per instance
    # because build_agg_config appends to these lists.
    __basic_agg_config = {
        "count": [],
        "avg": [],
        "sum": [],
        "max": [],
        "min": [],
        "lag": [],
        "rank": []
    }

    def __init__(self, df: DataFrame):
        """Wrap ``df``.

        Raises:
            Exception: if ``df`` is not a DataFrame.
        """
        if not isinstance(df, DataFrame):
            raise Exception("df is not a Spark DataFrame")
        self.df = df
        self.df_trans = None
        self.agg_config = copy.deepcopy(self.__class__.__basic_agg_config)
        self.entities = None

    def define_entities(self, entities: str | list[str] = None):
        """Choose which ``entity`` values :meth:`build_trans_df` keeps.

        Args:
            entities: One entity, a list of entities, or ``None`` for every
                distinct value of the ``entity`` column.

        Raises:
            Exception: if a requested entity does not occur in ``df``.
        """
        # Collect the distinct entities once and reuse them for both paths.
        available = [row['entity'] for row in self.df.select('entity').distinct().collect()]
        if entities is None:
            self.entities = available
            return
        requested = [entities] if isinstance(entities, str) else entities
        if not set(requested).issubset(available):
            raise Exception("Given entity does not exist in the dataframe!")
        self.entities = list(set(requested))

    @staticmethod
    def cols_checker(df: DataFrame, cols: str | list[str]) -> bool:
        """Return True if every column in ``cols`` (a name or list of names) is in ``df.columns``."""
        wanted = cols if isinstance(cols, list) else [cols]
        return set(wanted).issubset(df.columns)

    def build_agg_config(
        self, agg_metric: str, group_by_cols: str | list[str], name: str,
        agg_on_col: str = None, offset: int = 1, default=None, order_by: str = None
    ):
        """Register one aggregation to be computed by :meth:`build_trans_df`.

        Columns are validated against ``df_trans`` once it exists, otherwise
        against ``df``.

        Args:
            agg_metric: One of count, avg, sum, max, min, lag, rank.
            group_by_cols: Column(s) used to partition the window.
            name: Name of the output column.
            agg_on_col: Column to aggregate. Required for every metric except rank.
            offset: Lag offset (lag only).
            default: Value passed to ``F.lag`` as its default (lag only).
            order_by: Window ordering. Required for lag and rank.

        Raises:
            Exception: for an unsupported metric, unknown columns, or a missing
                ``order_by`` for lag/rank.
        """
        if agg_metric not in self.agg_config:
            raise Exception("Given aggregation type is not supported!")
        built = self.df_trans is not None
        target = self.df_trans if built else self.df
        where = "trans dataframe" if built else "dataframe"
        if not DeclarativeAggregations.cols_checker(target, group_by_cols):
            raise Exception(f"Given partition columns do not exist in the {where}!")
        if agg_metric != "rank" and not DeclarativeAggregations.cols_checker(target, agg_on_col):
            raise Exception(f"Given agg_on does not exist in the {where}!")

        config = {"group_by_cols": group_by_cols, "name": name}
        if agg_metric == "rank":
            if not order_by:
                raise Exception("Order by is required for rank!")
            config["order_by"] = order_by
        else:
            if agg_metric == "lag":
                if not order_by:
                    raise Exception("Order by is required for lag!")
                config["offset"] = offset
                config["default"] = default
                config["order_by"] = order_by
            config["agg_on_col"] = agg_on_col
        self.agg_config[agg_metric].append(config)

    def build_trans_df(self):
        """Apply every registered aggregation as a window column of ``df_trans``.

        On the first call, ``df_trans`` is built from ``df``:

        * If entities were defined, it is filtered to ``self.entities``.
        * If ``define_entities`` was never called, all rows are kept.

        Later calls build on the existing ``df_trans``. Metrics are applied in
        the order count, avg, sum, max, min, lag, rank (``rank`` uses
        ``F.dense_rank``), and configs within a metric in registration order.
        """
        simple_aggs = {
            "avg": F.avg,
            "max": F.max,
            "min": F.min,
            "sum": F.sum,
            "count": F.count,
        }

        def expr_for(agg_metric, info):
            if agg_metric in simple_aggs:
                return simple_aggs[agg_metric](F.col(info["agg_on_col"]))
            if agg_metric == "lag":
                return F.lag(
                    F.col(info["agg_on_col"]),
                    info.get("offset", 1),
                    info.get("default"),
                )
            if agg_metric == "rank":
                return F.dense_rank()
            raise Exception(f"Unsupported aggregation type: {agg_metric}")

        if self.df_trans is None:
            base = self.df
            # isin(None) would evaluate to NULL for every row and drop them all,
            # so only filter when entities were actually defined.
            if self.entities is not None:
                base = base.filter(F.col('entity').isin(self.entities))
            self.df_trans = base
        for agg_metric, configs in self.agg_config.items():
            for info in configs:
                w_spec = W.partitionBy(info['group_by_cols'])
                if agg_metric in ("lag", "rank") and info.get("order_by"):
                    w_spec = w_spec.orderBy(info["order_by"])
                self.df_trans = self.df_trans.withColumn(
                    info["name"], expr_for(agg_metric, info).over(w_spec)
                )

    def add_comparison_col_percent(self, prev_col: str, curr_col: str, comp_col_name: str):
        """Add ``comp_col_name`` with the percent change from ``prev_col`` to ``curr_col``.

        The value is ``round((curr - prev) * 100 / prev, 3)``. Rows where
        ``prev_col`` is 0 or NULL get 0, because ``prev != 0`` is not true
        there and the ``otherwise`` branch applies.

        Raises:
            Exception: if ``df_trans`` is not built, a column is missing, or the
                two columns have different data types.
        """
        if self.df_trans is None:
            raise Exception("Trans dataframe is not built yet!")
        columns = self.df_trans.columns
        if prev_col not in columns or curr_col not in columns:
            raise Exception("Given col does not exist in the trans dataframe!")
        schema_trans = self.df_trans.schema
        if schema_trans[prev_col].dataType != schema_trans[curr_col].dataType:
            raise Exception("Given cols are not of the same type!")
        prev, curr = F.col(prev_col), F.col(curr_col)
        pct_change = F.round((curr - prev) * 100 / prev, 3)
        self.df_trans = self.df_trans.withColumn(
            comp_col_name, F.when(prev != F.lit(0), pct_change).otherwise(F.lit(0))
        )

    def clear_agg_config(self):
        """Reset all registered aggregations to the empty per-metric config."""
        self.agg_config = copy.deepcopy(self.__class__.__basic_agg_config)

    def clear_df_trans(self):
        """Discard ``df_trans`` so the next build starts again from ``df``."""
        self.df_trans = None

    def current_attributes(self):
        """Return the instance ``__dict__`` (a live reference, not a copy)."""
        return self.__dict__


In [0]:
class LinePlot:
    """Line plot of one ``df_trans`` column against another, one line per entity.

    The ``x``, ``y`` and ``entity`` columns are collected to pandas when the
    object is constructed; :meth:`plot` draws them with seaborn.
    """

    __base_fig_size = (25, 15)
    __base_save_path = "/Volumes/google_fit/gold/weekly_plots_activity_metrics"

    @staticmethod
    def __axis_is_a_comparison_col(y: str) -> bool:
        """Return True when ``y`` contains 'percent' (a percent-comparison column)."""
        return re.search(r'percent', y) is not None

    @staticmethod
    def __get_regex_for_prev_col(y: str) -> str | None:
        """Return the text captured by ``_(.+)_change`` in ``y``, or None if it does not match."""
        match = re.search(r"_(.+)_change", y)
        return match.group(1) if match else None

    @staticmethod
    def __find_prev_col(metric: str | None, columns: list[str]) -> str | None:
        """Return the first column matching ``prev.*<metric>``, or None."""
        if metric is None:
            return None
        # Escape so metric text is matched literally, not as regex syntax.
        pattern = re.compile(fr"prev.*{re.escape(metric)}")
        return next((col for col in columns if pattern.search(col)), None)

    def __init__(self, DeclarativeAggregations_obj: DeclarativeAggregations, x, y, fig_size: tuple[int, int] = None, save_path: str = None):
        """Collect ``x``, ``y`` and ``entity`` from ``DeclarativeAggregations_obj.df_trans``.

        When ``y`` is a percent-comparison column and a matching ``prev*``
        column exists, rows where that column is NULL are dropped first. If no
        such column is found, all rows are kept.

        Raises:
            Exception: if the object has the wrong type, ``df_trans`` is not
                built, or ``x``/``y`` are missing.
        """
        if not isinstance(DeclarativeAggregations_obj, DeclarativeAggregations):
            raise Exception("Given object is not a DeclarativeAggregations object!")
        df_trans = DeclarativeAggregations_obj.df_trans
        if df_trans is None:
            raise Exception("Trans dataframe is not built yet!")
        columns = df_trans.columns
        if x not in columns:
            raise Exception("Given x column does not exist in the dataframe!")
        if y not in columns:
            raise Exception("Given y column does not exist in the dataframe!")
        prev_col = None
        if LinePlot.__axis_is_a_comparison_col(y):
            prev_col = LinePlot.__find_prev_col(LinePlot.__get_regex_for_prev_col(y), columns)
        if prev_col is not None:
            # Rows without a previous value (e.g. the first row of a lag window)
            # carry no meaningful percent change, so leave them out of the plot.
            df_trans = df_trans.filter(F.col(prev_col).isNotNull())
        self.df_pd = df_trans.select([x, y, 'entity']).toPandas()
        self.x = x
        self.y = y
        self.save_path = save_path if save_path is not None else f"{LinePlot.__base_save_path}/{self.y}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.png"
        self.fig_size = fig_size if fig_size is not None else LinePlot.__base_fig_size

    def plot(self, fig_size: tuple[int, int] = None, save: bool = False):
        """Draw ``y`` against ``x`` with one line per ``entity``.

        Args:
            fig_size: Optional ``(width, height)`` for this plot. Defaults to
                ``self.fig_size``.
            save: When True, also write the figure to ``self.save_path``.
        """
        fig, ax = plt.subplots(figsize=fig_size if fig_size is not None else self.fig_size)
        sns.lineplot(x=self.x, y=self.y, data=self.df_pd, ax=ax, hue='entity')
        plt.tight_layout()
        if save:
            fig.savefig(self.save_path)
        plt.tight_layout()

