In [None]:
import pandas as pd
from pathlib import Path
from pyspark.sql import SparkSession
from pyspark import SparkConf
from pyspark.sql import functions as f

from pyspark.sql.functions import col
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, TimestampNTZType, FloatType, IntegerType


from teehr import Evaluation
from pathlib import Path
import shutil


In [None]:
# Local import keeps this cell self-contained; `os` is only needed for the
# environment-variable override below.
import os

# Directory where the evaluation will be created.
# WARNING: any existing contents of this directory are deleted on every run so
# the notebook always starts from an empty study directory.
TEST_STUDY_DIR = Path(Path().home(), "temp", "test_study")
shutil.rmtree(TEST_STUDY_DIR, ignore_errors=True)
TEST_STUDY_DIR.mkdir(parents=True, exist_ok=True)

# Directory where the test data is stored.
# The default is the original author's local checkout; set the
# TEEHR_TEST_DATA_DIR environment variable to run the notebook on another
# machine without editing this cell.
TEST_DATA_DIR = Path(
    os.environ.get(
        "TEEHR_TEST_DATA_DIR",
        "/home/matt/repos/teehr/tests/data/v0_3_test_study",
    )
)

# Input files, all resolved relative to TEST_DATA_DIR.
GEOJSON_GAGES_FILEPATH = Path(TEST_DATA_DIR, "geo", "gages.geojson")
PRIMARY_TIMESERIES_FILEPATH = Path(
    TEST_DATA_DIR, "timeseries", "test_short_obs.parquet"
)
CROSSWALK_FILEPATH = Path(TEST_DATA_DIR, "geo", "crosswalk.csv")
SECONDARY_TIMESERIES_FILEPATH = Path(
    TEST_DATA_DIR, "timeseries", "test_short_fcast.parquet"
)
GEO_FILEPATH = Path(TEST_DATA_DIR, "geo")

In [None]:
# Create an Evaluation object rooted at the study directory prepared above.
# NOTE(review): the name `eval` shadows Python's built-in `eval()` for the rest
# of the kernel session. Later cells read `eval.spark` and
# `eval.primary_timeseries_dir`, so a rename must update those cells as well.
eval = Evaluation(dir_path=TEST_STUDY_DIR)

# Enable logging
eval.enable_logging()

# Clone the template
# (presumably populates the study directory from teehr's evaluation template —
# confirm against the teehr documentation)
eval.clone_template()

In [None]:
# Quick pandas look at the raw primary timeseries file before handing it to Spark.
pd.read_parquet(path=PRIMARY_TIMESERIES_FILEPATH)

In [None]:
# Spark schema for the timeseries table: every column is nullable.
# NOTE: the read cell below has `.schema(schema)` commented out, so this schema
# is not applied in the cells visible here.
schema = StructType(
    [
        StructField(field_name, field_type, True)
        for field_name, field_type in (
            ('reference_time', TimestampNTZType()),
            ('value_time', TimestampNTZType()),
            ('value', DoubleType()),
            ('variable_name', StringType()),
            ('configuration_name', StringType()),
            ('unit_name', StringType()),
            ('location_id', StringType()),
        )
    ]
)

In [None]:
# Read the primary timeseries parquet file through the evaluation's Spark session.
# The explicit `.schema(schema)` is deliberately not applied here, so column
# types come from the parquet file itself.
timeseries = eval.spark.read.parquet(str(PRIMARY_TIMESERIES_FILEPATH))

In [None]:
# Mapping of current column name -> desired column name.
# https://datamadness.medium.com/renaming-columns-in-pyspark-fe71f7111454
rename_dict = dict(
    measurement_unit='unit_name',
    configuration='configuration_name',
)

# Re-select every column in its original order, aliasing only those listed in
# rename_dict; columns without an entry keep their name.
timeseries = timeseries.select(
    *(
        col(column_name).alias(rename_dict.get(column_name, column_name))
        for column_name in timeseries.columns
    )
)

In [None]:
# Preview the timeseries after the column rename in the previous cell.
timeseries.show()

In [None]:
# https://datamadness.medium.com/casting-data-types-in-pyspark-f95d1326449b

# data_type_map = {
#   'reference_time': TimestampNTZType(),
#   'value_time': TimestampNTZType(),
#   'value': DoubleType(),
#   'variable_name': StringType(),
#   'unit_name': StringType(),
#   'configuration_name': StringType(),
#   'location_id': StringType()
# }


# timeseries = (
#     timeseries
#   .select([col(column_schema[0]).cast(data_type_map.get(column_schema[0], column_schema[1])) for column_schema in timeseries.dtypes])
# )

In [None]:
# Preview again. The casting cell above is entirely commented out, so this
# output is expected to match the previous preview.
timeseries.show()

In [None]:
# Persist the timeseries as parquet in the evaluation's primary timeseries
# directory, partitioned by configuration and variable. Existing output in that
# directory is overwritten.
timeseries.write.parquet(
    str(Path(eval.primary_timeseries_dir)),
    mode="overwrite",
    partitionBy=["configuration_name", "variable_name"],
)

In [None]:
import pandera as pa

In [None]:
# define schema
# schema = pa.DataFrameSchema({
#     "reference_time": pa.Column(str),
#     "value_time": pa.Column(str),
#     "value": pa.Column(float),
#     "variable_name": pa.Column(str),
#     "configuration_name": pa.Column(str),
#     "unit_name": pa.Column(str),
#     "location_id": pa.Column(str)
# })

# validated_df = schema(timeseries)
# print(validated_df)

In [None]:
import pandera.pyspark as pa
import pyspark.sql.types as T

from decimal import Decimal
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame
from pandera.pyspark import DataFrameModel

# Get the active SparkSession, or create one if none exists.
# NOTE(review): presumably this returns the same session that `eval.spark`
# already holds — confirm.
spark = SparkSession.builder.getOrCreate()

# Example pandera model for a small product table, used to try out
# pandera's PySpark validation before applying it to the timeseries data.
#
# NOTE(review): the demo DataFrame built further down in this cell names its
# column `product` (not `product_name`) and contains a row with id == 5, so
# validation is expected to record errors — presumably intentional, to
# demonstrate the error report; confirm.
class PanderaSchema(DataFrameModel):
    # Integer id; each value must be strictly greater than 5.
    id: T.IntegerType() = pa.Field(gt=5)
    # Product name; each value must start with "B".
    product_name: T.StringType() = pa.Field(str_startswith="B")
    # Decimal price with precision 20 and scale 5; no extra checks.
    price: T.DecimalType(20, 5) = pa.Field()
    # Array of strings; no extra checks.
    description: T.ArrayType(T.StringType()) = pa.Field()
    # String-to-string map; no extra checks.
    meta: T.MapType(T.StringType(), T.StringType()) = pa.Field()

# Two demo rows: (id, product, price, description, meta).
data = [
    (
        5,
        "Bread",
        Decimal(44.4),
        ["description of product"],
        {"product_category": "dairy"},
    ),
    (
        15,
        "Butter",
        Decimal(99.0),
        ["more details here"],
        {"product_category": "bakery"},
    ),
]

# Spark schema for the demo rows; every field (and the array/map contents) is
# declared non-nullable.
spark_schema = (
    T.StructType()
    .add("id", T.IntegerType(), False)
    .add("product", T.StringType(), False)
    .add("price", T.DecimalType(20, 5), False)
    .add("description", T.ArrayType(T.StringType(), False), False)
    .add("meta", T.MapType(T.StringType(), T.StringType(), False), False)
)

# Build the demo DataFrame and print it.
df = spark.createDataFrame(data=data, schema=spark_schema)
df.show()

In [None]:
# Validate the demo DataFrame against PanderaSchema and keep the returned
# object; the bare expression on the last line displays it as the cell output.
df_out = PanderaSchema.validate(check_obj=df)
df_out

In [None]:
# Print the rows of the DataFrame returned by validation.
df_out.show()

In [None]:
# Pandera model describing the timeseries columns after the rename step.
#
# NOTE(review): the annotations here are Spark type classes (no parentheses),
# whereas PanderaSchema above annotates with type instances — confirm pandera
# treats both forms the same.
# NOTE(review): the Spark `schema` defined earlier uses TimestampNTZType for the
# two time columns while this model uses TimestampType — confirm which one
# matches the data as read from parquet.
class PanderaTimeseriesSchema(DataFrameModel):
    # The only field explicitly declared nullable.
    reference_time: T.TimestampType = pa.Field(nullable=True)
    value_time: T.TimestampType = pa.Field()
    value: T.DoubleType = pa.Field()
    variable_name: T.StringType = pa.Field()
    configuration_name: T.StringType = pa.Field()
    # Only the unit string "m^3/s" is accepted.
    unit_name: T.StringType = pa.Field(isin=["m^3/s"])
    location_id: T.StringType = pa.Field()

In [None]:
# Validate the renamed timeseries DataFrame against PanderaTimeseriesSchema and
# print the rows of the returned DataFrame.
ts_out = PanderaTimeseriesSchema.validate(check_obj=timeseries)
ts_out.show()

In [None]:
import json

# Read the error report from the validated timeseries DataFrame's `pandera`
# accessor and pretty-print it as indented JSON.
# NOTE(review): despite its name, `df_out_errors` holds the errors of `ts_out`
# here, not of `df_out`.
df_out_errors = ts_out.pandera.errors
print(json.dumps(dict(df_out_errors), indent=4))

In [None]:
# Object-based pandera schema for the timeseries columns (the counterpart of
# the class-based PanderaTimeseriesSchema).
# NOTE: this rebinds `schema`, which earlier in the notebook held the Spark
# StructType.
schema = pa.DataFrameSchema(
    dict(
        reference_time=pa.Column(T.TimestampNTZType, nullable=True),
        value_time=pa.Column(T.TimestampNTZType),
        value=pa.Column(T.FloatType, coerce=True),
        variable_name=pa.Column(T.StringType),
        configuration_name=pa.Column(T.StringType),
        unit_name=pa.Column(T.StringType, pa.Check.isin(["m^3/s"])),
        location_id=pa.Column(T.StringType),
    )
)

# Calling the schema object on the DataFrame runs validation.
validated_df = schema(timeseries)
print(validated_df)

In [None]:
import json

# Read the error report from `validated_df`'s `pandera` accessor and
# pretty-print it as indented JSON.
# NOTE: this rebinds `df_out_errors`, replacing the report taken from `ts_out`
# in the earlier cell.
df_out_errors = validated_df.pandera.errors
print(json.dumps(dict(df_out_errors), indent=4))

In [None]:
# Column names declared by the pandera schema, in declaration order.
list(schema.columns)