In [0]:
# Define the functions needed to create life tables for men and women for year <year> over months 1-<month>

In [0]:
def _life_table_columns(q_by_age, radix):
    """Derive the I, D, LY, TY and LE columns from age-ordered death rates.

    ``q_by_age`` is a sequence of in-age death rates ordered from the youngest
    to the oldest age; ``radix`` is the number of survivors assigned to the
    youngest age. Returns a list of ``(I, D, LY, TY, LE)`` tuples in the same
    order as ``q_by_age``.
    """
    if not q_by_age:
        return []

    # I: survivors entering each age. Rounded to whole persons at every step
    # and clamped at zero. The last age's q is not needed to build I.
    survivors = [radix]
    for q in q_by_age[:-1]:
        survivors.append(max(0, round(survivors[-1] * (1 - q))))

    # D: deaths within each age; everyone who reaches the last age dies in it.
    deaths = [now - nxt for now, nxt in zip(survivors, survivors[1:])]
    deaths.append(survivors[-1])

    # LY: years lived within each age, counting half a year per death.
    years_lived = [alive - 0.5 * dead for alive, dead in zip(survivors, deaths)]

    # TY: years lived at this age and every older age (running total taken
    # from the oldest age downwards).
    running = 0.0
    total_years = []
    for lived in reversed(years_lived):
        running += lived
        total_years.append(running)
    total_years.reverse()

    # LE: TY per survivor; 0 when nobody survives to the age.
    life_expectancy = [
        total / alive if alive > 0 else 0.0
        for total, alive in zip(total_years, survivors)
    ]

    return list(zip(survivors, deaths, years_lived, total_years, life_expectancy))


def makelifetable(sample, year):
    """Build and display a life table per gender for ages 60-103 in ``year``.

    Relies on the notebook globals ``spark``, ``spark_funcs`` and ``display``.

    Args:
        sample: Spark DataFrame with the columns ``date_of_birth``,
            ``date_of_death``, ``death_indicator`` and ``gender``
            (only the values ``'M'`` and ``'F'`` are reported).
        year: Calendar year the table is built for.

    Displays one table for men, then one for women, with the columns
    ``age_jan_year, gender, alive_jan1, died, q, I, D, LY, TY, LE``.
    Returns nothing.
    """
    min_age, max_age = 60, 103
    radix = 100000
    keys = ['age_jan_year', 'gender']

    is_dead = spark_funcs.col('death_indicator') == spark_funcs.lit(True)

    # Keep only those who were members in the given year, i.e. drop everyone
    # who died before <year>.
    # NOTE(review): filter() only keeps rows where the condition is true, so
    # rows for which it evaluates to NULL (e.g. death_indicator is NULL) are
    # dropped as well - confirm this is intended.
    sample = sample.withColumn('death_year', spark_funcs.year('date_of_death'))
    sample = sample.filter(
        ~(is_dead & (spark_funcs.col('death_year') < spark_funcs.lit(year)))
    )

    # Age in completed years at Jan 1 of <year>.
    comp_str = str(year) + '-01-01'
    sample = sample.withColumn('age_comp', spark_funcs.to_date(spark_funcs.lit(comp_str)))
    sample = sample.withColumn(
        'age_jan_year',
        spark_funcs.floor(
            spark_funcs.months_between(spark_funcs.col('age_comp'), spark_funcs.col('date_of_birth'))
            / spark_funcs.lit(12)
        ),
    )

    # Head counts per age and gender: alive at Jan 1, and died during <year>.
    living = sample.groupBy(keys).count().withColumnRenamed('count', 'alive_jan1')
    dead = (
        sample.filter(is_dead & (spark_funcs.col('death_year') == spark_funcs.lit(year)))
        .groupBy(keys)
        .count()
        .withColumnRenamed('count', 'died')
    )

    # Template with every age/gender so that empty groups (e.g. no men aged
    # 102) still appear in the output.
    ages = list(range(min_age, max_age + 1))
    template = spark.createDataFrame(
        [[age, gender] for age in ages for gender in ('F', 'M')], keys
    )
    complete_ages = template.join(living, on=keys, how='left')
    complete_ages = complete_ages.join(dead, on=keys, how='left')

    # Nobody alive in the group: living = dead = 1 (q = 1).
    # Somebody alive but nobody died: dead = 0 (q = 0).
    # 'died' must be filled before 'alive_jan1', while the NULL is still visible.
    nobody_alive = spark_funcs.col('alive_jan1').isNull()
    complete_ages = complete_ages.withColumn(
        'died', spark_funcs.when(nobody_alive, 1).otherwise(spark_funcs.col('died'))
    )
    complete_ages = complete_ages.withColumn(
        'alive_jan1', spark_funcs.when(nobody_alive, 1).otherwise(spark_funcs.col('alive_jan1'))
    )
    complete_ages = complete_ages.withColumn(
        'died',
        spark_funcs.when(spark_funcs.col('died').isNull(), 0).otherwise(spark_funcs.col('died')),
    )

    # In-age death rate.
    complete_ages = complete_ages.withColumn(
        'q', spark_funcs.col('died') / spark_funcs.col('alive_jan1')
    )

    # The table is tiny (one row per age and gender), so fetch it once and do
    # the recursive survivorship maths on the driver instead of stacking a
    # withColumn per age onto the query plan.
    rows_by_key = {
        (row['gender'], row['age_jan_year']): row for row in complete_ages.collect()
    }

    columns = ['age_jan_year', 'gender', 'alive_jan1', 'died', 'q',
               'I', 'D', 'LY', 'TY', 'LE']

    # Men first, then women.
    for gender in ('M', 'F'):
        base_rows = [rows_by_key[(gender, age)] for age in ages]
        derived = _life_table_columns([row['q'] for row in base_rows], radix)

        data = [
            [row['age_jan_year'], row['gender'], row['alive_jan1'], row['died'], row['q']]
            + list(extra)
            for row, extra in zip(base_rows, derived)
        ]

        # Display output for download.
        display(spark.createDataFrame(data, columns).sort('age_jan_year'))

