In [None]:
"""
This notebook demonstrates how to reshape the sampled Sentinel-2 data into a format suitable for modeling.
The final dataset contains the same information as the original samples, reorganized as one row per pixel and year,
with the band values, scene dates, SCL values, tile IDs and the CDL column encoded as compact binary columns.
"""

import sys
import os

# Make the project root (and thus the `preprocessing` package) importable.
# Assumes the kernel's working directory is a direct subdirectory of the
# project root (e.g. notebooks/) — TODO confirm if the notebook moves.
# The membership check keeps re-runs of this cell from growing sys.path.
project_root = os.path.abspath("..")
if project_root not in sys.path:
    sys.path.append(project_root)

from preprocessing.spark_session import spark
from preprocessing.transform_input_data import (
    create_image_map,
    group_time_series,
)

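# Define the UDFs here in the notebook rather than importing them: cloudpickle serializes
# functions defined in __main__ by value, whereas imported functions are pickled by reference
# and their module would then have to be importable on every Spark executor.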
from pyspark.sql.types import ArrayType, StringType, BinaryType
import pyspark.sql.functions as F
from datetime import datetime

def get_sorted_input(dicts_list):
    """Flatten a pixel's per-scene dicts into four comma-joined strings.

    Each element of ``dicts_list`` is a dict; only its first key is used, as
    the scene date (it must support ``strftime``). The value under that key is
    a sequence laid out as ``[band_1, ..., band_k, scl_value, tile]``.
    Entries are processed in ascending key order.

    Returns:
        A 4-element list ``[bands, tiles, dates, scl_vals]``. ``bands`` is
        every entry's band values concatenated in date order, ``dates`` are
        formatted ``YYYY-MM-DD``, and each field is joined with ``','``.
    """
    def scene_key(entry):
        # First key of the dict == the scene date.
        return list(entry)[0]

    def to_csv(parts):
        return ','.join(map(str, parts))

    band_parts, tile_parts, date_parts, scl_parts = [], [], [], []
    for entry in sorted(dicts_list, key=scene_key):
        scene_date = scene_key(entry)
        values = entry.get(scene_date)
        band_parts += values[:-2]           # everything except the trailing scl/tile
        scl_parts.append(values[-2])
        tile_parts.append(values[-1])
        date_parts.append(scene_date.strftime('%Y-%m-%d'))

    return [to_csv(band_parts), to_csv(tile_parts), to_csv(date_parts), to_csv(scl_parts)]

def convert_bytes(bands_str: str) -> bytes:
    """Pack comma-separated band values into big-endian unsigned 16-bit ints.

    Each element is parsed with ``float`` and truncated with ``int``, so
    strings such as ``"1234.0"`` are accepted. Every value becomes 2 bytes.

    Returns:
        The packed bytes; ``b''`` for an empty string (no band values); and
        ``None`` for a ``None`` input, so a Spark null stays null instead of
        failing the task.

    Raises:
        ValueError: if an element is not numeric.
        OverflowError: if a value is outside ``0..65535``.
    """
    if bands_str is None:
        return None
    if not bands_str:
        return b''
    # Join once: repeated ``bytes +=`` copies the whole buffer each iteration,
    # which is quadratic in the number of values.
    return b''.join(
        int(float(num)).to_bytes(2, 'big') for num in bands_str.split(',')
    )

def convert_bytes_scl_vals(scl_vals_in: str) -> bytes:
    """Pack comma-separated integer SCL values into one unsigned byte each.

    Elements are parsed with ``int`` directly, so they must be integer strings
    (e.g. ``"4"``, not ``"4.0"``).

    Returns:
        The packed bytes; ``b''`` for an empty string; ``None`` for a ``None``
        input so a Spark null stays null.

    Raises:
        ValueError: if an element is not an integer string.
        OverflowError: if a value is outside ``0..255``.
    """
    if scl_vals_in is None:
        return None
    if not scl_vals_in:
        return b''
    # Single join instead of quadratic ``bytes +=`` in a loop.
    return b''.join(int(num).to_bytes(1, 'big') for num in scl_vals_in.split(','))

def convert_string_utf8(s: str) -> bytes:
    """Encode ``s`` as UTF-8 bytes.

    A ``None`` input (a Spark null) is returned as ``None`` instead of raising
    ``AttributeError``, consistent with the other binary encoders.
    """
    return None if s is None else s.encode('UTF-8')

def date_array_2_int(date_arr: str) -> bytes:
    """Pack comma-separated ``YYYY-MM-DD`` dates as days since 1970-01-01.

    Each date becomes a big-endian unsigned 16-bit integer (2 bytes).

    Returns:
        The packed bytes; ``b''`` for an empty string; ``None`` for a ``None``
        input so a Spark null stays null.

    Raises:
        ValueError: if a date does not match ``%Y-%m-%d``.
        OverflowError: for dates before 1970-01-01 or more than 65535 days
            after it (unsigned 16-bit range).
    """
    if date_arr is None:
        return None
    if not date_arr:
        return b''
    # Loop-invariant: compute the epoch once rather than per date.
    epoch = datetime(1970, 1, 1).date()
    return b''.join(
        (datetime.strptime(x, '%Y-%m-%d').date() - epoch).days.to_bytes(2, 'big')
        for x in date_arr.split(',')
    )

# Register local UDFs: wrap each helper defined above as a Spark UDF with an
# explicit return type — array<string> for the flattener, binary for encoders.
get_sorted_input_udf = F.udf(get_sorted_input, returnType=ArrayType(StringType()))
convert_bytes_udf = F.udf(convert_bytes, returnType=BinaryType())
convert_bytes_scl_vals_udf = F.udf(convert_bytes_scl_vals, returnType=BinaryType())
convert_string_utf8_udf = F.udf(convert_string_utf8, returnType=BinaryType())
date_array_2_int_udf = F.udf(date_array_2_int, returnType=BinaryType())

# Main arguments
# Source and destination datasets (relative paths, resolved against the
# current working directory; the trailing "/" matters because the input path
# is extended by string formatting when reading).
input_uri, output_uri = (
    "../data/s2_unique_scene.parquet/",
    "../data/CDL_unique_scene_ts.parquet/",
)
# Partition of the input to process; both values are also added as columns
# of the output and used as its partition keys.
bbox, year = "484932, 1401912, 489035, 1405125", "2019"

# Group pixels by scene date using the project helpers
# create_image_map -> group_time_series.
# The path points inside the bbox=/year= partition directories, so Spark's
# partition discovery does not add `bbox`/`year` columns here; they are
# re-added as literals before writing.
scene_partition_path = f"{input_uri}bbox={bbox}/year={year}"
df = group_time_series(create_image_map(spark.read.parquet(scene_partition_path)))

# Flatten: sort each pixel's scenes by date with the UDF, then split its
# 4-element result into one string column per field, in the same order that
# get_sorted_input returns them.
df = df.withColumn("inputs_lists", get_sorted_input_udf(F.col("image_dicts_list")))
df = df.drop("image_dicts_list")
for position, field_name in enumerate(("bands", "tiles", "img_dates", "scl_vals")):
    df = df.withColumn(field_name, F.col("inputs_lists").getItem(position))
df = df.drop("inputs_lists")

# Binary conversion: encode every payload column as binary.
# `tiles` and `CDL` only need a UTF-8 encoding, so use Spark's built-in
# F.encode instead of a Python UDF: it runs in the JVM (no Python
# serialization for these columns) and maps nulls to null instead of failing.
# NOTE(review): `bbox` becomes a binary literal yet is used as a partition
# column when writing — confirm the resulting directory names and read-back
# type are what downstream readers expect.
ts = (
    df
    .withColumn("bands", convert_bytes_udf(F.col("bands")))
    .withColumn("img_dates", date_array_2_int_udf(F.col("img_dates")))
    .withColumn("tiles", F.encode(F.col("tiles"), "UTF-8"))
    .withColumn("CDL", F.encode(F.col("CDL"), "UTF-8"))
    .withColumn("scl_vals", convert_bytes_scl_vals_udf(F.col("scl_vals")))
    .withColumn("bbox", F.lit(bbox.encode("UTF-8")))
    .withColumn("year", F.lit(year))
)

In [None]:
# Check final output: print the schema, then a small sample.
# The binary columns (e.g. `bands`, 2 bytes per band value per scene) render
# as long hex strings, so truncate them instead of dumping full rows into
# the notebook output.
ts.printSchema()
ts.show(5, truncate=60)

In [None]:
# Write the time series out, partitioned by bounding box and year.
# NOTE(review): mode="append" is not idempotent — re-running this cell (or
# the whole notebook) for the same bbox/year adds duplicate rows. Consider
# mode="overwrite" with option("partitionOverwriteMode", "dynamic") if each
# run should replace only its own partitions.
ts.write.parquet(output_uri, mode="append", partitionBy=["bbox", "year"])