# In [None]:
import os

from pyspark.ml.feature import MinMaxScaler, StandardScaler
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, collect_list, struct, udf
from pyspark.sql.types import ArrayType, DoubleType, StringType, StructField, StructType

# === Parameters ===
# Parquet locations for the raw series, the scaled series and the fitted
# per-series scaler statistics.
input_path = "/path/to/input.parquet"
output_path = "/path/to/output_scaled.parquet"
scaler_save_path = "/path/to/scalers.parquet"
# True -> MinMaxScaler, False -> StandardScaler.
use_minmax = True

# === Spark Session ===
_session_builder = SparkSession.builder.appName("ScalerPerGroup")
spark = _session_builder.getOrCreate()

# === Load Data ===
# Keep only the three columns the rest of the script works with.
_raw_df = spark.read.parquet(input_path)
df = _raw_df.select("unique_id", "ds", "y")

from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql.functions import udf


# Convert y to vector for scaler
def _wrap_in_dense_vector(value):
    """Wrap a single scalar in a one-element dense vector."""
    return Vectors.dense([value])


to_vector = udf(_wrap_in_dense_vector, VectorUDT())
df_vec = df.withColumn("y_vec", to_vector(col("y")))

# === Scale per unique_id ===


def scale_group(pdf, scaler_type="minmax"):
    """Scale the ``y`` column of one group and report the fitted statistics.

    Args:
        pdf: pandas DataFrame holding the rows of a single ``unique_id``;
            must contain the columns ``unique_id`` and ``y``.
        scaler_type: ``"minmax"`` fits a scikit-learn ``MinMaxScaler``; any
            other value fits a ``StandardScaler``.

    Returns:
        A ``(scaled, info)`` tuple. ``scaled`` is a copy of ``pdf`` with an
        added ``y_scaled`` column. ``info`` is a one-row DataFrame with the
        columns ``unique_id``, ``mean``, ``std``, ``min`` and ``max``; only
        the statistics belonging to the chosen scaler are filled in, the
        others are ``None``.
    """
    import pandas as pd
    from sklearn.preprocessing import MinMaxScaler, StandardScaler

    uid = pdf["unique_id"].iloc[0]
    y_vals = pdf["y"].values.reshape(-1, 1)

    stats = {"mean": None, "std": None, "min": None, "max": None}
    if scaler_type == "minmax":
        scaler = MinMaxScaler()
        scaled = scaler.fit_transform(y_vals)
        # MinMaxScaler also exposes ``scale_``, but there it is the rescaling
        # factor, not a standard deviation, so it must not be stored as "std".
        stats["min"] = float(scaler.data_min_[0])
        stats["max"] = float(scaler.data_max_[0])
    else:
        scaler = StandardScaler()
        scaled = scaler.fit_transform(y_vals)
        # ``scale_`` is the divisor the transform applied, i.e. the value
        # needed together with ``mean_`` to invert the scaling.
        stats["mean"] = float(scaler.mean_[0])
        stats["std"] = float(scaler.scale_[0])

    # assign() builds a new frame, so the caller's DataFrame is left untouched.
    scaled_pdf = pdf.assign(y_scaled=scaled.ravel())

    # Save scaler stats
    scaler_info = {"unique_id": uid, **stats}

    return scaled_pdf, pd.DataFrame([scaler_info])


# Apply per-group scaling
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import (
    DoubleType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)

# Schema of the frames returned by the per-group scaling step: the columns of
# df_vec plus the new y_scaled column.
schema_scaled = (
    StructType()
    .add("unique_id", StringType())
    .add("ds", TimestampType())
    .add("y", DoubleType())
    .add("y_vec", VectorUDT())
    .add("y_scaled", DoubleType())
)

# Schema of the scaler statistics table: the group key followed by four
# double-typed statistic columns.
schema_scaler = (
    StructType()
    .add("unique_id", StringType())
    .add("mean", DoubleType())
    .add("std", DoubleType())
    .add("min", DoubleType())
    .add("max", DoubleType())
)


def grouped_scaler(pdf_iter):
    """Return the rows of one ``unique_id`` group with ``y_scaled`` added.

    ``groupBy(...).applyInPandas`` calls this function once per group with a
    single pandas DataFrame and expects a pandas DataFrame back, so that case
    is handled directly. An iterable of DataFrames is still accepted for
    backward compatibility and produces an iterator of scaled DataFrames.

    Args:
        pdf_iter: pandas DataFrame of one group, or an iterable of such
            DataFrames.

    Returns:
        The scaled DataFrame, or an iterator of scaled DataFrames when an
        iterable was passed in.
    """
    import pandas as pd

    scaler_type = "minmax" if use_minmax else "standard"

    if isinstance(pdf_iter, pd.DataFrame):
        # Iterating a DataFrame walks its column labels, not its rows, so the
        # frame has to be handed to scale_group as a whole.
        scaled_pdf, _ = scale_group(pdf_iter, scaler_type=scaler_type)
        return scaled_pdf

    return (scale_group(pdf, scaler_type=scaler_type)[0] for pdf in pdf_iter)


def grouped_scaler_info(pdf_iter):
    """Return the one-row scaler statistics frame for one ``unique_id`` group.

    ``groupBy(...).applyInPandas`` calls this function once per group with a
    single pandas DataFrame and expects a pandas DataFrame back, so that case
    is handled directly. An iterable of DataFrames is still accepted for
    backward compatibility and produces an iterator of statistics frames.

    Args:
        pdf_iter: pandas DataFrame of one group, or an iterable of such
            DataFrames.

    Returns:
        The statistics DataFrame produced by ``scale_group``, or an iterator
        of such DataFrames when an iterable was passed in.
    """
    import pandas as pd

    scaler_type = "minmax" if use_minmax else "standard"

    if isinstance(pdf_iter, pd.DataFrame):
        # Iterating a DataFrame walks its column labels, not its rows, so the
        # frame has to be handed to scale_group as a whole.
        _, info_pdf = scale_group(pdf_iter, scaler_type=scaler_type)
        return info_pdf

    return (scale_group(pdf, scaler_type=scaler_type)[1] for pdf in pdf_iter)


# Apply to groups
# Both results are derived from the same grouping on unique_id.
groups_by_id = df_vec.groupBy("unique_id")
scaled_df = groups_by_id.applyInPandas(grouped_scaler, schema=schema_scaled)
scaler_info_df = groups_by_id.applyInPandas(grouped_scaler_info, schema=schema_scaler)

# === Save Output ===
# Only unique_id, ds and y_scaled are written for the scaled series.
scaled_writer = scaled_df.select("unique_id", "ds", "y_scaled").write.mode("overwrite")
scaled_writer.parquet(output_path)

# The scaler statistics are written with all of their columns.
info_writer = scaler_info_df.write.mode("overwrite")
info_writer.parquet(scaler_save_path)

spark.stop()