In [0]:
%run ./loadData

In [0]:
from pyspark.sql.types import BooleanType, DoubleType, IntegerType, StringType, StructField, StructType
from pyspark.sql.functions import col, expr, lit, regexp_replace, when
# List the contents of the /mnt/worldbankdata mount; as the last expression in the
# cell, the listing is what the cell displays.
dbutils.fs.ls("/mnt/worldbankdata")

In [0]:
# Load the Life Expectancy CSV (first line is treated as the header) and
# preview its first five columns.
dfLifeExp = spark.read.csv(lifeExpCsv, header=True)
previewColumns = dfLifeExp.columns[:5]
dfLifeExp.select(previewColumns).show()

In [0]:
# Create the first dimension table for countries.
# Some of the entries are regions rather than countries, so two flag columns
# (isCountry / isRegion) are added. A code that appears in neither set below
# gets null in both columns.

# Intended shape of dimCountry. It is not passed to Spark in this cell: the
# columns are derived directly from dfLifeExp below.
dimCountrySchema = StructType([
    StructField("regionCode", StringType(), True),
    StructField("countryName", StringType(), True),
    StructField("isCountry", BooleanType(), True),
    StructField("isRegion", BooleanType(), True),
])

# ISO 3166-1 alpha-3 style codes that denote individual countries/economies.
# Fix: "SVK" (Slovak Republic) and "SYR" (Syrian Arab Republic) were missing
# because "SVN" and "SYC" had each been listed twice.
countriesSet = {"ABW", "AFG", "AGO", "ALB", "AND", "ARE", "ARG", "ARM", "ASM", "ATG",
                "AUS", "AUT", "AZE", "BDI", "BEL", "BEN", "BFA", "BGD", "BGR", "BHR",
                "BHS", "BIH", "BLR", "BLZ", "BMU", "BOL", "BRA", "BRB", "BRN", "BTN",
                "BWA", "CAF", "CAN", "CHE", "CHI", "CHL", "CHN", "CIV", "CMR", "COD",
                "COG", "COL", "COM", "CPV", "CRI", "CUB", "CUW", "CYM", "CYP", "CZE",
                "DEU", "DJI", "DMA", "DNK", "DOM", "DZA", "ECU", "EGY", "ERI", "ESP",
                "EST", "ETH", "FIN", "FJI", "FRA", "FRO", "FSM", "GAB", "GBR", "GEO",
                "GHA", "GIB", "GIN", "GMB", "GNB", "GNQ", "GRC", "GRD", "GRL", "GTM",
                "GUM", "GUY", "HKG", "HND", "HRV", "HTI", "HUN", "IDN", "IMN", "IND",
                "IRL", "IRN", "IRQ", "ISL", "ISR", "ITA", "JAM", "JOR", "JPN", "KAZ",
                "KEN", "KGZ", "KHM", "KIR", "KNA", "KOR", "KWT", "LAO", "LBN", "LBR",
                "LBY", "LCA", "LIE", "LKA", "LSO", "LTU", "LUX", "LVA", "MAC", "MAF",
                "MAR", "MCO", "MDA", "MDG", "MDV", "MEX", "MHL", "MKD", "MLI", "MLT",
                "MMR", "MNE", "MNG", "MNP", "MOZ", "MRT", "MUS", "MWI", "MYS", "NAM",
                "NCL", "NER", "NGA", "NIC", "NLD", "NOR", "NPL", "NRU", "NZL", "OMN",
                "PAK", "PAN", "PER", "PHL", "PLW", "PNG", "POL", "PRI", "PRK", "PRT",
                "PRY", "PSE", "PYF", "QAT", "ROU", "RUS", "RWA", "SAU", "SDN", "SEN",
                "SGP", "SLB", "SLE", "SLV", "SMR", "SOM", "SRB", "SSD", "STP", "SUR",
                "SVK", "SVN", "SWE", "SWZ", "SXM", "SYC", "SYR", "TCA", "TCD", "TGO",
                "THA", "TJK", "TKM", "TLS", "TON", "TTO", "TUN", "TUR", "TUV", "TZA",
                "UGA", "UKR", "URY", "USA", "UZB", "VCT", "VEN", "VGB", "VIR", "VNM",
                "VUT", "WSM", "XKX", "YEM", "ZAF", "ZMB", "ZWE"}

# Codes that denote geographic regions rather than single countries.
regionSet = {"AFE", "AFW", "ARB", "CEB", "CSS", "EAS", "ECS", "EMU", "EUU", "LCN",
             "MEA", "NAC", "SAS", "SSF"}

dimCountry = dfLifeExp.select(
    col("Country Code").alias("regionCode"),
    col("Country Name").alias("countryName"),
)
# The flags are True when the code is in the matching set and null otherwise
# (never False), matching the nullable BooleanType declared in the schema.
dimCountry = (
    dimCountry
    .withColumn("isCountry", when(col("regionCode").isin(countriesSet), True).otherwise(None))
    .withColumn("isRegion", when(col("regionCode").isin(regionSet), True).otherwise(None))
)
display(dimCountry)

regionCode,countryName,isCountry,isRegion
ABW,Aruba,True,
AFE,Africa Eastern and Southern,,True
AFG,Afghanistan,True,
AFW,Africa Western and Central,,True
AGO,Angola,True,
ALB,Albania,True,
AND,Andorra,True,
ARB,Arab World,,True
ARE,United Arab Emirates,True,
ARG,Argentina,True,


In [0]:
# Create the second dimension table for indicator codes.
# Each CSV in csvList contributes one (indicatorCode, indicatorName) pair, taken
# from its first data row.
dimIndicatorSchema = StructType([
    StructField("indicatorCode", StringType(), True),
    StructField("indicatorName", StringType(), True),
])
dimIndicatorData = []

for csvPath in csvList:
    df = spark.read.option("header", "true").csv(csvPath)
    # Only the first row is needed, so fetch just that row (both columns at once)
    # instead of collecting each full column to the driver.
    firstRow = df.select(col("Indicator Code"), col("Indicator Name")).first()
    if firstRow is None:
        raise ValueError("No data rows found in " + str(csvPath))
    indicatorCode, indicatorName = firstRow[0], firstRow[1]
    dimIndicatorData.append((indicatorCode, indicatorName))

dimIndicator = spark.createDataFrame(data=dimIndicatorData, schema=dimIndicatorSchema)
display(dimIndicator)

indicatorCode,indicatorName
EG.ELC.ACCS.ZS,Access to electricity (% of population)
AG.LND.AGRI.ZS,Agricultural land (% of land area)
NV.AGR.TOTL.ZS,"Agriculture, forestry, and fishing, value added (% of GDP)"
AG.LND.ARBL.ZS,Arable land (% of land area)
EN.ATM.CO2E.PC,CO2 emissions (metric tons per capita)
SH.STA.DIAB.ZS,Diabetes prevalence (% of population ages 20 to 79)
EG.USE.ELEC.KH.PC,Electric power consumption (kWh per capita)
SP.DYN.TFRT.IN,"Fertility rate, total (births per woman)"
AG.LND.FRST.ZS,Forest area (% of land area)
NY.GDP.MKTP.KD.ZG,GDP growth (annual %)


In [0]:
# Create the third dimension table for dates.
# One row per year from 1960 to 2020 inclusive, each tagged with its decade
# label ("1960s", "1970s", ..., "2020s").
dimDateSchema = StructType([
    StructField("dateYear", IntegerType(), True),
    StructField("dateDecade", StringType(), True),
])

# The decade is computed from the year rather than hard-coded per range, so the
# label stays correct if the year range is ever widened. This also avoids
# rebinding the builtin name `tuple` at notebook scope.
dimDateData = [(year, str(year // 10 * 10) + "s") for year in range(1960, 2021)]

dimDate = spark.createDataFrame(data=dimDateData, schema=dimDateSchema)
display(dimDate)

dateYear,dateDecade
1960,1960s
1961,1960s
1962,1960s
1963,1960s
1964,1960s
1965,1960s
1966,1960s
1967,1960s
1968,1960s
1969,1960s


In [0]:
# Next, create the fact table.
# Driver-side approach: every CSV is collected to the driver and turned into
# (regionCode, indicatorCode, dateYear, value) tuples, one per region and year.
from pyspark.sql.types import BooleanType, DoubleType, IntegerType, StringType, StructField, StructType
factSchema = StructType([
    StructField("regionCode", StringType(), True),
    StructField("indicatorCode", StringType(), True),
    StructField("dateYear", IntegerType(), True),
    StructField("value", DoubleType(), True),
])
factData = []
factYears = range(1960, 2021)

for csvPath in csvList:
    df = spark.read.option("header", "true").csv(csvPath)
    # Collect the file once. Region code and yearly values then come from the
    # same Row, so they cannot drift out of alignment, and a single Spark job
    # replaces one job per year column.
    rows = df.select(
        col("Country Code"),
        col("Indicator Code"),
        *[col(str(year)) for year in factYears]
    ).collect()
    if not rows:
        raise ValueError("No data rows found in " + str(csvPath))
    # The indicator code is taken from the first row and applied to the whole file.
    indicatorCode = rows[0]["Indicator Code"]
    # Year-major order: all regions for 1960, then all regions for 1961, ...
    for year in factYears:
        yearStr = str(year)
        for row in rows:
            value = row[yearStr]
            # Values arrive as strings from the CSV reader; missing cells stay None.
            valueFloat = float(value) if value is not None else None
            factData.append((row["Country Code"], indicatorCode, year, valueFloat))

factTable = spark.createDataFrame(data=factData, schema=factSchema)
factTable.show()

In [0]:
# Try creating the fact table by unpivoting, to see if we get better performance.
# Every column from index 4 onwards is treated as a year column
# (NOTE(review): assumes the first four columns are the country/indicator
# name and code columns and everything after is a year - confirm against the CSVs).
factSchema = StructType([
    StructField("regionCode", StringType(), True),
    StructField("indicatorCode", StringType(), True),
    StructField("dateYear", IntegerType(), True),
    StructField("value", DoubleType(), True),
])

factTable2 = spark.createDataFrame(data=[], schema=factSchema)
for csvPath in csvList:
    ingestedDf = spark.read.option("header", "true").csv(csvPath)
    for yearColumn in ingestedDf.columns[4:]:
        # One slice per year column: the column name becomes dateYear and the
        # cell becomes value, both cast to the types declared in factSchema.
        yearDf = ingestedDf.select(
            col("Country Code").alias("regionCode"),
            col("Indicator Code").alias("indicatorCode"),
            lit(yearColumn).cast("integer").alias("dateYear"),
            ingestedDf[yearColumn].cast("double").alias("value"),
        )
        # Fix: accumulate onto factTable2. The previous code unioned onto
        # factTable, which discarded every CSV but the last and pulled in the
        # whole driver-side fact table.
        factTable2 = factTable2.union(yearDf)

display(factTable2)

regionCode,indicatorCode,dateYear,value
ABW,EG.ELC.ACCS.ZS,1960,
AFE,EG.ELC.ACCS.ZS,1960,
AFG,EG.ELC.ACCS.ZS,1960,
AFW,EG.ELC.ACCS.ZS,1960,
AGO,EG.ELC.ACCS.ZS,1960,
ALB,EG.ELC.ACCS.ZS,1960,
AND,EG.ELC.ACCS.ZS,1960,
ARB,EG.ELC.ACCS.ZS,1960,
ARE,EG.ELC.ACCS.ZS,1960,
ARG,EG.ELC.ACCS.ZS,1960,


In [0]:
# Count the number of rows in the tables.
# Each count is printed explicitly: a notebook cell only displays the value of
# its last bare expression, so the earlier counts would otherwise be computed
# and thrown away.
tableCounts = [
    ("dimCountry", dimCountry),
    ("dimIndicator", dimIndicator),
    ("dimDate", dimDate),
    ("factTable", factTable),
    ("factTable2", factTable2),
]
for tableLabel, tableDf in tableCounts:
    print(tableLabel, tableDf.count())

In [0]:
# Persist every table, overwriting any existing copy. factTable2 (the unpivot
# version) is stored as the primary "FactTable"; the driver-side factTable is
# kept as "FactTableBackup".
tablesToSave = [
    ("DimensionTableCountryCode", dimCountry),
    ("DimensionTableIndicatorCode", dimIndicator),
    ("DimensionTableDate", dimDate),
    ("FactTableBackup", factTable),
    ("FactTable", factTable2),
]
for tableName, tableDf in tablesToSave:
    tableDf.write.mode("overwrite").saveAsTable(tableName)