In [1]:
! pip install pyspark[sql]

Collecting pyspark[sql]
  Downloading pyspark-3.2.0.tar.gz (281.3 MB)
[K     |████████████████████████████████| 281.3 MB 19 kB/s 
[?25hCollecting py4j==0.10.9.2
  Downloading py4j-0.10.9.2-py2.py3-none-any.whl (198 kB)
[K     |████████████████████████████████| 198 kB 42.5 MB/s 
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.2.0-py2.py3-none-any.whl size=281805912 sha256=a7fd93ed15bcaef27a38b498735805baa0853bdbd89c850d7853738eb13312dd
  Stored in directory: /root/.cache/pip/wheels/0b/de/d2/9be5d59d7331c6c2a7c1b6d1a4f463ce107332b1ecd4e80718
Successfully built pyspark
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9.2 pyspark-3.2.0


In [2]:
from typing import Optional

# from pyspark.sql.session import SparkSession

from pyspark.sql.types import (
    ArrayType,
    BooleanType,
    NumericType,
    StringType,
    StructField,
    StructType,
)

In [3]:
def _create_entity(dt: int, *, id: str, version="0.0.1", **attrs):
        return {
            "_dt": _dt(dt),
            "id": id,
            "version": version,
            **attrs,
        }

def _create_schema(**fields):
    schema = StructType()
    for k, v in fields.items():
        schema.add(k, v)
    return schema


def _v(value: Optional[str]) -> dict:
    return {"value": value}


def _dt(secs: int) -> str:
    return f"2020-01-01T00:00:{secs:02}.0000"


def _ns(*attrs):
    ns = StructType()
    for key in attrs:
        ns.add(key, field_type)
    return ns


field_type = StructType(
    [
        # Value
        StructField("value", StringType(), nullable=True)
        # Any additional metadata
    ]
)

EVENT_SCHEMA = {
    "_dt": StringType(),
    "version": StringType(),
    "id": StringType(),
}

COMPACTED_SCHEMA = {
    "version": StringType(),
    "id": StringType(),
    "date_created": StringType(),
    "date_updated": StringType(),
}

UNION_SCHEMA = {
    "_dt": StringType(),
    "_compacted": BooleanType(),
    "date_created": StringType(),
    "date_updated": StringType(),
    "id": StringType(),
    "version": StringType(),
}

## Compactor

In [4]:
from datetime import datetime
from typing import NewType, Union, overload

class Action:
    T = NewType("T", str)
    DELETE = T("delete")
    REPLACE = T("replace")


def recursively_merge_attributes(base, update):
    if update is None:
        # Updates are None if the event didn't set that attribute
        return base
    if not isinstance(base, dict):
        # Base is None if it's a part of the skeleton
        # If base is another non-dict value we can't merge, so the update wins
        # can't just `return update` here, need to handle possible `c_action` below
        base = {}
    if not isinstance(update, dict):
        # can't merge so update wins
        return update

    merged = {}

    if update.get("c_action") == Action.REPLACE:
        base = {}

    # NOTE: top level deletion is enabled
    if update.get("c_action") == Action.DELETE:
        return {}

    # NOTE: top level expiry is enabled
    exipire_dt = update.get("c_expire_dt", "9999-12-31T23:59:59.999999")
    if exipire_dt and datetime.fromisoformat(exipire_dt) <= datetime.now():
        return {}

    for key in (base.keys() | update.keys()) - {"c_action"}:
        nested_update = update.get(key)

        if isinstance(nested_update, dict):
            to_delete = nested_update.get("c_action") == Action.DELETE
            exipire_dt = nested_update.get("c_expire_dt", "9999-12-31T23:59:59.999999")
            to_expire = exipire_dt and datetime.fromisoformat(exipire_dt) <= datetime.now()

            if to_delete or to_expire:
                continue

        merged[key] = recursively_merge_attributes(base.get(key), nested_update)
    return merged


In [5]:
import logging
from typing import Dict, List, Optional, Tuple, Union

from pyspark.sql import DataFrame, Row, functions as F
from pyspark.sql.session import SparkSession
from pyspark.sql.types import ArrayType, StringType, StructType

log = logging.getLogger(__name__)

EntityAttributesItem = Union[str, None, dict]
EntityAttributes = Dict[str, EntityAttributesItem]
AttributeSkeleton = Dict[str, Union[None, dict]]

spark_session = SparkSession.builder.getOrCreate()
# spark = SparkSession.builder.getOrCreate()

def get_attributes_skeleton(schema: StructType) -> AttributeSkeleton:
    """
    Constructs an empty dictionary that can contain every possible
    attribute in schema.
    Used to simplify merging new and old schemae together
    """
    skeleton: AttributeSkeleton = {}
    for field in schema:
        if isinstance(field.dataType, StructType):
            skeleton[field.name] = get_attributes_skeleton(field.dataType)
        else:
            skeleton[field.name] = None
    return skeleton


def merge_schemas(base: StructType, update: StructType) -> Tuple[StructType, bool]:
    """
    Creates a new schema that contains all the attributes in both passed
    schemas, so we can easily merge events into compacted entities.

    There is no reasonable way to preserve field ordering while doing this,
    so we just alphabetise the fields so that everything is consistent.
    It also looks nicer when debugging.

    Field order is also important when mutating dataframes, which we
    currently don't do, but if we start doing it having everything
    alphabetical will help with that.

    Also, it returns a convenience sentinel to let you know if the compacted
    entities need to be updated to add a new field to them.
    """

    base_keys = set(base.fieldNames())
    update_keys = set(update.fieldNames())
    base_has_new_keys = bool(update_keys - base_keys)

    ret = StructType()

    for key in sorted(base_keys | update_keys):
        base_type = base[key].dataType if key in base_keys else None
        update_type = update[key].dataType if key in update_keys else None

        if base_type is None or update_type is None:
            new_type = base_type or update_type
            ret.add(key, new_type)
            continue

        if type(base_type) != type(update_type):
            raise RuntimeError(
                "Trying to merge invalid schemas: "
                f"Got a {type(update_type)} {key} but expected a {type(base_type)}!"
            )

        if isinstance(update_type, StructType):
            update_type, new_keys = merge_schemas(base_type, update_type)
            if new_keys:
                base_has_new_keys = True

        if isinstance(update_type, ArrayType):
            update_sub_type = update_type.elementType
            base_sub_type = base_type.elementType

            if type(base_sub_type) != type(update_sub_type):
                raise RuntimeError(
                    "Trying to merge invalid schemas: "
                    f"Got a {type(update_sub_type)} {key} but expected a {type(base_sub_type)}!"
                )

            if isinstance(update_sub_type, StructType):
                update_sub_type, new_keys = merge_schemas(base_sub_type, update_sub_type)
                update_type = ArrayType(update_sub_type)
                if new_keys:
                    base_has_new_keys = True

            # TODO: Arrays of arrays oh no this needs a huge refactor
            assert not isinstance(update_sub_type, ArrayType)

        ret.add(key, update_type)

    return ret, base_has_new_keys


def extract_entities(data: DataFrame) -> DataFrame:
    """
    Extracts all the entities from the passed events, maintaining order.
    Does not de-duplicate
    """

    entity_columns = [f"`{col}`" for col in data.columns if not col.startswith("_")]

    return data.select(F.col("`_dt`"), F.col("`_version`").alias("version"), *entity_columns)


def create_entity_union(
    spark_session: SparkSession, events: DataFrame, compacted: Optional[DataFrame]
) -> DataFrame:
    """
    Joins the events and compacted data into one giant un-de-duplicated
    dataframe so they can be more easily compacted.
    """

    if "date_created" not in events.columns:
        events = events.withColumn("date_created", F.lit(None).cast(StringType()))
    events = events.withColumn("date_updated", F.col("_dt"))
    events = events.withColumn("_compacted", F.lit(False))

    # Counting a dataframe can be extremely slow, so pull the first item and
    #  see if it exists.
    # This is merely fairly slow which is a huge improvement over .count()!
    if not compacted or compacted.first() is None:
        return events

    # Compacted entities don't have a datetime since they're not events, so
    # make sure they have a datetime of "0" that will always sort before any
    # event to avoid suprises
    compacted = compacted.withColumn("_dt", F.lit("0"))

    # Make it so we can find compacted entities easily
    compacted = compacted.withColumn("_compacted", F.lit(True))

    schema, update_compacted = merge_schemas(compacted.schema, events.schema)
    skeleton = get_attributes_skeleton(schema)

    def fix_schema(row: Row) -> EntityAttributesItem:
        return recursively_merge_attributes(skeleton, row.asDict(recursive=True))

    # Add all the missing fields to the events so we can merge them with the compacted
    events = spark_session.createDataFrame(events.rdd.map(fix_schema), schema)
    # Make sure compacted follows same schema as events
    compacted = spark_session.createDataFrame(compacted.rdd.map(fix_schema), schema)

    return compacted.unionByName(events)


def do_compaction(data: Tuple[str, List[dict]]) -> dict:
    """
    Accepts a list of entities, expected to be a compacted base & updated
    events (in arbitrary order)

    The first argument is a tuple beacuse of how grouping works in RDDs. We
    could have a tiny map step before this one that removes the tuple, but I
    thought it would be more efficient to remove it here.
    """

    # Get rid of the grouping tuple
    events = data[1]

    # # NOTE: events: [{compacted}, {new_events}]
    # # Short circuit for entities with no events
    # if len(events) == 1 and events[0]["_compacted"]:
    #     return events[0]

    events.sort(key=lambda a: a["_dt"])

    base = events[0]  # compacted event
    ret: dict = {}

    if base["_compacted"]:
        if base["date_created"] is not None:
            # Ensure no events can overwrite date_created
            # NOTE: overwrite `date_created` in the end
            events.append({"date_created": base["date_created"]})
    else:
        # Invent a date_created for the new entity, though in such a way that
        # it can be easily overwritten by events
        ret["date_created"] = base["_dt"]

    for update in events:
        
        ret = recursively_merge_attributes(ret, update)
        print(ret)

    return ret


def _map_value(value: Row) -> dict:
    """
    Turns the pyspark rows into something a bit more easy to work with
    """
    return value.asDict(recursive=True)


Combiner = List[dict]


# The following functions are used by combineByKey to do distributed map/reduce
#  operations across multiple partitions.
# For more info, check out the docs:
# https://spark.apache.org/docs//2.4.3/api/python/pyspark.html#pyspark.RDD.combineByKey
def create_combiner(value: Row) -> Combiner:
    """
    First map operation done per partition.
    In our case, we turn the value (a row) into a dict and then add it to a list
    """
    return [_map_value(value)]


def merge_value(combiner: Combiner, value: Row) -> Combiner:
    """
    Combined map/reduce operation on values in a partition.

    `combiner` is the output of create_combiner and `value` is a raw unmapped value.

    We convert it to a dict also, and then append it to the list.

    We can't merge values yet as this is done in an arbitrary order and our
    code is order-dependant, so we're just building a list here.
    """
    combiner.append(_map_value(value))
    return combiner


def merge_combiners(a: Combiner, b: Combiner) -> Combiner:
    """
    Reduce operation between partitions.

    In this case, we just smack the two lists together, though we could do
    other operations if we wanted.

    Note that PySpark recommends modifying the first argument rather than
    creating a new list to avoid memory operations.
    """
    a.extend(b)
    return a


def compact_entities(spark_session: SparkSession, union: DataFrame) -> DataFrame:
    """
    Compacts all the entities in the passed dataframe, grouping by ID and compacting in date order.
    """

    # TODO: Is it worth spending the time to figure out if there are any
    #  duplicate entries in the first place?
    # We can skip an entire map/reduce step if nothing needs compacting...
    # (Also presumably once we've done that we can partition better??)

    merged = (
        union.rdd
        # Group everything by ID.
        # Note that combineByKey only works if the dataset is two columns, so
        # we need to map it first
        .map(lambda entity: (entity["id"], entity)).combineByKey(
            create_combiner, merge_value, merge_combiners
        )
        # Turn the groups back into entities
        .map(do_compaction)
    )

    return spark_session.createDataFrame(merged, union.schema).drop("_dt", "_compacted").dropna("any", subset="id")


## Examples

### do_compaction

In [6]:
events = [
            {"_dt": _dt(2), "_compacted": False, "date_created": None, "foo": None, "bar": "3", "c_action": "delete"},
            {"_dt": _dt(0), "_compacted": True, "date_created": _dt(0), "foo": "1", "bar": "1"},
            {"_dt": _dt(1), "_compacted": False, "date_created": None, "foo": "2", "bar": "2"},
        ]
compacted = do_compaction(("", events))

{'foo': '1', '_compacted': True, '_dt': '2020-01-01T00:00:00.0000', 'date_created': '2020-01-01T00:00:00.0000', 'bar': '1'}
{'foo': '2', '_compacted': False, '_dt': '2020-01-01T00:00:01.0000', 'date_created': '2020-01-01T00:00:00.0000', 'bar': '2'}
{}
{'date_created': '2020-01-01T00:00:00.0000'}


In [7]:
compacted

{'date_created': '2020-01-01T00:00:00.0000'}

### compact_entities

In [8]:
union = spark_session.createDataFrame(
    [
        dict(
            _dt=_dt(4),
            id="foo",
            version="1.0.0",
            date_updated=_dt(4),
            live={"foo": _v("3")},
        ),
        dict(
            _dt=_dt(3),
            id="foo",
            version="1.0.0",
            date_updated=_dt(3),
            live={"foo": _v("2")},
        ),
        dict(
            _dt="0",
            _compacted=True,
            id="foo",
            version="1.0.0",
            date_updated=_dt(2),
            date_created=_dt(1),
            live={"foo": _v("1")},
        ),
        dict(
            _dt="0",
            _compacted=True,
            id="chic",
            version="1.0.0",
            date_updated=_dt(2),
            date_created=_dt(1),
            c_action = "delete",
            # live={"foo": _v("1"), "c_expire_dt": "0001-01-01T00:00:00", "core": {"food": "chic", "c_expire_dt": "0001-01-01T00:00:00"}},
            live={"foo": _v("1"), "core": {"food": "chic"}},
        ),
    ],
    _create_schema(**UNION_SCHEMA, live=_ns("foo")),
)

compacted = compact_entities(spark_session, union)

In [9]:
union.collect()

[Row(_dt='2020-01-01T00:00:04.0000', _compacted=None, date_created=None, date_updated='2020-01-01T00:00:04.0000', id='foo', version='1.0.0', live=Row(foo=Row(value='3'))),
 Row(_dt='2020-01-01T00:00:03.0000', _compacted=None, date_created=None, date_updated='2020-01-01T00:00:03.0000', id='foo', version='1.0.0', live=Row(foo=Row(value='2'))),
 Row(_dt='0', _compacted=True, date_created='2020-01-01T00:00:01.0000', date_updated='2020-01-01T00:00:02.0000', id='foo', version='1.0.0', live=Row(foo=Row(value='1'))),
 Row(_dt='0', _compacted=True, date_created='2020-01-01T00:00:01.0000', date_updated='2020-01-01T00:00:02.0000', id='chic', version='1.0.0', live=Row(foo=Row(value='1')))]

In [10]:
compacted.collect()

[Row(date_created='2020-01-01T00:00:01.0000', date_updated='2020-01-01T00:00:04.0000', id='foo', version='1.0.0', live=Row(foo=Row(value='3'))),
 Row(date_created='2020-01-01T00:00:01.0000', date_updated='2020-01-01T00:00:02.0000', id='chic', version='1.0.0', live=Row(foo=Row(value='1')))]

In [11]:
import pandas as pd

df = pd.DataFrame([
        dict(
            _dt=_dt(4),
            id="foo",
            version="1.0.0",
            date_updated=_dt(4),
            live={"foo": _v("3")},
            _compacted=False,
             date_created=_dt(1),
             c_action = None,
        ),
        dict(
            _dt=_dt(3),
            id="foo",
            version="1.0.0",
            date_updated=_dt(3),
            live={"foo": _v("2")},
             date_created=_dt(1),
             _compacted=False,
             c_action = None
        ),
        dict(
            _dt=_dt(1),
            _compacted=True,
            id="foo",
            version="1.0.0",
            date_updated=_dt(2),
            date_created=_dt(1),
            live={"foo": _v("1")},
             c_action = None
        ),
        dict(
            _dt=_dt(0),
            _compacted=True,
            id="chic",
            version="1.0.0",
            date_updated=_dt(2),
            date_created=_dt(1),
            live={"foo": _v("1"), "c_expire_dt": "0001-01-01T00:00:00", "core": {"food": "chic", "c_expire_dt": "0001-01-01T00:00:00"}},
            # live={"foo": _v("1"), "core": {"food": "chic"}},
             c_action = "delete"
        ),
    ])

union = spark_session.createDataFrame(df)

In [12]:
df

Unnamed: 0,_dt,id,version,date_updated,live,_compacted,date_created,c_action
0,2020-01-01T00:00:04.0000,foo,1.0.0,2020-01-01T00:00:04.0000,{'foo': {'value': '3'}},False,2020-01-01T00:00:01.0000,
1,2020-01-01T00:00:03.0000,foo,1.0.0,2020-01-01T00:00:03.0000,{'foo': {'value': '2'}},False,2020-01-01T00:00:01.0000,
2,2020-01-01T00:00:01.0000,foo,1.0.0,2020-01-01T00:00:02.0000,{'foo': {'value': '1'}},True,2020-01-01T00:00:01.0000,
3,2020-01-01T00:00:00.0000,chic,1.0.0,2020-01-01T00:00:02.0000,"{'foo': {'value': '1'}, 'c_expire_dt': '0001-0...",True,2020-01-01T00:00:01.0000,delete


In [13]:
union.collect()

[Row(_dt='2020-01-01T00:00:04.0000', id='foo', version='1.0.0', date_updated='2020-01-01T00:00:04.0000', live={'foo': {'value': '3'}}, _compacted=False, date_created='2020-01-01T00:00:01.0000', c_action=None),
 Row(_dt='2020-01-01T00:00:03.0000', id='foo', version='1.0.0', date_updated='2020-01-01T00:00:03.0000', live={'foo': {'value': '2'}}, _compacted=False, date_created='2020-01-01T00:00:01.0000', c_action=None),
 Row(_dt='2020-01-01T00:00:01.0000', id='foo', version='1.0.0', date_updated='2020-01-01T00:00:02.0000', live={'foo': {'value': '1'}}, _compacted=True, date_created='2020-01-01T00:00:01.0000', c_action=None),
 Row(_dt='2020-01-01T00:00:00.0000', id='chic', version='1.0.0', date_updated='2020-01-01T00:00:02.0000', live={'core': {'food': 'chic', 'c_expire_dt': '0001-01-01T00:00:00'}, 'foo': {'value': '1'}, 'c_expire_dt': None}, _compacted=True, date_created='2020-01-01T00:00:01.0000', c_action='delete')]

In [14]:
compacted = compact_entities(spark_session, union)

In [15]:
compacted.dropna("any", subset="id").collect()

[Row(id='foo', version='1.0.0', date_updated='2020-01-01T00:00:04.0000', live={'foo': {'value': '3'}}, date_created='2020-01-01T00:00:01.0000', c_action=None)]

In [16]:
compacted.collect()

[Row(id='foo', version='1.0.0', date_updated='2020-01-01T00:00:04.0000', live={'foo': {'value': '3'}}, date_created='2020-01-01T00:00:01.0000', c_action=None)]

In [17]:
# test c_expire_dt
import pandas as pd

df = pd.DataFrame([
        dict(
            _dt=_dt(4),
            id="foo",
            version="1.0.0",
            date_updated=_dt(4),
            live={"foo": _v("3"), "c_expire_dt": "0001-01-01T00:00:00"},
            _compacted=False,
            date_created=_dt(1),
            c_action = None,
            c_expire_dt = None,
        ),
        dict(
            _dt=_dt(3),
            id="foo",
            version="1.0.0",
            date_updated=_dt(3),
            live={"foo": _v("2"), "c_expire_dt":None},
            date_created=_dt(1),
            _compacted=False,
            c_action = None,
            c_expire_dt = None,
        ),
        dict(
            _dt=_dt(1),
            _compacted=True,
            id="foo",
            version="1.0.0",
            date_updated=_dt(2),
            date_created=_dt(1),
            live={"foo": _v("1"), "c_expire_dt":None},
            c_action = None,
            c_expire_dt = None,
        ),
        dict(
            _dt=_dt(0),
            _compacted=True,
            id="chic",
            version="1.0.0",
            date_updated=_dt(2),
            date_created=_dt(1),
            live={"foo": _v("1"), "c_expire_dt": "0001-01-01T00:00:00", "core": {"food": "chic", "c_expire_dt": "0001-01-01T00:00:00"}},
            c_action = "lol",
            c_expire_dt = "2201-01-01T00:00:00"
        ),
    ])

union = spark_session.createDataFrame(df)

In [18]:
df.live[3]

{'c_expire_dt': '0001-01-01T00:00:00',
 'core': {'c_expire_dt': '0001-01-01T00:00:00', 'food': 'chic'},
 'foo': {'value': '1'}}

In [19]:
_ns("live")

StructType(List(StructField(live,StructType(List(StructField(value,StringType,true))),true)))

In [20]:
union.printSchema()

root
 |-- _dt: string (nullable = true)
 |-- id: string (nullable = true)
 |-- version: string (nullable = true)
 |-- date_updated: string (nullable = true)
 |-- live: map (nullable = true)
 |    |-- key: string
 |    |-- value: map (valueContainsNull = true)
 |    |    |-- key: string
 |    |    |-- value: string (valueContainsNull = true)
 |-- _compacted: boolean (nullable = true)
 |-- date_created: string (nullable = true)
 |-- c_action: string (nullable = true)
 |-- c_expire_dt: string (nullable = true)



# Schema

In [21]:
from pyspark.sql.types import (
    ArrayType,
    BooleanType,
    NumericType,
    StringType,
    StructField,
    StructType,
)

field_type = StructType(
    [
        # Value
        StructField("value", StringType(), nullable=True),
        # c_expire_dt
        StructField("c_expire_dt", StringType(), nullable=True)
    ]
)


def _v(value: Optional[str]) -> dict:
    return {"value": value}


def _dt(secs: int) -> str:
    return f"2020-01-01T00:00:{secs:02}.000"


def _expire_dt(secs: int) -> dict:
    return {"c_expire_dt": f"2020-01-01T00:00:{secs:02}.000"}


def _ns(*attrs):
    ns = StructType()
    for key in attrs:
        ns.add(key, field_type)
    return ns

In [22]:
_ns("live")

StructType(List(StructField(live,StructType(List(StructField(value,StringType,true),StructField(c_expire_dt,StringType,true))),true)))

In [23]:
live={"foo": {"value", "bar"}}