In [0]:
import pyspark.sql.functions as F
from pyspark.sql.types import DateType
from pyspark.sql.window import Window as W
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pyspark.sql.connect.dataframe import DataFrame
import copy

In [0]:
def null_check(df):
    """Return the names of the columns of ``df`` that contain at least one null.

    All null counts are computed in one aggregation (a single Spark job) instead
    of one ``filter(...).count()`` job per column.

    Args:
        df: Spark DataFrame to inspect.

    Returns:
        list[str]: names of columns with one or more nulls, in ``df.columns`` order.
    """
    cols = df.columns
    if not cols:
        return []
    # when() without otherwise() yields null for non-matching rows and count()
    # skips nulls, so each entry is exactly the number of null cells in that column.
    null_counts = df.select(
        [F.count(F.when(F.col(c).isNull(), 1)).alias(c) for c in cols]
    ).first()
    # A global aggregation always returns exactly one row (zeros for an empty frame).
    return [c for c, n in zip(cols, null_counts) if n > 0]

daily_activity_metrics

In [0]:
# Raw silver table with every column; df_dam is reassigned further down via CreateDF.from_table().
df_dam = spark.table("google_fit.silver.daily_activity_metrics")

In [0]:
display(df_dam)

In [0]:
%sql

-- `if not exists` makes this cell safe to re-run.
create schema if not exists google_fit.gold

In [0]:
class CreateDF:
    """Factory for Spark DataFrames read from the ``google_fit`` catalog.

    ``from_table()`` reads a table (``google_fit.silver.daily_activity_metrics``
    by default) and drops the columns listed in ``__excepted_cols``.
    """

    __base_catalog = "google_fit"
    __base_schema = "silver"
    __base_table = "daily_activity_metrics"
    __base_table_name = f"{__base_catalog}.{__base_schema}.{__base_table}"
    # Dropped from every DataFrame returned by from_table (presumably ETL bookkeeping columns).
    __excepted_cols = ['etl_timestamp', 'file_path']

    @classmethod
    def get_tables(cls, schema_name: str):
        """Return the fully-qualified names (``catalog.schema.table``) of the tables in ``schema_name``.

        Names matching ``_sqldf`` are skipped (presumably the temp views Databricks
        creates for ``%sql`` cell results).
        """
        return [
            row["table"]
            for row in spark.sql(
                f"SHOW TABLES IN {cls.__base_catalog}.{schema_name}"
            ).selectExpr(
                f"concat_ws('.', '{cls.__base_catalog}', database, tableName) as table"
            ).filter(""" table not rlike "_sqldf" """)
            .collect()
        ]

    @classmethod
    def from_table(cls, schema_name: str = None, table_name: str = None):
        """Load ``google_fit.<schema_name>.<table_name>`` without the excepted columns.

        Args:
            schema_name: schema inside the ``google_fit`` catalog. Defaults to
                ``silver`` when only ``table_name`` is given.
            table_name: table to read. When both arguments are omitted,
                ``google_fit.silver.daily_activity_metrics`` is read.

        Returns:
            The table as a DataFrame without ``etl_timestamp`` and ``file_path``.

        Raises:
            ValueError: if ``schema_name`` is given without ``table_name``.
            Exception: if the table is not listed in the schema. Errors raised by
                ``SHOW TABLES`` itself (e.g. an unknown schema) propagate unchanged.
        """
        if table_name is None:
            if schema_name is not None:
                raise ValueError("table_name is required when schema_name is given!")
            full_name = cls.__base_table_name
        else:
            if schema_name is None:
                schema_name = cls.__base_schema
            full_name = f"{cls.__base_catalog}.{schema_name}.{table_name}"
            # get_tables() yields fully-qualified names, so compare the full name;
            # lower-case both sides since table identifiers are case-insensitive in Spark SQL.
            existing = {name.lower() for name in cls.get_tables(schema_name)}
            if full_name.lower() not in existing:
                raise Exception("Given table or schema does not exist!")
        ref_table = spark.table(full_name)
        return ref_table.select(*[c for c in ref_table.columns if c not in cls.__excepted_cols])


In [0]:
class DeclarativeAggregations:
    __basic_agg_config = {
        "count": [],
        "avg": [],
        "sum": [],
        "max": [],
        "min": [],
        "lag": []
    }
     
    def __init__(self, df: DataFrame):
        if(isinstance(df, DataFrame)):
            self.df = df
            self.df_trans = None
            self.agg_config =  copy.deepcopy(self.__class__.__basic_agg_config)
            self.entities = None
        else:
            raise Exception("df is not a Spark DataFrame")

    def define_entities(self, entities: str|list[str] = None):
        if(entities is None):
            self.entities = [row['entity'] for row in self.df.select('entity').distinct().collect()]
            return
        if isinstance(entities, str):
            entities_list = [entities]
        else:
            entities_list = entities
        if(set(entities_list).issubset({row['entity'] for row in self.df.select('entity').distinct().collect()})):
            self.entities = list(set(entities_list))
        else:
            raise Exception("Given entity does not exist in the dataframe!")

    @staticmethod
    def cols_checker(df: DataFrame, cols: str|list[str]) -> bool:
        check = lambda cols, df: set(cols).issubset(set(df.columns)) if type(cols) == list else set([cols]).issubset(set(df.columns))
        return check(cols, df)
    
    def build_agg_config(
        self, agg_metric: str, group_by_cols: str | list[str],
        agg_on_col: str, name: str, offset: int = 1, default=None, order_by: str = None
    ):
        if agg_metric not in self.agg_config.keys():
            raise Exception("Given aggregation type is not supported!")
        if self.df_trans is None:
             if not DeclarativeAggregations.cols_checker(self.df, group_by_cols):
                 raise Exception("Given partition columns do not exist in the dataframe!")
             if not DeclarativeAggregations.cols_checker(self.df, agg_on_col):
                 raise Exception("Given agg_on does not exist in the dataframe!")
        else:
            if not (DeclarativeAggregations.cols_checker(self.df_trans, group_by_cols)):
                raise Exception("Given partition columns do not exist in the trans dataframe!")
            if not DeclarativeAggregations.cols_checker(self.df_trans, agg_on_col):
                raise Exception("Given agg_on does not exist in the trans dataframe!")

        config = {
            "group_by_cols": group_by_cols,
            "agg_on_col": agg_on_col,
            "name": name
        }
        if agg_metric in ["lag"]:
            config["offset"] = offset
            config["default"] = default
            config["order_by"] = order_by
        self.agg_config[agg_metric].append(config)
    
    def build_trans_df(self):
        def get_expr_for_agg_metric(agg_metric, derived_col_info):
            if agg_metric == "avg":
                return F.avg(F.col(derived_col_info["agg_on_col"]))
            elif agg_metric == "max":
                return F.max(F.col(derived_col_info["agg_on_col"]))
            elif agg_metric == "min":
                return F.min(F.col(derived_col_info["agg_on_col"]))
            elif agg_metric == "sum":
                return F.sum(F.col(derived_col_info["agg_on_col"]))
            elif agg_metric == "count":
                return F.count(F.col(derived_col_info["agg_on_col"]))
            elif agg_metric == "lag":
                return F.lag(
                    F.col(derived_col_info["agg_on_col"]),
                    derived_col_info.get("offset", 1),
                    derived_col_info.get("default")
                )
 
        if not self.df_trans:
            self.df_trans = self.df.filter(F.col('entity').isin(self.entities))
        for agg_metric in self.agg_config.keys():
            for derived_col_info in self.agg_config[agg_metric]:
                w_spec = W.partitionBy(derived_col_info['group_by_cols'])
                if agg_metric in ["lag"] and derived_col_info.get("order_by"):
                    w_spec = w_spec.orderBy(derived_col_info["order_by"])
                expr = get_expr_for_agg_metric(agg_metric, derived_col_info)
                self.df_trans = self.df_trans.withColumn(derived_col_info["name"], expr.over(w_spec))

    def clear_agg_config(self):
        self.agg_config = copy.deepcopy(self.__class__.__basic_agg_config)

    def clear_df_trans(self):
        self.df_trans = None

    def current_attributes(self):
        return self.__dict__
    


In [0]:
class LinePlot:
    """Line chart of one column of a DeclarativeAggregations' ``df_trans``, one line per ``entity``.

    The needed columns are collected to pandas once, in the constructor, so
    ``plot()`` can be called repeatedly without another Spark job.
    """
    __base_fig_size = (25, 15)

    def __init__(self, DeclarativeAggregations_obj: DeclarativeAggregations, x, y, fig_size: tuple[int, int] = None):
        """Validate the columns and collect ``x``, ``y`` and ``entity`` into pandas.

        Args:
            DeclarativeAggregations_obj: object whose ``df_trans`` has been built.
            x: column for the x axis.
            y: column for the y axis.
            fig_size: default figure size for plot(); (25, 15) when None.

        Raises:
            Exception: if the object has the wrong type, its ``df_trans`` has not
                been built, or ``x``/``y`` is not a column of ``df_trans``.
        """
        if not isinstance(DeclarativeAggregations_obj, DeclarativeAggregations):
            raise Exception("Given object is not a DeclarativeAggregations object!")
        df_trans = DeclarativeAggregations_obj.df_trans
        if df_trans is None:
            # e.g. after clear_df_trans(); otherwise df_trans.columns fails with an opaque AttributeError.
            raise Exception("df_trans has not been built; call build_trans_df() first!")
        columns = df_trans.columns
        if x not in columns:
            raise Exception("Given x column does not exist in the dataframe!")
        if y not in columns:
            raise Exception("Given y column does not exist in the dataframe!")
        # dict.fromkeys drops duplicates (e.g. x == 'entity') while keeping order.
        self.df_pd = df_trans.select(list(dict.fromkeys([x, y, 'entity']))).toPandas()
        self.x = x
        self.y = y
        self.fig_size = fig_size if fig_size is not None else LinePlot.__base_fig_size

    def plot(self, fig_size: tuple[int, int] = None):
        """Draw the chart; ``fig_size`` overrides the size given to the constructor."""
        fig, ax = plt.subplots(figsize=fig_size if fig_size is not None else self.fig_size)
        sns.lineplot(x=self.x, y=self.y, data=self.df_pd, ax=ax, hue='entity')
        ax.set_title(f"{self.y} by {self.x}")
        fig.tight_layout()



In [0]:
df_dam = CreateDF.from_table()

In [0]:
df_dam_weekly = (
    df_dam.withColumn("week", F.date_trunc("week", F.col("date")).cast(DateType()))
          # .withColumn("prev_week", F.date_sub(F.col("week"), 7))
        # .withColumn("month", F.date_trunc("month", F.col("date")).cast(DateType()))
         .selectExpr("entity", "date", "week", "heart_points as daily_heart_points")
)

In [0]:
da_dam_weekly = DeclarativeAggregations(df_dam_weekly)

In [0]:
da_dam_weekly.define_entities()

In [0]:
da_dam_weekly.current_attributes()

In [0]:
da_dam_weekly.df.display()

In [0]:
da_dam_weekly.build_agg_config(agg_metric="sum", group_by_cols=["entity", "week"], agg_on_col="daily_heart_points", name= "current_week_heart_points")

In [0]:
da_dam_weekly.build_agg_config(agg_metric="lag", group_by_cols=["entity"], agg_on_col="current_week_heart_points", order_by= "week", offset= 7, default= 0, name= "prev_week_heart_points")

In [0]:
da_dam_weekly.build_trans_df()
da_dam_weekly.df_trans.display()

In [0]:
da_dam_weekly.clear_agg_config()
da_dam_weekly.clear_df_trans()

In [0]:
LinePlot(da_dam_weekly, "week", "current_week_heart_points").plot()

In [0]:
        # NOTE(review): dead code. This looks like an earlier version of the build_trans_df loop:
        # it uses self.window_part_cols / self.agg_on_cols / self.agg_metrics / self.w_order_by_cols,
        # none of which DeclarativeAggregations sets any more. Nothing in this cell executes;
        # consider deleting it.
        # for window_part_col in self.window_part_cols:
        #     w_spec = W.partitionBy(window_part_col)
        #     if(self.agg_on_cols is None):
        #         self.df_trans = self.df_trans.withColumn(f"count_all_rows_over_{"_and_".join(window_part_col)}", F.count(F.lit(1)).over(w_spec))
        #     else:
        #         for agg_on_col in self.agg_on_cols:
        #             for agg_metric in self.agg_metrics:
        #                 expr = get_expr_for_agg_metric(agg_metric, agg_on_col)
        #                 if(agg_metric == "lag" or agg_metric == "rank"):
        #                     w_spec = W.partitionBy(window_part_col).orderBy(self.w_order_by_cols)
        #                     self.df_trans = self.df_trans.withColumn(
        #                         f"{agg_metric}_of_{agg_on_col}_over_{"_and_".join(window_part_col)}_ord_by_{"_and_".join(self.w_order_by_cols)}",
        #                         expr.over(w_spec)
        #                     )
        #                 else:
        #                     self.df_trans = self.df_trans.withColumn(
        #                         f"{agg_metric}_of_{agg_on_col}_over_{"_and_".join(window_part_col)}",
        #                         expr.over(w_spec)
        #                     )      

In [0]:
spark.sql("show tables in google_fit.silver ").display()

In [0]:
type(p.df)

In [0]:
null_check(df_dam)

In [0]:
df_dam.columns

In [0]:
# Subset of activity metrics with `week` and `month` keys, used by the per-period plots below.
df_dam_test = (
    df_dam.select(['entity', 'date', 'distance_m', 'step_count', 'heart_points', 'heart_minutes', 'move_minutes_count'])
        # NOTE(review): next_day returns the first Sunday strictly *after* `date` (a Sunday maps to the
        # following Sunday), unlike date_trunc('week', ...) used for df_dam_weekly — confirm the intended key.
        .withColumn('week', F.next_day(F.col('date'), 'sunday'))
        .withColumn('month', F.date_trunc('month', F.col('date')).cast(DateType()))
        .orderBy(F.col('date').desc())
)

In [0]:
display(df_dam_test)

In [0]:
# NOTE(review): add_agg_metrics is not defined in any cell visible here, so this cell raises
# NameError on a fresh kernel — restore its definition (or rebuild via DeclarativeAggregations).
# plot_linegraphs expects its output to have columns containing 'daily' / 'weekly' / 'monthly'.
df_dam_test_agg = add_agg_metrics(df= df_dam_test)

In [0]:
display(df_dam_test_agg)

In [0]:
def plot_linegraphs(df_pd, figsize=(60, 30), graphs= None, category= 'entity'):
    """Draw a grid of line plots: one row per x-axis column, one panel per y column.

    Args:
        df_pd: pandas DataFrame, or a Spark DataFrame (converted with ``toPandas()``).
        figsize: overall figure size passed to ``plt.subplots``.
        graphs: mapping ``{x_column: [y_column, ...]}``. When None, it is derived
            from the column names: names containing 'daily' are plotted against
            'date', 'weekly' against 'week' and 'monthly' against 'month'.
        category: column used as the seaborn ``hue`` (one line per value).

    Raises:
        ValueError: if there is no y column to plot.
    """
    if isinstance(df_pd, DataFrame):
        df_pd = df_pd.toPandas()
    if graphs is None:
        graphs = {'date': [], 'week': [], 'month': []}
        for col in df_pd.columns:
            if 'daily' in col:
                graphs['date'].append(col)
            elif 'weekly' in col:
                graphs['week'].append(col)
            elif 'monthly' in col:
                graphs['month'].append(col)
    # Drop x-axes with nothing to plot (they would only add blank rows); copying also
    # leaves a caller-supplied dict untouched.
    graphs = {x: list(ys) for x, ys in graphs.items() if ys}
    if not graphs:
        raise ValueError("Nothing to plot: no y columns were given or found!")
    n_cols = max(len(ys) for ys in graphs.values())
    # squeeze=False keeps `axs` 2-D even with a single row or column, so axs[row, col] always works.
    fig, axs = plt.subplots(nrows=len(graphs), ncols=n_cols, figsize=figsize, squeeze=False)
    for row, (x, ys) in enumerate(graphs.items()):
        for col_idx, ax in enumerate(axs[row]):
            if col_idx >= len(ys):
                ax.set_visible(False)  # shorter row: hide the unused panel
                continue
            y = ys[col_idx]
            sns.lineplot(data=df_pd, x=x, y=y, hue=category, ax=ax)
            ax.set_title(f"{y}_by_{x}")
    fig.tight_layout()


In [0]:
# Collect to pandas once; plot_linegraphs would otherwise call toPandas() on every run.
df_dam_test_agg_pd = df_dam_test_agg.toPandas()

In [0]:
# Panels are chosen from column names containing 'daily' / 'weekly' / 'monthly'.
plot_linegraphs(df_dam_test_agg_pd)

activities

In [0]:
# Silver activities table, read with all columns.
df_act = spark.table("google_fit.silver.activities")

In [0]:
display(df_act)

In [0]:
# Columns of df_act that contain at least one null.
null_check(df_act)

all_sessions

In [0]:
# Silver all_sessions table, read with all columns.
df_ses = spark.table("google_fit.silver.all_sessions")

In [0]:
display(df_ses)

In [0]:
# Columns of df_ses that contain at least one null.
null_check(df_ses)

In [0]:
# Number of session rows per entity.
display(df_ses.groupBy('entity').count())