In [None]:
# Notebook display tweak: force <pre> output blocks not to wrap, so long text
# lines keep their layout (they scroll horizontally instead).
# `display` is imported explicitly from the public IPython.display module rather
# than relying on the kernel-injected builtin / internal IPython.core.display.
from IPython.display import HTML, display

display(HTML("<style>pre { white-space: pre !important; }</style>"))

In [None]:
from pyspark.sql import SparkSession
import pyspark.sql.types as T
import pyspark.sql.functions as F
from pyspark.sql.types import FloatType
from pyspark.ml.linalg import DenseVector, SparseVector, Vectors, VectorUDT
import numpy as np
from pyspark.ml.feature import MinMaxScaler
from pyspark.ml.feature import VectorAssembler

# Spark session shared by the rest of the notebook, with 10g of driver memory.
# getOrCreate() returns an already-running session if there is one, in which
# case this setting is not applied.
spark = (
    SparkSession.builder
    .config("spark.driver.memory", "10g")
    .getOrCreate()
)

In [None]:
# Python UDF mapping +/-inf to null (FloatType); other values, including NaN,
# pass through unchanged. Spark hands SQL nulls to Python UDFs as None, and
# np.isinf(None) raises TypeError, so None is passed through explicitly.
replace_infs_udf = F.udf(
    lambda x: None if x is None or np.isinf(x) else x, FloatType()
)


def toDense(v):
    """Return `v` (a sparse or dense ML vector) as a DenseVector; None stays None."""
    return None if v is None else Vectors.dense(v.toArray())


toDenseUdf = F.udf(toDense, VectorUDT())

In [None]:
# Convert the Istella LETOR text files into parquet datasets partitioned by qid.
#
# Input lines are space-delimited: "<label> qid:<id> 1:<v1> 2:<v2> ...". The raw
# CSV column `_c222` is dropped.
# - Feature values are cast to float32; values that are +/-inf after the cast
#   are replaced with null, then nulls/NaNs are filled with 0.0.
# - The MinMaxScaler is fitted on the train split only and reused for test,
#   so "train" must be processed first.
# - WRITE_SCALED_FEATURES selects which vector is written; the default (False)
#   writes the unscaled features.
ISTELLA_DIR = "./datasets/istella-letor"
SPLITS = ["train", "test"]
WRITE_SCALED_FEATURES = False
_INFINITIES = [float("inf"), float("-inf")]


def _load_istella_split(split):
    """Read one Istella split and parse it into typed columns.

    Returns a DataFrame with int `label`, int `qid`, and one float column per
    feature (keeping the raw names `_c2`, `_c3`, ...), with +/-inf replaced by
    null. Everything is built in a single `select` rather than hundreds of
    chained `withColumn` calls, which keeps the logical plan small.
    """
    raw = (
        spark.read.option("inferSchema", False)
        .option("delimiter", " ")
        .csv(f"{ISTELLA_DIR}/full/{split}.txt")
        .drop("_c222")
    )
    feature_cols = [c for c in raw.columns if c not in ("_c0", "_c1")]

    def parse_feature(name):
        # "<index>:<value>" -> float32 value; +/-inf -> null (native, null-safe).
        value = F.split(F.col(name), ":").getItem(1).cast("float")
        return F.when(value.isin(_INFINITIES), F.lit(None)).otherwise(value).alias(name)

    return raw.select(
        F.col("_c0").cast("int").alias("label"),
        F.split(F.col("_c1"), ":").getItem(1).cast("int").alias("qid"),
        *[parse_feature(c) for c in feature_cols],
    )


assert SPLITS[0] == "train", "the scaler and feature list are fitted on the train split"

feature_names = []
for ff in SPLITS:
    df = _load_istella_split(ff)
    if ff == "train":
        # The feature list is fixed by the train split and reused for test.
        feature_names = [c for c in df.columns if c not in ("label", "qid")]

    # fillna on float columns replaces both null and NaN.
    df = df.fillna(0.0, subset=feature_names)

    assembler = VectorAssembler(inputCols=feature_names, outputCol="features")
    df = assembler.transform(df)

    # Fitted on train only (one pass over the data) and applied to every split.
    if ff == "train":
        scaler = MinMaxScaler(inputCol="features", outputCol="scaledFeatures").fit(df)
    df = scaler.transform(df)

    vector_col = "scaledFeatures" if WRITE_SCALED_FEATURES else "features"
    (
        df.select("qid", "label", toDenseUdf(F.col(vector_col)).alias("features"))
        .write.mode("overwrite")
        .partitionBy("qid")
        .parquet(f"{ISTELLA_DIR}/{ff}_parquet")
    )