In [1]:
import pandas as pd


Waiting for a Spark session to start...
Spark Initialization Done! ApplicationId = app-20181210222736-0000
KERNEL_ID = 92943368-22c8-40fc-9ee7-b76d44e0dfc4


In [2]:
# Sample trips: one row per trip with its origin, its destination and the
# list of internal flight IDs that make up the trip.
# The pandas frame is handed straight to Spark, so `trips` only ever names
# the Spark DataFrame.
trips = spark.createDataFrame(
    pd.DataFrame({
        "origin": ["PMI", "ATH", "JFK", "HND"],
        "destination": ["OPO", "BCN", "MAD", "LAX"],
        "internal_flight_ids": [[2, 1], [3], [5, 4, 6], [8, 9, 7, 0]],
    })
)

In [3]:
# Print the trips Spark DataFrame to inspect its rows and columns.
trips.show()

+-----------+-------------------+------+
|destination|internal_flight_ids|origin|
+-----------+-------------------+------+
|        OPO|             [2, 1]|   PMI|
|        BCN|                [3]|   ATH|
|        MAD|          [5, 4, 6]|   JFK|
|        LAX|       [8, 9, 7, 0]|   HND|
+-----------+-------------------+------+



In [4]:
# Lookup table mapping every internal flight ID (0-9) to its public flight
# number. As with the trips table, the pandas frame is passed directly to
# Spark so `flights` only ever names the Spark DataFrame.
flights = spark.createDataFrame(
    pd.DataFrame({
        "internal_flight_id": list(range(10)),
        "public_flight_number": [
            "FR5763", "UT9586", "B4325", "RW35675", "LP656",
            "NB4321", "CX4599", "AZ8844", "KH8851", "OP8777",
        ],
    })
)

In [5]:
# Print the flights lookup table (internal ID -> public flight number).
flights.show()

+------------------+--------------------+
|internal_flight_id|public_flight_number|
+------------------+--------------------+
|                 0|              FR5763|
|                 1|              UT9586|
|                 2|               B4325|
|                 3|             RW35675|
|                 4|               LP656|
|                 5|              NB4321|
|                 6|              CX4599|
|                 7|              AZ8844|
|                 8|              KH8851|
|                 9|              OP8777|
+------------------+--------------------+



In [6]:
from pyspark.sql.functions import col, explode, posexplode, collect_list, monotonically_increasing_id
from pyspark.sql.window import Window

In [7]:
# Tag every trip with a unique row_id so that the exploded flight rows can
# later be grouped per trip and joined back to this table.
# The first attempt that follows (explode + groupBy/collect_list) is the
# wrong implementation: it does not keep the flight numbers in the order of
# internal_flight_ids.
trips = trips.withColumn("row_id", monotonically_increasing_id())

In [8]:
# Inspect the new row_id column: the generated IDs are unique per row but
# not consecutive.
trips.show()

+-----------+-------------------+------+----------+
|destination|internal_flight_ids|origin|    row_id|
+-----------+-------------------+------+----------+
|        OPO|             [2, 1]|   PMI|         0|
|        BCN|                [3]|   ATH|         1|
|        MAD|          [5, 4, 6]|   JFK|8589934592|
|        LAX|       [8, 9, 7, 0]|   HND|8589934593|
+-----------+-------------------+------+----------+



In [9]:
# Turn each element of internal_flight_ids into its own row, keeping the
# row_id so every flight can still be traced back to its trip.
exploded = trips.select(
    "row_id",
    explode("internal_flight_ids").alias("internal_flight_id"),
)
exploded.show()

+----------+------------------+
|    row_id|internal_flight_id|
+----------+------------------+
|         0|                 2|
|         0|                 1|
|         1|                 3|
|8589934592|                 5|
|8589934592|                 4|
|8589934592|                 6|
|8589934593|                 8|
|8589934593|                 9|
|8589934593|                 7|
|8589934593|                 0|
+----------+------------------+



In [10]:
# Look up the public flight number of every exploded flight by joining with
# the flights table on the internal flight ID (inner join, the default).
exploded_with_flight_number = exploded.join(
    flights, on="internal_flight_id", how="inner"
)
exploded_with_flight_number.show()

+------------------+----------+--------------------+
|internal_flight_id|    row_id|public_flight_number|
+------------------+----------+--------------------+
|                 0|8589934593|              FR5763|
|                 7|8589934593|              AZ8844|
|                 6|8589934592|              CX4599|
|                 9|8589934593|              OP8777|
|                 5|8589934592|              NB4321|
|                 1|         0|              UT9586|
|                 3|         1|             RW35675|
|                 8|8589934593|              KH8851|
|                 2|         0|               B4325|
|                 4|8589934592|               LP656|
+------------------+----------+--------------------+



In [11]:
# Group the joined rows by trip (row_id) and gather the public flight
# numbers of each trip into a list with the built-in collect_list().
# Nothing here orders the rows inside a group, so the resulting lists follow
# whatever row order the join produced.
collected = (
    exploded_with_flight_number
    .groupBy("row_id")
    .agg(collect_list("public_flight_number").alias("public_flight_numbers"))
)
collected.show()

+----------+---------------------+
|    row_id|public_flight_numbers|
+----------+---------------------+
|8589934592| [CX4599, NB4321, ...|
|         0|      [UT9586, B4325]|
|8589934593| [FR5763, AZ8844, ...|
|         1|            [RW35675]|
+----------+---------------------+



In [12]:
# Attach the collected flight-number lists to their trips via row_id, then
# drop the helper row_id column and the internal_flight_ids column.
trips_with_flight_numbers = (
    collected
    .join(trips, on="row_id")
    .drop("row_id", "internal_flight_ids")
)
trips_with_flight_numbers.show()

+---------------------+-----------+------+
|public_flight_numbers|destination|origin|
+---------------------+-----------+------+
| [CX4599, NB4321, ...|        MAD|   JFK|
|      [UT9586, B4325]|        OPO|   PMI|
| [FR5763, AZ8844, ...|        LAX|   HND|
|            [RW35675]|        BCN|   ATH|
+---------------------+-----------+------+



In [15]:
# The result above is wrong: the flight numbers do not follow the order of
# internal_flight_ids (e.g. the trip with IDs [2, 1] came out as
# [UT9586, B4325] instead of [B4325, UT9586]).
# Use posexplode() instead of explode(): it yields two columns per array
# element, the element's position in the array and the element itself, so
# the original order can be restored later.
exploded = trips.select(
    "row_id",
    posexplode("internal_flight_ids").alias("position", "internal_flight_id"),
)
exploded.show()

+----------+--------+------------------+
|    row_id|position|internal_flight_id|
+----------+--------+------------------+
|         0|       0|                 2|
|         0|       1|                 1|
|         1|       0|                 3|
|8589934592|       0|                 5|
|8589934592|       1|                 4|
|8589934592|       2|                 6|
|8589934593|       0|                 8|
|8589934593|       1|                 9|
|8589934593|       2|                 7|
|8589934593|       3|                 0|
+----------+--------+------------------+



In [16]:
# As before, add public_flight_number by joining the exploded rows with the
# flights table. The position column produced by posexplode() is carried
# along; the ordered collect_list() over a Window that follows relies on it.
exploded_with_flight_number = exploded.join(
    flights, on="internal_flight_id", how="inner"
)


In [17]:
# Inspect the joined rows: each flight now has its trip (row_id), its
# position within the trip and its public flight number.
exploded_with_flight_number.show()

+------------------+----------+--------+--------------------+
|internal_flight_id|    row_id|position|public_flight_number|
+------------------+----------+--------+--------------------+
|                 0|8589934593|       3|              FR5763|
|                 7|8589934593|       2|              AZ8844|
|                 6|8589934592|       2|              CX4599|
|                 9|8589934593|       1|              OP8777|
|                 5|8589934592|       0|              NB4321|
|                 1|         0|       1|              UT9586|
|                 3|         1|       0|             RW35675|
|                 8|8589934593|       0|              KH8851|
|                 2|         0|       0|               B4325|
|                 4|8589934592|       1|               LP656|
+------------------+----------+--------+--------------------+



In [19]:
# Window covering all rows of one trip, ordered by the flight's position in
# the original array. The frame spans the whole partition so that every row
# sees the complete list rather than a running prefix.
ordered_trip_window = (
    Window
    .partitionBy("row_id")
    .orderBy("position")
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)

# collect_list() over the window keeps one output row per input row, so each
# trip's full list is repeated once per flight of that trip.
collected = (
    exploded_with_flight_number
    .withColumn(
        "public_flight_numbers",
        collect_list("public_flight_number").over(ordered_trip_window),
    )
    .select("row_id", "public_flight_numbers")
)
collected.show()

+----------+---------------------+
|    row_id|public_flight_numbers|
+----------+---------------------+
|8589934592| [NB4321, LP656, C...|
|8589934592| [NB4321, LP656, C...|
|8589934592| [NB4321, LP656, C...|
|         0|      [B4325, UT9586]|
|         0|      [B4325, UT9586]|
|8589934593| [KH8851, OP8777, ...|
|8589934593| [KH8851, OP8777, ...|
|8589934593| [KH8851, OP8777, ...|
|8589934593| [KH8851, OP8777, ...|
|         1|            [RW35675]|
+----------+---------------------+



In [20]:
# The collected table repeats each trip's list once per flight. Remove the
# duplicated rows, join the result back to the trips table on row_id, and
# drop the helper row_id column together with internal_flight_ids.
trips_with_flight_numbers = (
    collected
    .dropDuplicates()
    .join(trips, on="row_id")
    .drop("row_id", "internal_flight_ids")
)
trips_with_flight_numbers.show()

+---------------------+-----------+------+
|public_flight_numbers|destination|origin|
+---------------------+-----------+------+
| [NB4321, LP656, C...|        MAD|   JFK|
|      [B4325, UT9586]|        OPO|   PMI|
| [KH8851, OP8777, ...|        LAX|   HND|
|            [RW35675]|        BCN|   ATH|
+---------------------+-----------+------+

