# Big programming assignment 1 - Sliding aggregation on NYC taxi data


## Configuring Spark

In [None]:
!pip install pyspark --quiet
!pip install -U -q PyDrive --quiet 
!apt install openjdk-8-jdk-headless &> /dev/null

In [None]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"

In [None]:
!lscpu

Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   48 bits physical, 48 bits virtual
CPU(s):                          2
On-line CPU(s) list:             0,1
Thread(s) per core:              2
Core(s) per socket:              1
Socket(s):                       1
NUMA node(s):                    1
Vendor ID:                       AuthenticAMD
CPU family:                      23
Model:                           49
Model name:                      AMD EPYC 7B12
Stepping:                        0
CPU MHz:                         2249.998
BogoMIPS:                        4499.99
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       32 KiB
L1i cache:                       32 KiB
L2 cache:                        512 KiB
L3 cache:                        16 MiB
NUMA node0 CPU(s):               0,1
Vulnerability Itlb multihit:  

### Creating Spark Session

This time, rather than `SparkContext`, we will be using `SparkSession`. It is a wrapper around `SparkContext` that also adds extra functionality and a higher-level API.

In [None]:
import pyspark
from pyspark.sql import SparkSession
spark = SparkSession.builder \
                    .master("local[*]") \
                    .config("spark.executor.memory", "4g") \
                    .config("spark.driver.memory", "1g") \
                    .appName("mlibs") \
                    .getOrCreate()

We can also get `SparkContext` from `SparkSession` and use it to access RDD API.

In [None]:
from pyspark import SparkContext
sc = spark.sparkContext

Note: pyspark documentation can be found [here](https://spark.apache.org/docs/3.1.2/api/python/reference/index.html). For example: list of `SparkContext` methods is [here](https://spark.apache.org/docs/3.1.2/api/python/reference/pyspark.html#spark-context-apis) and list of `SparkSession` methods is [here](https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.sql.SparkSession.html#pyspark.sql.SparkSession).


The following code lets us inspect the full Spark configuration. More information about the available configuration options and their default values can be found in [the documentation](https://spark.apache.org/docs/latest/configuration.html).

In [None]:
spark.sparkContext.getConf().getAll()

[('spark.driver.host', 'f52ddef0b9e0'),
 ('spark.executor.memory', '4g'),
 ('spark.executor.id', 'driver'),
 ('spark.app.id', 'local-1685357561405'),
 ('spark.driver.memory', '1g'),
 ('spark.driver.extraJavaOptions',
  '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.use

## Download data

In [None]:
!mkdir data

!wget https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2022-01.parquet -O data/yellow_tripdata_2022_01.parquet

mkdir: cannot create directory ‘data’: File exists
--2023-05-29 10:52:42--  https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2022-01.parquet
Resolving d37ci6vzurychx.cloudfront.net (d37ci6vzurychx.cloudfront.net)... 13.249.90.174, 13.249.90.176, 13.249.90.209, ...
Connecting to d37ci6vzurychx.cloudfront.net (d37ci6vzurychx.cloudfront.net)|13.249.90.174|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 38139949 (36M) [application/x-www-form-urlencoded]
Saving to: ‘data/yellow_tripdata_2022_01.parquet’


2023-05-29 10:52:43 (58.5 MB/s) - ‘data/yellow_tripdata_2022_01.parquet’ saved [38139949/38139949]



In [None]:
!du -hs ./data/*


37M	./data/yellow_tripdata_2022_01.parquet


In [None]:
N = 100

In [None]:
df_org = spark.read.parquet("data/yellow_tripdata_2022_01.parquet").cache()
df_org = df_org.limit(N)
df_org.printSchema()

root
 |-- VendorID: long (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: double (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: double (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- airport_fee: double (nullable = true)



In [None]:
df_org.head(5)

[Row(VendorID=1, tpep_pickup_datetime=datetime.datetime(2022, 1, 1, 0, 35, 40), tpep_dropoff_datetime=datetime.datetime(2022, 1, 1, 0, 53, 29), passenger_count=2.0, trip_distance=3.8, RatecodeID=1.0, store_and_fwd_flag='N', PULocationID=142, DOLocationID=236, payment_type=1, fare_amount=14.5, extra=3.0, mta_tax=0.5, tip_amount=3.65, tolls_amount=0.0, improvement_surcharge=0.3, total_amount=21.95, congestion_surcharge=2.5, airport_fee=0.0),
 Row(VendorID=1, tpep_pickup_datetime=datetime.datetime(2022, 1, 1, 0, 33, 43), tpep_dropoff_datetime=datetime.datetime(2022, 1, 1, 0, 42, 7), passenger_count=1.0, trip_distance=2.1, RatecodeID=1.0, store_and_fwd_flag='N', PULocationID=236, DOLocationID=42, payment_type=1, fare_amount=8.0, extra=0.5, mta_tax=0.5, tip_amount=4.0, tolls_amount=0.0, improvement_surcharge=0.3, total_amount=13.3, congestion_surcharge=0.0, airport_fee=0.0),
 Row(VendorID=2, tpep_pickup_datetime=datetime.datetime(2022, 1, 1, 0, 53, 21), tpep_dropoff_datetime=datetime.dateti

## Prepare data for the assignment


In [None]:
selected_columns = [
    'tpep_dropoff_datetime',
    'tpep_pickup_datetime',
]
df = df_org.select(selected_columns)
df.head(5)

[Row(tpep_dropoff_datetime=datetime.datetime(2022, 1, 1, 0, 53, 29), tpep_pickup_datetime=datetime.datetime(2022, 1, 1, 0, 35, 40)),
 Row(tpep_dropoff_datetime=datetime.datetime(2022, 1, 1, 0, 42, 7), tpep_pickup_datetime=datetime.datetime(2022, 1, 1, 0, 33, 43)),
 Row(tpep_dropoff_datetime=datetime.datetime(2022, 1, 1, 1, 2, 19), tpep_pickup_datetime=datetime.datetime(2022, 1, 1, 0, 53, 21)),
 Row(tpep_dropoff_datetime=datetime.datetime(2022, 1, 1, 0, 35, 23), tpep_pickup_datetime=datetime.datetime(2022, 1, 1, 0, 25, 21)),
 Row(tpep_dropoff_datetime=datetime.datetime(2022, 1, 1, 1, 14, 20), tpep_pickup_datetime=datetime.datetime(2022, 1, 1, 0, 36, 48))]

In [None]:
org_rdd = df.rdd

In [None]:
from typing import Tuple
from datetime import datetime
from dataclasses import dataclass

Obj = Tuple[datetime, datetime]

@dataclass
class Point:
  """Base wrapper around one record's coordinate tuple.

  The subclasses carry no extra fields; they only tag the role of a row so
  that later stages can dispatch on it with `isinstance`.
  """
  # The wrapped coordinates; `Obj` is declared above as a pair of datetimes.
  obj: Obj

@dataclass
class Data(Point):
  """A data row: one of the records that the count stage tallies."""
  pass

@dataclass
class Query(Point):
  """A query row: the count stage reports, per query, how many data rows preceded it."""
  pass

@dataclass
class Dummy(Point):
  """A padding row added so the row count divides evenly by the partition count.

  Dummies get a dedicated sort key and are dropped again by `reduplicate_row`.
  """
  pass

### Add queries and label placeholder

In [None]:
from datetime import datetime
from typing import Tuple, List, NewType
from random import uniform

Event = Tuple[datetime, datetime, str]

ADD_QUERY_P = 0.3

def add_query(x: Tuple[datetime, datetime]) -> List[Event]:
  """Turn one input row into a data event and, at random, a matching query event.

  Args:
    x: a two-element row. In this notebook the rows come from
      `df.select(['tpep_dropoff_datetime', 'tpep_pickup_datetime'])`, so the
      first element is the drop-off time and the second the pick-up time.

  Returns:
    A list of `(first, second, point)` tuples: always one `Data` point, plus
    one `Query` point with the same coordinates with probability ADD_QUERY_P.
  """
  first_dim, second_dim = x
  events = [(first_dim, second_dim, Data((first_dim, second_dim)))]
  draw = uniform(0, 1)
  if draw <= ADD_QUERY_P:
    events.append((first_dim, second_dim, Query((first_dim, second_dim))))
  return events

def add_label_placeholder(x):
  """Pair `x` with an empty label tuple, giving the `(labels, row)` shape used downstream."""
  no_labels = ()
  return no_labels, x

rdd = org_rdd.flatMap(add_query).map(add_label_placeholder)

rdd.take(5)

[((),
  (datetime.datetime(2022, 1, 1, 0, 53, 29),
   datetime.datetime(2022, 1, 1, 0, 35, 40),
   Data(obj=(datetime.datetime(2022, 1, 1, 0, 53, 29), datetime.datetime(2022, 1, 1, 0, 35, 40))))),
 ((),
  (datetime.datetime(2022, 1, 1, 0, 53, 29),
   datetime.datetime(2022, 1, 1, 0, 35, 40),
   Query(obj=(datetime.datetime(2022, 1, 1, 0, 53, 29), datetime.datetime(2022, 1, 1, 0, 35, 40))))),
 ((),
  (datetime.datetime(2022, 1, 1, 0, 42, 7),
   datetime.datetime(2022, 1, 1, 0, 33, 43),
   Data(obj=(datetime.datetime(2022, 1, 1, 0, 42, 7), datetime.datetime(2022, 1, 1, 0, 33, 43))))),
 ((),
  (datetime.datetime(2022, 1, 1, 1, 2, 19),
   datetime.datetime(2022, 1, 1, 0, 53, 21),
   Data(obj=(datetime.datetime(2022, 1, 1, 1, 2, 19), datetime.datetime(2022, 1, 1, 0, 53, 21))))),
 ((),
  (datetime.datetime(2022, 1, 1, 0, 35, 23),
   datetime.datetime(2022, 1, 1, 0, 25, 21),
   Data(obj=(datetime.datetime(2022, 1, 1, 0, 35, 23), datetime.datetime(2022, 1, 1, 0, 25, 21)))))]

#### Set artificial number of partitions if needed

In [None]:
T = 7
rdd = rdd.repartition(T)
rdd.getNumPartitions()

7

#### Add dummy rows if needed

In [None]:
from datetime import datetime

def add_dummy_rows(rdd):
  """Pad `rdd` with Dummy rows so that its row count is a multiple of T.

  Reads the global partition count `T` and the global SparkContext `sc`.
  Updates the globals `N` (row count after padding) and `M` (rows per
  partition, `N / T`) on every call, so `M` cannot be left stale when no
  padding is required.

  Args:
    rdd: RDD of `(labels, (dim, ..., point))` rows.

  Returns:
    The input RDD unchanged when `N % T == 0`; otherwise the input unioned
    with the dummy rows and repartitioned into `T` partitions.
  """
  global N, T, M
  N = rdd.count()
  print(f"RDD containing {N} rows.")
  remainder = N % T
  if remainder != 0:
    n_dummies = T - remainder
    print(f"Adding {n_dummies} dummy rows")
    # datetime.max is used as the coordinate of every padding row.
    dummies = [
        ((), (datetime.max, datetime.max, Dummy((datetime.max, datetime.max))))
        for _ in range(n_dummies)
    ]
    rdd = rdd.union(sc.parallelize(dummies)).repartition(T)
    # The padded size is known exactly, so no second Spark job is needed to count it.
    N += n_dummies
  # Refresh M on both paths (previously it was only updated when padding was added).
  M = N / T
  return rdd

rdd = add_dummy_rows(rdd)

RDD containing 127 rows.
Adding 6 dummy rows


In [None]:
N = rdd.count()
M = N / T

## TeraSort

### Defining constants

In [None]:
T = rdd.getNumPartitions()
T 

7

In [None]:
N = rdd.count()
N

133

In [None]:
M = N / T
M

19.0

In [None]:
from numpy import log

print(T * log(N * T))
# assert M >= T * log(N * T)

47.85381494093947


In [None]:
from numpy import log

P = 1 / M * log(N * T)
P

0.3598031198566877

### Debug utils

In [None]:
from typing import Any, Iterable, Tuple

def check_balance(partition: Iterable[Tuple[int, Any]]):
  """Summarise one partition of `(key, value)` pairs for debugging.

  Yields a single `(distinct_keys, row_count)` tuple, where `distinct_keys`
  is the list of different keys seen and `row_count` the number of pairs.
  """
  seen_keys = set()
  size = 0
  for size, (key, _value) in enumerate(partition, start=1):
    seen_keys.add(key)
  yield list(seen_keys), size


### Extract sort key

In [None]:
def extract_sort_key(obj):
  """Build the sort key of a `(labels, (curr_dim, ..., point))` row.

  The key is `(labels, curr_dim, weight)` with weight 0 for `Data` and 1 for
  any other non-dummy point, so a data row precedes a query row with the same
  labels and coordinate. `Dummy` rows always get the fixed label `('2',)` and
  weight 2 (presumably so that padding sorts after real rows).
  """
  labels, payload = obj
  curr_dim, *_other_dims, point = payload
  if isinstance(point, Dummy):
    return ('2',), curr_dim, 2
  if isinstance(point, Data):
    return labels, curr_dim, 0
  return labels, curr_dim, 1

extract_sort_key((('pref1', 'pref2', 'pref3'), (1, 2, 3, 4, Data((1, 2, 3, 4))))) 

(('pref1', 'pref2', 'pref3'), 1, 0)

### Sort key ordering tests

In [None]:
extract_sort_key((('pref1', 'pref2', 'pref3'), (1, 2, 3, 4, Data((1, 2, 3, 4))))) < extract_sort_key((('pref1', 'pref2', 'pref3'), (1, 2, 3, 4, Query((1, 2, 3, 4)))))

True

In [None]:
extract_sort_key((('abc', 'pref2', 'pref3'), (1, 2, 3, 4, Data((1, 2, 3, 4))))) < extract_sort_key((('cde', 'pref2', 'pref3'), (1, 2, 3, 4, Query((1, 2, 3, 4)))))

True

In [None]:
extract_sort_key((('abc', 'pref2', 'pref3'), (1, 2, 3, 4, Data((1, 2, 3, 4))))) < extract_sort_key((('abc', 'pref2', 'pref3'), (10, 2, 3, 4, Query((1, 2, 3, 4)))))

True

In [None]:
extract_sort_key((('abc', 'pref2', 'pref3'), (10, 2, 3, 4, Data((1, 2, 3, 4))))) < extract_sort_key((('abc', 'pref2', 'pref3'), (1, 2, 3, 4, Query((1, 2, 3, 4)))))

False

In [None]:
extract_sort_key((('abc', 'pref2', 'pref3'), (1, 2, 3, 4, Query((1, 2, 3, 4))))) < extract_sort_key((('abc', 'pref2', 'pref3'), (1, 2, 3, 4, Data((1, 2, 3, 4)))))

False

### Sample - Round 1 Map-Shuffle

In [None]:
from typing import Callable, Iterable
from random import uniform


def get_sampler(p: float) -> Callable[[Iterable[Any]], Iterable[Any]]:
  """Return a function that keeps each element of an iterable with probability `p`.

  The returned sampler draws one `uniform(0, 1)` number per element and
  returns the kept elements as a list, in their original order.
  """
  def sampler(iterable):
    return [item for item in iterable if uniform(0, 1) <= p]
  return sampler

sampled_rdd = rdd.mapPartitions(get_sampler(P)).repartition(1)

In [None]:
sampled_rdd.getNumPartitions()

1

In [None]:
sampled_rdd.glom().collect()[0][0]

((),
 (datetime.datetime(2022, 1, 1, 0, 25, 52),
  datetime.datetime(2022, 1, 1, 0, 13, 35),
  Data(obj=(datetime.datetime(2022, 1, 1, 0, 25, 52), datetime.datetime(2022, 1, 1, 0, 13, 35)))))

### Calculate boundaries - Round 1 Reduce


In [None]:
def calculate_boundaries(sampled_rdd):
  """Pick bucket boundaries from a sample of rows (TeraSort round 1 reduce).

  The sample is sorted by `extract_sort_key` and every `step`-th sort key is
  returned, where `step = sample_count // T` (global partition count `T`).

  Args:
    sampled_rdd: RDD of sampled `(labels, (dim, ..., point))` rows.

  Returns:
    A list of sort keys in ascending order. It may hold more than `T - 1`
    entries; `round2_map_shuffle` only consults the first `T - 1`. It holds
    fewer when the sample has fewer than `T - 1` rows.
  """
  sampled_sorted_rdd = sampled_rdd.sortBy(extract_sort_key)
  sample_count = sampled_sorted_rdd.count()
  # With fewer than T samples the integer division yields 0, and `% 0` would
  # raise ZeroDivisionError on the workers; fall back to taking every sample.
  step = max(1, sample_count // T)
  boundaries_rdd = sampled_sorted_rdd.zipWithIndex() \
      .filter(lambda x: (x[1] + 1) % step == 0) \
      .map(lambda x: extract_sort_key(x[0]))
  return boundaries_rdd.collect()

boundaries = calculate_boundaries(sampled_rdd)
boundaries

[((), datetime.datetime(2022, 1, 1, 0, 26, 57), 0),
 ((), datetime.datetime(2022, 1, 1, 0, 34, 7), 0),
 ((), datetime.datetime(2022, 1, 1, 0, 44, 49), 0),
 ((), datetime.datetime(2022, 1, 1, 0, 52, 51), 1),
 ((), datetime.datetime(2022, 1, 1, 1, 1, 35), 0),
 ((), datetime.datetime(2022, 1, 1, 1, 14, 20), 0),
 (('2',), datetime.datetime(9999, 12, 31, 23, 59, 59, 999999), 2)]

### Redistribute  - Round 2 Map-Shuffle

In [None]:
boundaries_broadcasted = sc.broadcast(boundaries)
boundaries_broadcasted

<pyspark.broadcast.Broadcast at 0x7feeaf724f40>

In [None]:
def round2_map_shuffle(partition):
    """Tag every row of a partition with its target bucket (TeraSort round 2 map).

    Reads the broadcast boundary list `boundaries_broadcasted` and the global
    partition count `T`. The boundaries must be sorted ascending, which is how
    `calculate_boundaries` produces them. Only the first `T - 1` are used.

    Bucket `j` receives the rows whose sort key lies in
    `[boundaries[j - 1], boundaries[j])`; bucket 0 has no lower bound and the
    last bucket no upper bound. A binary search finds the bucket, which also
    covers `T == 1` (everything goes to bucket 0) and boundary lists shorter
    than `T - 1`.

    Yields:
        `(bucket_index, row)` pairs, exactly one per input row.
    """
    from bisect import bisect_right  # local import keeps this cell self-contained

    split_points = boundaries_broadcasted.value[:max(T - 1, 0)]
    for obj in partition:
        # bisect_right counts the boundaries <= key, which is the bucket index.
        yield (bisect_right(split_points, extract_sort_key(obj)), obj)

shuffled_rdd = rdd.mapPartitions(round2_map_shuffle).partitionBy(T, lambda x: x)

In [None]:
shuffled_rdd.take(5)

[(0,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 25, 52),
    datetime.datetime(2022, 1, 1, 0, 13, 35),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 25, 52), datetime.datetime(2022, 1, 1, 0, 13, 35)))))),
 (0,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 20, 49),
    datetime.datetime(2022, 1, 1, 0, 15, 35),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 20, 49), datetime.datetime(2022, 1, 1, 0, 15, 35)))))),
 (0,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 25, 53),
    datetime.datetime(2022, 1, 1, 0, 21, 1),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 25, 53), datetime.datetime(2022, 1, 1, 0, 21, 1)))))),
 (0,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 20, 33),
    datetime.datetime(2022, 1, 1, 0, 15, 42),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 20, 33), datetime.datetime(2022, 1, 1, 0, 15, 42)))))),
 (0,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 22, 45),
    datetime.datetime(2022, 1, 1, 0, 13, 4),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 22, 45), datetime.dateti

In [None]:
shuffled_rdd.mapPartitions(check_balance).collect()

[([0], 17), ([1], 12), ([2], 21), ([3], 17), ([4], 21), ([5], 26), ([6], 19)]

### Sort - Round 2 Reduce

In [None]:
sorted_rdd = shuffled_rdd \
  .map(lambda x: x[1]) \
  .mapPartitions(lambda x: sorted(x, key=extract_sort_key))

sorted_rdd.take(5)

[((),
  (datetime.datetime(2022, 1, 1, 0, 5, 29),
   datetime.datetime(2022, 1, 1, 0, 0, 44),
   Data(obj=(datetime.datetime(2022, 1, 1, 0, 5, 29), datetime.datetime(2022, 1, 1, 0, 0, 44))))),
 ((),
  (datetime.datetime(2022, 1, 1, 0, 8, 54),
   datetime.datetime(2022, 1, 1, 0, 0, 44),
   Data(obj=(datetime.datetime(2022, 1, 1, 0, 8, 54), datetime.datetime(2022, 1, 1, 0, 0, 44))))),
 ((),
  (datetime.datetime(2022, 1, 1, 0, 8, 54),
   datetime.datetime(2022, 1, 1, 0, 0, 44),
   Query(obj=(datetime.datetime(2022, 1, 1, 0, 8, 54), datetime.datetime(2022, 1, 1, 0, 0, 44))))),
 ((),
  (datetime.datetime(2022, 1, 1, 0, 8, 58),
   datetime.datetime(2022, 1, 1, 0, 6, 10),
   Data(obj=(datetime.datetime(2022, 1, 1, 0, 8, 58), datetime.datetime(2022, 1, 1, 0, 6, 10))))),
 ((),
  (datetime.datetime(2022, 1, 1, 0, 14, 17),
   datetime.datetime(2022, 1, 1, 0, 6, 28),
   Data(obj=(datetime.datetime(2022, 1, 1, 0, 14, 17), datetime.datetime(2022, 1, 1, 0, 6, 28)))))]

In [None]:
# sorted_rdd.collect()

## Perfect Balance

### Compute Rank using Prefix Sums

#### Rank Map Shuffle

In [None]:
def partition_len(partition):
    """Yield the number of elements in `partition` as a single value.

    The elements are counted as they stream past, so the partition is not
    copied into an in-memory list just to be measured.
    """
    total_weight = sum(1 for _ in partition)
    yield total_weight

total_weights_rdd = sorted_rdd.mapPartitions(lambda x: partition_len(x))

total_weights = total_weights_rdd.collect()
total_weights

[17, 12, 21, 17, 21, 26, 19]

In [None]:
from functools import reduce

def reduce_func(acc, x):
  """Fold step for running totals: return `acc` extended by `acc[-1] + x`.

  `acc` is not modified; a new list is returned.
  """
  previous_total = acc[-1]
  return acc + [previous_total + x]

prefix_sums = list(reduce(reduce_func, total_weights, [0]))[:-1]
prefix_sums

[0, 17, 29, 50, 67, 88, 114]

In [None]:
prefix_sums_broadcasted = sc.broadcast(prefix_sums)
prefix_sums_broadcasted

<pyspark.broadcast.Broadcast at 0x7feeaf727070>

#### Rank Reduce

In [None]:
def rank_reduce(partition_idx, partition):
    """Assign a global rank to every row of one partition.

    The rank is the row's position inside the partition plus the offset stored
    for `partition_idx` in the broadcast `prefix_sums_broadcasted`.

    Yields:
        `(rank, row)` pairs in partition order.
    """
    offsets = prefix_sums_broadcasted.value
    yield from (
        (offsets[partition_idx] + local_rank, obj)
        for local_rank, obj in enumerate(partition)
    )

ranked_rdd = sorted_rdd.mapPartitionsWithIndex(rank_reduce)
ranked_rdd.take(5)

[(0,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 5, 29),
    datetime.datetime(2022, 1, 1, 0, 0, 44),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 5, 29), datetime.datetime(2022, 1, 1, 0, 0, 44)))))),
 (1,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 8, 54),
    datetime.datetime(2022, 1, 1, 0, 0, 44),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 8, 54), datetime.datetime(2022, 1, 1, 0, 0, 44)))))),
 (2,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 8, 54),
    datetime.datetime(2022, 1, 1, 0, 0, 44),
    Query(obj=(datetime.datetime(2022, 1, 1, 0, 8, 54), datetime.datetime(2022, 1, 1, 0, 0, 44)))))),
 (3,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 8, 58),
    datetime.datetime(2022, 1, 1, 0, 6, 10),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 8, 58), datetime.datetime(2022, 1, 1, 0, 6, 10)))))),
 (4,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 14, 17),
    datetime.datetime(2022, 1, 1, 0, 6, 28),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 14, 17), datetime.datetime(2022, 1, 1

In [None]:
# ranked_rdd.collect()

In [None]:
ranked_rdd.map(lambda x: x[0]).reduce(lambda x, y: max(x, y))

132

In [None]:
N / M

7.0

In [None]:
N

133

In [None]:
sum(total_weights)

133

In [None]:
ranked_rdd.count()

133

In [None]:
M

19.0

#### Perfect Balance Map-Shuffle

In [None]:
from pyspark import RDD

def shuffle(rank_obj):
  """Key a `(rank, row)` pair by its target partition `int(rank // M)`.

  `M` is the global rows-per-partition value.

  Returns:
    `(partition_index, (rank, row))`.
  """
  rank, payload = rank_obj
  target = int(rank // M)
  return target, (rank, payload)

ranked_rdd: RDD = ranked_rdd
perfectly_balanced_rdd = ranked_rdd.map(shuffle).groupByKey(T, lambda x: x).flatMapValues(list)

perfectly_balanced_rdd.take(5)

[(0,
  (0,
   ((),
    (datetime.datetime(2022, 1, 1, 0, 5, 29),
     datetime.datetime(2022, 1, 1, 0, 0, 44),
     Data(obj=(datetime.datetime(2022, 1, 1, 0, 5, 29), datetime.datetime(2022, 1, 1, 0, 0, 44))))))),
 (0,
  (1,
   ((),
    (datetime.datetime(2022, 1, 1, 0, 8, 54),
     datetime.datetime(2022, 1, 1, 0, 0, 44),
     Data(obj=(datetime.datetime(2022, 1, 1, 0, 8, 54), datetime.datetime(2022, 1, 1, 0, 0, 44))))))),
 (0,
  (2,
   ((),
    (datetime.datetime(2022, 1, 1, 0, 8, 54),
     datetime.datetime(2022, 1, 1, 0, 0, 44),
     Query(obj=(datetime.datetime(2022, 1, 1, 0, 8, 54), datetime.datetime(2022, 1, 1, 0, 0, 44))))))),
 (0,
  (3,
   ((),
    (datetime.datetime(2022, 1, 1, 0, 8, 58),
     datetime.datetime(2022, 1, 1, 0, 6, 10),
     Data(obj=(datetime.datetime(2022, 1, 1, 0, 8, 58), datetime.datetime(2022, 1, 1, 0, 6, 10))))))),
 (0,
  (4,
   ((),
    (datetime.datetime(2022, 1, 1, 0, 14, 17),
     datetime.datetime(2022, 1, 1, 0, 6, 28),
     Data(obj=(datetime.datetim

In [None]:
# perfectly_balanced_rdd.collect()

In [None]:
perfectly_balanced_rdd.mapPartitions(check_balance).collect()

[([0], 19), ([1], 19), ([2], 19), ([3], 19), ([4], 19), ([5], 19), ([6], 19)]

In [None]:
perfectly_balanced_rdd = perfectly_balanced_rdd.map(lambda x: x[1])
perfectly_balanced_rdd.take(5)

[(0,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 5, 29),
    datetime.datetime(2022, 1, 1, 0, 0, 44),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 5, 29), datetime.datetime(2022, 1, 1, 0, 0, 44)))))),
 (1,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 8, 54),
    datetime.datetime(2022, 1, 1, 0, 0, 44),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 8, 54), datetime.datetime(2022, 1, 1, 0, 0, 44)))))),
 (2,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 8, 54),
    datetime.datetime(2022, 1, 1, 0, 0, 44),
    Query(obj=(datetime.datetime(2022, 1, 1, 0, 8, 54), datetime.datetime(2022, 1, 1, 0, 0, 44)))))),
 (3,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 8, 58),
    datetime.datetime(2022, 1, 1, 0, 6, 10),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 8, 58), datetime.datetime(2022, 1, 1, 0, 6, 10)))))),
 (4,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 14, 17),
    datetime.datetime(2022, 1, 1, 0, 6, 28),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 14, 17), datetime.datetime(2022, 1, 1

##  Multidimensional Interval Multiquery Processor

### Reduplicate

In [None]:
from math import log, ceil

TREE_HEIGHT = ceil(log(N, 2))
TREE_HEIGHT

8

In [None]:
def get_binary_notation_str(number):
  """Return `number` in binary, left-padded with zeros to TREE_HEIGHT digits.

  Numbers that need more than TREE_HEIGHT digits are returned unpadded.
  """
  digits = bin(number)[2:]  # drop the '0b' prefix
  return digits.zfill(TREE_HEIGHT)

get_binary_notation_str(10)

'00001010'

In [None]:
def get_prefixes_followed_by_digit(rank: int, digit: str):
  """List every prefix of `rank`'s binary string that is immediately followed by `digit`.

  The binary string comes from `get_binary_notation_str(rank)`; prefixes are
  returned from shortest to longest and may include the empty string.
  """
  bits = get_binary_notation_str(rank)
  return [bits[:pos] for pos, bit in enumerate(bits) if bit == digit]

get_prefixes_followed_by_digit(10, '0')

['', '0', '00', '000', '00001', '0000101']

In [None]:
def reduplicate_row(rank_obj):
  """Expand one ranked row into one copy per matching binary prefix of its rank.

  Args:
    rank_obj: `(rank, (labels, (cur_dim, *next_dims, point)))`.

  Returns:
    `[]` for `Dummy` points. Otherwise one `(labels + (prefix,), (*next_dims, point))`
    row per prefix of the rank's binary string that is followed by '0' (for
    `Data`) or by '1' (for any other point). The current dimension is dropped.
  """
  rank, (labels, payload) = rank_obj
  _cur_dim, *next_dims, point = payload
  if isinstance(point, Dummy):
    return []
  digit = '1'
  if isinstance(point, Data):
    digit = '0'
  remaining = (*next_dims, point)
  return [
      ((*labels, prefix), remaining)
      for prefix in get_prefixes_followed_by_digit(rank, digit)
  ]

reduplicate_row((10, (('pref1', 'pref2'), (1, 2, 3, 4, Data((1, 2, 3, 4))))))

[(('pref1', 'pref2', ''), (2, 3, 4, Data(obj=(1, 2, 3, 4)))),
 (('pref1', 'pref2', '0'), (2, 3, 4, Data(obj=(1, 2, 3, 4)))),
 (('pref1', 'pref2', '00'), (2, 3, 4, Data(obj=(1, 2, 3, 4)))),
 (('pref1', 'pref2', '000'), (2, 3, 4, Data(obj=(1, 2, 3, 4)))),
 (('pref1', 'pref2', '00001'), (2, 3, 4, Data(obj=(1, 2, 3, 4)))),
 (('pref1', 'pref2', '0000101'), (2, 3, 4, Data(obj=(1, 2, 3, 4))))]

In [None]:
N, perfectly_balanced_rdd.count()

(133, 133)

In [None]:
perfectly_balanced_rdd.flatMap(reduplicate_row).count()

539

### Iterate sorting and reduplicating

In [None]:
def terra_sort_with_perfect_balance(rdd1):
  """Sort `rdd1` globally by `extract_sort_key` and spread it evenly over T partitions.

  Steps, each built from the helpers defined in earlier cells:
    1. pad with dummy rows (`add_dummy_rows`) and recompute the globals
       N (rows), T (partitions), M (rows per partition) and P (sampling rate);
    2. sample rows, derive bucket boundaries and broadcast them;
    3. route every row to its bucket (`round2_map_shuffle` + `partitionBy`)
       and sort each bucket locally;
    4. turn the per-partition sizes into prefix sums, rank every row
       (`rank_reduce`) and regroup by `shuffle`.

  Side effects: rebinds the globals N, T, M, P, `boundaries_broadcasted` and
  `prefix_sums_broadcasted`, which the helper functions read, and prints
  progress information. Several intermediate RDDs are evaluated more than
  once (`take`/`collect` for the prints, then again for the next stage).

  NOTE(review): `log` and `reduce` are resolved from whatever earlier cells
  imported; `log` is called with a single argument here.

  Args:
    rdd1: RDD of `(labels, (dim, ..., point))` rows.

  Returns:
    RDD of `(rank, row)` pairs.
  """
  global N, T, M, P, prefix_sums_broadcasted, boundaries_broadcasted
  # Pad so that the row count divides evenly, then refresh the size constants.
  rdd1 = add_dummy_rows(rdd1)
  N = rdd1.count()
  print(f"N: {N}")
  T = rdd1.getNumPartitions()
  print(f"T: {T}")
  M = N / T
  print(f"M: {M}")
  print(T * log(N * T))
  # Only a warning: the run continues even when M is below T * log(N * T).
  if M < T * log(N * T):
    print(f"Warning M < T * log(N * T): {M} < {T * log(N * T)}")
  # Sampling probability used by get_sampler below.
  P = 1 / M * log(N * T)
  print(f"P: {P}")

  # Sort: sample -> boundaries -> bucket by boundary -> sort inside each bucket.
  sampled_rdd = rdd1.mapPartitions(get_sampler(P)).repartition(1)
  boundaries = calculate_boundaries(sampled_rdd)
  print(f"Boundaries: {boundaries}")
  # Must be rebound before round2_map_shuffle is used: it reads this global.
  boundaries_broadcasted = sc.broadcast(boundaries)
  print(f"RDD ORG: {rdd1.take(5)}")
  # The key emitted by round2_map_shuffle is the bucket index, used directly as the partition id.
  shuffled_rdd = rdd1.mapPartitions(round2_map_shuffle).partitionBy(T, lambda x: x)
  print(f"SHUFFLED RDD: {shuffled_rdd.take(5)}")
  # Drop the bucket index and sort the rows of each partition locally.
  sorted_rdd = shuffled_rdd \
  .map(lambda x: x[1]) \
  .mapPartitions(lambda x: sorted(x, key=extract_sort_key))
  print(f"SORTED RDD: {sorted_rdd.take(5)}")


  # Perfect balance: rank every row globally, then send rank r to partition int(r // M).
  total_weights_rdd = sorted_rdd.mapPartitions(lambda x: partition_len(x))
  total_weights = total_weights_rdd.collect()
  # Running totals starting at 0 with the grand total dropped:
  # entry i is the number of rows in the partitions before partition i.
  prefix_sums = list(reduce(reduce_func, total_weights, [0]))[:-1]
  print(f"PREFIX SUMS: {prefix_sums}")
  # Must be rebound before rank_reduce is used: it reads this global.
  prefix_sums_broadcasted = sc.broadcast(prefix_sums)
  ranked_rdd = sorted_rdd.mapPartitionsWithIndex(rank_reduce)
  print(f"RANKED RDD: {ranked_rdd.take(5)}")
  # `shuffle` keys each row by its target partition; the identity partitioner places it there.
  perfectly_balanced_rdd = ranked_rdd.map(shuffle).groupByKey(T, lambda x: x).flatMapValues(list)
  print(f"PERFECTLY BALANCED: {perfectly_balanced_rdd.mapPartitions(check_balance).collect()}")
  # Strip the partition key, leaving (rank, row).
  perfectly_balanced_rdd = perfectly_balanced_rdd.map(lambda x: x[1])
  return perfectly_balanced_rdd


In [None]:
def reduplicate(rdd):
  """Expand every ranked row of `rdd` with `reduplicate_row`.

  Updates the globals `N` (row count of `rdd`) and `TREE_HEIGHT` (number of
  binary digits used for ranks, i.e. ceil(log2(N))) before building the
  result, because `get_binary_notation_str` reads `TREE_HEIGHT`.

  The height is computed with integer arithmetic, so it does not depend on
  which `log` (numpy's or math's) the notebook imported last, and an empty
  RDD yields height 0 instead of a math domain error.

  Args:
    rdd: RDD of `(rank, (labels, (dim, ..., point)))` rows.

  Returns:
    RDD of `(labels + (prefix,), (remaining dims..., point))` rows.
  """
  global N, TREE_HEIGHT
  N = rdd.count()
  print(f"N: {N}")
  # (N - 1).bit_length() == ceil(log2(N)) for every N >= 1.
  TREE_HEIGHT = max(N - 1, 0).bit_length()
  print(f"TREE_HEIGHT: {TREE_HEIGHT}")
  return rdd.flatMap(reduplicate_row)

In [None]:
rdd.take(5)

[((),
  (datetime.datetime(2022, 1, 1, 0, 25, 52),
   datetime.datetime(2022, 1, 1, 0, 13, 35),
   Data(obj=(datetime.datetime(2022, 1, 1, 0, 25, 52), datetime.datetime(2022, 1, 1, 0, 13, 35))))),
 ((),
  (datetime.datetime(2022, 1, 1, 0, 39, 38),
   datetime.datetime(2022, 1, 1, 0, 32, 27),
   Data(obj=(datetime.datetime(2022, 1, 1, 0, 39, 38), datetime.datetime(2022, 1, 1, 0, 32, 27))))),
 ((),
  (datetime.datetime(2022, 1, 1, 0, 39, 38),
   datetime.datetime(2022, 1, 1, 0, 32, 27),
   Query(obj=(datetime.datetime(2022, 1, 1, 0, 39, 38), datetime.datetime(2022, 1, 1, 0, 32, 27))))),
 ((),
  (datetime.datetime(2022, 1, 1, 1, 1, 35),
   datetime.datetime(2022, 1, 1, 0, 43, 15),
   Data(obj=(datetime.datetime(2022, 1, 1, 1, 1, 35), datetime.datetime(2022, 1, 1, 0, 43, 15))))),
 ((),
  (datetime.datetime(2022, 1, 1, 0, 20, 49),
   datetime.datetime(2022, 1, 1, 0, 15, 35),
   Data(obj=(datetime.datetime(2022, 1, 1, 0, 20, 49), datetime.datetime(2022, 1, 1, 0, 15, 35)))))]

In [None]:
ordered_by_first_rdd = terra_sort_with_perfect_balance(rdd)
ordered_by_first_rdd.take(5)

RDD containing 133 rows.
N: 133
T: 7
M: 19.0
47.85381494093947
P: 0.3598031198566877
Boundaries: [((), datetime.datetime(2022, 1, 1, 0, 26, 57), 0), ((), datetime.datetime(2022, 1, 1, 0, 34, 7), 0), ((), datetime.datetime(2022, 1, 1, 0, 44, 49), 0), ((), datetime.datetime(2022, 1, 1, 0, 52, 51), 1), ((), datetime.datetime(2022, 1, 1, 1, 1, 35), 0), ((), datetime.datetime(2022, 1, 1, 1, 14, 20), 0), (('2',), datetime.datetime(9999, 12, 31, 23, 59, 59, 999999), 2)]
RDD ORG: [((), (datetime.datetime(2022, 1, 1, 0, 25, 52), datetime.datetime(2022, 1, 1, 0, 13, 35), Data(obj=(datetime.datetime(2022, 1, 1, 0, 25, 52), datetime.datetime(2022, 1, 1, 0, 13, 35))))), ((), (datetime.datetime(2022, 1, 1, 0, 39, 38), datetime.datetime(2022, 1, 1, 0, 32, 27), Data(obj=(datetime.datetime(2022, 1, 1, 0, 39, 38), datetime.datetime(2022, 1, 1, 0, 32, 27))))), ((), (datetime.datetime(2022, 1, 1, 0, 39, 38), datetime.datetime(2022, 1, 1, 0, 32, 27), Query(obj=(datetime.datetime(2022, 1, 1, 0, 39, 38), dat

[(0,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 5, 29),
    datetime.datetime(2022, 1, 1, 0, 0, 44),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 5, 29), datetime.datetime(2022, 1, 1, 0, 0, 44)))))),
 (1,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 8, 54),
    datetime.datetime(2022, 1, 1, 0, 0, 44),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 8, 54), datetime.datetime(2022, 1, 1, 0, 0, 44)))))),
 (2,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 8, 54),
    datetime.datetime(2022, 1, 1, 0, 0, 44),
    Query(obj=(datetime.datetime(2022, 1, 1, 0, 8, 54), datetime.datetime(2022, 1, 1, 0, 0, 44)))))),
 (3,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 8, 58),
    datetime.datetime(2022, 1, 1, 0, 6, 10),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 8, 58), datetime.datetime(2022, 1, 1, 0, 6, 10)))))),
 (4,
  ((),
   (datetime.datetime(2022, 1, 1, 0, 14, 17),
    datetime.datetime(2022, 1, 1, 0, 6, 28),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 14, 17), datetime.datetime(2022, 1, 1

In [None]:
label_first_rdd = reduplicate(ordered_by_first_rdd)
label_first_rdd.take(5), label_first_rdd.count()

N: 133
TREE_HEIGHT: 8


([(('',),
   (datetime.datetime(2022, 1, 1, 0, 0, 44),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 5, 29), datetime.datetime(2022, 1, 1, 0, 0, 44))))),
  (('0',),
   (datetime.datetime(2022, 1, 1, 0, 0, 44),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 5, 29), datetime.datetime(2022, 1, 1, 0, 0, 44))))),
  (('00',),
   (datetime.datetime(2022, 1, 1, 0, 0, 44),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 5, 29), datetime.datetime(2022, 1, 1, 0, 0, 44))))),
  (('000',),
   (datetime.datetime(2022, 1, 1, 0, 0, 44),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 5, 29), datetime.datetime(2022, 1, 1, 0, 0, 44))))),
  (('0000',),
   (datetime.datetime(2022, 1, 1, 0, 0, 44),
    Data(obj=(datetime.datetime(2022, 1, 1, 0, 5, 29), datetime.datetime(2022, 1, 1, 0, 0, 44)))))],
 539)

In [None]:
label_first_sorted_rdd_with_rank = terra_sort_with_perfect_balance(label_first_rdd)
label_first_sorted_rdd_with_rank.take(5), label_first_sorted_rdd_with_rank.count()

RDD containing 539 rows.
N: 539
T: 7
M: 77.0
57.64938003975018
P: 0.1069561781813547
Boundaries: [(('',), datetime.datetime(2022, 1, 1, 0, 41, 32), 0), (('0',), datetime.datetime(2022, 1, 1, 0, 13, 17), 0), (('0',), datetime.datetime(2022, 1, 1, 0, 33, 50), 1), (('00011',), datetime.datetime(2022, 1, 1, 0, 27, 30), 1), (('01',), datetime.datetime(2022, 1, 1, 0, 46, 41), 0), (('010001',), datetime.datetime(2022, 1, 1, 0, 47, 55), 0), (('0110',), datetime.datetime(2022, 1, 1, 0, 56, 34), 1)]
RDD ORG: [(('',), (datetime.datetime(2022, 1, 1, 0, 0, 44), Data(obj=(datetime.datetime(2022, 1, 1, 0, 5, 29), datetime.datetime(2022, 1, 1, 0, 0, 44))))), (('0',), (datetime.datetime(2022, 1, 1, 0, 0, 44), Data(obj=(datetime.datetime(2022, 1, 1, 0, 5, 29), datetime.datetime(2022, 1, 1, 0, 0, 44))))), (('00',), (datetime.datetime(2022, 1, 1, 0, 0, 44), Data(obj=(datetime.datetime(2022, 1, 1, 0, 5, 29), datetime.datetime(2022, 1, 1, 0, 0, 44))))), (('000',), (datetime.datetime(2022, 1, 1, 0, 0, 44), D

([(0,
   (('',),
    (datetime.datetime(2022, 1, 1, 0, 0, 44),
     Data(obj=(datetime.datetime(2022, 1, 1, 0, 5, 29), datetime.datetime(2022, 1, 1, 0, 0, 44)))))),
  (1,
   (('',),
    (datetime.datetime(2022, 1, 1, 0, 0, 44),
     Data(obj=(datetime.datetime(2022, 1, 1, 0, 8, 54), datetime.datetime(2022, 1, 1, 0, 0, 44)))))),
  (2,
   (('',),
    (datetime.datetime(2022, 1, 1, 0, 5, 26),
     Data(obj=(datetime.datetime(2022, 1, 1, 0, 29, 5), datetime.datetime(2022, 1, 1, 0, 5, 26)))))),
  (3,
   (('',),
    (datetime.datetime(2022, 1, 1, 0, 5, 57),
     Data(obj=(datetime.datetime(2022, 1, 1, 0, 32, 31), datetime.datetime(2022, 1, 1, 0, 5, 57)))))),
  (4,
   (('',),
    (datetime.datetime(2022, 1, 1, 0, 6, 10),
     Data(obj=(datetime.datetime(2022, 1, 1, 0, 8, 58), datetime.datetime(2022, 1, 1, 0, 6, 10))))))],
 539)

In [None]:
label_first_sorted_rdd = label_first_sorted_rdd_with_rank.map(lambda x: x[1])
label_first_sorted_rdd.take(5)

[(('',),
  (datetime.datetime(2022, 1, 1, 0, 0, 44),
   Data(obj=(datetime.datetime(2022, 1, 1, 0, 5, 29), datetime.datetime(2022, 1, 1, 0, 0, 44))))),
 (('',),
  (datetime.datetime(2022, 1, 1, 0, 0, 44),
   Data(obj=(datetime.datetime(2022, 1, 1, 0, 8, 54), datetime.datetime(2022, 1, 1, 0, 0, 44))))),
 (('',),
  (datetime.datetime(2022, 1, 1, 0, 5, 26),
   Data(obj=(datetime.datetime(2022, 1, 1, 0, 29, 5), datetime.datetime(2022, 1, 1, 0, 5, 26))))),
 (('',),
  (datetime.datetime(2022, 1, 1, 0, 5, 57),
   Data(obj=(datetime.datetime(2022, 1, 1, 0, 32, 31), datetime.datetime(2022, 1, 1, 0, 5, 57))))),
 (('',),
  (datetime.datetime(2022, 1, 1, 0, 6, 10),
   Data(obj=(datetime.datetime(2022, 1, 1, 0, 8, 58), datetime.datetime(2022, 1, 1, 0, 6, 10)))))]

In [None]:
N

539

### Calculate counts


In [142]:
from numpy.lib.twodim_base import triu_indices
def compute_queries_and_count_last_label(partition_idx, partition):
  """First pass over one partition of label-grouped records.

  Every record has the shape ``(labels, (*dims, point))`` where ``point`` is
  either a ``Data`` or a ``Query`` instance. A change of ``labels`` between
  two consecutive records starts a new label group and resets the running
  ``Data`` counter (assumes the caller has already sorted the records so that
  equal labels are adjacent).

  * Records of the *first* label group in the partition are not resolved
    here; they are returned verbatim so a later pass can add whatever was
    counted for the same label in earlier partitions.
  * For every later label group, each ``Query`` is paired with the number of
    ``Data`` records of the same label seen before it in this partition.

  Args:
    partition_idx: index of the partition (as supplied by
      ``mapPartitionsWithIndex``); echoed back in the summary.
    partition: iterable of records belonging to this partition.

  Yields:
    Exactly one tuple ``(last_label_count, first_labels_objs, queries)``:
      * ``last_label_count`` -- ``(last_labels, partition_idx, count)`` where
        ``count`` is the number of ``Data`` records of the last label group
        (``(None, partition_idx, 0)`` for an empty partition);
      * ``first_labels_objs`` -- the untouched records of the first label
        group;
      * ``queries`` -- list of ``(query.obj, count)`` for all other groups.
  """
  # True while we are still inside the first label group of the partition.
  processing_first_label = True
  last_labels = None
  count = 0
  queries = []
  first_labels_objs = []
  for obj in partition:
    labels, (*_dims, point) = obj
    if last_labels != labels:
      # A label change that is not the very first record ends the first group.
      if last_labels is not None:
        processing_first_label = False
      last_labels = labels
      count = 0
    if processing_first_label:
      first_labels_objs.append(obj)
    if isinstance(point, Data):
      count += 1
    if isinstance(point, Query) and not processing_first_label:
      queries.append((point.obj, count))
  last_label_count = (last_labels, partition_idx, count)
  yield (last_label_count, first_labels_objs, queries)

# First pass: scan every partition once. Each partition contributes a single
# (last_label_count, first_labels_objs, queries) tuple to the resulting RDD.
initially_processed_rdd = label_first_sorted_rdd.mapPartitionsWithIndex(compute_queries_and_count_last_label)
initially_processed_rdd.take(2)

[((('',), 0, 77),
  [(('',),
    (datetime.datetime(2022, 1, 1, 0, 0, 44),
     Data(obj=(datetime.datetime(2022, 1, 1, 0, 5, 29), datetime.datetime(2022, 1, 1, 0, 0, 44))))),
   (('',),
    (datetime.datetime(2022, 1, 1, 0, 0, 44),
     Data(obj=(datetime.datetime(2022, 1, 1, 0, 8, 54), datetime.datetime(2022, 1, 1, 0, 0, 44))))),
   (('',),
    (datetime.datetime(2022, 1, 1, 0, 5, 26),
     Data(obj=(datetime.datetime(2022, 1, 1, 0, 29, 5), datetime.datetime(2022, 1, 1, 0, 5, 26))))),
   (('',),
    (datetime.datetime(2022, 1, 1, 0, 5, 57),
     Data(obj=(datetime.datetime(2022, 1, 1, 0, 32, 31), datetime.datetime(2022, 1, 1, 0, 5, 57))))),
   (('',),
    (datetime.datetime(2022, 1, 1, 0, 6, 10),
     Data(obj=(datetime.datetime(2022, 1, 1, 0, 8, 58), datetime.datetime(2022, 1, 1, 0, 6, 10))))),
   (('',),
    (datetime.datetime(2022, 1, 1, 0, 6, 28),
     Data(obj=(datetime.datetime(2022, 1, 1, 0, 14, 17), datetime.datetime(2022, 1, 1, 0, 6, 28))))),
   (('',),
    (datetime.datetim

In [143]:
# Keep only the (last_labels, partition_idx, count) summary of each partition
# and collect the summaries into a plain Python list.
last_label_counts_rdd = initially_processed_rdd.map(lambda x: x[0])
last_label_counts = last_label_counts_rdd.collect()
last_label_counts

[(('',), 0, 77),
 (('0',), 1, 48),
 (('000001',), 2, 2),
 (('0010001',), 3, 1),
 (('010',), 4, 1),
 (('011',), 5, 4),
 (('0111111',), 6, 1)]

In [None]:
# Broadcast the per-partition summaries. finish_processing reads them through
# the global name last_label_counts_broadcasted, so this cell must run before
# the second pass is executed.
last_label_counts_broadcasted = sc.broadcast(last_label_counts)
last_label_counts_broadcasted

<pyspark.broadcast.Broadcast at 0x7feeaf7e8550>

In [136]:
import operator

def finish_processing(partition_idx, first_labels_objs_queries):
  """Second pass: resolve the queries of a partition's first label group.

  The first pass could not answer queries in the first label group because
  ``Data`` records with the same label may live in earlier partitions. Here
  the running count is seeded with the counts of all earlier partitions whose
  *last* label equals this partition's first label, and the deferred records
  are then replayed in order.

  Reads the global ``last_label_counts_broadcasted`` whose ``.value`` is a
  list of ``(last_labels, partition_idx, count)`` tuples.

  Args:
    partition_idx: index of the partition being processed.
    first_labels_objs_queries: iterator expected to hold one
      ``(first_labels_objs, queries)`` pair for this partition.

  Yields:
    One list of ``(query.obj, count)`` pairs: the already resolved
    ``queries`` followed by the newly resolved first-group queries. Nothing
    is yielded if the iterator is empty.
  """
  last_label_counts = last_label_counts_broadcasted.value

  # A bare next() on an exhausted iterator would surface as RuntimeError
  # inside a generator, so use a sentinel instead.
  pending = next(first_labels_objs_queries, None)
  if pending is None:
    return
  first_labels_objs, queries = pending

  # Work on a copy so the caller's list is left untouched.
  resolved = list(queries)

  # An empty partition has no first label group, hence nothing to replay.
  if first_labels_objs:
    first_label = first_labels_objs[0][0]
    count = sum(
      label_count
      for labels, idx, label_count in last_label_counts
      if labels == first_label and idx < partition_idx
    )
    for obj in first_labels_objs:
      _labels, (*_dims, point) = obj
      if isinstance(point, Data):
        count += 1
      if isinstance(point, Query):
        resolved.append((point.obj, count))
  yield resolved

# Second pass: hand each partition its deferred first-label records together
# with the queries resolved so far, flatten the per-partition lists and sum
# the counts that belong to the same query key.
# Note: because of the trailing collect(), processed_rdd is a plain list here.
processed_rdd = (
  initially_processed_rdd
  .map(lambda partition_result: (partition_result[1], partition_result[2]))
  .mapPartitionsWithIndex(finish_processing)
  .flatMap(lambda partition_queries: partition_queries)
  .reduceByKey(operator.add)
  .collect()
)

processed_rdd

[((datetime.datetime(2022, 1, 1, 1, 2, 50),
   datetime.datetime(2022, 1, 1, 0, 40, 18)),
  55),
 ((datetime.datetime(2022, 1, 1, 0, 38, 11),
   datetime.datetime(2022, 1, 1, 0, 30, 6)),
  30),
 ((datetime.datetime(2022, 1, 1, 0, 47, 36),
   datetime.datetime(2022, 1, 1, 0, 39, 46)),
  45),
 ((datetime.datetime(2022, 1, 1, 0, 44, 46),
   datetime.datetime(2022, 1, 1, 0, 41, 7)),
  39),
 ((datetime.datetime(2022, 1, 1, 0, 52, 56),
   datetime.datetime(2022, 1, 1, 0, 47, 55)),
  54),
 ((datetime.datetime(2022, 1, 1, 0, 52, 51),
   datetime.datetime(2022, 1, 1, 0, 37, 15)),
  47),
 ((datetime.datetime(2022, 1, 1, 1, 29, 25),
   datetime.datetime(2022, 1, 1, 1, 6, 32)),
  97),
 ((datetime.datetime(2022, 1, 1, 0, 53, 29),
   datetime.datetime(2022, 1, 1, 0, 35, 40)),
  46),
 ((datetime.datetime(2022, 1, 1, 0, 39, 38),
   datetime.datetime(2022, 1, 1, 0, 32, 27)),
  34),
 ((datetime.datetime(2022, 1, 1, 0, 32, 51),
   datetime.datetime(2022, 1, 1, 0, 27, 30)),
  23),
 ((datetime.datetime(202

### All in one place

In [150]:
def all_in_one_place(rdd, verbose=True):
  """Run the full counting pipeline on an RDD of labelled records.

  Steps, in order:
    1. sort ``rdd`` with ``terra_sort_with_perfect_balance``;
    2. expand it with ``reduplicate``;
    3. sort the expanded RDD again and drop the rank from each
       ``(rank, record)`` pair;
    4. first pass per partition
       (``compute_queries_and_count_last_label``);
    5. collect and broadcast the per-partition last-label summaries;
    6. second pass per partition (``finish_processing``) and sum the
       per-query counts with ``reduceByKey``.

  Side effect: rebinds the global ``last_label_counts_broadcasted``, which
  ``finish_processing`` reads by name. The returned RDD is lazy, so it should
  be evaluated (e.g. ``collect()``) before this function is called again;
  otherwise the pending job may end up using the broadcast of the later call.
  NOTE(review): confirm against how PySpark serialises the closure.

  Args:
    rdd: RDD of ``(labels, (*dims, point))`` records.
    verbose: when True (default, the previous behaviour) print samples and
      counts of the intermediate RDDs. Each of these diagnostics launches
      additional Spark jobs; pass False to skip them.

  Returns:
    A lazy RDD of ``(query.obj, count)`` pairs.
  """
  global last_label_counts_broadcasted

  ordered_by_first_rdd = terra_sort_with_perfect_balance(rdd)

  label_first_rdd = reduplicate(ordered_by_first_rdd)
  if verbose:
    print(f"Label first rdd, count: {label_first_rdd.take(5), label_first_rdd.count()}")

  label_first_sorted_rdd_with_rank = terra_sort_with_perfect_balance(label_first_rdd)
  if verbose:
    print(f"Label first sorted rdd with rank, count: {label_first_sorted_rdd_with_rank.take(5), label_first_sorted_rdd_with_rank.count()}")

  # Drop the rank, keep the record.
  label_first_sorted_rdd = label_first_sorted_rdd_with_rank.map(lambda x: x[1])
  initially_processed_rdd = label_first_sorted_rdd.mapPartitionsWithIndex(compute_queries_and_count_last_label)

  # The summaries are tiny (one tuple per partition), so collecting them on
  # the driver and broadcasting them back is cheap.
  last_label_counts = initially_processed_rdd.map(lambda x: x[0]).collect()
  if verbose:
    print(f"Last label counts: {last_label_counts}")
  last_label_counts_broadcasted = sc.broadcast(last_label_counts)

  processed_rdd = initially_processed_rdd \
    .map(lambda x: (x[1], x[2])) \
    .mapPartitionsWithIndex(finish_processing) \
    .flatMap(lambda x: x) \
    .reduceByKey(operator.add)

  return processed_rdd

In [151]:
def all_in_one_place_preprocessing(org_rdd):
  """Prepare a raw RDD and feed it to ``all_in_one_place``.

  Each input row is expanded with ``add_query`` (via ``flatMap``) and the
  resulting rows are passed through ``add_label_placeholder``. A sample of
  five prepared rows is printed before the pipeline is started.

  Args:
    org_rdd: the raw input RDD.

  Returns:
    Whatever ``all_in_one_place`` returns for the prepared RDD.
  """
  with_queries_rdd = org_rdd.flatMap(add_query)
  labelled_rdd = with_queries_rdd.map(add_label_placeholder)
  print(f"RDD: {labelled_rdd.take(5)}")
  return all_in_one_place(labelled_rdd)


In [152]:
# Run the whole pipeline on org_rdd (defined in an earlier cell) and
# materialise the resulting (query, count) pairs.
processed_rdd = all_in_one_place_preprocessing(org_rdd)
processed_rdd.collect()

RDD: [((), (datetime.datetime(2022, 1, 1, 0, 53, 29), datetime.datetime(2022, 1, 1, 0, 35, 40), Data(obj=(datetime.datetime(2022, 1, 1, 0, 53, 29), datetime.datetime(2022, 1, 1, 0, 35, 40))))), ((), (datetime.datetime(2022, 1, 1, 0, 53, 29), datetime.datetime(2022, 1, 1, 0, 35, 40), Query(obj=(datetime.datetime(2022, 1, 1, 0, 53, 29), datetime.datetime(2022, 1, 1, 0, 35, 40))))), ((), (datetime.datetime(2022, 1, 1, 0, 42, 7), datetime.datetime(2022, 1, 1, 0, 33, 43), Data(obj=(datetime.datetime(2022, 1, 1, 0, 42, 7), datetime.datetime(2022, 1, 1, 0, 33, 43))))), ((), (datetime.datetime(2022, 1, 1, 1, 2, 19), datetime.datetime(2022, 1, 1, 0, 53, 21), Data(obj=(datetime.datetime(2022, 1, 1, 1, 2, 19), datetime.datetime(2022, 1, 1, 0, 53, 21))))), ((), (datetime.datetime(2022, 1, 1, 0, 35, 23), datetime.datetime(2022, 1, 1, 0, 25, 21), Data(obj=(datetime.datetime(2022, 1, 1, 0, 35, 23), datetime.datetime(2022, 1, 1, 0, 25, 21)))))]
RDD containing 127 rows.
Adding 6 dummy rows
N: 133
T: 7


[((datetime.datetime(2022, 1, 1, 1, 2, 50),
   datetime.datetime(2022, 1, 1, 0, 40, 18)),
  55),
 ((datetime.datetime(2022, 1, 1, 0, 38, 11),
   datetime.datetime(2022, 1, 1, 0, 30, 6)),
  30),
 ((datetime.datetime(2022, 1, 1, 0, 47, 36),
   datetime.datetime(2022, 1, 1, 0, 39, 46)),
  45),
 ((datetime.datetime(2022, 1, 1, 0, 44, 46),
   datetime.datetime(2022, 1, 1, 0, 41, 7)),
  39),
 ((datetime.datetime(2022, 1, 1, 0, 52, 56),
   datetime.datetime(2022, 1, 1, 0, 47, 55)),
  54),
 ((datetime.datetime(2022, 1, 1, 0, 52, 51),
   datetime.datetime(2022, 1, 1, 0, 37, 15)),
  47),
 ((datetime.datetime(2022, 1, 1, 1, 29, 25),
   datetime.datetime(2022, 1, 1, 1, 6, 32)),
  97),
 ((datetime.datetime(2022, 1, 1, 0, 53, 29),
   datetime.datetime(2022, 1, 1, 0, 35, 40)),
  46),
 ((datetime.datetime(2022, 1, 1, 0, 39, 38),
   datetime.datetime(2022, 1, 1, 0, 32, 27)),
  34),
 ((datetime.datetime(2022, 1, 1, 0, 32, 51),
   datetime.datetime(2022, 1, 1, 0, 27, 30)),
  23),
 ((datetime.datetime(202

### Testing

In [155]:
# Synthetic test input: one Data point for every integer (x, y) with
# 0 <= x, y < 100, i.e. a full 100 x 100 grid (10,000 rows).
data = [(x, y, Data((x, y))) for x in range(100) for y in range(100)]

In [158]:
# Sanity check: a single query at (100, 100) lies beyond every point of the
# 100 x 100 grid, so it must count all 10,000 Data points.
test = data + [(100, 100, Query((100, 100)))]
test_rdd = sc.parallelize(test).map(add_label_placeholder)
result_rdd = all_in_one_place(test_rdd)
single_query_counts = result_rdd.collect()
assert single_query_counts == [((100, 100), 10000)], single_query_counts
single_query_counts

RDD containing 10001 rows.
Adding 2 dummy rows
N: 10003
T: 7
M: 1429.0
78.09585333228345
P: 0.007807243160280261
Boundaries: [((), 8, 0), ((), 35, 0), ((), 37, 0), ((), 46, 0), ((), 53, 0), ((), 63, 0), ((), 88, 0)]
RDD ORG: [((), (0, 40, Data(obj=(0, 40)))), ((), (0, 41, Data(obj=(0, 41)))), ((), (0, 42, Data(obj=(0, 42)))), ((), (0, 43, Data(obj=(0, 43)))), ((), (0, 44, Data(obj=(0, 44))))]
SHUFFLED RDD: [(0, ((), (0, 40, Data(obj=(0, 40))))), (0, ((), (0, 41, Data(obj=(0, 41))))), (0, ((), (0, 42, Data(obj=(0, 42))))), (0, ((), (0, 43, Data(obj=(0, 43))))), (0, ((), (0, 44, Data(obj=(0, 44)))))]
SORTED RDD: [((), (0, 40, Data(obj=(0, 40)))), ((), (0, 41, Data(obj=(0, 41)))), ((), (0, 42, Data(obj=(0, 42)))), ((), (0, 43, Data(obj=(0, 43)))), ((), (0, 44, Data(obj=(0, 44))))]
PREFIX SUMS: [0, 800, 3500, 3700, 4600, 5300, 6300]
RANKED RDD: [(0, ((), (0, 40, Data(obj=(0, 40))))), (1, ((), (0, 41, Data(obj=(0, 41))))), (2, ((), (0, 42, Data(obj=(0, 42))))), (3, ((), (0, 43, Data(obj=(0,

[((100, 100), 10000)]

In [163]:
# Sanity check with a coarse grid of queries (every 10th coordinate, 0..100).
# On the full 100 x 100 Data grid, a query at (x, y) must count every Data
# point (a, b) with a <= x and b <= y, i.e. min(x + 1, 100) * min(y + 1, 100).
test = data + [(x, y, Query((x, y))) for x in range(0, 101, 10) for y in range(0, 101, 10)]
test_rdd = sc.parallelize(test).map(add_label_placeholder)
result_rdd = all_in_one_place(test_rdd)
grid_query_counts = result_rdd.collect()
grid_query_mismatches = [
  ((qx, qy), count)
  for (qx, qy), count in grid_query_counts
  if count != min(qx + 1, 100) * min(qy + 1, 100)
]
assert grid_query_counts, "pipeline returned no query results"
assert not grid_query_mismatches, grid_query_mismatches
grid_query_counts

RDD containing 10121 rows.
Adding 1 dummy rows
N: 10122
T: 7
M: 1446.0
78.17863690580535
P: 0.0077236353394393754
Boundaries: [((), 8, 0), ((), 13, 0), ((), 37, 0), ((), 43, 0), ((), 53, 0), ((), 64, 0), ((), 88, 0)]
RDD ORG: [((), (0, 40, Data(obj=(0, 40)))), ((), (0, 41, Data(obj=(0, 41)))), ((), (0, 42, Data(obj=(0, 42)))), ((), (0, 43, Data(obj=(0, 43)))), ((), (0, 44, Data(obj=(0, 44))))]
SHUFFLED RDD: [(0, ((), (0, 40, Data(obj=(0, 40))))), (0, ((), (0, 41, Data(obj=(0, 41))))), (0, ((), (0, 42, Data(obj=(0, 42))))), (0, ((), (0, 43, Data(obj=(0, 43))))), (0, ((), (0, 44, Data(obj=(0, 44)))))]
SORTED RDD: [((), (0, 40, Data(obj=(0, 40)))), ((), (0, 41, Data(obj=(0, 41)))), ((), (0, 42, Data(obj=(0, 42)))), ((), (0, 43, Data(obj=(0, 43)))), ((), (0, 44, Data(obj=(0, 44))))]
PREFIX SUMS: [0, 811, 1322, 3744, 4355, 5366, 6477]
RANKED RDD: [(0, ((), (0, 40, Data(obj=(0, 40))))), (1, ((), (0, 41, Data(obj=(0, 41))))), (2, ((), (0, 42, Data(obj=(0, 42))))), (3, ((), (0, 43, Data(obj=(0

[((100, 30), 3100),
 ((100, 50), 5100),
 ((80, 60), 4941),
 ((80, 70), 5751),
 ((90, 70), 6461),
 ((100, 70), 7100),
 ((80, 80), 6561),
 ((90, 80), 7371),
 ((80, 90), 7371),
 ((90, 90), 8281),
 ((30, 20), 651),
 ((20, 20), 441),
 ((30, 30), 961),
 ((20, 40), 861),
 ((0, 20), 21),
 ((0, 30), 31),
 ((0, 50), 51),
 ((0, 100), 100),
 ((50, 90), 4641),
 ((70, 100), 7100),
 ((40, 0), 41),
 ((50, 60), 3111),
 ((100, 40), 4100),
 ((80, 100), 8100),
 ((20, 0), 21),
 ((20, 30), 651),
 ((30, 40), 1271),
 ((20, 50), 1071),
 ((30, 70), 2201),
 ((10, 40), 451),
 ((0, 60), 61),
 ((0, 80), 81),
 ((0, 90), 91),
 ((50, 80), 4131),
 ((60, 100), 6100),
 ((40, 10), 451),
 ((70, 10), 781),
 ((40, 20), 861),
 ((50, 30), 1581),
 ((40, 50), 2091),
 ((60, 60), 3721),
 ((100, 0), 100),
 ((90, 20), 1911),
 ((90, 30), 2821),
 ((90, 100), 9100),
 ((20, 10), 231),
 ((30, 60), 1891),
 ((30, 80), 2511),
 ((30, 90), 2821),
 ((10, 0), 11),
 ((10, 10), 121),
 ((60, 80), 4941),
 ((60, 90), 5551),
 ((50, 100), 5100),
 ((70