In [None]:
# pyspark --py-files ./v0.7.0.zip --packages anguenot/pyspark-cassandra:0.9.0,org.apache.spark:spark-streaming-kafka-0-8_2.11:2.4.3 --conf spark.cassandra.connection.host=127.0.0.1
import findspark
findspark.init()

import os
from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.sql import Row, SparkSession
from pyspark.streaming.kafka import KafkaUtils
import pyspark
import json
import sys
import operator
import datetime
from dateutil import parser
from pyspark_cassandra import streaming

# Reuse the SparkContext pre-created by the pyspark shell, or create one when
# this code runs as a plain script/notebook where no `sc` global exists yet.
sc = SparkContext.getOrCreate()
sc.setLogLevel("WARN")

# Streaming context with a 60-second batch interval.
ssc = StreamingContext(sc, 60)

# Checkpoint directory; the updateStateByKey transformations used by the
# queries below need checkpointing to be enabled.
ssc.checkpoint("/tmp")

def getSparkSessionInstance(sparkConf):
    """Return the lazily created SparkSession cached in this module's globals.

    Args:
        sparkConf: SparkConf passed to the session builder the first time a
            session has to be created; ignored once a session is cached.

    Returns:
        The cached SparkSession.
    """
    cache = globals()
    slot = 'sparkSessionSingletonInstance'
    if slot in cache:
        return cache[slot]
    cache[slot] = SparkSession.builder.config(conf=sparkConf).getOrCreate()
    return cache[slot]

# Receiver-based Kafka stream: ZooKeeper quorum, consumer group id, and a
# {topic: partitions-to-consume} mapping.
kafkaStream = KafkaUtils.createStream(
    ssc,
    zkQuorum='172.30.7.114:2181,172.30.6.85:2181,172.30.9.172:2181',
    groupId='spark-streaming',
    topics={'test': 1})

# Each record is indexed as a pair; element 1 carries the JSON payload.
parsed = kafkaStream.map(lambda record: json.loads(record[1]))

# define the update function
def update_Q1_1(newValues, runningCount):
    """State update for ``updateStateByKey``: add this batch's counts to the total.

    Args:
        newValues: Counts seen for the key in the current batch.
        runningCount: Previous total for the key, or None for a new key.

    Returns:
        The new running total.
    """
    previous_total = runningCount if runningCount is not None else 0
    return sum(newValues, previous_total)

# define the update function
def update_Q1_2(newValues, runningState):
    """State update for ``updateStateByKey``: accumulate two running totals.

    Args:
        newValues: Indexable items for the current batch; elements 0 and 1 of
            each are added to the matching slot of the state.
        runningState: Previous state for the key, or None for a new key
            (treated as ``[0, 0]``).

    Returns:
        The updated state as a tuple.
    """
    totals = [0, 0] if runningState is None else list(runningState)
    for contribution in newValues:
        # Slot 0 first, then slot 1, for every contribution in order.
        for slot in (0, 1):
            totals[slot] += contribution[slot]
    return tuple(totals)

def process1_1(time, rdd):
    """Write one batch of Q1.1 results to Cassandra table ``test.q1_1``.

    Used as a ``foreachRDD`` callback. The table is written in 'overwrite'
    mode, so each non-empty batch replaces the previous contents.

    Args:
        time: Batch time handed over by ``foreachRDD``; only printed.
        rdd: RDD of ``(id, count, airport)`` tuples.

    Errors derived from ``Exception`` are reported and swallowed so a failing
    batch does not stop the stream; ``KeyboardInterrupt`` and ``SystemExit``
    propagate.
    """
    print("========= %s =========" % str(time))

    try:
        # Skip empty batches so the table is not overwritten with nothing.
        if rdd.isEmpty():
            return

        spark = getSparkSessionInstance(rdd.context.getConf())
        rows = rdd.map(lambda w: Row(id=w[0], count=w[1], airport=w[2]))
        (spark.createDataFrame(rows).write
            .format("org.apache.spark.sql.cassandra")
            .mode('overwrite')
            .options(table="q1_1", keyspace="test")
            .option("confirm.truncate", "true")
            .save())
    except Exception as exc:
        # Best effort: report the failure (type and message) and carry on.
        print("Oops!", repr(exc), "occurred.")

def process1_2(time, rdd):
    """Write one batch of Q1.2 results to Cassandra table ``test.q1_2``.

    Used as a ``foreachRDD`` callback. The table is written in 'overwrite'
    mode, so each non-empty batch replaces the previous contents.

    Args:
        time: Batch time handed over by ``foreachRDD``; only printed.
        rdd: RDD of ``(id, arrival_delay, carrier)`` tuples.

    Errors derived from ``Exception`` are reported and swallowed so a failing
    batch does not stop the stream; ``KeyboardInterrupt`` and ``SystemExit``
    propagate.
    """
    print("========= %s =========" % str(time))

    try:
        # Skip empty batches so the table is not overwritten with nothing.
        if rdd.isEmpty():
            return

        spark = getSparkSessionInstance(rdd.context.getConf())
        rows = rdd.map(lambda w: Row(id=w[0], arrival_delay=w[1], carrier=w[2]))
        (spark.createDataFrame(rows).write
            .format("org.apache.spark.sql.cassandra")
            .mode('overwrite')
            .options(table="q1_2", keyspace="test")
            .option("confirm.truncate", "true")
            .save())
    except Exception as exc:
        # Best effort: report the failure (type and message) and carry on.
        print("Oops!", repr(exc), "occurred.")


def update_Q2(newValues, runningState):
    """State update for ``updateStateByKey``: keep two running totals per key.

    Args:
        newValues: Indexable items from the current batch; element 0 and
            element 1 of each are added to the corresponding state slot.
        runningState: Previous state for the key, or None when the key is new
            (the state then starts at ``[0, 0]``).

    Returns:
        The updated state as a tuple.
    """
    if runningState is None:
        state = [0, 0]
    else:
        state = list(runningState)

    for sample in newValues:
        state[0] += sample[0]
        state[1] += sample[1]

    return tuple(state)

def update_Q3(newValues, runningState):
    """State update for ``updateStateByKey``: append this batch's values.

    Args:
        newValues: List of values for the key in the current batch, or None.
        runningState: Previously accumulated list, or None for a new key.

    Returns:
        None when ``newValues`` is None; otherwise the previous state
        (an empty list if there was none) concatenated with ``newValues``.
    """
    if newValues is None:
        return None
    accumulated = [] if runningState is None else runningState
    return accumulated + newValues

def process2_1(time, rdd):
    """Write one batch of Q2.1 results to Cassandra table ``test.q2_1``.

    Used as a ``foreachRDD`` callback. The table is written in 'overwrite'
    mode, so each non-empty batch replaces the previous contents.

    Args:
        time: Batch time handed over by ``foreachRDD``; only printed.
        rdd: RDD of ``(airport, (carrier, departure_delay))`` tuples.

    Errors derived from ``Exception`` are reported and swallowed so a failing
    batch does not stop the stream; ``KeyboardInterrupt`` and ``SystemExit``
    propagate.
    """
    print("========= %s =========" % str(time))

    try:
        # Skip empty batches so the table is not overwritten with nothing.
        if rdd.isEmpty():
            return

        spark = getSparkSessionInstance(rdd.context.getConf())
        rows = rdd.map(lambda w: Row(airport=w[0], departure_delay=w[1][1], carrier=w[1][0]))
        (spark.createDataFrame(rows).write
            .format("org.apache.spark.sql.cassandra")
            .mode('overwrite')
            .options(table="q2_1", keyspace="test")
            .option("confirm.truncate", "true")
            .save())
    except Exception as exc:
        # Best effort: report the failure (type and message) and carry on.
        print("Oops!", repr(exc), "occurred.")

def process2_2(time, rdd):
    """Write one batch of Q2.2 results to Cassandra table ``test.q2_2``.

    Used as a ``foreachRDD`` callback. The table is written in 'overwrite'
    mode, so each non-empty batch replaces the previous contents.

    Args:
        time: Batch time handed over by ``foreachRDD``; only printed.
        rdd: RDD of ``(airport, (destination_airport, departure_delay))``
            tuples.

    Errors derived from ``Exception`` are reported and swallowed so a failing
    batch does not stop the stream; ``KeyboardInterrupt`` and ``SystemExit``
    propagate.
    """
    print("========= %s =========" % str(time))

    try:
        # Skip empty batches so the table is not overwritten with nothing.
        if rdd.isEmpty():
            return

        spark = getSparkSessionInstance(rdd.context.getConf())
        rows = rdd.map(lambda w: Row(airport=w[0], departure_delay=w[1][1], destination_airport=w[1][0]))
        (spark.createDataFrame(rows).write
            .format("org.apache.spark.sql.cassandra")
            .mode('overwrite')
            .options(table="q2_2", keyspace="test")
            .option("confirm.truncate", "true")
            .save())
    except Exception as exc:
        # Best effort: report the failure (type and message) and carry on.
        print("Oops!", repr(exc), "occurred.")

def process2_3(time, rdd):
    """Write one batch of Q2.3 results to Cassandra table ``test.q2_3``.

    Used as a ``foreachRDD`` callback. The table is written in 'overwrite'
    mode, so each non-empty batch replaces the previous contents.

    Args:
        time: Batch time handed over by ``foreachRDD``; only printed.
        rdd: RDD of ``(airport, destination_airport, carrier, arrival_delay)``
            tuples.

    Errors derived from ``Exception`` are reported and swallowed so a failing
    batch does not stop the stream; ``KeyboardInterrupt`` and ``SystemExit``
    propagate.
    """
    print("========= %s =========" % str(time))

    try:
        # Skip empty batches so the table is not overwritten with nothing.
        if rdd.isEmpty():
            return

        spark = getSparkSessionInstance(rdd.context.getConf())
        rows = rdd.map(lambda w: Row(airport=w[0], destination_airport=w[1], arrival_delay=w[3], carrier=w[2]))
        (spark.createDataFrame(rows).write
            .format("org.apache.spark.sql.cassandra")
            .mode('overwrite')
            .options(table="q2_3", keyspace="test")
            .option("confirm.truncate", "true")
            .save())
    except Exception as exc:
        # Best effort: report the failure (type and message) and carry on.
        print("Oops!", repr(exc), "occurred.")

# ======================================================================================
# Question 1.1 - Rank the top 10 most popular airports by numbers of flights to/from the airport.
# ======================================================================================
# CREATE TABLE test.q1_1 (id int, count int, airport text, PRIMARY KEY(id, count, airport));
# SELECT * FROM test.q1_1 where id = 1 ORDER BY count DESC LIMIT 10;
def Q1_1(parsed):
    """Keep a running flight count per airport and hand each batch to ``process1_1``.

    Args:
        parsed: DStream of flight records indexable by "Origin" and "Dest".
    """
    # Every flight counts once for its origin and once for its destination.
    airport_hits = parsed.flatMap(lambda flight: [flight["Origin"], flight["Dest"]])

    # Per-batch counts, then folded into the running total per airport.
    batch_counts = airport_hits \
        .map(lambda code: (code, 1)) \
        .reduceByKey(lambda left, right: left + right)
    running_counts = batch_counts.updateStateByKey(update_Q1_1)

    # Emit (id, count, airport) with a constant id of 1 for every row.
    ranked_rows = running_counts.map(lambda kv: (1, kv[1], kv[0]))
    ranked_rows.foreachRDD(process1_1)

# ======================================================================================
# Question 1.2 - Rank the top 10 airlines by on-time arrival performance.
# ======================================================================================
# CREATE TABLE test.q1_2 (id int, arrival_delay float, carrier text, PRIMARY KEY(id, arrival_delay, carrier));
# SELECT * FROM test.q1_2 where id = 1 ORDER BY arrival_delay ASC LIMIT 10;
def Q1_2(parsed):
    """Keep a running mean arrival delay per carrier and hand it to ``process1_2``.

    Args:
        parsed: DStream of flight records indexable by "UniqueCarrier" and
            "ArrDelay".
    """
    # carrier -> [carrier, delay, 1] for every flight in the batch.
    per_flight = parsed.map(
        lambda flight: (flight["UniqueCarrier"],
                        [flight["UniqueCarrier"], float(flight["ArrDelay"]), 1]))

    # Sum delays and flight counts per carrier within the batch.
    batch_totals = per_flight.reduceByKey(
        lambda x, y: (x[0], x[1] + y[1], x[2] + y[2]))

    # Drop the redundant carrier element, keeping (delay_sum, flight_count).
    batch_pairs = batch_totals.map(lambda kv: (kv[0], (kv[1][1], kv[1][2])))
    running_totals = batch_pairs.updateStateByKey(update_Q1_2)

    # Emit (id, mean_delay, carrier) with a constant id of 1 for every row.
    mean_rows = running_totals.map(lambda kv: (1, kv[1][0] / kv[1][1], kv[0]))
    mean_rows.foreachRDD(process1_2)

# ======================================================================================
# Question 2.1 - For each airport X, 
#                rank the top-10 carriers in decreasing order of on-time departure performance from X.
# ======================================================================================
# CREATE TABLE test.q2_1 (airport text, departure_delay float, carrier text, PRIMARY KEY(airport, departure_delay, carrier));
# SELECT * FROM test.q2_1 WHERE airport='SRQ' ORDER BY departure_delay LIMIT 10;
def Q2_1(parsed):
    """Keep a running mean departure delay per (origin, carrier) for ``process2_1``.

    Args:
        parsed: DStream of flight records indexable by "Origin",
            "UniqueCarrier" and "DepDelay".
    """
    # Key "<origin>_<carrier>" -> (delay, 1) for every flight.
    keyed = parsed.map(
        lambda flight: ('{}_{}'.format(flight["Origin"], flight["UniqueCarrier"]),
                        (float(flight["DepDelay"]), 1)))

    batch_totals = keyed.reduceByKey(lambda x, y: (x[0] + y[0], x[1] + y[1]))
    running_totals = batch_totals.updateStateByKey(update_Q2)

    # Split the key back apart and emit (origin, (carrier, mean_delay)).
    mean_rows = running_totals.map(
        lambda kv: (kv[0].split("_")[0],
                    (kv[0].split("_")[1], kv[1][0] / kv[1][1])))
    mean_rows.foreachRDD(process2_1)
    
# ======================================================================================
# Question 2.2 - For each source airport X, 
#                rank the top-10 destination airports in decreasing order of on-time departure performance from X.
# ======================================================================================
# CREATE TABLE test.q2_2 (airport text, departure_delay float, destination_airport text, PRIMARY KEY(airport, departure_delay, destination_airport));
# SELECT * FROM test.q2_2 WHERE airport='SRQ' ORDER BY departure_delay LIMIT 10;
def Q2_2(parsed):
    """Keep a running mean departure delay per (origin, destination) for ``process2_2``.

    Args:
        parsed: DStream of flight records indexable by "Origin", "Dest" and
            "DepDelay".
    """
    # Key "<origin>_<destination>" -> (delay, 1) for every flight.
    keyed = parsed.map(
        lambda flight: ('{}_{}'.format(flight["Origin"], flight["Dest"]),
                        (float(flight["DepDelay"]), 1)))

    batch_totals = keyed.reduceByKey(lambda x, y: (x[0] + y[0], x[1] + y[1]))
    running_totals = batch_totals.updateStateByKey(update_Q2)

    # Split the key back apart and emit (origin, (destination, mean_delay)).
    mean_rows = running_totals.map(
        lambda kv: (kv[0].split("_")[0],
                    (kv[0].split("_")[1], kv[1][0] / kv[1][1])))
    mean_rows.foreachRDD(process2_2)

# ======================================================================================
# Question 2.3 - For each source-destination pair X-Y, 
#                rank the top-10 carriers in decreasing order of on-time arrival performance at Y from X.
# ======================================================================================
# CREATE TABLE test.q2_3 (airport text, destination_airport text, arrival_delay float, carrier text, PRIMARY KEY((airport, destination_airport), arrival_delay, carrier));
# SELECT * FROM test.q2_3 WHERE airport='LGA' and destination_airport='BOS' ORDER BY arrival_delay LIMIT 10;
def Q2_3(parsed):
    """Keep a running mean arrival delay per (origin, destination, carrier) for ``process2_3``.

    Args:
        parsed: DStream of flight records indexable by "Origin", "Dest",
            "UniqueCarrier" and "ArrDelay".
    """
    # Key "<origin>_<destination>_<carrier>" -> (delay, 1) for every flight.
    keyed = parsed.map(
        lambda flight: ('{}_{}_{}'.format(flight["Origin"], flight["Dest"], flight["UniqueCarrier"]),
                        (float(flight["ArrDelay"]), 1)))

    batch_totals = keyed.reduceByKey(lambda x, y: (x[0] + y[0], x[1] + y[1]))
    running_totals = batch_totals.updateStateByKey(update_Q2)

    # Split the key back apart and emit (origin, destination, carrier, mean_delay).
    mean_rows = running_totals.map(
        lambda kv: (kv[0].split("_")[0],
                    kv[0].split("_")[1],
                    kv[0].split("_")[2],
                    kv[1][0] / kv[1][1]))
    mean_rows.foreachRDD(process2_3)

# ======================================================================================
# Question 3.2
# ======================================================================================

def increase_2_days(date):
    """Return the date two days after ``date``, formatted as ``YYYY-MM-DD``.

    Args:
        date: Date string accepted by ``dateutil.parser.parse``.

    Returns:
        The shifted date as a ``%Y-%m-%d`` string.
    """
    two_days = datetime.timedelta(days=2)
    shifted = parser.parse(date) + two_days
    return shifted.strftime("%Y-%m-%d")

def process3_2(time, rdd):
    """Pair up two-leg itineraries from one batch and write them to ``test.q3_2``.

    Used as a ``foreachRDD`` callback. The batch's flights are registered as
    the temporary view ``flights`` and self-joined: leg ``a`` departs before
    1200, leg ``b`` departs at or after 1200 from ``a``'s destination on the
    date two days after ``a``.

    Args:
        time: Batch time handed over by ``foreachRDD``; only printed.
        rdd: RDD of flight records indexable by "UniqueCarrier", "Origin",
            "Dest", "ArrDelay", "FlightDate" and "CRSDepTime".

    Errors derived from ``Exception`` are reported and swallowed so a failing
    batch does not stop the stream; ``KeyboardInterrupt`` and ``SystemExit``
    propagate.
    """
    print("========= %s =========" % str(time))

    try:
        # Skip empty batches so the table is not overwritten with nothing.
        if rdd.isEmpty():
            return

        spark = getSparkSessionInstance(rdd.context.getConf())
        flight_rows = rdd.map(lambda w: Row(carrier=w["UniqueCarrier"],
                                            origin=w["Origin"],
                                            destination=w["Dest"],
                                            arrival_delay=float(w["ArrDelay"]),
                                            date=w["FlightDate"],
                                            future_date=increase_2_days(w["FlightDate"]),
                                            time=w["CRSDepTime"]))
        spark.createDataFrame(flight_rows).createOrReplaceTempView("flights")

        sqlCommand = (
            "SELECT a.origin as lag1_origin, "
            "a.destination as lag1_destination, "
            "a.carrier as lag1_carrier, "
            "a.arrival_delay as lag1_arrival_delay, "
            "a.date as lag1_date, "
            "a.time as lag1_time, "
            "b.origin as lag2_origin, "
            "b.destination as lag2_destination, "
            "b.carrier as lag2_carrier, "
            "b.arrival_delay as lag2_arrival_delay, "
            "b.date as lag2_date, "
            "b.time as lag2_time "
            "FROM flights AS a "
            "INNER JOIN flights AS b ON a.destination = b.origin AND a.future_date=b.date AND a.time<1200 AND b.time>=1200"
        )

        # NOTE(review): 'overwrite' replaces the table on every non-empty batch
        # and this stream carries no state across batches, so only the latest
        # batch's pairs remain stored - confirm this is intended.
        (spark.sql(sqlCommand).write
            .format("org.apache.spark.sql.cassandra")
            .mode('overwrite')
            .options(table="q3_2", keyspace="test")
            .option("confirm.truncate", "true")
            .save())
    except Exception as exc:
        # Best effort: report the failure (type and message) and carry on.
        print("Oops!", repr(exc), "occurred.")

def Q3_2(parsed):
    """Pick the best-arriving 2008 flight per route, date and half-day for ``process3_2``.

    Flights are grouped by (origin, destination, flight date, AM/PM) and only
    the one with the lowest arrival delay in each group is kept.

    The AM/PM split uses the same boundary as the join in ``process3_2``
    (first leg ``time < 1200``, second leg ``time >= 1200``): a departure at
    exactly 1200 belongs to the PM group. Classifying it as AM would let it
    displace the real first-leg candidate while never being usable as a first
    leg itself.

    Args:
        parsed: DStream of flight records indexable by "Origin", "Dest",
            "FlightDate", "CRSDepTime" and "ArrDelay".
    """
    def group_key(flight):
        half_day = "PM" if int(flight["CRSDepTime"]) >= 1200 else "AM"
        return '{}_{}_{}_{}'.format(
            flight["Origin"], flight["Dest"], flight["FlightDate"], half_day)

    parsed \
        .filter(lambda v: "2008-" in v["FlightDate"]) \
        .map(lambda v: (group_key(v), v)) \
        .reduceByKey(lambda a, b: b if float(a["ArrDelay"]) >= float(b["ArrDelay"]) else a) \
        .map(lambda v: v[1]) \
        .foreachRDD(process3_2)
    

# Print the number of records in every batch.
parsed.count().pprint()
# Queries are switched on by uncommenting their call; they must be registered
# here, before ssc.start() is called below.
# Q1_1(parsed)
# Q1_2(parsed)
# Q2_1(parsed)
# Q2_2(parsed)
# Q2_3(parsed)
Q3_2(parsed)

# Start the stream and block until it terminates.
ssc.start()
ssc.awaitTermination()


-------------------------------------------
Time: 2019-08-03 21:24:00
-------------------------------------------
461156

-------------------------------------------
Time: 2019-08-03 21:25:00
-------------------------------------------

-------------------------------------------
Time: 2019-08-03 21:26:00
-------------------------------------------

-------------------------------------------
Time: 2019-08-03 21:27:00
-------------------------------------------

-------------------------------------------
Time: 2019-08-03 21:28:00
-------------------------------------------

-------------------------------------------
Time: 2019-08-03 21:29:00
-------------------------------------------

-------------------------------------------
Time: 2019-08-03 21:30:00
-------------------------------------------

-------------------------------------------
Time: 2019-08-03 21:31:00
-------------------------------------------

-------------------------------------------
Time: 2019-08-03 21:32:00
---