In [None]:
import datetime as dt
import os

import matplotlib.pyplot as plt
import numpy as np
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql import functions as F
from pyspark.sql.window import Window

In [None]:
# Load every flight CSV, then give the columns used further down proper types.
df = spark.read.csv("/user/s1919377/flights/*", header='true')

# The three time columns share one textual pattern.
for time_column in ("firstseen", "lastseen", "day"):
    df = df.withColumn(time_column, to_timestamp(time_column, "yyyy-MM-dd HH:mm:ss"))

# Coordinates and altitudes of both track end points become floats.
for numeric_column in (
    "longitude_1", "longitude_2",
    "latitude_1", "latitude_2",
    "altitude_1", "altitude_2",
):
    df = df.withColumn(numeric_column, col(numeric_column).cast("float"))

In [None]:
# Travel-restriction levels per country and day, read from a CSV located in
# the current working directory (note the explicit file:// scheme).
measures_path = f'file://{os.getcwd()}/international-travel-covid_filtered.csv'
measures_df = spark.read.csv(measures_path, header='true')

# Display label for each value of `international_travel_controls`. The
# plotting cells index this list with int(level), so position == level.
measure_translation_list = [
    "No measures",
    "Screening",
    "Quarantine arrivals from\nhigh-risk regions",
    "Ban on high-risk regions",
    "Total border closure",
]

# Analysis period (both ends inclusive); unpacked into Column.between(...)
# by the filtering cells.
date_range = ("2020-02-01", "2021-12-22")

# Schiphol

In [None]:
# Flights whose destination code is EHAM.
schiphol_data = df.filter(col('destination') == 'EHAM')

# Travel-restriction rows for the Netherlands inside the analysis period.
measures_nl = (
    measures_df
    .filter(col("Entity") == "Netherlands")
    .filter(col('day').between(date_range[0], date_range[1]))
)

In [None]:
# Epoch seconds of each day; the plotting cell converts them back to datetimes
# to position the annotations.
measures_nl = measures_nl.withColumn(
    'int_day',
    measures_nl.Day.cast('timestamp').cast('long'),
)

# Keep only rows whose restriction level differs from the row before it.
# NOTE(review): this relies on the rows arriving in chronological order - confirm.
mnl = measures_nl.toPandas()
mnl['prev_itc'] = mnl['international_travel_controls'].shift()
level_changed = mnl['international_travel_controls'].ne(mnl['prev_itc'])
# The very first row always counts as "changed" (it has no predecessor), so it
# is dropped again here.
mnl = mnl[level_changed].iloc[1:]

In [None]:
# 14-day case rate per 100k, read from a multi-line JSON file in the current
# working directory.
cases_json_path = f'file://{os.getcwd()}/14_day_daily.json'
covid_cases_nl = (
    spark.read.option("multiline", "true").json(cases_json_path)
    .withColumn("date", to_date("date", "yyyyMMdd"))
    .withColumn("rate_14_day_per_100k", col("rate_14_day_per_100k").cast("float"))
    .fillna(0, "rate_14_day_per_100k")  # missing rates are treated as zero
)

# One row per date for the Netherlands: the mean of all rates sharing that
# date, stored under the original column name and ordered by date.
covid_cases_nl = (
    covid_cases_nl
    .where(col("country") == "Netherlands")
    .groupby("date")
    .agg(F.avg("rate_14_day_per_100k").alias("rate_14_day_per_100k"))
    .orderBy("date")
)

In [None]:
# Number of flights to Schiphol per day.
schiphol_data_grouped = schiphol_data.groupby("day").count()

# Attach the case rate of the same calendar date. The inner join keeps only
# days that exist in both sources.
schiphol_data_joined = (
    schiphol_data_grouped
    .join(covid_cases_nl, to_date(schiphol_data.day) == covid_cases_nl.date, "inner")
    .select("day", "rate_14_day_per_100k", "count")
    .orderBy(col("day"))
)

In [None]:
# Restrict to the analysis period, then collect to pandas indexed by day.
data_schiphol_range = schiphol_data_joined.where(
    col('day').between(date_range[0], date_range[1])
)
data_schiphol_range_pd = data_schiphol_range.toPandas().set_index("day")

# 20-row rolling mean of both series; the first 19 rows come out as NaN.
smoothed_columns = ["count", "rate_14_day_per_100k"]
rolling_avg_schiphol = data_schiphol_range_pd[smoothed_columns].rolling(20).mean()
# The same frame with its rows in reverse order.
rolling_avg_schiphol_inf = rolling_avg_schiphol.iloc[::-1]

In [None]:
x = data_schiphol_range_pd.index.values

fig, ax = plt.subplots(figsize=(20, 10))
ax.plot(x, rolling_avg_schiphol["count"], label="Flights to Schiphol")
ax.plot(x, rolling_avg_schiphol["rate_14_day_per_100k"], label="Cases per 100k people")
# Empty scatter: draws nothing, only adds the red-arrow entry to the legend.
ax.scatter([], [], c='red', marker=r'$\rightarrow$', s=50, label='Stricter measures')
ax.legend(loc='upper left', markerscale=2.5)

# One red arrow per change in travel restrictions, pointing at the x-axis
# (y = 0) on the day of the change.
for _, change in mnl.iterrows():
    ax.annotate(
        measure_translation_list[int(change["international_travel_controls"])],
        xy=(dt.datetime.fromtimestamp(change["int_day"]), 0),
        xycoords='data',
        xytext=(-40, 50),
        textcoords='offset points',
        arrowprops=dict(arrowstyle="simple", facecolor="red", edgecolor="red"),
        size=12,
    )

ax.set_ylim(0, 1000)
plt.show()

# LAX

In [None]:
# Flights whose destination code is KLAX.
lax_data = df.filter(col('destination') == 'KLAX')

# Travel-restriction rows for the United States inside the analysis period.
measures_usa = (
    measures_df
    .filter(col("Entity") == "United States")
    .filter(col('day').between(date_range[0], date_range[1]))
)

In [None]:
# Epoch seconds of each day; the plotting cell converts them back to datetimes
# to position the annotations.
measures_usa = measures_usa.withColumn(
    'int_day',
    measures_usa.Day.cast('timestamp').cast('long'),
)

# Keep only rows whose restriction level differs from the row before it.
# NOTE(review): this relies on the rows arriving in chronological order - confirm.
musa = measures_usa.toPandas()
musa['prev_itc'] = musa['international_travel_controls'].shift()
level_changed = musa['international_travel_controls'].ne(musa['prev_itc'])
# Drop the first two surviving rows: the initial row (which has no
# predecessor) and the first recorded change after it.
musa = musa[level_changed].iloc[2:]

In [None]:
# Daily new cases for the state "CA", read from a CSV in the current working
# directory.
usa_cases_path = f'file://{os.getcwd()}/usa-covid-data.csv'
covid_cases_usa = (
    spark.read.csv(usa_cases_path, header="true")
    .where(col("state") == "CA")
    .withColumn("submission_date", to_date("submission_date", "MM/dd/yyyy"))
    .select("submission_date", "new_case")
)


def days(i):
    """Return the number of seconds in `i` days."""
    return i * 86400


# The range window below is expressed in seconds, so the date becomes a
# timestamp here and is cast to long (epoch seconds) in the ordering key.
covid_cases_usa = covid_cases_usa.withColumn(
    'submission_date', col('submission_date').cast('timestamp')
)
# Frame: every row dated from 7 days before the current row up to and
# including the current row.
w = Window.orderBy(F.col("submission_date").cast('long')).rangeBetween(-days(7), 0)
# NOTE(review): 39.5 looks like California's population in millions, which
# would make `rolling_average` a per-million rate - confirm.
covid_cases_usa = covid_cases_usa.withColumn(
    'rolling_average', F.avg("new_case").over(w) / 39.5
)

In [None]:
# Number of flights to LAX per day.
lax_data_grouped = lax_data.groupby("day").count()

# Attach the rolling case average of the same calendar date, then discard the
# case-side helper columns so only day, count and rolling_average remain.
lax_data_joined = (
    lax_data_grouped
    .join(covid_cases_usa, to_date(lax_data.day) == covid_cases_usa.submission_date, "inner")
    .drop("new_case", "submission_date")
    .orderBy(col("day"))
)

In [None]:
# Restrict to the analysis period, then collect to pandas indexed by day.
data_lax_range = lax_data_joined.where(
    col('day').between(date_range[0], date_range[1])
)
data_lax_range_pd = data_lax_range.toPandas().set_index("day")

# 20-row rolling mean of both series; the first 19 rows come out as NaN.
smoothed_columns = ["count", "rolling_average"]
rolling_avg_lax = data_lax_range_pd[smoothed_columns].rolling(20).mean()
# The same frame with its rows in reverse order; the plotting cell looks up
# annotation heights in it.
rolling_avg_lax_inf = rolling_avg_lax.iloc[::-1]

In [None]:
x = data_lax_range_pd.index.values
# Per-annotation text offsets (in points) and arrow colours, addressed by the
# row position in `musa`.
xlist = np.array([-60, -120, -20, -70, -50])
ylist = np.array([30, 30, 50, 40, 50])
clist = np.array(["red", "green", "red", "green", "red"])
# Renumber the rows 0..n-1 so the iterrows() index can address the arrays
# above. Reassigning (instead of inplace=True) keeps the cell re-runnable
# without mutating a frame owned by an earlier cell.
musa = musa.reset_index(drop=True)

fig, ax = plt.subplots(figsize=(20, 10))
ax.plot(x, rolling_avg_lax["count"], label="Flights to LAX")
# `rolling_average` is the windowed mean of new_case divided by 39.5 (see the
# cell that builds covid_cases_usa). With California's population of roughly
# 39.5 million that is a per-million rate, not per 100k.
ax.plot(x, rolling_avg_lax["rolling_average"], label="Cases per million people")
# Empty scatters: draw nothing, only add the arrow entries to the legend.
ax.scatter([], [], c='green', marker=r'$\rightarrow$', s=50, label='Looser measures')
ax.scatter([], [], c='red', marker=r'$\rightarrow$', s=50, label='Stricter measures')
ax.legend(loc='upper left', markerscale=2.5)

# One arrow per change in travel restrictions; its tip sits on the smoothed
# flight curve at the day of the change.
for index, row in musa.iterrows():
    ax.annotate(
        measure_translation_list[int(row["international_travel_controls"])],
        xy=(
            dt.datetime.fromtimestamp(row["int_day"]),
            rolling_avg_lax_inf.loc[row["Day"], 'count'].values[0],
        ),
        xycoords='data',
        xytext=(xlist[index], ylist[index]),
        textcoords='offset points',
        arrowprops=dict(arrowstyle="simple", facecolor=clist[index], edgecolor=clist[index]),
        size=12,
    )

ax.set_ylim(0, 1000)
plt.show()

# Dubai

In [None]:
# Flights whose destination code is OMDB.
dubai_data = df.filter(col('destination') == 'OMDB')

# Travel-restriction rows for the United Arab Emirates inside the analysis period.
measures_uae = (
    measures_df
    .filter(col("Entity") == "United Arab Emirates")
    .filter(col('day').between(date_range[0], date_range[1]))
)

In [None]:
# Epoch seconds of each day; the plotting cell converts them back to datetimes
# to position the annotations.
measures_uae = measures_uae.withColumn(
    'int_day',
    measures_uae.Day.cast('timestamp').cast('long'),
)

# Keep only rows whose restriction level differs from the row before it.
# NOTE(review): this relies on the rows arriving in chronological order - confirm.
muae = measures_uae.toPandas()
muae['prev_itc'] = muae['international_travel_controls'].shift()
level_changed = muae['international_travel_controls'].ne(muae['prev_itc'])
# Drop the first surviving row (the initial row, which has no predecessor)
# and the last one.
muae = muae[level_changed].iloc[1:-1]

In [None]:
# Daily new cases per million, read from a semicolon-separated CSV in the
# current working directory.
uae_cases_path = f'file://{os.getcwd()}/uae-covid-data.csv'
covid_cases_uae = (
    spark.read.csv(uae_cases_path, header="true", sep=";")
    .select("date", "new_cases_per_million")
    .withColumn("date", to_date("date", "yyyy-MM-dd"))
    .orderBy("date")
)


def days(i):
    """Return the number of seconds in `i` days."""
    return i * 86400


# The range window below is expressed in seconds, so the date becomes a
# timestamp here and is cast to long (epoch seconds) in the ordering key.
covid_cases_uae = covid_cases_uae.withColumn('date', col('date').cast('timestamp'))
# Frame: every row dated from 7 days before the current row up to and
# including the current row.
w = Window.orderBy(F.col("date").cast('long')).rangeBetween(-days(7), 0)
covid_cases_uae = covid_cases_uae.withColumn(
    'rolling_average', F.avg("new_cases_per_million").over(w)
)

In [None]:
# Number of flights to Dubai per day.
dubai_data_grouped = dubai_data.groupby("day").count()

# Attach the rolling case average of the same calendar date, then discard the
# case-side helper columns so only day, count and rolling_average remain.
# (The former "m_day" entry named a column that neither join input has, so
# dropping it never did anything.)
dubai_data_joined = (
    dubai_data_grouped
    .join(covid_cases_uae, to_date(dubai_data.day) == covid_cases_uae.date, "inner")
    .drop("new_cases_per_million", "date")
    .orderBy(col("day"))
)

In [None]:
# Restrict to the analysis period, then collect to pandas indexed by day.
data_dubai_range = dubai_data_joined.where(
    col('day').between(date_range[0], date_range[1])
)
data_dubai_range_pd = data_dubai_range.toPandas().set_index("day")

# 20-row rolling mean of both series; the first 19 rows come out as NaN.
smoothed_columns = ["count", "rolling_average"]
rolling_avg_dubai = data_dubai_range_pd[smoothed_columns].rolling(20).mean()
# The same frame with its rows in reverse order; the plotting cell looks up
# annotation heights in it.
rolling_avg_dubai_inf = rolling_avg_dubai.iloc[::-1]

In [None]:
x = data_dubai_range_pd.index.values
# Per-annotation text offsets (in points) and arrow colours, addressed by the
# row position in `muae`.
xlist = np.array([-55, 0, -60, -70, -50, -30, -150, -30, 0, 100])
ylist = np.array([50, 50, 30, 30, 60, 80, 20, 50, 60, 20])
clist = np.array(["red", "red", "green", "green", "red", "red", "green", "green", "red", "green"])
# Renumber the rows 0..n-1 so the iterrows() index can address the arrays
# above. Reassigning (instead of inplace=True) keeps the cell re-runnable
# without mutating a frame owned by an earlier cell.
muae = muae.reset_index(drop=True)

fig, ax = plt.subplots(figsize=(20, 10))
ax.plot(x, rolling_avg_dubai["count"], label="Flights to Dubai")
# `rolling_average` is derived from the `new_cases_per_million` column, so the
# series is a per-million rate, not per 100k.
ax.plot(x, rolling_avg_dubai["rolling_average"], label="Cases per million people")
# Empty scatters: draw nothing, only add the arrow entries to the legend.
ax.scatter([], [], c='green', marker=r'$\rightarrow$', s=50, label='Looser measures')
ax.scatter([], [], c='red', marker=r'$\rightarrow$', s=50, label='Stricter measures')
ax.legend(loc='upper left', markerscale=2.5)
ax.set_ylim(0, 1000)

# One arrow per change in travel restrictions; its tip sits on the smoothed
# flight curve at the day of the change.
for index, row in muae.iterrows():
    ax.annotate(
        measure_translation_list[int(row["international_travel_controls"])],
        xy=(
            dt.datetime.fromtimestamp(row["int_day"]),
            rolling_avg_dubai_inf.loc[row["Day"], 'count'].values[0],
        ),
        xycoords='data',
        xytext=(xlist[index], ylist[index]),
        textcoords='offset points',
        arrowprops=dict(arrowstyle="simple", facecolor=clist[index], edgecolor=clist[index]),
        size=12,
    )
plt.show()