In [1]:
from pysparkmeos.UDT.MeosDatatype import TGeogPointInstUDT, TFloatInstUDT, STBoxUDT

from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, IntegerType, TimestampType, DoubleType, FloatType
from pyspark.serializers import PickleSerializer
from pyspark.sql.functions import from_unixtime, col, udf, collect_list, count, monotonically_increasing_id
from pymeos import pymeos_initialize, pymeos_finalize, TGeogPointInst, TGeogPointSeq, TFloat, STBox, TFloatInst, TPointInst
from pysparkmeos.partitions.grid.grid_partitioner import GridPartition
from pysparkmeos.utils.udt_appender import udt_append
from pysparkmeos.UDF.udf import *

from pymeos import TPoint

from pyspark.sql.functions import max as psmax, min as psmin

import random, datetime

In [2]:
def main():
    # Initialize PyMEOS
    pymeos_initialize("UTC")

    # Initialize a Spark session
    spark = SparkSession.builder \
        .appName("PySpark UDF Example with PyMEOS") \
        .master("local[*]") \
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
        .config("spark.default.parallelism", 36) \
        .config("spark.executor.memory", "4g") \
        .config("spark.driver.memory", "4g") \
        .getOrCreate()

    # spark.sparkContext.setLogLevel("DEBUG")

    # Append the UDT mapping to the PyMEOS classes
    udt_append()

    @udf(returnType=FloatType())
    def get_point_x(point: TPointInst):
        return point.x().start_value()

    @udf(returnType=FloatType())
    def get_point_y(point: TPointInst):
        return point.y().start_value()

    @udf(returnType=FloatType())
    def get_point_z(point: TPointInst):
        return point.z().start_value()
        
    @udf(returnType=TimestampType())
    def get_point_timestamp(point: TPointInst):
        ts = point.start_timestamp()
        return ts

    @udf(returnType=STBoxUDT())
    def bounds_as_box(xmin: TFloat, xmax: TFloat, ymin: TFloat, ymax: TFloat, tmin: datetime=None, tmax: datetime=None, zmin: TFloat=None, zmax: TFloat=None):
        if tmin and zmin:
            return STBox(xmax=xmax, xmin=xmin, ymax=ymax, ymin=ymin, tmin=tmin, tmax=tmax, zmin=zmin, zmax=zmax)
        elif tmin and not zmin:
            return STBox(xmax=xmax, xmin=xmin, ymax=ymax, ymin=ymin, tmin=tmin, tmax=tmax)
        elif zmin and not tmin:
            return STBox(xmax=xmax, xmin=xmin, ymax=ymax, ymin=ymin, zmin=zmin, zmax=zmax)

    
    spark.udf.register("create_point_udf", create_point_udf)
    
    # Get the value of 'spark.default.parallelism'
    default_parallelism = spark.sparkContext.getConf().get("spark.default.parallelism")
    print(f"spark.default.parallelism: {default_parallelism}")
    
    # Read data from a CSV file
    data_path = "../../small_states_2022-06-27-00.csv"  # Update this with your CSV file path
    df = spark.read.csv(data_path, header=True, inferSchema=True).select("icao24", "time", "lat", "lon")
    
    df = df.dropna(subset=["lat", "lon"])

    #df.select(
    #    psmax(col("lon")).alias("max_lon"),
    #    psmin(col("lon")).alias("min_lon"),
    #    psmax(col("lat")).alias("max_lat"),
    #    psmin(col("lat")).alias("min_lat"),
    #).show()
    
    # Convert the 'time' column to the correct format
    df = df.withColumn("time", from_unixtime(col("time"), "yyyy-MM-dd' 'HH:mm:ss")).withColumn("Point", create_point_udf("lat", "lon", "time")).withColumn("x", get_point_x("Point")).withColumn("y", get_point_y("Point")).withColumn("t", get_point_timestamp("Point"))
    df = df.withColumn("PointStr", col("Point").cast("string")).withColumn("id", monotonically_increasing_id())
    
    df.show(3, truncate=False)

    #max_x = df.agg(psmax("x"))
    #print(max_x.collect()[0])

    #max_x2 = df.select(psmax("x"))
    #print(max_x.collect()[0])
    
    boundaries = df.agg(
        psmax(col("x")).alias("max_x"),
        psmin(col("x")).alias("min_x"),
        psmax(col("y")).alias("max_y"),
        psmin(col("y")).alias("min_y"),
        psmax(col("t").cast("timestamp")).alias("max_t"),
        psmin(col("t").cast("timestamp")).alias("min_t"),
        psmax(col("lat")).alias("max_lat"),
        psmin(col("lat")).alias("min_lat")
    )

    # boundaries.show()
    
    df.printSchema()

    # Function to count rows per partition
    def count_in_partition(iterator):
        cnt = 0
        for _ in iterator:
            cnt += 1
        return cnt

    # Get the bounding box of the data
    #bounds = boundaries.select(bounds_as_box("min_x", "max_x", "min_y", "max_y", "min_t", "max_t").alias("bounds"))
    #bounds.show(truncate=False)
    #bounds = bounds.collect()[0]['bounds']
    #print(bounds.__str__(), type(bounds))
    
    bounds = STBox(
        "STBOX XT(((-177.02969360351562,-46.421356201171875),(177.816650390625,70.29727935791016)),[2022-06-27 00:00:00+00, 2022-06-27 00:15:00+00])",
        geodetic=True)

    
    # Now we calculate the grid and partition accordingly
    gp = GridPartition(cells_per_side=3, bounds=bounds)
    print(gp.num_partitions())
    
    dfpointstr = df.select("id", "PointStr")
    gridstr = gp.gridstr
    datardd = dfpointstr.rdd.map(lambda row: (row['id'], row['PointStr'], gridstr))
    datardd.partitionBy(gp.num_partitions(), gp.get_partition)

    # Take the first 5 elements and print them
    first_five_items = datardd.take(5)
    for item in first_five_items:
        print(item[0], len(item))
    
    # Using mapPartitionsWithIndex to count rows per partition
    partition_counts = datardd.mapPartitions(count_in_partition).collect()

    # Print the results
    for i, cnt in enumerate(partition_counts):
        print(f"Partition {partition_id} has {cnt} rows")
    
    pymeos_finalize()
    spark.stop()

if __name__ == "__main__":
    main()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/24 15:08:51 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


spark.default.parallelism: 36


24/04/24 15:09:05 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors
                                                                                

+------+-------------------+------------------+------------------+------------------------------------------------------------------+---------+--------+-------------------+------------------------------------------------------------------+---+
|icao24|time               |lat               |lon               |Point                                                             |x        |y       |t                  |PointStr                                                          |id |
+------+-------------------+------------------+------------------+------------------------------------------------------------------+---------+--------+-------------------+------------------------------------------------------------------+---+
|34718e|2022-06-27 00:00:00|40.87294006347656 |1.9229736328125   |POINT(1.9229736328125 40.87294006347656)@2022-06-27 00:00:00+00   |1.9229736|40.87294|2022-06-27 00:00:00|POINT(1.9229736328125 40.87294006347656)@2022-06-27 00:00:00+00   |0  |
|ac6364|2022-06-27 00:00

                                                                                

0 3
1 3
2 3
3 3
4 3


24/04/24 15:09:36 ERROR Executor: Exception in task 0.0 in stage 4.0 (TID 27)24]
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 1247, in main
    process()
  File "/usr/local/lib/python3.9/dist-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 1239, in process
    serializer.dump_stream(out_iter, outfile)
  File "/usr/local/lib/python3.9/dist-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 272, in dump_stream
    iterator = iter(iterator)
TypeError: 'int' object is not iterable

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:784)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:766)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNe

Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 2 in stage 4.0 failed 1 times, most recent failure: Lost task 2.0 in stage 4.0 (TID 29) (8eace0f42ad5 executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 1247, in main
    process()
  File "/usr/local/lib/python3.9/dist-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 1239, in process
    serializer.dump_stream(out_iter, outfile)
  File "/usr/local/lib/python3.9/dist-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 272, in dump_stream
    iterator = iter(iterator)
TypeError: 'int' object is not iterable

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:784)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:766)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1049)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2438)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	at java.base/java.lang.Thread.run(Thread.java:1583)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:989)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2398)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2419)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2438)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2463)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1049)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:410)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1048)
	at org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:195)
	at org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:75)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:52)
	at java.base/java.lang.reflect.Method.invoke(Method.java:580)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:1583)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 1247, in main
    process()
  File "/usr/local/lib/python3.9/dist-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 1239, in process
    serializer.dump_stream(out_iter, outfile)
  File "/usr/local/lib/python3.9/dist-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 272, in dump_stream
    iterator = iter(iterator)
TypeError: 'int' object is not iterable

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:784)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:766)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1049)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2438)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	... 1 more
