In [1]:
import pyspark.sql.functions as fn
import pyspark.sql.types
from pyspark.sql import SparkSession
from pyspark.sql.types import *

In [2]:
# 1. Filters years from 2005-2020
# 2. Unbucketize the columns
# 3. Filters chosen countries
def filterCountryData(df,countries_chosen):
    
    # range b/w 2005-2020
    years = list(map(lambda x: str(x),list(range(2005,2021,1)))) 
    
    cols =["Country Name","Country Code","Indicator Name","Indicator Code"]+years
    country_2005_20 = df.select(cols)
    
    
    # filters countries chosen and fills any missing year values with 0.00
    ts = "2020-04-01"
    countries_chosen_2005_20 = (country_2005_20
                                .filter(fn.col("Country Name").isin(countries_chosen))
                                .withColumn("date",fn.date_format(fn.lit(ts),"yyyy-MM-dd"))
                               )
    
    #unbucketize the data
    unpivotStr= list(map(lambda x: " '{t}',`{t}`".format(t=x),years))
    sep = ','
    unpivotExpr = "stack("+str(len(years))+", "+sep.join(unpivotStr)+") as (Year, Value)"
    columns_without_years= set(countries_chosen_2005_20.columns ) - set(years)
    
    res = countries_chosen_2005_20.select(
        "Country Name",
        "Country Code",
        "Indicator Name",
        fn.expr(unpivotExpr),
        fn.month("date").alias("month"),
        fn.dayofmonth("date").alias("day"),
        fn.quarter("date").alias("quarter")            
    ).groupBy("Country Name","Year").pivot("Indicator Name").sum("Value")
    
    #TODO: join the dimensions to make a fact table

    return res

In [3]:
# Date Dimension
def generate_dates(spark,range_list,interval=60*60*24,dt_col="date_time_ref"): # TODO: attention to sparkSession
     """
     Create a Spark DataFrame with a single column named dt_col and a range of date within a specified interval (start and stop included).
     With hourly data, dates end at 23 of stop day

     :param spark: SparkSession or sqlContext depending on environment (server vs local)
     :param range_list: array of strings formatted as "2018-01-20" or "2018-01-20 00:00:00"
     :param interval: number of seconds (frequency), output from get_freq()
     :param dt_col: string with date column name. Date column must be TimestampType

     :returns: df from range
     """
     start,stop = range_list
     temp_df = spark.createDataFrame([(start, stop)], ("start", "stop"))
     temp_df = temp_df.select([fn.col(c).cast("timestamp") for c in ("start", "stop")])
     temp_df = temp_df.withColumn("stop",fn.date_add("stop",1).cast("timestamp"))
     temp_df = temp_df.select([fn.col(c).cast("long") for c in ("start", "stop")])
     start, stop = temp_df.first()
     return spark.range(start,stop,interval).select(fn.col("id").cast("timestamp").alias(dt_col))


def dateDimension():
    time_rng = ["2005-01-01","2020-12-31"]
    year_df= generate_dates(spark,time_rng)
    tmp = (year_df
           .withColumn("year",fn.year("date_time_ref"))
           .withColumn("month",fn.month("date_time_ref"))
           .withColumn("day",fn.dayofmonth("date_time_ref"))
           .withColumn("quarter",fn.quarter("date_time_ref"))
           .withColumn("decade",
                          fn.when(fn.col("year") % 10 >=5,fn.col("year")-fn.col("year")%10+10)
                              .otherwise(fn.col("year")- fn.col("year") % 10))
           .withColumn("year_code",fn.monotonically_increasing_id())

          )
    date_dim = (tmp
                   .select(tmp.year_code,*set(tmp.columns)-set(["year_code"]))
               )
    
    return date_dim

In [4]:
def naturalDisasterDim(df,filePath,countries_chosen):
    """
        creates natural disaster dimension + look up table
    
        df - date dataframe
        filePath - filePath to natural disaster csv
        countries_chosen - list of strings of countries to work on
    """
    
    columns = ["total deaths","Total Damages ('000 US$)"]
    
    # reads csv
    natural_disaster_df = (spark
                       .read
                       .format('csv')
                           .option("inferSchema",True)
                           .option("header",True)
                           .load(filePath)
                           .fillna(0.00,subset=columns)).dropDuplicates()
    

    # reconfigures column names + banding
    tmp_nd = (natural_disaster_df
                  # replaces United States of America -> united states
              .withColumn("Country",fn.when(fn.lower(fn.col("Country")).contains("united states"),"united states").otherwise(fn.lower(fn.col("Country"))))
              .withColumn("start_month",fn.col("Start Month"))
                  .withColumn("start_year",fn.col("Start Year"))
                  .withColumn("start_day",fn.col("Start Day"))
                  .withColumn("end_month",fn.col("End Month"))
                  .withColumn("end_year",fn.col("End Year"))
                  .withColumn("end_day",fn.col("End Day"))
              .withColumn("disaster_type",fn.col("Disaster Type"))
              .withColumn("disaster_subtype",fn.col("disaster subtype"))
              .withColumn("disaster_nestedsubtype",fn.col("disaster subsubtype"))
              .withColumn("disaster_subgroup",fn.col("disaster subgroup"))
              .withColumn("event_name",fn.col("event name"))
              .withColumn("ofda_response",fn.col("ofda response"))
              .fillna(1.0,["start_day","start_month","start_year","end_day","end_month","end_year"])
              # TODO figure out what to do about start and end year
              .fillna("Not Available",["disaster_type","disaster_subtype","disaster_nestedsubtype","disaster_subgroup","event_name","ofda_response"])
              .withColumn("ttl_death",
                          # range (low,medium, high)
                          fn.when(fn.col("total deaths")>7000,
                                  fn.when(fn.col("total deaths")>14000,"high").otherwise("medium")).otherwise("low")
                         )
              .withColumn("ttl_damages",
                          # 
                          fn.when(fn.col("Total Damages ('000 US$)")>1000000,
                                  fn.when(fn.col("Total Damages ('000 US$)")>100000000,"high").otherwise("medium")).otherwise("low")
                         )
              
                  .drop("year")
             )

    # join on start year
    max_year = df.select(fn.max("year")).limit(1).collect()[0][0]
    min_year = df.select(fn.min("year")).limit(1).collect()[0][0]
    
    nd_j_on_date = tmp_nd.filter(fn.col("start_year")>=min_year).filter(fn.col("end_year")<=max_year)

    # filter countries chosen
    filtered_byCountry_date = (nd_j_on_date
           .filter(fn.col("Country").isin(list(map(lambda x: x.lower(),countries_chosen))))
           
    )
    
    # distinct banded rows with key
    res = (filtered_byCountry_date  
                                  .select([
                                           "disaster_type",
                                           "disaster_subtype",
                                           "disaster_nestedsubtype",
                                           "disaster_subgroup",
                                           "event_name",
                                           "ttl_death",
                                           "ttl_damages",
                                           "ofda_response"])                                   
                                 ).distinct().withColumn("natural_disaster_key",fn.monotonically_increasing_id())
    
    
    lookup=(res.join(
        filtered_byCountry_date,
        on = [
            "disaster_type","disaster_subtype","disaster_nestedsubtype","disaster_subgroup","event_name","ofda_response","ttl_damages","ttl_death"
        ])
        .select("natural_disaster_key","Country","start_year","start_month","start_day","end_year","end_month","end_day")
        )
    
    
    # dimension, lookup
    return res,lookup

dateDim = dateDimension()
df, tmp=naturalDisasterDim(
    dateDim,
    countries_chosen=countries_chosen,
    filePath="AssignmentData/ExternalSources/DISASTERS/1900_2021_DISASTERS.xlsx - emdat data.csv"
)

display(df.toPandas())
display(tmp.toPandas())

NameError: name 'spark' is not defined

In [None]:
def countryDimension(time_df,indicators,countries_chosen,filePath):
    max_year = time_df.select(fn.max("year")).limit(1).collect()[0][0]
    min_year = time_df.select(fn.min("year")).limit(1).collect()[0][0]
    
    countries = (spark
                       .read
                       .format('csv')
                           .option("inferSchema",True)
                           .option("header",True)
                           .load(filePath)
                        )

    indicators = (indicators
                  .withColumn("country_name",fn.lower("Country Name"))
                  .drop("Country Name")
                  .withColumn("age_dependency_ratio_workingage",
                              fn.when(fn.col("Age dependency ratio (% of working-age population)")>100.00, 100.00)
                              .otherwise(fn.col("Age dependency ratio (% of working-age population)")),)
                  .withColumn("labor_force_total",
                              fn.when(fn.col("Labor force, total")>30000000,
                                     fn.when(fn.col("Labor force, total")>80000000,"high").otherwise("medium")
                                     ).otherwise("low"))
                  .select(
                      fn.col("country_name"),
                      fn.col("Population, total").alias("population_total"),
                      fn.col("Population growth (annual %)").alias("population_growth"),
                      fn.col("Urban population growth (annual %)").alias("urban_population_growth"),
                      fn.col("Urban population").alias("urban_population"),
                      fn.col("Rural population").alias("rural_population"),
                      fn.col("Unemployment, total (% of total labor force)").alias("unemployment_rate"),
                      fn.col("age_dependency_ratio_workingage"),
                      fn.col("Poverty headcount ratio at national poverty line (% of population)").alias("poverty_headcount_percentage"),
                      fn.col("labor_force_total"),
                      fn.col("Net migration").alias("net_migration"),
                      fn.col("year")
                  )
#                   .fillna(
#                       indicators.select(fn.avg("Age dependency ratio (% of working-age population)")).collect()[0][0],
#                       subset=["age_dependency_ratio_workingage"]
#                   )
                 )
    tmp = (countries
               .filter(fn.lower(fn.col("short name")).isin(list(map(lambda x: x.lower(),countries_chosen))))
               .select(
                   fn.lower("Currency Unit").alias("currency"),
                   fn.lower("short name").alias("country_name"),
                   fn.col("region"),
               )
          )
    
    res = (tmp.join(indicators,on=["country_name"]).withColumn("country_key",fn.monotonically_increasing_id()))
    
    lookup = (res.select(
        "country_name",
        "year",
        "country_key"
    ))

           
    return res.drop("year"), lookup
               
dateDim = dateDimension()
df,tmp = countryDimension(  
    dateDim,
    filterdCountryDf,
    countries_chosen=countries_chosen,
    filePath="AssignmentData/HNP_StatsCountry.csv")

display(df.toPandas())
display(tmp.toPandas())

summary_df(filterdCountryDf,"Age dependency ratio (% of working-age population)",bns=20)


In [None]:
def educationDimension(time_df,indicators,countries_chosen,filePath):
    #NOT SURE WHAT THIS DOES, PROBABLY CAN REMOVE
    max_year = time_df.select(fn.max("year")).limit(1).collect()[0][0]
    min_year = time_df.select(fn.min("year")).limit(1).collect()[0][0]
    
    #Read the file
    countries = (spark
                       .read
                       .format('csv')
                           .option("inferSchema",True)
                           .option("header",True)
                           .load(filePath)
                        )
    #Get the indicators for this dimension
    #In this case:
    """
        "Primary completion rate, total (% of relevant age group)",
        "School enrollment, primary (% net)",
        "School enrollment, secondary (% net)",
        "School enrollment, tertiary (% gross)",
                    
                    
        Bottom three suck (100+ zeros) or doesnt exist so not added
        "Literacy rate, adult total (% of people ages 15 and above)",
        "Literacy rate, youth total (% of people ages 15-24)"
         Government expenditure on education, total (% of GDP)
    """
    
    
    indicators = (indicators
                  .withColumn("country_name",fn.lower("Country Name"))
                  .drop("Country Name")
                  #Preprocess the columns that have bad data
                  .withColumn("primary_completion_rate",
                              fn.when(fn.col("Primary completion rate, total (% of relevant age group)")>100.00,100.00)
                              .otherwise(fn.col("Primary completion rate, total (% of relevant age group)")),)
                  .select(
                      #Add the actual data to the table
                      fn.col("country_name"),
                      fn.col("primary_completion_rate"),
                      fn.col("School enrollment, primary (% net)").alias("school_enrollment_primary_%_net"),
                      fn.col("School enrollment, secondary (% net)").alias("school_enrollment_secondary_%_net"),
                      fn.col("School enrollment, tertiary (% gross)").alias("school_enrollment_tertiary_%_gross"),
                      
                      fn.col("year")
                  )
                 )
    #Select the countries we care about
    tmp = (countries
           .filter(fn.lower(fn.col("short name")).isin(list(map(lambda x: x.lower(),countries_chosen))))
            .select(
                fn.lower("short name").alias("country_name"),
               )
    )
    
    #Join the indicators table and the res table to get the indicators from the countries we want
    #Also add the key
    res = (tmp.join(indicators,on=["country_name"]).distinct().withColumn("education_key",fn.monotonically_increasing_id()))
    
    #Set up the look up table
    lookup = (res.select(
        "country_name",
        "year",
        "education_key"
    ))
    
    return res.drop("year"), lookup

#TEST CODE REMOVE LATER
dateDim = dateDimension()
df,tmp = educationDimension(  
    dateDim,
    filterdCountryDf,
    countries_chosen=countries_chosen,
    filePath="AssignmentData/HNP_StatsCountry.csv")

display(df.toPandas())
display(tmp.toPandas())

In [12]:
def generalHealthDimension(time_df,indicators,countries_chosen,filePath):
    
    #Read the file
    countries = (spark
                       .read
                       .format('csv')
                           .option("inferSchema",True)
                           .option("header",True)
                           .load(filePath)
                        )
    
    # Replace null values with mean for the following columns
    
    list_of_attributes = ["Prevalence of undernourishment (% of population)", "Life expectancy at birth, total (years)",
     "Mortality rate, under-5 (per 1,000)", "Number of infant deaths", "Number of deaths ages 10-14 years",
    "Number of deaths ages 15-19 years", "Number of deaths ages 20-24 years"]

    indicators = replaceNullWithMean(indicators, list_of_attributes)
    
    #Get the indicators for this dimension
    #In this case:
    """
        mortality_rate_under_5
        life_expectancy_at_birth
        number_of_infant_deaths
        number_of_deaths_ages_10-14
        number_of_deaths_ages_15-19
        number_of_deaths_age_20_24
        percent_of_population_undernourished
    """
    
    indicators = (indicators
                  .withColumn("country_name",fn.lower("Country Name"))
                  .drop("Country Name")
                  .withColumn("mortality_rate_under_5",
                              fn.when(fn.col("Mortality rate, under-5 (per 1,000)")>50,
                                     fn.when(fn.col("Mortality rate, under-5 (per 1,000)")>100,"high").otherwise("medium")
                                     ).otherwise("low"))
                  .select(
                      #Add the actual data to the table
                      fn.col("country_name"),
                      fn.col("year"),
                      fn.col("Prevalence of undernourishment (% of population)").alias("percent_of_population_undernourished"),
                      fn.col("Life expectancy at birth, total (years)").alias("life_expectancy_at_birth"),
                      fn.col("mortality_rate_under_5"),
                      fn.col("Number of infant deaths").alias("number_of_infant_deaths"),
                      fn.col("Number of deaths ages 10-14 years").alias("number_of_deaths_ages_10-14"),
                      fn.col("Number of deaths ages 15-19 years").alias("number_of_deaths_ages_15-19"),
                      fn.col("Number of deaths ages 20-24 years").alias("number_of_deaths_ages_20-24")
                  )
                 )
    #Select the countries we care about
    tmp = (countries
           .filter(fn.lower(fn.col("short name")).isin(list(map(lambda x: x.lower(),countries_chosen))))
            .select(
                fn.lower("short name").alias("country_name"),
               )
    )
    
    #Join the indicators table and the res table to get the indicators from the countries we want
    #Also add the key
    res = (tmp.join(indicators,on=["country_name"]).distinct().withColumn("general_health_key",fn.monotonically_increasing_id()))
    
    #Set up the look up table
    lookup = (res.select(
        "country_name",
        "year",
        "general_health_key"
    ))
    
    return res.drop("year"), lookup

#Initializing data for this dimension
countries_chosen = ["United States", "Canada","Mexico","Thailand","China","India","Niger","Madagascar","Guinea"]
df=spark.read.format("csv").option("header",True).option("inferSchema",True).load("AssignmentData/HNP_StatsData.csv")
filterdCountryDf=filterCountryData(df,countries_chosen)
dateDim = dateDimension()

#Create the general health dimension
df,tmp = generalHealthDimension(  
    dateDim,
    filterdCountryDf,
    countries_chosen=countries_chosen,
    filePath="AssignmentData/HNP_StatsCountry.csv")

display(df.toPandas())
display(tmp.toPandas())



                                                                                

Unnamed: 0,country_name,percent_of_population_undernourished,life_expectancy_at_birth,mortality_rate_under_5,number_of_infant_deaths,number_of_deaths_ages_10-14,number_of_deaths_ages_15-19,number_of_deaths_ages_20-24,general_health_key
0,canada,2.500000,81.900000,low,1713.0,221.0,712.0,1392.0,0
1,madagascar,30.400000,62.509000,medium,35599.0,4449.0,5945.0,4791.0,1
2,china,2.500000,76.912000,low,112595.0,16111.0,22316.0,44112.0,2
3,madagascar,33.400000,61.212000,medium,38158.0,4235.0,5205.0,4530.0,3
4,niger,10.478095,53.411000,high,54103.0,5993.0,4032.0,4295.0,4
...,...,...,...,...,...,...,...,...,...
139,china,2.500000,76.704000,low,123080.0,16403.0,23259.0,45199.0,139
140,niger,10.478095,54.180000,high,53178.0,6228.0,4115.0,4409.0,140
141,thailand,10.400000,73.766000,low,9848.0,2536.0,6370.0,7131.0,141
142,united states,2.500000,77.487805,low,27770.0,3815.0,14037.0,20209.0,142


Unnamed: 0,country_name,year,general_health_key
0,canada,2017,0
1,madagascar,2008,1
2,china,2019,2
3,madagascar,2005,3
4,niger,2005,4
...,...,...,...
139,china,2018,139
140,niger,2006,140
141,thailand,2009,141
142,united states,2005,142


In [None]:
def medicalCapabilityDimension(time_df,indicators,countries_chosen,filePath):
    #NOT SURE WHAT THIS DOES, PROBABLY CAN REMOVE
    max_year = time_df.select(fn.max("year")).limit(1).collect()[0][0]
    min_year = time_df.select(fn.min("year")).limit(1).collect()[0][0]
    
    #Read the file
    countries = (spark
                       .read
                       .format('csv')
                           .option("inferSchema",True)
                           .option("header",True)
                           .load(filePath)
                        )
    #Get the indicators for this dimension
    #In this case:
    """
    "Domestic general government health expenditure per capita (current US$)", sux
    "Mortality rate attributed to unsafe water, unsafe sanitation and lack of hygiene (per 100,000 population)", sux
    "Hospital beds (per 1,000 people)", aight
    
    "Community health workers (per 1,000 people)", sux
    "People using safely managed drinking water services (% of population)", aight
    "People with basic handwashing facilities including soap and water (% of population)", aight
    "People using safely managed sanitation services (% of population)" goated
    "Physicians (per 1,000 people)" sux
    """
    
    
    indicators = (indicators
                  .withColumn("country_name",fn.lower("Country Name"))
                  .drop("Country Name")
                  #Preprocess the columns that have bad data
                  #Use .withColumn
                  .select(
                      #Add the actual data to the table
                      fn.col("country_name"),
                      fn.col("Hospital beds (per 1,000 people)").alias("num_hospital_beds"),
                      fn.col("People with basic handwashing facilities including soap and water (% of population)").alias("people_with_basic_handwashing_facilities"),
                      fn.col("People using safely managed sanitation services (% of population)").alias("people_using_safely_managed_sanitation"),
                      
                      fn.col("year")
                  )
                 )
    #Select the countries we care about
    tmp = (countries
           .filter(fn.lower(fn.col("short name")).isin(list(map(lambda x: x.lower(),countries_chosen))))
            .select(
                fn.lower("short name").alias("country_name"),
               )
    )
    
    #Join the indicators table and the res table to get the indicators from the countries we want
    #Also add the key
    res = (tmp.join(indicators,on=["country_name"]).withColumn("medical_capability_key",fn.monotonically_increasing_id()))
    
    #Set up the look up table
    lookup = (res.select(
        "country_name",
        "year",
        "medical_capability_key"
    ))
    
    return res.drop("year"), lookup

#TEST CODE REMOVE LATER
dateDim = dateDimension()
df,tmp = medicalCapabilityDimension(  
    dateDim,
    filterdCountryDf,
    countries_chosen=countries_chosen,
    filePath="AssignmentData/HNP_StatsCountry.csv")

display(df.toPandas())
display(tmp.toPandas())

In [None]:
def immunizationDimension(time_df,indicators,countries_chosen,filePath):
    #NOT SURE WHAT THIS DOES, PROBABLY CAN REMOVE
    max_year = time_df.select(fn.max("year")).limit(1).collect()[0][0]
    min_year = time_df.select(fn.min("year")).limit(1).collect()[0][0]
    
    #Read the file
    countries = (spark
                       .read
                       .format('csv')
                           .option("inferSchema",True)
                           .option("header",True)
                           .load(filePath)
                        )
    #Get the indicators for this dimension
    #In this case:
    """
    "Children (0-14) living with HIV", aight
    "Immunization, HepB3 (% of one-year-old children)", goated
    "Immunization, BCG (% of one-year-old children)", goated
    "Immunization, Hib3 (% of children ages 12-23 months)", aight
    "Immunization, DPT (% of children ages 12-23 months)" goated
    """
    
    
    indicators = (indicators
                  .withColumn("country_name",fn.lower("Country Name"))
                  .drop("Country Name")
                  #Preprocess the columns that have bad data
                  #Use .withColumn
                  .select(
                      #Add the actual data to the table
                      fn.col("country_name"),
                      fn.col("Children (0-14) living with HIV").alias("children_living_w_hiv"),
                      fn.col("Immunization, HepB3 (% of one-year-old children)").alias("hepb3_immunization_rate"),
                      fn.col("Immunization, BCG (% of one-year-old children)").alias("bcg_immunization_rate"),
                      fn.col("Immunization, Hib3 (% of children ages 12-23 months)").alias("hib3_immunization_rate"),
                      fn.col("Immunization, DPT (% of children ages 12-23 months)").alias("dpt_immunization_rate"),
                      
                      fn.col("year")
                  )
                 )
    #Select the countries we care about
    tmp = (countries
           .filter(fn.lower(fn.col("short name")).isin(list(map(lambda x: x.lower(),countries_chosen))))
            .select(
                fn.lower("short name").alias("country_name"),
               )
    )
    
    #Join the indicators table and the res table to get the indicators from the countries we want
    #Also add the key
    res = (tmp.join(indicators,on=["country_name"]).withColumn("immunization_key",fn.monotonically_increasing_id()))
    
    #Set up the look up table
    lookup = (res.select(
        "country_name",
        "year",
        "immunization_key"
    ))
    
    return res.drop("year"), lookup

#TEST CODE REMOVE LATER
dateDim = dateDimension()
df,tmp = immunizationDimension(  
    dateDim,
    filterdCountryDf,
    countries_chosen=countries_chosen,
    filePath="AssignmentData/HNP_StatsCountry.csv")

display(df.toPandas())
display(tmp.toPandas())

In [5]:
spark = SparkSession.builder.appName("ds_datastage").getOrCreate()

22/03/16 23:03:29 WARN Utils: Your hostname, lgcypher-Inspiron-13-5378 resolves to a loopback address: 127.0.1.1; using 192.168.0.71 instead (on interface wlp1s0)
22/03/16 23:03:29 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/03/16 23:03:31 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [None]:
#MAIN block
countries_chosen = ["United States", "Canada","Mexico","Thailand","China","India","Niger","Madagascar","Guinea"]

df=spark.read.format("csv").option("header",True).option("inferSchema",True).load("AssignmentData/HNP_StatsData.csv")

#filtered data
filterdCountryDf=filterCountryData(df,countries_chosen)

dateDim = dateDimension()
generalHealthDimension, nd_lookup=generalHealthDim(
    dateDim,
    countries_chosen=countries_chosen,
    filePath="
)

# countryDimension = 
# display(naturalDisasterDimension.toPandas())
# display(nd_lookup.toPandas())
display(filterdCountryDf.toPandas())
# display(dateDim.toPandas())
filterdCountryDf.count()

In [None]:

### TESTING BLOCK


# ## LOOKUP TABLE LOGIC
# # 2006-2010
# tmp = dateDim.filter(fn.col("year")==2006).select(fn.col("year").alias("year_2"))
# dateDim_a = dateDim.alias("a")
# tmp_b = tmp.alias("b")

# test2 = dateDim_a.join(tmp_b.alias("b"),tmp_b.year_2<dateDim_a.year)
# test3 = dateDim_a.join(tmp_b.alias("b"),2010>dateDim_a.year)

# test2.show()
# test3.show()

# test3.intersect(test2).show()


In [10]:
### TEST FUNCTIONS

import matplotlib.pyplot as plt 
import numpy as np

# df - dataframe
# cls - columns to replace null with medians
# returns dataframe with the corresponding columns null values as median
def replaceNullWithMean(df, cls):
    for column in cls:
        mean = df.filter(fn.col(column).isNotNull()).agg(fn.mean(column).alias("mean"))
        df = df.fillna(mean.collect()[0][0], column)
    return df

# df - dataframe
# col - column to observe
def nullCount(df,cl):
    non_null =(df
     .filter(fn.col(cl).isNotNull())
    )
    
    null = (df
        .filter(fn.col(cl).isNull()))
    
    print("Number of non null values: "+str(non_null.count()))
    print("Number of null values: "+str(null.count()))
    
    
def summary_df(df,cl,bns = 10):
    """
        returns null counts, basic statistics & plot of current values in a column
        
        df - dataframe you wish to observer these statistics
        cl - column of which you wish to observe
        bns - bins (number of bars) histogram will try to bucketize data in
    """
    nullCount(df,cl)
    
#     df.groupBy(fn.col(cl)).count().orderBy(fn.asc(fn.col(cl))).show()
#     df.groupBy(fn.col(cl)).count().orderBy(fn.desc(fn.col(cl))).show()
    
    tmp = df.filter(fn.col(cl).isNotNull())
    tmp.select(cl).describe().show()

    pd_data = tmp.select(fn.col(cl)).toPandas()
    # display(pd_data)
    plt.hist(pd_data,bins = bns)
    plt.title("Histogram of " +str(cl))
    plt.xlabel(cl)
    plt.ylabel("count")

In [None]:
#Dataframe summaries here: just change the middle parameter with the indicator of choice, change bns if you want
summary_df(filterdCountryDf,"",bns=20)

In [None]:
filterdCountryDf.columns

In [None]:
cl = 'Primary completion rate, total (% of relevant age group)'
first = filterdCountryDf.filter(fn.col(cl).isNotNull()).groupBy("Country Name").agg(fn.count("*").alias("nonnull"))
second = filterdCountryDf.filter(fn.col(cl).isNull()).groupBy("Country Name").agg(fn.count("*").alias("null"))

In [None]:
second.join(first,["Country Name"],'right').withColumn("ratio",fn.col("nonnull")/fn.col("null")).orderBy(fn.desc("ratio")).show(300,False)