In [1]:
import json
import os

import dask.dataframe as dd
import geopandas as gpd
import numpy as np
import pandas as pd
import pyproj
import pyspark.sql.functions as F
from dask.diagnostics import ProgressBar
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import regexp_extract
from pyspark.sql.types import StructType, StructField, LongType
from scipy.spatial import cKDTree
from shapely.geometry import Point

from util import load_data

DATA_DIR = "data"
PROCESSED_CSV_DIR = os.path.join(DATA_DIR, "processed")



In [2]:
spark = (
    SparkSession.builder.appName("Airline Twitter Sentiment Analysis")
    .config("spark.driver.memory", "8g")
    .getOrCreate()
)

In [3]:
PROCESSED_CSV = [
    os.path.join(PROCESSED_CSV_DIR, file)
    for file in os.listdir(PROCESSED_CSV_DIR)
    if file.endswith(".csv")
]

In [4]:
df = spark.read.csv(PROCESSED_CSV, header=True, inferSchema=True)

In [5]:
df.show()

+--------+---------+-------------------+-------+----------+-------+-------------------+-----------+------+----------+--------------+-----------------+--------------+---------------+-------------+--------------------+-----------+-------------------+--------+--------------------+--------------------+---------------+---------------+
|_country| _unit_id|        _created_at|_golden|       _id|_missed|        _started_at|   _channel|_trust|_worker_id|           _ip|airline_sentiment|       airline|           name|retweet_count|                text|tweet_coord|      tweet_created|tweet_id|      tweet_location|       user_timezone|negativereason1|negativereason2|
+--------+---------+-------------------+-------+----------+-------+-------------------+-----------+------+----------+--------------+-----------------+--------------+---------------+-------------+--------------------+-----------+-------------------+--------+--------------------+--------------------+---------------+---------------+
|   

In [6]:
df.printSchema()

root
 |-- _country: string (nullable = true)
 |-- _unit_id: integer (nullable = true)
 |-- _created_at: timestamp (nullable = true)
 |-- _golden: boolean (nullable = true)
 |-- _id: integer (nullable = true)
 |-- _missed: boolean (nullable = true)
 |-- _started_at: timestamp (nullable = true)
 |-- _channel: string (nullable = true)
 |-- _trust: double (nullable = true)
 |-- _worker_id: integer (nullable = true)
 |-- _ip: string (nullable = true)
 |-- airline_sentiment: string (nullable = true)
 |-- airline: string (nullable = true)
 |-- name: string (nullable = true)
 |-- retweet_count: integer (nullable = true)
 |-- text: string (nullable = true)
 |-- tweet_coord: string (nullable = true)
 |-- tweet_created: timestamp (nullable = true)
 |-- tweet_id: double (nullable = true)
 |-- tweet_location: string (nullable = true)
 |-- user_timezone: string (nullable = true)
 |-- negativereason1: string (nullable = true)
 |-- negativereason2: string (nullable = true)



In [7]:
df = df.withColumn(
    "latitude",
    F.when(F.col("tweet_coord").isNotNull(), F.split(F.col("tweet_coord"), ",")[0])
    .otherwise(None)
    .cast("double"),
).withColumn(
    "longitude",
    F.when(F.col("tweet_coord").isNotNull(), F.split(F.col("tweet_coord"), ",")[1])
    .otherwise(None)
    .cast("double"),
)

In [8]:
pattern = r"([\+-]?\d+(?:\.\d+)?),\s?([\+-]?\d+(?:\.\d+)?)"
df = df.withColumn(
    "latitude",
    F.when(
        F.col("tweet_coord").isNull() & F.col("tweet_location").isNotNull(),
        regexp_extract(
            F.col("tweet_location"), pattern, 1
        ),
    ).otherwise(F.col("latitude")).cast("double")
).withColumn(
    "longitude",
    F.when(
        F.col("tweet_coord").isNull() & F.col("tweet_location").isNotNull(),
        regexp_extract(
            F.col("tweet_location"), pattern, 2
        ),
    ).otherwise(F.col("longitude")).cast("double")
)

In [9]:
# Keep this for the join.
indexed = df.withColumn(
    "index",
    F.monotonically_increasing_id(),
)

In [8]:
df = df.dropDuplicates()

In [None]:
pddf = df.toPandas()

In [11]:
latlong_df = indexed.filter(F.col("latitude").isNotNull() & F.col("longitude").isNotNull())

In [12]:
latlong_pddf = latlong_df.select("index", "latitude", "longitude").toPandas()

In [13]:
with open(os.path.join(DATA_DIR, "cities.json"), "r", encoding="utf-8") as f:
    cities = json.load(f)
    cities = pd.DataFrame.from_records(cities, index="id")

with open(os.path.join(DATA_DIR, "countries.json"), "r", encoding="utf-8") as f:
    countries = json.load(f)
    for country in countries:
        del country["timezones"]
        del country["translations"]
    countries = pd.DataFrame.from_records(countries, index="id")
    countries["longitude"] = cities["longitude"].astype(float)
    countries["latitude"] = cities["latitude"].astype(float)

with open(os.path.join(DATA_DIR, "states.json"), "r", encoding="utf-8") as f:
    states = json.load(f)
    states = pd.DataFrame.from_records(states, index="id")

In [12]:
def get_distance(p1: Point, p2: Point, crs: int = 4326) -> float:
    geo_df = gpd.GeoDataFrame({"geometry": [p1, p2]})
    crs = pyproj.crs.CRS.from_user_input(crs)
    geo_df.set_crs(crs, inplace=True)
    p1 = geo_df.iloc[0].values
    p2 = geo_df.iloc[1].values
    geod = geo_df.crs.get_geod()
    dist = geod.inv(p1.x, p1.y, p2.x, p2.y)[2] # distance in metres
    return dist

In [13]:
def get_nearest_neighbour(row: pd.Series, countries_df: pd.DataFrame, max_dist: float = 25e5) -> str:
    origin = Point(row["longitude"], row["latitude"])
    closest_country = None
    closest_distance = np.inf
    for index, country_row in countries_df.iterrows():
        lat = country_row["latitude"]
        long = country_row["longitude"]
        country_coords = Point(long, lat)
        dist = get_distance(origin, country_coords)
        if dist < closest_distance and dist <= max_dist:
            closest_distance = dist
            closest_country = country_row["iso3"]
    return closest_country

In [29]:
def ckdnearest(gdA, gdB, threshold):
    nA = np.array(list(gdA.geometry.apply(lambda x: (x.x, x.y))))
    nB = np.array(list(gdB.geometry.apply(lambda x: (x.x, x.y))))
    btree = cKDTree(nB)
    print(
        np.isnan(nA).any(),
        np.isnan(nB).any(),
        np.any(np.isinf(nA)),
        np.any(np.isinf(nB)),
    )
    dist, idx = btree.query(nA, k=1, distance_upper_bound=threshold)
    gdB_nearest = gdB.iloc[[i - 1 for i in idx]].reset_index(drop=True)
    gdA = gdA.reset_index(drop=True)
    gdA["distance"] = dist
    return gdA[gdA["distance"] <= threshold], gdB_nearest[gdA["distance"] <= threshold]

In [15]:
gdf = gpd.GeoDataFrame(latlong_pddf, geometry=latlong_pddf.apply(
    lambda row: Point(row["longitude"], row["latitude"]), axis=1
))
countries_gdf = gpd.GeoDataFrame(countries, geometry=countries.apply(
    lambda row: Point(row["longitude"], row["latitude"]), axis=1
))

gdf.set_crs(epsg=4326, inplace=True)
countries_gdf.set_crs(epsg=4326, inplace=True)

Unnamed: 0_level_0,name,iso3,iso2,numeric_code,phone_code,capital,currency,currency_name,currency_symbol,tld,...,region,region_id,subregion,subregion_id,nationality,latitude,longitude,emoji,emojiU,geometry
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,Afghanistan,AFG,AF,004,93,Kabul,AFN,Afghan afghani,؋,.af,...,Asia,3,Southern Asia,14,Afghan,42.50779,1.52109,🇦🇫,U+1F1E6 U+1F1EB,POINT (1.52109 42.50779)
2,Aland Islands,ALA,AX,248,+358-18,Mariehamn,EUR,Euro,€,.ax,...,Europe,4,Northern Europe,18,Aland Island,42.57205,1.48453,🇦🇽,U+1F1E6 U+1F1FD,POINT (1.48453 42.57205)
3,Albania,ALB,AL,008,355,Tirana,ALL,Albanian lek,Lek,.al,...,Europe,4,Southern Europe,16,Albanian,42.56760,1.59756,🇦🇱,U+1F1E6 U+1F1F1,POINT (1.59756 42.56760)
4,Algeria,DZA,DZ,012,213,Algiers,DZD,Algerian dinar,دج,.dz,...,Africa,1,Northern Africa,1,Algerian,42.57952,1.65362,🇩🇿,U+1F1E9 U+1F1FF,POINT (1.65362 42.57952)
5,American Samoa,ASM,AS,016,+1-684,Pago Pago,USD,US Dollar,$,.as,...,Oceania,5,Polynesia,22,American Samoan,42.53474,1.58014,🇦🇸,U+1F1E6 U+1F1F8,POINT (1.58014 42.53474)
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
243,Wallis and Futuna Islands,WLF,WF,876,681,Mata Utu,XPF,CFP franc,₣,.wf,...,Oceania,5,Polynesia,22,"Wallis and Futuna, Wallisian or Futunan",41.50000,20.33333,🇼🇫,U+1F1FC U+1F1EB,POINT (20.33333 41.50000)
244,Western Sahara,ESH,EH,732,212,El-Aaiun,MAD,Moroccan Dirham,MAD,.eh,...,Africa,1,Northern Africa,1,"Sahrawi, Sahrawian, Sahraouian",39.91667,20.08333,🇪🇭,U+1F1EA U+1F1ED,POINT (20.08333 39.91667)
245,Yemen,YEM,YE,887,967,Sanaa,YER,Yemeni rial,﷼,.ye,...,Asia,3,Western Asia,11,Yemeni,40.58333,20.91667,🇾🇪,U+1F1FE U+1F1EA,POINT (20.91667 40.58333)
246,Zambia,ZMB,ZM,894,260,Lusaka,ZMW,Zambian kwacha,ZK,.zm,...,Africa,1,Eastern Africa,4,Zambian,41.75000,20.33333,🇿🇲,U+1F1FF U+1F1F2,POINT (20.33333 41.75000)


In [48]:
near_df, near_countries = ckdnearest(
    gdf, countries_gdf, threshold=30
)

False False False False


In [49]:
near_countries.reset_index(inplace=True, names="join_idx")

In [50]:
near_df.reset_index(inplace=True, names="join_idx")

In [51]:
# Join the dfs by index
result = pd.merge(near_df, near_countries[["join_idx", "iso3"]], how="left")

In [52]:
result

Unnamed: 0,join_idx,index,latitude,longitude,geometry,distance,iso3
0,0,27,40.965130,-73.872957,POINT (-73.87296 40.96513),26.255009,FSM
1,2,155,40.965130,-73.872957,POINT (-73.87296 40.96513),26.255009,FSM
2,4,235,42.706796,-71.210754,POINT (-71.21075 42.70680),26.769662,FSM
3,5,330,38.961334,-77.006119,POINT (-77.00612 38.96133),26.174363,FSM
4,6,374,18.463449,-70.003608,POINT (-70.00361 18.46345),8.212341,FSM
...,...,...,...,...,...,...,...
6655,8455,266287973006,41.233943,-74.396666,POINT (-74.39667 41.23394),26.736237,FSM
6656,8456,266287973075,40.965130,-73.872957,POINT (-73.87296 40.96513),26.255009,FSM
6657,8457,266287973076,40.965130,-73.872957,POINT (-73.87296 40.96513),26.255009,FSM
6658,8458,266287973104,50.079998,14.441111,POINT (14.44111 50.08000),9.272751,PHL


In [53]:
latlong_pddf = result[["index", "iso3"]]

In [54]:
latlong_df = spark.createDataFrame(latlong_pddf)

In [55]:
indexed = indexed.join(latlong_df, on="index", how="left")

In [56]:
indexed.show()

+-----+--------+---------+-------------------+-------+----------+-------+-------------------+-----------+------+----------+--------------+-----------------+--------------+---------------+-------------+--------------------+-----------+-------------------+--------+--------------------+--------------------+---------------+---------------+--------+---------+----+
|index|_country| _unit_id|        _created_at|_golden|       _id|_missed|        _started_at|   _channel|_trust|_worker_id|           _ip|airline_sentiment|       airline|           name|retweet_count|                text|tweet_coord|      tweet_created|tweet_id|      tweet_location|       user_timezone|negativereason1|negativereason2|latitude|longitude|iso3|
+-----+--------+---------+-------------------+-------+----------+-------+-------------------+-----------+------+----------+--------------+-----------------+--------------+---------------+-------------+--------------------+-----------+-------------------+--------+-------------

In [57]:
missing_iso3_df = indexed.filter(F.col("iso3").isNull())

In [58]:
missing_iso3_df.count()

373247

In [63]:
missing_iso3_df = missing_iso3_df.withColumn(
    "city",
    F.when(F.col("tweet_location").isNotNull(), F.split(F.col("tweet_location"), ",")[0]).otherwise(None)
)

In [64]:
cities.head()

Unnamed: 0_level_0,name,state_id,state_code,state_name,country_id,country_code,country_name,latitude,longitude,wikiDataId
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
52,Ashkāsham,3901,BDS,Badakhshan,1,AF,Afghanistan,36.68333,71.53333,Q4805192
68,Fayzabad,3901,BDS,Badakhshan,1,AF,Afghanistan,37.11664,70.58002,Q156558
78,Jurm,3901,BDS,Badakhshan,1,AF,Afghanistan,36.86477,70.83421,Q10308323
84,Khandūd,3901,BDS,Badakhshan,1,AF,Afghanistan,36.95127,72.318,Q3290334
115,Rāghistān,3901,BDS,Badakhshan,1,AF,Afghanistan,37.66079,70.67346,Q2670909


In [60]:
cities_df = spark.createDataFrame(cities)

In [70]:
cities_df = cities_df.dropDuplicates(["name"])

In [73]:
cities_df.show()

+--------------------+--------+----------+--------------------+----------+------------+------------+------------+------------+----------+
|                name|state_id|state_code|          state_name|country_id|country_code|country_name|    latitude|   longitude|wikiDataId|
+--------------------+--------+----------+--------------------+----------+------------+------------+------------+------------+----------+
|       's-Heerenberg|    2611|        GE|          Gelderland|       156|          NL| Netherlands| 51.87670000|  6.25877000|   Q425810|
|            't Kabel|    2612|        NH|       North Holland|       156|          NL| Netherlands| 52.25610000|  4.64970000|  Q2152439|
|          25 de Mayo|    3217|        FD|             Florida|       235|          UY|     Uruguay|-34.18917000|-56.33944000|   Q218472|
|            Aagtdorp|    2612|        NH|       North Holland|       156|          NL| Netherlands| 52.68890000|  4.70190000|   Q584873|
|             Aarberg|    1645|   

In [80]:
missing_iso3_df = missing_iso3_df.withColumn(
    "twitter_username",
    F.col("name")
).drop("name")

In [81]:
missing_iso3_df = missing_iso3_df.join(
    other=cities_df.select("name", "country_name"),
    on=(F.col("city") == F.col("name")),
    how="left",
).drop("name")

In [84]:
missing_iso3_df.show()

+-----+--------+---------+-------------------+-------+----------+-------+-------------------+---------+------+----------+---------------+-----------------+--------------+-------------+--------------------+--------------------+-------------------+--------+--------------------+--------------------+---------------+---------------+--------+---------+----+--------------------+----------------+-------------+
|index|_country| _unit_id|        _created_at|_golden|       _id|_missed|        _started_at| _channel|_trust|_worker_id|            _ip|airline_sentiment|       airline|retweet_count|                text|         tweet_coord|      tweet_created|tweet_id|      tweet_location|       user_timezone|negativereason1|negativereason2|latitude|longitude|iso3|                city|twitter_username| country_name|
+-----+--------+---------+-------------------+-------+----------+-------+-------------------+---------+------+----------+---------------+-----------------+--------------+-------------+----

In [87]:
countries_df = spark.createDataFrame(countries)

In [88]:
missing_iso3_df = missing_iso3_df.drop("iso3").join(
    other=countries_df.select("iso3", "name"),
    on=(F.col("country_name") == F.col("name")),
    how="left"
).drop("name")

In [89]:
missing_iso3_df.show()

+-----+--------+---------+-------------------+-------+----------+-------+-------------------+------------+------+----------+---------------+-----------------+--------------+-------------+--------------------+--------------------+-------------------+--------+--------------------+--------------------+---------------+---------------+--------+---------+--------------------+----------------+-------------+----+
|index|_country| _unit_id|        _created_at|_golden|       _id|_missed|        _started_at|    _channel|_trust|_worker_id|            _ip|airline_sentiment|       airline|retweet_count|                text|         tweet_coord|      tweet_created|tweet_id|      tweet_location|       user_timezone|negativereason1|negativereason2|latitude|longitude|                city|twitter_username| country_name|iso3|
+-----+--------+---------+-------------------+-------+----------+-------+-------------------+------------+------+----------+---------------+-----------------+--------------+---------

In [90]:
missing_iso3_df.columns, latlong_df.columns

(['index',
  '_country',
  '_unit_id',
  '_created_at',
  '_golden',
  '_id',
  '_missed',
  '_started_at',
  '_channel',
  '_trust',
  '_worker_id',
  '_ip',
  'airline_sentiment',
  'airline',
  'retweet_count',
  'text',
  'tweet_coord',
  'tweet_created',
  'tweet_id',
  'tweet_location',
  'user_timezone',
  'negativereason1',
  'negativereason2',
  'latitude',
  'longitude',
  'city',
  'twitter_username',
  'country_name',
  'iso3'],
 ['index', 'iso3'])

In [91]:
filled_iso3_df = missing_iso3_df.filter(F.col("iso3").isNotNull()).select("index", "iso3")

In [92]:
merged_iso3_df = filled_iso3_df.union(latlong_df)

In [94]:
merged_indexed_df = indexed.drop("iso3").join(
    merged_iso3_df, on="index", how="left"
).dropna(subset=["iso3"])

In [95]:
merged_indexed_df.show()

+-----+--------+---------+-------------------+-------+----------+-------+-------------------+--------------+------+----------+---------------+-----------------+--------------+---------------+-------------+--------------------+-----------+-------------------+--------+--------------------+--------------------+---------------+---------------+---------+----------+----+
|index|_country| _unit_id|        _created_at|_golden|       _id|_missed|        _started_at|      _channel|_trust|_worker_id|            _ip|airline_sentiment|       airline|           name|retweet_count|                text|tweet_coord|      tweet_created|tweet_id|      tweet_location|       user_timezone|negativereason1|negativereason2| latitude| longitude|iso3|
+-----+--------+---------+-------------------+-------+----------+-------+-------------------+--------------+------+----------+---------------+-----------------+--------------+---------------+-------------+--------------------+-----------+-------------------+------

In [96]:
merged_indexed_df.count()

201216

In [97]:
merged_indexed_df.drop("index", "tweet_location", "tweet_coord", "latitude", "longitude").repartition(1).write.csv(
    "./data/task_3/", header=True, mode="overwrite", quote='"'
)