## Imports

In [None]:
from functools import reduce
from pyspark.sql import *
from pyspark.sql.types import *
from pyspark.sql.functions import *
import datetime
import math
import builtins

## Query B

In [None]:
# strptime format used in this file to parse the pickup/drop-off datetime strings.
FMT = '%Y-%m-%d %H:%M:%S'
# Bounding box applied to both pickup and drop-off coordinates further down:
# latitude is kept within [FINAL_LATITUDE, INITIAL_LATITUDE] and
# longitude within [INITIAL_LONGITUDE, FINAL_LONGITUDE].
INITIAL_LATITUDE = 41.477182778
INITIAL_LONGITUDE = -74.916578
FINAL_LATITUDE = 40.129715978
FINAL_LONGITUDE = -73.120778

def getTimestamp(d):
    """Return the half-hour mark (HH:30:00) of the hour that `d` falls in.

    Args:
        d: datetime string laid out according to FMT.

    Returns:
        A naive datetime with the same date and hour as `d`, minute 30
        and second 0.

    Raises:
        ValueError: if `d` does not match FMT.
    """
    parsed = datetime.datetime.strptime(d, FMT)
    # Keep the date and hour; pin the minutes to 30 and drop the seconds.
    return datetime.datetime.combine(parsed.date(), datetime.time(parsed.hour, 30, 0))

def getTime30(d):
    """Return the start (HH:00:00) of the hour that `d` falls in.

    Args:
        d: datetime string laid out according to FMT.

    Returns:
        A naive datetime with the same date and hour as `d`, minute 0
        and second 0.

    Raises:
        ValueError: if `d` does not match FMT.
    """
    parsed = datetime.datetime.strptime(d, FMT)
    # Keep the date and hour; zero out minutes and seconds.
    return datetime.datetime.combine(parsed.date(), datetime.time(parsed.hour, 0, 0))

def getTime15(d):
    """Return the quarter-hour mark (HH:15:00) of the hour that `d` falls in.

    Args:
        d: datetime string laid out according to FMT.

    Returns:
        A naive datetime with the same date and hour as `d`, minute 15
        and second 0.

    Raises:
        ValueError: if `d` does not match FMT.
    """
    parsed = datetime.datetime.strptime(d, FMT)
    # Keep the date and hour; pin the minutes to 15 and drop the seconds.
    return datetime.datetime.combine(parsed.date(), datetime.time(parsed.hour, 15, 0))

def getTop10(array):
    """Return the first ten elements of `array` (fewer if it is shorter).

    Args:
        array: any sliceable sequence.

    Returns:
        A slice of `array` holding at most its first ten items, in order.
    """
    return array[:10]

In [None]:
# Create (or reuse) a Spark session named 'taxi' with master 'local[*]'.
spark = (
    SparkSession.builder
    .master('local[*]')
    .appName('taxi')
    .getOrCreate()
)
# The underlying SparkContext is used below for the RDD-based CSV read.
sc = spark.sparkContext

In [None]:
def _tripToRow(el):
    """Convert one comma-split line of 'sorted_data.csv' into a Row.

    Each datetime string is parsed once and reused for every derived field.

    Args:
        el: list of string fields from one CSV line. Index 0 is the id,
            2 and 3 are the pickup / drop-off datetimes (FMT layout),
            6-9 are pickup lon, pickup lat, drop-off lon, drop-off lat,
            11 is the fare and 14 is the tip.

    Returns:
        A Row with id, day (weekday name of the pickup), p_date/d_date,
        p_h/d_h, the raw p_dt/d_dt strings, coordinates rounded to two
        decimals, fare and tip.

    Raises:
        ValueError: if a datetime or numeric field cannot be parsed.
        IndexError: if the line has fewer than 15 fields.
    """
    pickup = datetime.datetime.strptime(el[2], FMT)
    dropoff = datetime.datetime.strptime(el[3], FMT)
    # builtins.round is spelled out because `from pyspark.sql.functions import *`
    # shadows the built-in round with the Column-based one.
    return Row(id = el[0],
               day = pickup.strftime('%A'),
               p_date = pickup.date(),
               d_date = dropoff.date(),
               p_h = pickup.hour,
               d_h = dropoff.hour,
               p_dt = el[2],
               d_dt = el[3],
               p_lon = builtins.round(float(el[6]), 2),
               p_lat = builtins.round(float(el[7]), 2),
               d_lon = builtins.round(float(el[8]), 2),
               d_lat = builtins.round(float(el[9]), 2),
               fare = float(el[11]),
               tip = float(el[14]))


# Query B: per weekday and hour, list up to ten "(lat lon)" areas taken from a
# profit ranking (average fare + tip divided by a free-taxi count).
try:
    lines = sc.textFile('sorted_data.csv')
    # Drop empty lines, split once, keep rows whose field 4 is positive
    # (presumably the trip duration -- TODO confirm against the dataset schema).
    dataRows = lines.filter(lambda line : len(line) > 0) \
                    .map(lambda line : line.split(',')) \
                    .filter(lambda el : float(el[4]) > 0) \
                    .map(_tripToRow)

    # The helper names are rebound to their UDF wrappers, as before.
    # NOTE(review): re-running this cell without re-running the cell that
    # defines the helpers would wrap the wrappers -- confirm that is acceptable.
    getTimestamp = udf(getTimestamp, TimestampType())
    getTime30 = udf(getTime30, TimestampType())
    getTime15 = udf(getTime15, TimestampType())
    getTop10 = udf(getTop10, ArrayType(StringType()))

    dataDF = spark.createDataFrame(dataRows)

    # Keep only trips whose pickup and drop-off both lie inside the bounding box.
    dataDF = dataDF.where(dataDF.p_lon >= INITIAL_LONGITUDE) \
                   .where(dataDF.p_lon <= FINAL_LONGITUDE) \
                   .where(dataDF.d_lon >= INITIAL_LONGITUDE) \
                   .where(dataDF.d_lon <= FINAL_LONGITUDE) \
                   .where(dataDF.p_lat >= FINAL_LATITUDE) \
                   .where(dataDF.p_lat <= INITIAL_LATITUDE) \
                   .where(dataDF.d_lat >= FINAL_LATITUDE) \
                   .where(dataDF.d_lat <= INITIAL_LATITUDE)

    # Drop-offs in the window (HH:00:00, HH:30:00] of their own hour.
    # d_dt is a string column compared against a timestamp UDF result; this
    # presumably relies on Spark's implicit cast -- TODO confirm.
    tripsDown = dataDF.where(dataDF.d_dt <= getTimestamp(dataDF.d_dt))
    tripsDown = tripsDown.where(dataDF.d_dt > getTime30(dataDF.d_dt))

    # One drop-off record per taxi id, drop-off date and hour.
    tripsDown30 = tripsDown.select("id", "day", "d_date", "d_h", "d_dt", "d_lat", "d_lon") \
                           .dropDuplicates(subset=["id", "d_date", "d_h"])

    # Pickups compared against the same half-hour window.
    # NOTE(review): the window bounds are derived from d_dt, not p_dt --
    # confirm this is intended.
    tripsUp = dataDF.where(dataDF.p_dt <= getTimestamp(dataDF.d_dt))
    tripsUp = tripsUp.where(dataDF.p_dt > getTime30(dataDF.d_dt))

    # One pickup record per taxi id, pickup date and hour; id is renamed so the
    # join below has unambiguous column names.
    tripsUp = tripsUp.select("id", "p_date", "p_h", "p_dt") \
                     .dropDuplicates(subset=["id", "p_date", "p_h"]) \
                     .withColumnRenamed("id", "taxi_id")

    # Left outer join keeps drop-offs that have no matching pickup (p_dt null).
    joined = tripsDown30.join(tripsUp, tripsDown30.id == tripsUp.taxi_id, "leftouter") \
                      .select("id", "day", "d_h", "p_h", "d_dt", "p_dt", "d_lat", "d_lon")

    # taxi_busy is 1 when a joined pickup is later than the drop-off, else 0
    # (including the case where p_dt is null).
    taxisState = joined.select(joined.id, joined.day, joined.d_h, joined.d_lat, joined.d_lon, \
                              when(joined.p_dt > joined.d_dt, 1).otherwise(0).alias("taxi_busy"))

    taxisState = taxisState.filter(taxisState.taxi_busy == 0)

    # Count the non-busy rows per weekday, hour and drop-off cell.
    taxisState = taxisState.groupBy("day", "d_h", "d_lat", "d_lon").count() \
                           .select(col("day"), col("d_h"), col("d_lat"), col("d_lon"), col("count").alias("taxi_free"))

    # Narrow the half-hour drop-offs to the window (HH:15:00, HH:30:00].
    trips15 = tripsDown.where(dataDF.d_dt <= getTimestamp(dataDF.d_dt))
    trips15 = trips15.where(dataDF.d_dt > getTime15(dataDF.d_dt))

    # profit = average fare + average tip per weekday, pickup hour and pickup cell.
    profit = trips15.groupBy("day", "p_h", "p_lat", "p_lon") \
                    .avg("fare", "tip") \
                    .select(col("day").alias("p_day"), col("p_h"), col("p_lat"), col("p_lon"), \
                           (col("avg(fare)") + col("avg(tip)")).alias("profit"))

    # Match a profit cell with the free-taxi count of the same weekday, hour and cell.
    cond = [profit.p_day == taxisState.day, profit.p_h == taxisState.d_h, profit.p_lat == taxisState.d_lat, profit.p_lon == taxisState.d_lon]

    # p = profit / (taxi_free + 1) when taxi_free > 0, otherwise profit / 1
    # (the otherwise branch also covers cells with no taxisState match).
    profitable = profit.join(taxisState, cond, "leftouter") \
                       .select(col("p_day").alias('weekday'), col("p_h").alias('hour'), col("p_lat"), col("p_lon"), (col("profit") / \
                        when(taxisState.taxi_free > 0, (taxisState.taxi_free + 1)).otherwise(1)).alias("p")) \
                       .orderBy("p", ascending=False)

    # Render each cell as "(lat lon)", gather the cells per weekday/hour and keep
    # the first ten of each list.
    # NOTE(review): collect_list after groupBy is not guaranteed to preserve the
    # orderBy("p") ordering above, so the ten kept areas may not be the ten with
    # the highest p -- confirm, and sort inside the aggregation if it matters.
    final = profitable.withColumn('col_aux', concat(lit("("), col('p_lat'), lit(" "), col('p_lon'), lit(")"))) \
                      .drop(col('p_lat')).drop(col('p_lon')) \
                      .groupBy(["weekday", "hour"]).agg(collect_list(col("col_aux")).alias("area")) \
                      .withColumn('areas', getTop10(col('area'))).drop(col("area")) \
                      .orderBy(["weekday", "hour"], ascending=[True, True])

    final.show(truncate=False)

except Exception as err:
    # The original `except err:` referenced an undefined name, so any failure
    # surfaced as a NameError instead of being reported here.
    print(err)
finally:
    # Stop the SparkContext on both the success and the failure path.
    sc.stop()