In [1]:
import marimo as mo

In [2]:
from pyspark.sql import SparkSession

spark = (SparkSession.builder
    .appName("DataLoader")
    .master("local[*]") # use all local cores
    .getOrCreate())

df = spark.read.parquet("datasets/bing_covid-19_data.parquet")

In [3]:
policy_df = spark.read.parquet("datasets/covid_policy_tracker.parquet")

In [4]:
policy_df.show(vertical = True)

-RECORD 0-----------------------------------------------------
 countryname                           | Aruba                
 countrycode                           | ABW                  
 date                                  | 2020-01-01           
 c1_school_closing                     | 0.0                  
 c1_flag                               | NULL                 
 c2_workplace_closing                  | 0.0                  
 c2_flag                               | NULL                 
 c3_cancel_public_events               | 0.0                  
 c3_flag                               | NULL                 
 c4_restrictions_on_gatherings         | 0.0                  
 c4_flag                               | NULL                 
 c5_close_public_transport             | 0.0                  
 c5_flag                               | NULL                 
 c6_stay_at_home_requirements          | 0.0                  
 c6_flag                               | NULL          

# Exploring the datatypes and data

In [5]:
df.printSchema()

root
 |-- id: integer (nullable = true)
 |-- updated: date (nullable = true)
 |-- confirmed: integer (nullable = true)
 |-- confirmed_change: integer (nullable = true)
 |-- deaths: integer (nullable = true)
 |-- deaths_change: short (nullable = true)
 |-- recovered: integer (nullable = true)
 |-- recovered_change: integer (nullable = true)
 |-- latitude: double (nullable = true)
 |-- longitude: double (nullable = true)
 |-- iso2: string (nullable = true)
 |-- iso3: string (nullable = true)
 |-- country_region: string (nullable = true)
 |-- admin_region_1: string (nullable = true)
 |-- iso_subdivision: string (nullable = true)
 |-- admin_region_2: string (nullable = true)
 |-- load_time: timestamp (nullable = true)



In [6]:
df.show(n=20)

+------+----------+---------+----------------+------+-------------+---------+----------------+--------+---------+----+----+--------------+--------------+---------------+--------------+--------------------+
|    id|   updated|confirmed|confirmed_change|deaths|deaths_change|recovered|recovered_change|latitude|longitude|iso2|iso3|country_region|admin_region_1|iso_subdivision|admin_region_2|           load_time|
+------+----------+---------+----------------+------+-------------+---------+----------------+--------+---------+----+----+--------------+--------------+---------------+--------------+--------------------+
|338995|2020-01-21|      262|            NULL|     0|         NULL|     NULL|            NULL|    NULL|     NULL|NULL|NULL|     Worldwide|          NULL|           NULL|          NULL|2025-05-10 02:05:...|
|338996|2020-01-22|      313|              51|     0|            0|     NULL|            NULL|    NULL|     NULL|NULL|NULL|     Worldwide|          NULL|           NULL|       

In [7]:
reduced_df = df.drop("load_time")

In [8]:
reduced_df.show(20)

+------+----------+---------+----------------+------+-------------+---------+----------------+--------+---------+----+----+--------------+--------------+---------------+--------------+
|    id|   updated|confirmed|confirmed_change|deaths|deaths_change|recovered|recovered_change|latitude|longitude|iso2|iso3|country_region|admin_region_1|iso_subdivision|admin_region_2|
+------+----------+---------+----------------+------+-------------+---------+----------------+--------+---------+----+----+--------------+--------------+---------------+--------------+
|338995|2020-01-21|      262|            NULL|     0|         NULL|     NULL|            NULL|    NULL|     NULL|NULL|NULL|     Worldwide|          NULL|           NULL|          NULL|
|338996|2020-01-22|      313|              51|     0|            0|     NULL|            NULL|    NULL|     NULL|NULL|NULL|     Worldwide|          NULL|           NULL|          NULL|
|338997|2020-01-23|      578|             265|     0|            0|     NUL

In [9]:
reduced_df.describe().show(vertical=True)

-RECORD 0--------------------------------
 summary          | count                
 id               | 4766736              
 confirmed        | 4766736              
 confirmed_change | 4761213              
 deaths           | 4668438              
 deaths_change    | 4657186              
 recovered        | 1114073              
 recovered_change | 1104531              
 latitude         | 4765761              
 longitude        | 4765761              
 iso2             | 4759261              
 iso3             | 4759261              
 country_region   | 4766736              
 admin_region_1   | 4530381              
 iso_subdivision  | 3932598              
 admin_region_2   | 3779699              
-RECORD 1--------------------------------
 summary          | mean                 
 id               | 1.4054495640649387E8 
 confirmed        | 169962.092957529     
 confirmed_change | 358.2463617989785    
 deaths           | 2426.586312595348    
 deaths_change    | 4.018275628244

## Get the number of unique strings per column

In [10]:
from pyspark.sql.types import StringType
from pyspark.sql.functions import countDistinct, col

def count_distinct_strings(df):
    """
    Counts the number of distinct values for each string-type column in a DataFrame, 
    excluding nulls.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        The input DataFrame containing various columns, including string-type columns.

    Returns
    -------
    pyspark.sql.DataFrame
        A DataFrame with one row where each column corresponds to the number of distinct 
        non-null values in the original string-type column, with each column renamed to 
        indicate the distinct count (e.g., 'column_name_distinct').
    """
    # get the string type columns from the schema
    string_cols = [field.name for field in df.schema.fields if isinstance(field.dataType, StringType)]

    # filter out nulls
    # must filter seperately, filter only accepts a single condition
    filtered_df = df
    for c in string_cols:
        filtered_df = filtered_df.filter(col(c).isNotNull())

    # count the distinct strings, and rename the columns once in the iteration
    result = filtered_df.agg(
        *[countDistinct(col(c)).alias(f"{c}_distinct") for c in string_cols]
    )

    return result

distinct_counts = count_distinct_strings(df)
distinct_counts.show()

+-------------+-------------+-----------------------+-----------------------+------------------------+-----------------------+
|iso2_distinct|iso3_distinct|country_region_distinct|admin_region_1_distinct|iso_subdivision_distinct|admin_region_2_distinct|
+-------------+-------------+-----------------------+-----------------------+------------------------+-----------------------+
|           11|           11|                     11|                    108|                     108|                   2807|
+-------------+-------------+-----------------------+-----------------------+------------------------+-----------------------+



In [11]:
reduced_df.select("country_region").distinct().orderBy("country_region").show() #use select to return a dataframe

+-------------------+
|     country_region|
+-------------------+
|        Afghanistan|
|            Albania|
|            Algeria|
|     American Samoa|
|            Andorra|
|             Angola|
|           Anguilla|
|         Antarctica|
|Antigua and Barbuda|
|          Argentina|
|            Armenia|
|              Aruba|
|          Australia|
|            Austria|
|         Azerbaijan|
|            Bahamas|
|            Bahrain|
|         Bangladesh|
|           Barbados|
|            Belarus|
+-------------------+
only showing top 20 rows



In [12]:
covid_stats = ["id", "updated", "confirmed", "confirmed_change", "deaths", "deaths_change", "recovered", "recovered_change", "country_region"]

# Get the Hungarian covid statistics

In [13]:
reduced_df.select(covid_stats).where(condition=col("country_region")=="Hungary").show(5)

+-------+----------+---------+----------------+------+-------------+---------+----------------+--------------+
|     id|   updated|confirmed|confirmed_change|deaths|deaths_change|recovered|recovered_change|country_region|
+-------+----------+---------+----------------+------+-------------+---------+----------------+--------------+
|7169162|2020-03-04|        2|            NULL|     0|         NULL|     NULL|            NULL|       Hungary|
|7169165|2020-03-05|        3|               1|     0|            0|     NULL|            NULL|       Hungary|
|7169167|2020-03-06|        4|               1|     0|            0|        0|            NULL|       Hungary|
|7169169|2020-03-07|        7|               3|     0|            0|        0|               0|       Hungary|
|7169183|2020-03-08|        9|               2|     0|            0|        0|               0|       Hungary|
+-------+----------+---------+----------------+------+-------------+---------+----------------+--------------+
o

In [14]:
reduced_df.select(covid_stats).where((col("country_region") == "Hungary") & (reduced_df.confirmed > 50)).show(5) #wrap each logic condition seperately

+-------+----------+---------+----------------+------+-------------+---------+----------------+--------------+
|     id|   updated|confirmed|confirmed_change|deaths|deaths_change|recovered|recovered_change|country_region|
+-------+----------+---------+----------------+------+-------------+---------+----------------+--------------+
|7169200|2020-03-18|       58|               8|     1|            0|        2|               0|       Hungary|
|7169202|2020-03-19|       73|              15|     1|            0|        2|               0|       Hungary|
|7169204|2020-03-20|       85|              12|     4|            3|        7|               5|       Hungary|
|7169206|2020-03-21|      103|              18|     4|            0|        7|               0|       Hungary|
|7169208|2020-03-22|      131|              28|     6|            2|       16|               9|       Hungary|
+-------+----------+---------+----------------+------+-------------+---------+----------------+--------------+
o

In [15]:
from pyspark.sql.functions import count, min, max

max_cases = reduced_df.select(covid_stats).filter((col("country_region") == "Hungary")).agg(*[max(col) for col in covid_stats if col not in ["id", "updated", "country_region"]]).collect()[0]

In [16]:
max_cases

Row(max(confirmed)=2195926, max(confirmed_change)=20597, max(deaths)=48751, max(deaths_change)=834, max(recovered)=784539, max(recovered_change)=24851)

In [17]:
max_cases

Row(max(confirmed)=2195926, max(confirmed_change)=20597, max(deaths)=48751, max(deaths_change)=834, max(recovered)=784539, max(recovered_change)=24851)

In [18]:
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, rank, dense_rank, lag, lead, sum, avg

In [19]:
type(max_cases)

pyspark.sql.types.Row

In [20]:
from pyspark.sql import Row
import datetime
import pandas as pd

hungarian_covid_stats_df = reduced_df.select(covid_stats).filter((col("country_region") == "Hungary"))

max_dict = {}
for column, i in zip(hungarian_covid_stats_df.columns, range(len(max_cases))):
    date_for_max_case = hungarian_covid_stats_df.select("updated").where(max_cases[i]==hungarian_covid_stats_df[2+i])
    if hungarian_covid_stats_df.columns[2+i] in ["confirmed_change", "deaths_change", "recovered_change"]:
        max_dict[hungarian_covid_stats_df.columns[2+i]] = date_for_max_case.collect()[0]["updated"]

pandas_df = pd.DataFrame([max_dict])
print(max_dict)
print(pandas_df)

{'confirmed_change': datetime.date(2021, 3, 31), 'deaths_change': datetime.date(2021, 4, 1), 'recovered_change': datetime.date(2021, 5, 13)}
  confirmed_change deaths_change recovered_change
0       2021-03-31    2021-04-01       2021-05-13


In [21]:
from pyspark.sql.functions import when

hungarian_covid_stats_extra = (
    hungarian_covid_stats_df
    .withColumn(
        "percentage_of_recovered", 
        when(col("confirmed") == 0, 0)
        .otherwise(col("recovered") / col("confirmed"))
    )
    .withColumn(
        "percentage_of_deaths", 
        when(col("confirmed") == 0, 0)  
        .otherwise(col("deaths") / col("confirmed"))
    )
    .withColumn(
        "active_cases", 
        when(col("confirmed") == 0, 0)
        .otherwise(col("confirmed") - col("recovered") - col("deaths"))
    )
)

date_window = Window.partitionBy("updated").orderBy("updated").rowsBetween(-7, 0) #assuming dates are unique
hungarian_covid_stats_extra = (hungarian_covid_stats_extra.withColumn("avg_last_7_active",avg("active_cases").over(date_window)))

date_window_lagging = Window.orderBy("updated") #no need for rows_between, or partition by, we apply globally
hungarian_covid_stats_extra = hungarian_covid_stats_extra.withColumn("pastweek_weekly_avg_active", lag("avg_last_7_active").over(date_window_lagging)) #could possibly change offset for lag

hungarian_covid_stats_extra = hungarian_covid_stats_extra.withColumn(
    "danger_zone",
    when(col("active_cases") > col("avg_last_7_active"), "danger")
    .when((col("active_cases") <= col("avg_last_7_active")) & (col("avg_last_7_active")*0.5 < col("active_cases")), "mild_danger")
    .otherwise("normal")
)

hungarian_covid_stats_extra = hungarian_covid_stats_extra.withColumn("weekly_active_cases_change", col("avg_last_7_active")-col("pastweek_weekly_avg_active"))

hungarian_covid_stats_extra.show(30, vertical=True)

-RECORD 0------------------------------------------
 id                         | 7169162              
 updated                    | 2020-03-04           
 confirmed                  | 2                    
 confirmed_change           | NULL                 
 deaths                     | 0                    
 deaths_change              | NULL                 
 recovered                  | NULL                 
 recovered_change           | NULL                 
 country_region             | Hungary              
 percentage_of_recovered    | NULL                 
 percentage_of_deaths       | 0.0                  
 active_cases               | NULL                 
 avg_last_7_active          | NULL                 
 pastweek_weekly_avg_active | NULL                 
 danger_zone                | normal               
 weekly_active_cases_change | NULL                 
-RECORD 1------------------------------------------
 id                         | 7169165              
 updated    

# Save the dataframe

In [22]:
#hungarian_covid_stats_extra.write.mode("overwrite").parquet("hungarian_covid_stats.parquet")

# Compute the aggregate statistics of "countries" with "united substring"

In [23]:
#https://stackoverflow.com/questions/38610559/convert-spark-dataframe-column-to-python-list
united_countries = (reduced_df.filter(col("country_region").contains("United"))
                         .select("country_region")
                         .distinct()
                         .toPandas()['country_region'].tolist())

In [24]:
united_countries

['United States', 'United Arab Emirates', 'United Kingdom']

In [25]:
united_df = reduced_df.select(covid_stats).where((col("country_region").isin(united_countries)) & (reduced_df.confirmed > 50)).groupby(col("country_region")).agg(sum(col="confirmed").alias("SumConfirmed")) #wrap each logic condition seperately

united_window = Window.orderBy(col("SumConfirmed").desc())  # no need to partition, we window globally

united_df = united_df.withColumn("dense_rank", dense_rank().over(united_window))

united_df.show()

+--------------------+------------+----------+
|      country_region|SumConfirmed|dense_rank|
+--------------------+------------+----------+
|       United States|140905721002|         1|
|      United Kingdom| 21159388426|         2|
|United Arab Emirates|   645314878|         3|
+--------------------+------------+----------+



In [26]:
reduced_df.join(policy_df, (reduced_df.country_region == policy_df.countryname) & (reduced_df.country_region == "Hungary"), how="inner").show(10, vertical=True)

-RECORD 0-----------------------------------------------------
 id                                    | 7169162              
 updated                               | 2020-03-04           
 confirmed                             | 2                    
 confirmed_change                      | NULL                 
 deaths                                | 0                    
 deaths_change                         | NULL                 
 recovered                             | NULL                 
 recovered_change                      | NULL                 
 latitude                              | 47.16519             
 longitude                             | 19.41204             
 iso2                                  | HU                   
 iso3                                  | HUN                  
 country_region                        | Hungary              
 admin_region_1                        | NULL                 
 iso_subdivision                       | NULL          

In [27]:
reduced_df.createOrReplaceTempView("covid_view")
policy_df.createOrReplaceTempView("policy_view")

#spark.sql("SELECT * FROM policy_view").show()
policy_df.show(vertical=True)

-RECORD 0-----------------------------------------------------
 countryname                           | Aruba                
 countrycode                           | ABW                  
 date                                  | 2020-01-01           
 c1_school_closing                     | 0.0                  
 c1_flag                               | NULL                 
 c2_workplace_closing                  | 0.0                  
 c2_flag                               | NULL                 
 c3_cancel_public_events               | 0.0                  
 c3_flag                               | NULL                 
 c4_restrictions_on_gatherings         | 0.0                  
 c4_flag                               | NULL                 
 c5_close_public_transport             | 0.0                  
 c5_flag                               | NULL                 
 c6_stay_at_home_requirements          | 0.0                  
 c6_flag                               | NULL          

In [28]:
policy_df.printSchema()

root
 |-- countryname: string (nullable = true)
 |-- countrycode: string (nullable = true)
 |-- date: date (nullable = true)
 |-- c1_school_closing: double (nullable = true)
 |-- c1_flag: boolean (nullable = true)
 |-- c2_workplace_closing: double (nullable = true)
 |-- c2_flag: boolean (nullable = true)
 |-- c3_cancel_public_events: double (nullable = true)
 |-- c3_flag: boolean (nullable = true)
 |-- c4_restrictions_on_gatherings: double (nullable = true)
 |-- c4_flag: boolean (nullable = true)
 |-- c5_close_public_transport: double (nullable = true)
 |-- c5_flag: boolean (nullable = true)
 |-- c6_stay_at_home_requirements: double (nullable = true)
 |-- c6_flag: boolean (nullable = true)
 |-- c7_restrictions_on_internal_movement: double (nullable = true)
 |-- c7_flag: boolean (nullable = true)
 |-- c8_international_travel_controls: double (nullable = true)
 |-- e1_income_support: double (nullable = true)
 |-- e1_flag: boolean (nullable = true)
 |-- e2_debt/contract_relief: double (nu

In [29]:
from pyspark.sql.functions import expr

In [30]:
(policy_df.select("countryname",
                 expr("e1_income_support != 0").alias("income_support"))
    .filter(expr("e1_income_support != 0"))
    .show())

+-----------+--------------+
|countryname|income_support|
+-----------+--------------+
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
|      Aruba|          true|
+-----------+--------------+
only showing top 20 rows



In [31]:
spark.sql("SELECT * FROM covid_view INNER JOIN policy_view ON covid_view.country_region == policy_view.countryname WHERE covid_view.country_region=='Hungary'").show(vertical=True)

-RECORD 0-----------------------------------------------------
 id                                    | 7169162              
 updated                               | 2020-03-04           
 confirmed                             | 2                    
 confirmed_change                      | NULL                 
 deaths                                | 0                    
 deaths_change                         | NULL                 
 recovered                             | NULL                 
 recovered_change                      | NULL                 
 latitude                              | 47.16519             
 longitude                             | 19.41204             
 iso2                                  | HU                   
 iso3                                  | HUN                  
 country_region                        | Hungary              
 admin_region_1                        | NULL                 
 iso_subdivision                       | NULL          

In [32]:
spark.sql("SELECT country_region, MAX(confirmed) AS CUMULATIVE_CONFIRMED FROM covid_view GROUP BY country_region ORDER BY cumulative_confirmed DESC").show()

+--------------+--------------------+
|country_region|CUMULATIVE_CONFIRMED|
+--------------+--------------------+
|     Worldwide|           675622359|
|      Bulgaria|           183974853|
| United States|           103582936|
|         India|            44688722|
|        France|            38591184|
|       Germany|            38202571|
|        Brazil|            37063464|
|         Japan|            33260228|
|   South Korea|            30555102|
|         Italy|            25603510|
|United Kingdom|            24396534|
|        Russia|            22342128|
|        Turkey|            17042722|
|         Spain|            13763336|
|       Vietnam|            11526950|
|     Australia|            11385534|
|        Taiwan|            10055439|
|     Argentina|            10044125|
|   Netherlands|             8607177|
|          Iran|             7569483|
+--------------+--------------------+
only showing top 20 rows

