In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.window import Window
from pyspark.sql.functions import *
import os
import pandas as pd
import numpy as np
from scipy.spatial import distance

In [2]:
# PYSPARK_SUBMIT_ARGS is only consulted when PySpark launches its JVM gateway,
# so it must be exported BEFORE the session is created. Setting it afterwards
# (as this cell previously did) leaves the Kafka connector package unrequested.
# NOTE: if a SparkSession already exists in this kernel, getOrCreate() reuses
# it and the packages setting has no effect until the kernel is restarted.
os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages org.apache.spark:spark-sql-kafka-0-10_2.12:3.1.2 pyspark-shell'
spark = SparkSession.builder.appName("XDaaAXDXD").getOrCreate()

In [None]:
def tmp(x):
    """Return the element of ``x`` whose absolute value is smallest.

    Parameters
    ----------
    x : array-like (NumPy array, list, or pandas Series)
        Numeric values; the sign of the selected element is preserved.

    Returns
    -------
    The element closest to zero (first one on ties, per ``np.argmin``).

    Raises
    ------
    ValueError
        If ``x`` is empty (raised by ``np.argmin``).
    """
    # np.argmin yields a *position*. Indexing a pandas Series with ``[]`` and
    # an integer may be interpreted as a *label*, so select from the plain
    # array view to keep the lookup positional for every input type.
    values = np.asarray(x)
    return values[np.argmin(np.abs(values))]
    
def process_row_python(row):
    """Estimate how far a vehicle is from its timetable, using pandas.

    Reads three notebook-level names that must exist before the call:
    ``lines`` (the timetable frame), ``idx`` (``pd.IndexSlice``) and ``tmp``.
    NOTE(review): ``lines`` is not assigned in any visible cell; presumably it
    is a frame indexed by a MultiIndex that includes ``busstopId``,
    ``busstopNr`` and ``iter`` levels, with ``lon``, ``lat`` and ``time``
    columns - confirm against the cell that builds it.

    Parameters
    ----------
    row : object exposing ``line``, ``lon``, ``lat`` and ``time`` attributes
        (e.g. a ``pd.Series`` with those labels).

    Returns
    -------
    The negated mean, over the selected stops, of the smallest-magnitude
    difference ``scheduled time - row.time``.
    """
    # Timetable entries for the observed line.
    line = lines.loc[idx[row.line, :, :]]
    # One (lon, lat) pair per stop: the first entry of each (busstopId, busstopNr) group.
    stops = line.groupby(['busstopId', 'busstopNr']).first().loc[:, ['lon', 'lat']]
    # Euclidean distance in raw coordinate space to the observed position; keep the 4 nearest stops.
    stops = np.sqrt(((stops - np.array([row.lon, row.lat]))**2).sum(axis=1)).sort_values().loc[idx[:, :]].head(4).index
    # Timetable entries belonging to those nearest stops.
    chosen = line.loc[idx[stops.get_level_values(0), stops.get_level_values(1), :]]
    # Signed difference between each scheduled time and the observed time.
    chosen = chosen.time - row.time
    # The `iter` value containing the entry closest in time to the observation.
    best_iter = chosen.abs().groupby('iter').min().sort_values().index[0]
    best_stops = chosen.loc[idx[:, :, best_iter]].abs().groupby('busstopId').min().index
    times = chosen.loc[idx[best_stops, :, best_iter]]
    # Pass the function itself to apply(); the previous `tmp(x)` evaluated an
    # undefined name `x` and raised NameError before apply() was ever reached.
    return -times.groupby('busstopId').apply(tmp).mean()

def process_row_spark(row):
    """Spark counterpart of the timetable-offset computation.

    Reads the notebook-level ``table`` (timetable DataFrame) and
    ``distance_udf``. ``abs``, ``min`` and ``mean`` are called on columns here,
    so they are expected to be the pyspark.sql.functions versions.

    Parameters
    ----------
    row : DataFrame
        Joined to ``table`` on ``line``; must also provide the ``coords2`` and
        ``time2`` columns used below.

    Returns
    -------
    The ``aaa`` column attribute of the aggregated one-column frame (a Column
    reference, not a collected value).
    """
    line = table.join(row, on='line')

    # Keep a single timetable entry per stop: the earliest one by `time`.
    earliest_per_stop = Window.partitionBy('busstopId', 'busstopNr').orderBy(col("time"))
    stops = (
        line
        .withColumn("row", row_number().over(earliest_per_stop))
        .filter(col("row") == 1)
        .drop("row")
    )

    # Distance from each stop to the observed position; keep the 4 nearest stops.
    stops = stops.withColumn('coords', array(stops.lon, stops.lat))
    stops = stops.withColumn('dists', distance_udf(stops.coords, stops.coords2))
    nearest_stops = stops.orderBy('dists').limit(4).select('busstopId', 'busstopNr')

    # Signed and absolute time differences for every entry at those stops.
    chosen = line.join(nearest_stops, on=['busstopId', 'busstopNr'])
    chosen = chosen.withColumn('time_diff', chosen.time - chosen.time2)
    chosen = chosen.withColumn('time_diff_abs', abs(chosen.time_diff))

    # The `iter` whose closest entry has the smallest absolute time difference.
    best_iter = (
        chosen
        .groupby('iter')
        .agg(min('time_diff_abs').alias('time_dist'))
        .orderBy('time_dist')
        .limit(1)
        .select('iter')
    )

    # Negated mean of the two smallest-magnitude differences within that iter.
    closest_two = chosen.join(best_iter, on='iter').orderBy('time_diff_abs').limit(2)
    result = closest_two.agg((-mean('time_diff')).alias('aaa')).select('aaa')
    return result.aaa

def _euclidean_distance(x, row):
    """Return the Euclidean distance between two coordinate sequences as a Python float."""
    return float(distance.euclidean(x, row))


# Spark UDF wrapper around the helper above; declared to return FloatType.
distance_udf = udf(_euclidean_distance, FloatType())

### Timetables

In [4]:
# Batch read of the timetable topic from Kafka.
df = (
    spark.read
    .format("kafka")
    .option("kafka.bootstrap.servers", "instance-tram-1:9092")
    .option("subscribe", "lines")
    .load()
)

# Each Kafka value is treated as one comma-separated record; fields are picked by position.
_CSV_FIELDS = ['busstopId', 'busstopNr', 'line', 'direction', 'lon', 'lat', 'time', 'iter']
table = (
    df.select(col("value").cast("string"))
    .alias("csv")
    .select("csv.*")
    .selectExpr(*[f"split(value,',')[{pos}] as {name}" for pos, name in enumerate(_CSV_FIELDS)])
)

# Cast the string fields that need a concrete type; the rest stay as strings.
_FIELD_TYPES = {
    'busstopNr': IntegerType(),
    'line': IntegerType(),
    'lon': FloatType(),
    'lat': FloatType(),
    'time': TimestampType(),
    'iter': IntegerType(),
}
table = table.select(*[
    col(name).cast(_FIELD_TYPES[name]).alias(name) if name in _FIELD_TYPES else col(name)
    for name in table.columns
])

# Replace the timestamp with 60*hour + minute and add a [lon, lat] array column.
table = (
    table
    .withColumn('time', 60 * hour(col('time')) + minute(col('time')))
    .withColumn('coords', array(col('lon'), col('lat')))
)

In [5]:
# Print the first 20 rows of the parsed timetable as a sanity check.
table.show(20)

[Stage 0:>                                                          (0 + 1) / 1]

+---------+---------+----+--------------------+---------+---------+----+----+--------------------+
|busstopId|busstopNr|line|           direction|      lon|      lat|time|iter|              coords|
+---------+---------+----+--------------------+---------+---------+----+----+--------------------+
|     R-03|        0|  15|             Wołoska|  52.1885|20.999907| 223|   0|[52.1885, 20.999907]|
|     3240|        7|  15|         Samochodowa| 52.18876| 21.00332| 225|   0|[52.18876, 21.00332]|
|     R-04|        0|   4|"Zgrupowania AK "...|52.299137|20.934156| 225|   0|[52.299137, 20.93...|
|     3116|        3|  15|    Telewizja Polska|52.188824|21.007105| 226|   0|[52.188824, 21.00...|
|     R-04|        0|   1|"Zgrupowania AK "...|52.299137|20.934156| 227|   0|[52.299137, 20.93...|
|     3115|        3|  15|      Metro Wierzbno| 52.18887|  21.0114| 227|   0| [52.18887, 21.0114]|
|     6061|        5|   1|          Marymoncka|52.299557|20.935863| 228|   0|[52.299557, 20.93...|
|     6014

                                                                                

### Tram positions

In [4]:
# Streaming read of the processed tram-position topic from Kafka.
df2 = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "instance-tram-1:9092")
    .option("subscribe", "tram_positions_processed")
    .load()
)

# Schema applied to the JSON payload carried in the Kafka `value` field.
schema = StructType([
    StructField(field_name, field_type)
    for field_name, field_type in [
        ('Lines', StringType()),
        ('Lon', FloatType()),
        ('VehicleNumber', StringType()),
        ('Time', TimestampType()),
        ('Lat', FloatType()),
        ('Brigade', StringType()),
    ]
])

# Parse the JSON payload and rename fields; the `2` suffix keeps them distinct
# from the timetable columns of the same base name.
table2 = (
    df2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
    .select(from_json('value', schema).alias('temp'))
    .select('temp.*')
    .withColumnRenamed('Lines', 'line')
    .withColumnRenamed('Lon', 'lon2')
    .withColumnRenamed('Time', 'time2')
    .withColumnRenamed('Lat', 'lat2')
)

# Same derived columns as the timetable: 60*hour + minute and a coordinate array.
# Optional projection, currently disabled:
# table2 = table2.select('line', 'lon2', 'lat2', 'time2', 'coords2')
table2 = (
    table2
    .withColumn('time2', 60 * hour(col('time2')) + minute(col('time2')))
    .withColumn('coords2', array(col('lon2'), col('lat2')))
    .withColumn('line', col('line').cast(IntegerType()))
)

In [9]:
# Shorthand for building MultiIndex slicers; process_row_python reads `idx`
# as a global, so this cell must run before that function is called.
idx = pd.IndexSlice

### Static example - Python

In [None]:
# Hand-made observation (line, position, time) for trying the pandas implementation.
row = pd.Series({'line': 15, 'lon': 52.2, 'lat': 21, 'time': 700})
process_row_python(row)

### Static example - PySpark

In [4]:
# Single hand-made observation, shaped like one record of the position stream.
columns = ['line', 'lon2', 'lat2', 'time2']
data = [[15, 52.21901, 20.983505, 600]]
row = (
    spark.sparkContext.parallelize(data)
    .toDF(columns)
    .withColumn('coords2', array(col('lon2'), col('lat2')))
)

                                                                                

In [None]:
# Run the Spark implementation on the single-row `row` DataFrame
# (columns line / lon2 / lat2 / time2 / coords2).
process_row_spark(row)

### Attempts with stream processing

In [6]:
# Join the timetable with the position stream on `line`, then add the
# coordinate distance and the signed / absolute time differences per pair.
table3 = (
    table.join(table2, on='line')
    .withColumn('dists', distance_udf(col('coords'), col('coords2')))
    .withColumn('time_diff', col('time') - col('time2'))
    .withColumn('time_diff_abs', abs(col('time_diff')))
)

In [7]:
# Window per (line, Brigade), ordered by coordinate distance and then by
# absolute time difference, so rank 1 is the closest timetable entry.
w = Window.partitionBy('line', 'Brigade').orderBy(col("dists"), col('time_diff_abs'))

In [120]:
# Keep the 4 best-ranked rows of each (line, Brigade) window.
# NOTE(review): the AttributeError shown under this cell mentions
# GroupedData.limit, which this expression never calls - the captured output
# appears to belong to an earlier version of the cell.
# NOTE(review): table3 derives from a readStream source; confirm that Spark
# accepts a row_number() window on a streaming DataFrame before relying on this.
table3.withColumn("row", row_number().over(w)).filter(col("row") <= 4).drop("row")

AttributeError: 'GroupedData' object has no attribute 'limit'

In [107]:
# Add a `delay` column and start a streaming query that writes to the console sink.
# NOTE(review): `process_row` is not defined in any visible cell (only
# process_row_python and process_row_spark exist), so on a fresh kernel this
# line raises NameError - confirm which function was intended.
table3.withColumn('delay', process_row(table3)).writeStream.format('console').start()

22/01/16 18:08:42 WARN org.apache.spark.sql.streaming.StreamingQueryManager: Temporary checkpoint location created which is deleted normally when the query didn't fail: /tmp/temporary-46d63f72-64c6-418f-9a91-8562bfe59d07. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
22/01/16 18:08:42 WARN org.apache.spark.sql.streaming.StreamingQueryManager: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.


<pyspark.sql.streaming.StreamingQuery at 0x7f5865577400>

                                                                                

-------------------------------------------
Batch: 0
-------------------------------------------
+----+---------+---------+---------+---+---+----+----+----+----+-----+-------+-----+
|line|busstopId|busstopNr|direction|lon|lat|time|iter|lon2|lat2|time2|coords2|delay|
+----+---------+---------+---------+---+---+----+----+----+----+-----+-------+-----+
+----+---------+---------+---------+---+---+----+----+----+----+-----+-------+-----+



In [6]:
# df_kafka = table\
#     .withColumn('value', to_json(struct(table["`dt`"], table["`main.temp`"], table["`main.pressure`"], table["`main.humidity`"], table["`visibility`"],
#                                        table["`wind.speed`"], table["`clouds.all`"], table["`rain.3h`"], table["`snow.3h`"], table["`pop`"])))