In [None]:
import math
from decimal import Decimal
from itertools import islice
from collections import namedtuple
from datetime import date
from datetime import timedelta
from datetime import datetime
from dateutil.relativedelta import relativedelta
from dateutil.rrule import rrule, MONTHLY
from pyspark.sql.types import BooleanType, DateType, FloatType, IntegerType, TimestampType
from pyspark.sql.functions import udf, struct, date_add, expr, col, lit, posexplode, months_between, trunc, datediff
from pyspark.storagelevel import StorageLevel



### List the monthly SAS data files in `../MonthlyData`

In [None]:
from os import listdir
from os.path import isfile, join
# os.listdir() returns names in arbitrary order; sort so later cells that pick
# files by position (e.g. the last one) get a deterministic, name-ordered list.
datafiles = sorted(f for f in listdir("../MonthlyData") if isfile(join("../MonthlyData", f)))

### Claim generation methods

In [None]:
def _svc_dt_to_date(svc_dt):
    """Convert a compact ``YYYYMMDD`` service-date string to a ``datetime.date``."""
    iso_text = "-".join((svc_dt[0:4], svc_dt[4:6], svc_dt[6:8]))
    return date.fromisoformat(iso_text)


def _to_rounded_int(value):
    """Parse ``value`` with ``try_float`` and round it to an int with Python's ``round``."""
    return round(try_float(value))


def _daily_dose(quantity, daysDispence, strength, mmeFactor):
    """Return quantity / (rounded days + 1) * strength * MME factor, each input parsed with ``try_float``."""
    per_day = try_float(quantity) / (round(try_float(daysDispence)) + 1)
    return per_day * try_float(strength) * try_float(mmeFactor)


# Spark UDF wrappers. The helpers look up try_float only when called, so it may
# be defined later in this cell.
# NOTE(review): create_dose appears unused in the visible cells (DOSE is computed
# with native column arithmetic in the split cell) — confirm before removing.
create_start_date = udf(_svc_dt_to_date, DateType())
round_int = udf(_to_rounded_int, IntegerType())
create_dose = udf(_daily_dose, FloatType())

def try_float(value):
    """Best-effort numeric parse.

    Args:
        value: anything ``float()`` might accept (typically a CSV string or None).

    Returns:
        ``float(value)`` when conversion succeeds, otherwise the int ``0``
        (for unparsable strings such as ``''`` or ``'abc'`` and for ``None``).
    """
    try:
        parsed = float(value)
    except (TypeError, ValueError):
        parsed = 0
    return parsed
    
def is_valid(daysSupplied):
    """Row filter for DAYS_SUPPLY_CNT: True when the day supply is usable.

    A value is rejected when it is missing/empty, does not parse (try_float
    yields 0), is non-finite (NaN/inf — round() would raise on these and abort
    the Spark job running this UDF), or rounds to 0 days.

    Args:
        daysSupplied: raw day-supply value, typically a CSV string or None.

    Returns:
        bool
    """
    if not daysSupplied:  # None or '' (day supply missing/empty)
        return False
    days = try_float(daysSupplied)
    if not math.isfinite(days):
        return False
    return round(days) != 0  # reject a day supply that rounds to 0

def split_claim_by_effective_month(claim):
    """Yield one copy of ``claim`` per calendar month its service period touches.

    Every month from ``claim.start_date``'s month through ``claim.end_date``'s
    month is included (both ends inclusive), regardless of the day-of-month.
    This matches the trunc/add_months expansion used in the Spark split cell.
    Nothing is yielded when ``end_date`` precedes ``start_date``.

    NOTE(review): ``Claim`` is not defined in the visible cells (``namedtuple``
    is imported but unused); it must exist before this generator is consumed.

    Args:
        claim: object exposing ``start_date``/``end_date`` (dates) plus the
            claim fields copied into each yielded record.

    Yields:
        ``Claim`` records with the month as a ``"YYYYMM"`` string first,
        followed by the original fields, in chronological order.
    """
    start, end = claim.start_date, claim.end_date
    if end < start:
        return
    year, month = start.year, start.month
    last = (end.year, end.month)
    while (year, month) <= last:
        yield Claim(f"{year:04d}{month:02d}", claim.start_date, claim.end_date, claim.claim_id,
                    claim.patient_id, claim.provider_id, claim.product_id, claim.quantity,
                    claim.mme_factor, claim.strength, claim.daysDispence, claim.dose)
        # Advance one calendar month, rolling December over to January.
        year, month = (year + 1, 1) if month == 12 else (year, month + 1)

### Split data into monthly claim files

In [None]:
# The eager .checkpoint() calls in this cell require a checkpoint directory.
spark.sparkContext.setCheckpointDir("../MonthlySpark/checkpoint")

# Claims not yet written to a monthly parquet file, accumulated across input
# files. Initialised here so the cell does not depend on leftover kernel state.
totalClaims = None

# os.listdir() order is arbitrary, so sort before selecting by position.
# NOTE(review): only the last (name-ordered) file is processed; this assumes
# names sort chronologically (e.g. <prefix>_YYYYMM.<ext>) — confirm, and widen
# the slice to process more months in one run.
dataFilesToRun = sorted(datafiles)[-1:]
for dataFile in dataFilesToRun:
    print("Processing file: " + dataFile)
    data = spark.read.load("../MonthlyData/" + dataFile, format="csv", sep="\t", inferSchema="false", header="true")
    filter_data_func = udf(is_valid, BooleanType())

    data = data.persist(StorageLevel.DISK_ONLY)

    # Keep rows with a usable day supply, then add the rounded integer day count
    # and the service start date. col(...) resolves against the frame being
    # built rather than the pre-filter/pre-checkpoint `data` object.
    data = (
        data.filter(filter_data_func(col('DAYS_SUPPLY_CNT')))
        .checkpoint().persist(StorageLevel.DISK_ONLY)
        .withColumn('DAYS_SUPPLY_CNT_int', round_int(col('DAYS_SUPPLY_CNT')))
        .withColumn("START_DATE", create_start_date(col('SVC_DT')))
        .checkpoint().persist(StorageLevel.DISK_ONLY)
    )

    # END_DATE = START_DATE + day supply; DOSE = qty / (days + 1) * strength * MME factor.
    data = (
        data.withColumn('END_DATE', expr('date_add(START_DATE, DAYS_SUPPLY_CNT_int)'))
        .withColumn('DOSE', (col('DSPNSD_QTY') / (col('DAYS_SUPPLY_CNT_int') + 1.0)) * col('STRENGTH') * col('MME_Conversion_Factor'))
        .drop('SVC_DT', 'DSPNSD_QTY', 'DAYS_SUPPLY_CNT', 'MME_Conversion_Factor', 'STRENGTH', 'DAYS_SUPPLY_CNT_int')
        .checkpoint().persist(StorageLevel.DISK_ONLY)
    )

    # monthsDiff = months between the START_DATE and END_DATE months (both
    # truncated to the 1st); "repeat" is an array of monthsDiff + 1 empty
    # strings, i.e. one element per month spanned.
    data = (
        data.withColumn("monthsDiff", months_between(trunc("END_DATE", 'month'), trunc("START_DATE", 'month')))
        .withColumn("repeat", expr("split(repeat(',', monthsDiff), ',')"))
        .checkpoint().persist(StorageLevel.DISK_ONLY)
    )

    # One row per spanned month: array position i -> 1st of (START_DATE's month + i).
    data = (
        data.select("*", posexplode("repeat").alias("Effective_Month", "val"))
        .withColumn("Effective_Month", expr("trunc(add_months(START_DATE, Effective_Month), 'month')"))
        .drop('monthsDiff', 'repeat', 'val')
        .checkpoint().persist(StorageLevel.DISK_ONLY)
    )

    totalClaims = data if totalClaims is None else totalClaims.union(data)
    totalClaims = totalClaims.checkpoint().persist(StorageLevel.DISK_ONLY)

    # Month of this input file, parsed from a name assumed to look like <prefix>_YYYYMM.<ext>.
    currentMonth = dataFile.split('_')[1].split('.')[0]
    currentMonth = date.fromisoformat(currentMonth[0:4] + "-" + currentMonth[4:6] + "-01")
    fileName = "../MonthlySpark/" + currentMonth.strftime('%Y%m') + ".parquet"
    print("Writing file: " + fileName + " for month: " + currentMonth.strftime('%Y%m'))

    # Effective_Month is always the 1st of a month, so date equality selects
    # exactly this file's month.
    totalClaims.where(col('Effective_Month') == lit(currentMonth)).write.parquet(fileName)
    # Carry claims for every other month forward to later files / later cells.
    totalClaims = totalClaims.where(col('Effective_Month') != lit(currentMonth))
    data.unpersist()
    del data
    
# Write out every Effective_Month still held in totalClaims, one parquet file
# per month, in ascending month order.
remainingMonths = [] if totalClaims is None else sorted(
    totalClaims.select(totalClaims.Effective_Month).distinct().collect()
)

for currentMonth in remainingMonths:
    monthStart = currentMonth.Effective_Month
    currentMonthStr = monthStart.strftime('%Y%m')
    print("Processing month: " + currentMonthStr)
    fileName = "../MonthlySpark/" + currentMonthStr + ".parquet"

    print("Writing file: " + fileName + " for month: " + currentMonthStr)
    totalClaims.where(col('Effective_Month') == lit(monthStart)).write.parquet(fileName)

    # Drop the month just written and checkpoint to truncate the query plan.
    # The new frame reads from checkpoint files, so the previously cached copy
    # can be released.
    remaining = (
        totalClaims.where(col('Effective_Month') != lit(monthStart))
        .checkpoint().persist(StorageLevel.DISK_ONLY)
    )
    totalClaims.unpersist()
    totalClaims = remaining