In [1]:
import os
from argparse import ArgumentParser
from dataclasses import dataclass
from enum import Enum
from typing import List
from tqdm import tqdm
import numpy as np
import hsfs

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

Starting Spark application


ID,Application ID,Kind,State,Spark UI,Driver log
43,application_1646303173729_0002,pyspark,idle,Link,Link


SparkSession available as 'spark'.


In [3]:
# Connect through hsfs (no arguments are passed, so the library's defaults apply)
# and fetch a feature store handle from that connection.
connection = hsfs.connection()
fs = connection.get_feature_store()

# Storage connector to s3
# NOTE(review): PySpark shells conventionally bind `sc` to the SparkContext; here the
# name holds the storage connector instead — confirm nothing relies on the former.
sc = fs.get_storage_connector("experiment-s3")

Connected. Call `.close()` to terminate connection gracefully.

In [5]:
# Let the storage connector prepare Spark before anything is written to its bucket
# (presumably credential/endpoint configuration — confirm against the hsfs docs).
sc.prepare_spark()

In [6]:
# Root s3a:// location, inside the storage connector's bucket, that `save` writes under.
s3_directory = "s3a://" + sc.bucket + "/axel_experiments"

In [7]:
# Obtain the SparkSession used by the generator functions in this notebook.
# NOTE(review): getOrCreate() returns an already-running session when one exists, so
# master("local") / appName may have no effect in a kernel that pre-creates `spark` — confirm.
spark = SparkSession.builder.master("local").appName("Data Generation").getOrCreate()

In [8]:
# Seeded generator so a fresh run starts from the same pseudo-random state.
rng = np.random.default_rng(seed=1337)

# Durations expressed in seconds; a "month" is fixed at 30 days.
SECONDS_IN_DAY = 24 * 60 * 60
SECONDS_IN_MONTH = 30 * SECONDS_IN_DAY

In [9]:
# Left table rows are (id, ts, label); every column is declared non-nullable.
LEFT_SCHEMA = StructType(
    [
        StructField(column, dtype, nullable=False)
        for column, dtype in (
            ("id", IntegerType()),
            ("ts", IntegerType()),
            ("label", StringType()),
        )
    ]
)

# Right table rows are (id, ts, value); every column is declared non-nullable.
RIGHT_SCHEMA = StructType(
    [
        StructField(column, dtype, nullable=False)
        for column, dtype in (
            ("id", IntegerType()),
            ("ts", IntegerType()),
            ("value", StringType()),
        )
    ]
)

In [10]:
@dataclass
class DistributionConfiguration:
    """Mean and standard deviation of a normal distribution.

    generate_right_table draws the number of events for an id from a normal
    distribution with these parameters.
    """

    mean: float  # centre (loc) of the distribution
    sd: float  # standard deviation (scale) of the distribution


class TimestampDistribution(Enum):
    """Shapes that the timestamps of generated events can follow."""

    Normal = "normal"
    Uniform = "uniform"

    def __str__(self) -> str:
        """Render the member as its plain string value (e.g. ``"normal"``)."""
        return str(self.value)


In [11]:
def limit_value(value, min_v, max_v):
    """Clamp ``value`` into the range ``[min_v, max_v]``.

    The lower bound is applied first and the upper bound second, so the upper
    bound wins if the two bounds are given in the wrong order.
    """
    raised = value if value > min_v else min_v
    return max_v if max_v < raised else raised

In [33]:
def generate_left_table(unique_ids: int, max_timestamp: int) -> DataFrame:
    """Build the left table: one row per id carrying a random timestamp.

    Args:
        unique_ids: Number of ids to generate; ids run from 1 to ``unique_ids``.
        max_timestamp: Exclusive upper bound for the generated timestamps.

    Returns:
        A DataFrame with LEFT_SCHEMA (id, ts, label) where ``ts`` is drawn
        uniformly from [0, max_timestamp) and ``label`` is the id rendered as
        a string.
    """
    # Capturing the module-level `rng` in the mapped function would ship an
    # identical copy of its state to every task, so every partition would
    # repeat the same random sequence. Instead, draw one seed on the driver and
    # derive an independent generator per partition from (seed, partition).
    base_seed = int(rng.integers(low=0, high=2**32))

    def generate_partition(partition_index, ids):
        partition_rng = np.random.default_rng([base_seed, partition_index])
        for row_id in ids:
            yield (
                row_id,
                partition_rng.integers(low=0, high=max_timestamp).item(),
                "{}".format(row_id),
            )

    # A range is handed over as-is rather than materialising every id in a
    # driver-side list first.
    id_rdd = spark.sparkContext.parallelize(range(1, unique_ids + 1))
    left_rdd = id_rdd.mapPartitionsWithIndex(generate_partition)
    return left_rdd.toDF(LEFT_SCHEMA)

In [None]:
def generate_right_table(
    unique_ids: int,
    max_timestamp: int,
    events_per_id_confs: List[DistributionConfiguration],
    timestamp_distributions: List[TimestampDistribution],
) -> DataFrame:
    """Build the right table: a random number of timestamped events per id.

    For every id in 1..unique_ids, one entry of ``events_per_id_confs`` is
    picked at random and the event count is drawn from that normal
    distribution (truncated to an int, floored at 0). One entry of
    ``timestamp_distributions`` is then picked at random to place the events:
    uniformly over [0, max_timestamp), or normally around a random centre with
    a standard deviation between 5 and 15 days. Every timestamp is clamped to
    [0, max_timestamp - 1].

    Args:
        unique_ids: Number of ids to generate events for.
        max_timestamp: Exclusive upper bound for the generated timestamps.
        events_per_id_confs: Candidate distributions for the events-per-id count.
        timestamp_distributions: Candidate distributions for event timestamps.

    Returns:
        A DataFrame with RIGHT_SCHEMA (id, ts, value) where ``value`` is the id
        rendered as a string.

    Raises:
        ValueError: If either configuration list is empty. Raised up front on
            the driver instead of surfacing later as a failed Spark task.
    """
    if len(events_per_id_confs) == 0:
        raise ValueError("events_per_id_confs must contain at least one configuration")
    if len(timestamp_distributions) == 0:
        raise ValueError("timestamp_distributions must contain at least one distribution")

    # Capturing the module-level `rng` in the mapped function would ship an
    # identical copy of its state to every task, so every partition would
    # repeat the same random sequence. Instead, draw one seed on the driver and
    # derive an independent generator per partition from (seed, partition).
    base_seed = int(rng.integers(low=0, high=2**32))

    def events_for_id(row_id, partition_rng):
        # Pick an event generation conf at random
        conf = events_per_id_confs[
            int(partition_rng.integers(low=0, high=len(events_per_id_confs)))
        ]
        # int() truncates toward zero; negative draws become zero events
        events = max(int(partition_rng.normal(loc=conf.mean, scale=conf.sd)), 0)

        # Pick a timestamp distribution at random
        timestamp_distribution = timestamp_distributions[
            int(partition_rng.integers(low=0, high=len(timestamp_distributions)))
        ]

        if timestamp_distribution is TimestampDistribution.Uniform:
            timestamps = partition_rng.integers(
                low=0, high=max_timestamp, size=events
            ).tolist()
        elif timestamp_distribution is TimestampDistribution.Normal:
            mean = partition_rng.uniform(low=0, high=max_timestamp)
            sd = partition_rng.uniform(low=SECONDS_IN_DAY * 5, high=SECONDS_IN_DAY * 15)
            timestamps = partition_rng.normal(mean, sd, events).astype(int).tolist()
        else:
            raise ValueError(
                "Invalid timestamp distribution: {}".format(timestamp_distribution)
            )

        value = "{}".format(row_id)
        # Normal draws can fall outside the interval, so clamp every timestamp
        return [
            (row_id, limit_value(timestamp, 0, max_timestamp - 1), value)
            for timestamp in timestamps
        ]

    def generate_partition(partition_index, ids):
        partition_rng = np.random.default_rng([base_seed, partition_index])
        for row_id in ids:
            yield from events_for_id(row_id, partition_rng)

    # A range is handed over as-is rather than materialising every id in a
    # driver-side list first.
    id_rdd = spark.sparkContext.parallelize(range(1, unique_ids + 1))
    right_rdd = id_rdd.mapPartitionsWithIndex(generate_partition)
    return right_rdd.toDF(RIGHT_SCHEMA)

"""
generate_right_table(1_000_000, SECONDS_IN_MONTH * 12, [
    DistributionConfiguration(20, 2),
    DistributionConfiguration(80, 8),
], [TimestampDistribution.Uniform, TimestampDistribution.Normal]).count()
"""

In [14]:
def save(ids, max_ts, left: DataFrame, right: DataFrame):
    """Write both tables as parquet under ``<s3_directory>/raw/<ids>-<max_ts>``.

    The left table goes to ``left.parquet`` and the right table to
    ``right.parquet``; existing output at either path is overwritten.
    """
    directory = s3_directory + "/raw/{}-{}".format(ids, max_ts)
    # Left is written before right, one parquet dataset per table.
    for table, file_name in ((left, "/left.parquet"), (right, "/right.parquet")):
        table.write.parquet(directory + file_name, mode="overwrite")

In [15]:
# Numbers of distinct ids to generate; generate_all produces one dataset per entry.
UNIQUE_IDS = [10_000, 100_000, 1_000_000, 10_000_000]
# (max timestamp in seconds, name used in the output directory) pairs.
INTERVALS = [
    (SECONDS_IN_MONTH * 12, "1_year"),
]
# Candidate (mean, sd) settings for the number of right-table events per id.
FEATURE_UPDATES = [
    DistributionConfiguration(20, 2),
    DistributionConfiguration(80, 8),
]
# NOTE(review): not referenced by any code visible in this notebook — confirm it is still needed.
NO_FEATURES = [1, 2]
# Timestamp distributions that the right-table generator picks between per id.
DISTRIBUTIONS = [TimestampDistribution.Normal, TimestampDistribution.Uniform]

In [16]:
def generate_all():
    """Generate and save a left/right table pair for every configured dataset.

    Iterates over every (interval, id count) combination from INTERVALS and
    UNIQUE_IDS, builds both tables with FEATURE_UPDATES and DISTRIBUTIONS, and
    writes them via ``save``. Progress is reported in units of generated ids.
    """
    # Every interval processes every id count, so the bar has to span the full
    # product; the context manager closes the bar even if a write fails.
    with tqdm(total=len(INTERVALS) * sum(UNIQUE_IDS)) as progress:
        for max_timestamp, interval_name in INTERVALS:
            for unique_ids in UNIQUE_IDS:
                left = generate_left_table(unique_ids, max_timestamp)
                right = generate_right_table(
                    unique_ids, max_timestamp, FEATURE_UPDATES, DISTRIBUTIONS
                )
                save(unique_ids, interval_name, left, right)
                progress.update(unique_ids)

In [17]:
# Run the full generation; this writes one parquet pair per configured dataset.
generate_all()

An error was encountered:
An error occurred while calling o257.parquet.
: org.apache.spark.SparkException: Job aborted.
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:231)
	at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:188)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult$lzycompute(commands.scala:108)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult(commands.scala:106)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.doExecute(commands.scala:131)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:180)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215)
	