In [None]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import udf, avg

In [None]:
# Full state / territory name -> two-letter code. Used below to turn the census
# file's `State` names into codes that can be joined against the insurance
# tables' `StateCode` column. A few entries ('Dakota', 'Orleans',
# 'Philippine Islands') look like historical codes; they are harmless extras.
us = {
    'Alabama': 'AL',
    'Alaska': 'AK',
    'American Samoa': 'AS',
    'Arizona': 'AZ',
    'Arkansas': 'AR',
    'California': 'CA',
    'Colorado': 'CO',
    'Connecticut': 'CT',
    'Dakota': 'DK',
    'Delaware': 'DE',
    'District of Columbia': 'DC',
    'Florida': 'FL',
    'Georgia': 'GA',
    'Guam': 'GU',
    'Hawaii': 'HI',
    'Idaho': 'ID',
    'Illinois': 'IL',
    'Indiana': 'IN',
    'Iowa': 'IA',
    'Kansas': 'KS',
    'Kentucky': 'KY',
    'Louisiana': 'LA',
    'Maine': 'ME',
    'Maryland': 'MD',
    'Massachusetts': 'MA',
    'Michigan': 'MI',
    'Minnesota': 'MN',
    'Mississippi': 'MS',
    'Missouri': 'MO',
    'Montana': 'MT',
    'Nebraska': 'NE',
    'Nevada': 'NV',
    'New Hampshire': 'NH',
    'New Jersey': 'NJ',
    'New Mexico': 'NM',
    'New York': 'NY',
    'North Carolina': 'NC',
    'North Dakota': 'ND',
    'Northern Mariana Islands': 'MP',
    'Ohio': 'OH',
    'Oklahoma': 'OK',
    'Oregon': 'OR',
    'Orleans': 'OL',
    'Pennsylvania': 'PA',
    'Philippine Islands': 'PI',
    'Puerto Rico': 'PR',
    'Rhode Island': 'RI',
    'South Carolina': 'SC',
    'South Dakota': 'SD',
    'Tennessee': 'TN',
    'Texas': 'TX',
    'Utah': 'UT',
    'Vermont': 'VT',
    'Virgin Islands': 'VI',
    'Virginia': 'VA',
    'Washington': 'WA',
    'West Virginia': 'WV',
    'Wisconsin': 'WI',
    'Wyoming': 'WY',
}
def state(mapping):
    """Build a Spark UDF that looks each input value up in ``mapping``.

    Args:
        mapping: dict from full state name to state code (e.g. ``us``).

    Returns:
        A ``StringType`` UDF. Applying it to a column yields the mapped code,
        or null when the value is not a key of ``mapping``.
    """
    def state_(name):
        # dict.get returns None for unknown names, which Spark stores as null.
        code = mapping.get(name)
        return code

    lookup_udf = udf(state_, returnType=StringType())
    return lookup_udf

In [None]:
# Reuse the active Spark session if there is one; otherwise start a new one
# with an 8g driver memory setting.
# NOTE(review): spark.driver.memory is a launch-time setting. It likely has no
# effect if a session/JVM already exists when this cell runs — confirm.
ss = (
    SparkSession.builder
    .config('spark.driver.memory', '8g')
    .getOrCreate()
)

In [None]:
# True -> read the local uncompressed extracts; False -> read the gzipped
# copies on S3 (s3a:// paths).
test_mode = True

# Remote dataset locations, keyed by logical table name.
s3 = dict(
    rate='s3a://msds-durian-candy/insurance/Rate.csv.gz',
    plan='s3a://msds-durian-candy/insurance/PlanAttributes.csv.gz',
    service_area='s3a://msds-durian-candy/insurance/ServiceArea.csv.gz',
    census='s3a://msds-durian-candy/census/acs2015_county_data.csv.gz',
)

# Local dataset locations (relative to the notebook's working directory),
# with the same keys as `s3`.
local = dict(
    rate='../data/s3/insurance/Rate.csv',
    plan='../data/s3/insurance/PlanAttributes.csv',
    service_area='../data/s3/insurance/ServiceArea.csv',
    census='../data/s3/census/acs2015_county_data.csv',
)

if test_mode:
    source = local
else:
    source = s3

# Individual premium rates, keeping only rows with 1 < IndividualRate < 9999.
# The CSV is read without schema inference, so IndividualRate is a string
# column. The bounds are therefore checked on an explicit DOUBLE cast. The
# original relied on Spark's implicit string-vs-integer coercion, which differs
# between Spark versions; some coerce the string to the literal's integer type,
# truncating e.g. "1.50" to 1. The output column keeps its original (string)
# type; only the comparison is numeric. Unparseable values cast to null and
# are dropped, as before.
# NOTE(review): presumably 1 and 9999 exclude placeholder/out-of-range rates —
# confirm against the data dictionary.
rate = (
    ss.read.csv(source['rate'], header=True)
      .filter('CAST(IndividualRate AS DOUBLE) > 1 AND CAST(IndividualRate AS DOUBLE) < 9999')
      .drop('RowNumber')
      .cache()
)
# Plan attributes, trimmed to the columns used downstream and de-duplicated.
# StandardComponentId is renamed to PlanId so it matches the rate table's
# join key.
# NOTE(review): CSRVariationType is kept, so a single StandardComponentId may
# appear on several rows (one per variation). If so, the later rate join
# duplicates rate rows — confirm this is intended.
plan = (
    ss.read.csv(source['plan'], header=True)
      .select(
          'StandardComponentId', 'PlanType', 'BusinessYear', 'StateCode',
          'ServiceAreaId', 'IssuerId',
          'BeginPrimaryCareCostSharingAfterNumberOfVisits',
          'BeginPrimaryCareDeductibleCoinsuranceAfterNumberOfCopays',
          'CSRVariationType', 'ChildOnlyOffering', 'OutOfServiceAreaCoverage',
      )
      .withColumnRenamed('StandardComponentId', 'PlanId')
      .distinct()
      .cache()
)
# Distinct (year, state, issuer, service area, whole-state flag) combinations.
# Any extra per-row detail in the source file is collapsed away by the
# projection + distinct.
service_area = (
    ss.read.csv(source['service_area'], header=True)
      .select('BusinessYear', 'StateCode', 'IssuerId', 'ServiceAreaId', 'CoverEntireState')
      .distinct()
      .cache()
)

# Census rows, each tagged with a StateCode derived from the full `State`
# name. Names missing from `us` get a null StateCode. The identifier/name
# columns are dropped so that only metrics remain.
census = (
    ss.read.csv(source['census'], header=True)
      .withColumn('StateCode', state(us)('State'))
      .drop('CensusId', 'State', 'County')
)
# Collapse to one row per StateCode by averaging every remaining column and
# keeping its original name. Columns were read as strings (no schema
# inference), so this relies on Spark implicitly casting them for avg().
# NOTE(review): this is an unweighted mean over the input rows (counties, per
# the file name). Population-style columns therefore become "average county"
# values rather than state totals, and percentage columns are not
# population-weighted — confirm that is intended. Unmapped state names form
# one null-key group, which never matches in the later join.
census = (
    census.groupBy('StateCode')
          .agg(*(avg(metric).alias(metric) for metric in census.drop('StateCode').columns))
          .cache()
)

In [None]:
# Plans matched to their service-area rows. Inner join: plans without a
# matching service-area row are dropped.
plan_service_area = plan.join(
    service_area,
    on=['BusinessYear', 'StateCode', 'ServiceAreaId', 'IssuerId'],
    how='inner',
).cache()

# Every filtered rate row is kept; plan/service-area columns are null where no
# plan matches.
# NOTE(review): if plan_service_area holds several rows per
# (PlanId, BusinessYear, StateCode), each rate row is repeated once per match —
# verify row counts.
insurance = rate.join(
    plan_service_area,
    on=['PlanId', 'BusinessYear', 'StateCode'],
    how='left_outer',
).cache()

# Attach the state-level census averages. Census columns are null where the
# StateCode has no census row.
df = insurance.join(census, on='StateCode', how='left_outer').cache()

In [None]:
# Peek at the first five joined rows, printed one field per line because the
# joined frame is wide.
df.show(n=5, vertical=True)

In [None]:
# Stop the Spark session (and its underlying SparkContext) once exploration is
# finished. Any later cell that uses `ss` or the DataFrames above will fail
# after this.
ss.stop()