In [12]:
import pandas as pd
import pyspark 
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

spark

In [204]:
import numpy as np
# create test data
from pandas import DataFrame
df = DataFrame({
    'id': ['a'] * 6 + ['b'] * 7, 
    'year': [2022, 2022, 2022, 2023, 2023, 2023, 2020, 2021, 2021, 2021, 2021, 2021, 2021],
    'month': [10, 11, 12, 1, 2, 3, 12, 1, 4, 5, 8, 9, 12], 
    'balance': (round(i, 0) for i in np.random.normal(1000, 10, size=13))
})


df = DataFrame({
    'id': ['a', 'a', 'b', 'b'],
    'year': [2022, 2023, 2022, 2022],
    'month': [6, 3, 1, 2],
    'balance': [10, 20, 0, 0]
})
table = spark.createDataFrame(df)
table.show()


  for column, series in pdf.iteritems():


+---+----+-----+-------+
| id|year|month|balance|
+---+----+-----+-------+
|  a|2022|    6|     10|
|  a|2023|    3|     20|
|  b|2022|    1|      0|
|  b|2022|    2|      0|
+---+----+-----+-------+



In [205]:
table = table.withColumn('year_month', F.expr('make_date(year, month, 1)'))

table.show()

+---+----+-----+-------+----------+
| id|year|month|balance|year_month|
+---+----+-----+-------+----------+
|  a|2022|    6|     10|2022-06-01|
|  a|2023|    3|     20|2023-03-01|
|  b|2022|    1|      0|2022-01-01|
|  b|2022|    2|      0|2022-02-01|
+---+----+-----+-------+----------+



In [193]:
all_dates_df.join(table, [ID_COL, DATE_COL], 'left').show()

+---+----------+----+-----+-------+
| id|year_month|year|month|balance|
+---+----------+----+-----+-------+
|  a|2022-10-01|2022|   10|  997.0|
|  a|2022-10-31|null| null|   null|
|  a|2022-12-01|2022|   12| 1000.0|
|  a|2022-12-31|null| null|   null|
|  a|2023-01-31|null| null|   null|
|  a|2023-03-01|2023|    3| 1013.0|
|  b|2020-12-01|2020|   12| 1002.0|
|  b|2020-12-31|null| null|   null|
|  b|2021-01-31|null| null|   null|
|  b|2021-03-01|null| null|   null|
|  b|2021-03-31|null| null|   null|
|  b|2021-05-01|2021|    5|  982.0|
|  b|2021-05-31|null| null|   null|
|  b|2021-07-01|null| null|   null|
|  b|2021-07-31|null| null|   null|
|  b|2021-08-31|null| null|   null|
|  b|2021-10-01|null| null|   null|
|  b|2021-10-31|null| null|   null|
|  b|2021-12-01|2021|   12|  990.0|
+---+----------+----+-----+-------+



In [206]:
from typing import Optional, Union
from datetime import datetime
from numpy import datetime64

class BackFillJoin:
    '''Class for joining and filling time-series data

    Class accepts a pyspark table and creates a consistent spine of year-month entries per-id
    ending on either the maximum date per-id OR the ref-date (if specified)

    Args:
        id_col: identifies unique entities 
        date_col: string to identify column used as date, must of DateType
        year_month_date: if set to true, transform() expects table with two columns named "year", and "date"
        backfill: if set to true, returned data will be backfilled with the last non-null value
        ref_date: if specified, time-series will be created up to this date point per-id

    Raises:
        ValueError: if date_col is specified and year_month_date is not False
    
    '''
    def __init__(self, 
        id_col:str,
        date_col: str = None,
        year_month_date: bool = False,
        backfill:bool= False, 
        ref_date: Optional[Union[str, datetime, datetime64]] = None
    ):

        if year_month_date and date_col is not None:
            raise ValueError('if year_month date is specified, cannot specify separate date column, as dataframe is expected to have a column named "year" and "month"')

        if ref_date and type(ref_date) not in (datetime64, datetime):
            raise ValueError('ref_date must be np.datetime64 object or python datetime object')

        self.id_col = id_col
        self.backfill = backfill
        self.ref_date = BackFillJoin._convert_ref_date(ref_date)
        self.date_col = date_col
        self.year_month_date = year_month_date

    @staticmethod
    def _convert_ref_date(ref_date):
        if type(ref_date) is np.datetime64:
            return np.datetime_as_string(ref_date, unit='D')
        if type(ref_date) is datetime:
            return ref_date.strftime('%Y-%m-%d')
            

    def _create_year_month(self, table):
        if self.year_month_date:
            expr = F.expr('make_date(year, month, 1)')
        else:
            expr = F.expr(f'make_date(year({self.date_col}), month({self.date_col}), 1)')

        return table.withColumn('year_month', expr)

    def _create_all_dates(self, table):
        grouped = table.groupBy(self.id_col).agg(
                        F.max('year_month').alias("max_date"),
                        F.min('year_month').alias("min_date"))

        grouped.show()
        if self.ref_date:
            all_dates = (grouped
                            .withColumn('ref_date', F.lit(self.ref_date))
                            .select(
                                self.id_col, 
                                F.expr(f"sequence(to_date(min_date), to_date(ref_date), interval 1 month)").alias('year_month'))
                            .withColumn('year_month', F.explode('year_month')))


        else:
            all_dates = (grouped
                            .select(
                                self.id_col, 
                                F.expr("sequence(to_date(min_date), to_date(max_date), interval 1 month)").alias('year_month'))
                            .withColumn('year_month', F.explode('year_month')))
            
        return all_dates

    def transform(self, table):
        table = self._create_year_month(table)
        all_dates = self._create_all_dates(table)

        all_dates.show(truncate=False)

        if self.backfill:
            w = Window.partitionBy(self.id_col).orderBy('year_month')
            return (all_dates
                .join(table, [self.id_col, 'year_month'], "left")
                .select(self.id_col, 'year_month', *[F.last(F.col(c), ignorenulls=True).over(w).alias(c) for c in df.columns if c not in (self.id_col, 'year_month')] )
                .withColumn('year', F.expr('year(year_month)'))
                .withColumn('month', F.expr('month(year_month)')))

        return (all_dates
                .join(table, [self.id_col, 'year_month'], 'left')
                .withColumn('year', F.expr('year(year_month)'))
                .withColumn('month', F.expr('month(year_month)')))

In [211]:
import numpy as np
s = BackFillJoin(id_col='id',date_col='year_month', backfill=True)
table.drop(*['year', 'month']).show()

+---+-------+----------+
| id|balance|year_month|
+---+-------+----------+
|  a|     10|2022-06-01|
|  a|     20|2023-03-01|
|  b|      0|2022-01-01|
|  b|      0|2022-02-01|
+---+-------+----------+



In [212]:
s.transform(table).show(40)

+---+----------+----------+
| id|  max_date|  min_date|
+---+----------+----------+
|  a|2023-03-01|2022-06-01|
|  b|2022-02-01|2022-01-01|
+---+----------+----------+

+---+----------+
|id |year_month|
+---+----------+
|a  |2022-06-01|
|a  |2022-07-01|
|a  |2022-08-01|
|a  |2022-09-01|
|a  |2022-10-01|
|a  |2022-11-01|
|a  |2022-12-01|
|a  |2023-01-01|
|a  |2023-02-01|
|a  |2023-03-01|
|b  |2022-01-01|
|b  |2022-02-01|
+---+----------+

+---+----------+----+-----+-------+
| id|year_month|year|month|balance|
+---+----------+----+-----+-------+
|  a|2022-06-01|2022|    6|     10|
|  a|2022-07-01|2022|    7|     10|
|  a|2022-08-01|2022|    8|     10|
|  a|2022-09-01|2022|    9|     10|
|  a|2022-10-01|2022|   10|     10|
|  a|2022-11-01|2022|   11|     10|
|  a|2022-12-01|2022|   12|     10|
|  a|2023-01-01|2023|    1|     10|
|  a|2023-02-01|2023|    2|     10|
|  a|2023-03-01|2023|    3|     20|
|  b|2022-01-01|2022|    1|      0|
|  b|2022-02-01|2022|    2|      0|
+---+----------+--

'2024-01-01'