In [18]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, MapType, IntegerType


from vectorize import vectorize_routes

spark = SparkSession.builder.appName("Pay Routes").getOrCreate()


schema = StructType([
    StructField("route", ArrayType(
        StructType([
            StructField("from_city", StringType(), nullable=True),
            StructField("to_city", StringType(), nullable=True),
            StructField("merch", MapType(StringType(), IntegerType()), nullable=True),
        ])
    ), nullable=True),
    StructField("uuid", StringType(), nullable=True),
])

def load_json_to_spark(file_name):
    df = spark.read.json(file_name, schema=schema)
    return df


planned_routes_df = load_json_to_spark("planned_routes.json")
actual_routes_df = load_json_to_spark("actual_routes.json")

In [19]:
planned_routes_df.show()
actual_routes_df.show()

+--------------------+--------------------+
|               route|                uuid|
+--------------------+--------------------+
|[{Utrecht, The Ha...|ca5b2fa758194c7da...|
|[{Eindhoven, Delf...|0de74fb5d653403f8...|
|[{Eindhoven, Utre...|c5baf719b94a48238...|
|[{Amsterdam, Delf...|4557a1ab54d34989b...|
|[{Delft, Tilburg,...|f1e90ad34080440d9...|
|[{Delft, The Hagu...|e461200c6bbf4c02b...|
|[{Tilburg, Delft,...|b6473c53e8484c69a...|
|[{Eindhoven, Rott...|20c8eb90c26d40b7b...|
|[{Utrecht, The Ha...|83b09d82a045448ab...|
|[{Utrecht, Eindho...|a7f479c9817941cfb...|
|[{Rotterdam, Amst...|e7f8b68cf37345c89...|
|[{The Hague, Eind...|dd9b1a98315141fa8...|
|[{Groningen, The ...|09bb3c81d23d4a4f9...|
|[{Tilburg, The Ha...|04ce53e1189b47b5a...|
|[{Tilburg, The Ha...|8e7f9aac22a44fe38...|
|[{Eindhoven, Gron...|486d433fdb844987a...|
|[{Groningen, Tilb...|137c932bff4e469a8...|
|[{Delft, Utrecht,...|cefd484a2e1045c79...|
|[{Eindhoven, Utre...|59c336008e9c4ef1a...|
|[{Groningen, Tilb...|757257941f

In [20]:
planned_routes_df.printSchema()

root
 |-- route: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- from_city: string (nullable = true)
 |    |    |-- to_city: string (nullable = true)
 |    |    |-- merch: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: integer (valueContainsNull = true)
 |-- uuid: string (nullable = true)



In [40]:
import yaml
from itertools import product, chain
from pyspark.sql.functions import create_map, lit, udf
from pyspark.sql.types import IntegerType


with open('src/data_gen_config.yaml', encoding="utf-8") as f:
    config = yaml.safe_load(f)

merch_items = config["merch_items"]
cities = config["cities"]
merch_item_min = config["merch_sampler_map"]["low"]
merch_item_max = config["merch_sampler_map"]["high"]

combinations = list(product(cities, cities, merch_items))
# Make sure from_city and to_city are not the same
combinations = [
    (from_city, to_city, merch)
    for from_city, to_city, merch in combinations
    if from_city != to_city
]
vector_size = len(combinations)
# Generate a hashmap from combination to index
vector_map = {combination: index for index, combination in enumerate(combinations)}
print(vector_map)

@udf(IntegerType())
def get_vector_index(from_city, to_city, merch_name):
    return vector_map.get((from_city, to_city, merch_name))

{('Amsterdam', 'Rotterdam', 'Apple'): 0, ('Amsterdam', 'Rotterdam', 'Pear'): 1, ('Amsterdam', 'Rotterdam', 'Banana'): 2, ('Amsterdam', 'Rotterdam', 'Kiwi'): 3, ('Amsterdam', 'Rotterdam', 'Orange'): 4, ('Amsterdam', 'Rotterdam', 'Mandarin'): 5, ('Amsterdam', 'Rotterdam', 'Strawberry'): 6, ('Amsterdam', 'Rotterdam', 'Mango'): 7, ('Amsterdam', 'The Hague', 'Apple'): 8, ('Amsterdam', 'The Hague', 'Pear'): 9, ('Amsterdam', 'The Hague', 'Banana'): 10, ('Amsterdam', 'The Hague', 'Kiwi'): 11, ('Amsterdam', 'The Hague', 'Orange'): 12, ('Amsterdam', 'The Hague', 'Mandarin'): 13, ('Amsterdam', 'The Hague', 'Strawberry'): 14, ('Amsterdam', 'The Hague', 'Mango'): 15, ('Amsterdam', 'Utrecht', 'Apple'): 16, ('Amsterdam', 'Utrecht', 'Pear'): 17, ('Amsterdam', 'Utrecht', 'Banana'): 18, ('Amsterdam', 'Utrecht', 'Kiwi'): 19, ('Amsterdam', 'Utrecht', 'Orange'): 20, ('Amsterdam', 'Utrecht', 'Mandarin'): 21, ('Amsterdam', 'Utrecht', 'Strawberry'): 22, ('Amsterdam', 'Utrecht', 'Mango'): 23, ('Amsterdam', 'De

In [None]:
from pyspark.ml.linalg import Vectors, VectorUDT

@udf(VectorUDT())
def vectorize(vector_elements):
    

In [48]:
from pyspark.sql.functions import explode, map_keys, map_values, col, collect_list, struct
from pyspark.ml.linalg import Vectors, VectorUDT


# Step 1: Explode the merch column into multiple rows
df = planned_routes_df.select("*", explode(planned_routes_df.route).alias("individual_route"))

# Now individual_route is a column of struct type, so you can access its fields like this:
df = df.select("*", 
               df.individual_route.from_city.alias("from_city"), 
               df.individual_route.to_city.alias("to_city"),
               df.individual_route.merch.alias("merch"))
df = df.select("*", explode(df.merch).alias("merch_name", "merch_weight"))

df = df.withColumn("normalized_weight", (df.merch_weight - merch_item_min) / (merch_item_max - merch_item_min))
df = df.withColumn("vector_index", get_vector_index(col("from_city"), col("to_city"), col("merch_name")))

# # Now each row of the DataFrame represents one entry of the vector for a route
# # If you want to re-aggregate the data into one row per original route, you can group by uuid:
#df = df.groupBy("uuid").agg(collect_list(struct("vector_index", "normalized_weight")).alias("vector_elements"))
#df = df.withColumn("vector", col("vector_elements").apply(lambda x: Vectors.sparse(vector_size, [e.vector_index for e in x], [e.normalized_weight for e in x])))
# # And finally convert the vector elements to a sparse vector
#df = df.withColumn("vector", Vectors.sparse(vector_size, df.vector_elements))
df.show()

# Step 2: Normalize the merch_weight and map to the vector index
# df = df.withColumn("normalized_weight", (df.merch_weight - merch_item_min) / (merch_item_max - merch_item_min))
# df = df.withColumn("vector_index", vector_map[(df.from_city, df.to_city, df.merch_name)])

# # Step 3: Re-aggregate the data (optional)
# df = df.groupBy("uuid").agg(collect_list(struct("vector_index", "normalized_weight")).alias("vector_elements"))

# # Convert the vector elements to a sparse vector
# df = df.withColumn("vector", Vectors.sparse(vector_size, df.vector_elements))


+--------------------+--------------------+--------------------+---------+---------+--------------------+----------+------------+-----------------+------------+
|               route|                uuid|    individual_route|from_city|  to_city|               merch|merch_name|merch_weight|normalized_weight|vector_index|
+--------------------+--------------------+--------------------+---------+---------+--------------------+----------+------------+-----------------+------------+
|[{Utrecht, The Ha...|ca5b2fa758194c7da...|{Utrecht, The Hag...|  Utrecht|The Hague|{Banana -> 61, St...|    Banana|          61|             0.22|         186|
|[{Utrecht, The Ha...|ca5b2fa758194c7da...|{Utrecht, The Hag...|  Utrecht|The Hague|{Banana -> 61, St...|Strawberry|          95|              0.9|         190|
|[{Utrecht, The Ha...|ca5b2fa758194c7da...|{Utrecht, The Hag...|  Utrecht|The Hague|{Banana -> 61, St...|     Mango|          89|             0.78|         191|
|[{Utrecht, The Ha...|ca5b2fa75819