In [78]:
import calendar
from datetime import date
from os import listdir
from os.path import isfile, join

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import pandas_udf, PandasUDFType, mean
from pyspark.storagelevel import StorageLevel


### List the monthly claim files

In [79]:
# Every entry in ../MonthlySpark is a monthly claim dataset, except:
#  - "checkpoint": Spark's checkpoint directory for this notebook is created
#    inside ../MonthlySpark, so it appears in the listing once it exists;
#  - hidden entries (e.g. ".DS_Store").
# listdir() order is arbitrary, so sort to process months deterministically.
# NOTE(review): this replaces slicing off the last listdir() entry, which
# presumably was meant to drop the checkpoint directory; confirm no other
# entry was intended to be excluded.
datafiles = sorted(
    entry
    for entry in listdir("../MonthlySpark")
    if entry != "checkpoint" and not entry.startswith(".")
)

### Set up the running environment

In [80]:
# Reuse the SparkContext a pyspark-launched kernel already provides, or start
# one on a plain Python kernel, so this cell survives Restart & Run All
# instead of relying on a pre-existing `sc` global.
sc = SparkContext.getOrCreate()
sqlContext = SQLContext(sc)

# Note: this directory sits inside ../MonthlySpark, so it shows up whenever
# that input directory is listed.
sc.setCheckpointDir(dirName="../MonthlySpark/checkpoint")

### Daily dose calculation

In [81]:
def _to_date(value):
    """Return a datetime.date / datetime.datetime / pandas.Timestamp as a plain datetime.date."""
    # Rebuilding from year/month/day works for all three types and makes them
    # mutually comparable (a Timestamp cannot be compared with a bare date).
    return date(value.year, value.month, value.day)


@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def daily_dose(startDays, endDays, effectiveDays, doses):
    """Return the largest total dose taken on any single day of the effective month.

    Grouped-aggregate pandas UDF: Spark calls it once per group with one pandas
    Series per argument, all holding that group's rows in the same order.

    Each row adds its dose to every calendar day it covers between its
    effective date and the end of that date's month. The doses of overlapping
    rows are summed per day, and the highest daily total is returned.

    Parameters
    ----------
    startDays, endDays : pandas.Series of dates/timestamps
        First day of each row and the day it ends. NOTE(review): the end day is
        treated as exclusive (``range(start, end)``); confirm against how
        END_DATE is derived.
    effectiveDays : pandas.Series of dates/timestamps
        Days before this date are not counted, and counting stops at the last
        day of this date's month.
    doses : pandas.Series of float
        Amount added to each covered day.

    Returns
    -------
    float
        Maximum per-day total, or 0.0 when no row overlaps its window.
    """
    # Rows with a missing value cannot be placed on the calendar; skip them,
    # the same way Spark's built-in aggregates ignore nulls.
    complete = (startDays.notna() & endDays.notna()
                & effectiveDays.notna() & doses.notna())

    # Index 0 is unused so a day-of-month (1..31) can index the list directly.
    totals = [0.0] * 32
    for start, end, effective, dose in zip(startDays[complete], endDays[complete],
                                           effectiveDays[complete], doses[complete]):
        start, end, effective = _to_date(start), _to_date(end), _to_date(effective)
        days_in_month = calendar.monthrange(effective.year, effective.month)[1]
        month_end = date(effective.year, effective.month, days_in_month)

        # Rows that finish before the window opens or begin after the month
        # closes would otherwise be mapped onto unrelated day numbers.
        if end <= effective or start > month_end:
            continue

        first_day = max(start, effective).day
        # Rows running past the month are counted through its last day inclusive.
        stop_day = end.day if end <= month_end else days_in_month + 1
        for day in range(first_day, stop_day):
            totals[day] += dose

    return float(max(totals))


### Iterate over files and write each file's dose statistics to CSV

In [None]:
def _describe_dose(doseFrame, summaryName, doseName):
    """Run describe() on the DOSE column of `doseFrame` and rename its two columns.

    The first column of describe()'s output holds the statistic names and the
    DOSE column holds the values. Renaming both lets several summaries be
    joined side by side on the statistic name without column clashes.
    """
    stats = doseFrame.describe("DOSE")
    return stats.select(stats[0].alias(summaryName), stats.DOSE.alias(doseName))


for dataFile in datafiles:
    print("Processing file: " + dataFile)
    df = sqlContext.read.parquet("../MonthlySpark/" + dataFile)
    # Name the aggregate explicitly rather than addressing it by position later.
    aggDf = df.groupBy(df.PATIENT_ID)\
        .agg(daily_dose(df.START_DATE, df.END_DATE, df.Effective_Month, df.DOSE).alias("DOSE"))\
        .checkpoint().persist(StorageLevel.DISK_ONLY)
    print("Collecting statistics for file: " + dataFile)
    # aggDf is already checkpointed, so a second checkpoint of this projection
    # only duplicates data on disk; persisting is enough for the three
    # describe() passes below.
    doses = aggDf.select(aggDf.DOSE).persist(StorageLevel.DISK_ONLY)
    try:
        totalDoseStats = _describe_dose(doses, "Total", "DOSE")
        moreThan50Under90 = _describe_dose(
            doses.filter((doses.DOSE >= 50.0) & (doses.DOSE < 90.0)), "STATS50", "DOSE50")
        moreThan90 = _describe_dose(doses.filter(doses.DOSE >= 90.0), "STATS90", "DOSE90")

        # Strip only the final extension so names with several dots keep their
        # full stem and cannot overwrite each other's output.
        outputName = dataFile.rsplit(".", 1)[0]
        totalDoseStats.join(moreThan50Under90, totalDoseStats.Total == moreThan50Under90.STATS50)\
            .join(moreThan90, totalDoseStats.Total == moreThan90.STATS90)\
            .toPandas().to_csv("../MonthlyStats/" + outputName + ".csv")
    finally:
        # Only these two frames were persisted; release them even when a file
        # fails so one bad month does not leave its disk cache behind.
        doses.unpersist()
        aggDf.unpersist()
    