# Statistics for Drug Purchases

In [1]:
from pyspark.sql import functions as fn
from pyspark.sql import Window
from IPython.display import display
from datetime import datetime
import collections

root_path = "/user/silva/output/fall"

patients_path = root_path + "patients"
diagnoses_path = root_path + "diagnoses"
drugs_path = root_path + "drug-purchases"
fractures_path = root_path + "fractures"

Dates = collections.namedtuple('dates', ['lower_bound', 'upper_bound'])
dates = Dates(
    lower_bound = datetime(2006, 1, 1, 0, 0, 0),
    upper_bound = datetime(2011, 1, 1, 0, 0, 0)
)

def filter_date(df):
    return df.where(
        (fn.col("start") >= fn.lit(dates.lower_bound)) &\
        (fn.col("start") < fn.lit(dates.upper_bound))
    )

drugs = filter_date(sqlContext.read.parquet(drugs_path)).cache()
diagnoses = filter_date(sqlContext.read.parquet(diagnoses_path)).cache()
fractures = filter_date(sqlContext.read.parquet(fractures_path))
events = drugs.union(diagnoses).cache()

affected_patients = fractures.select("patientID").distinct().cache()

print("affected patients: {}".format(affected_patients.count()))
display(events.agg(fn.min("start"), fn.max("start")).toPandas())

affected patients: 32912


Unnamed: 0,min(start),max(start)
0,2009-10-02,2010-12-31


In [2]:
patients = drugs.select("patientID").distinct().cache()
patients.count()

2203891

In [3]:
class MyDF:
    window = Window.partitionBy("patientID")
    sorted_window = window.orderBy("start")
    
    def __init__(self, df, drug_family_col = "value"):
        self.df = df
        self.family_col = drug_family_col
    
    @staticmethod
    def start_gap_col_name(n):
        return "with_{0}mo_start_gap".format(n)
    
    @staticmethod
    def start_gap_col(n):
        return fn.col(MyDF.start_gap_col_name(n))
    
    @staticmethod
    def trackloss_col_name(n, kind):
        return "with_{}_trackloss_{}mo".format(kind, n)
    
    @staticmethod
    def trackloss_col(n, kind):
        return fn.col(MyDF.trackloss_col_name(n, kind))
        
    """
    Groups by (PatientID, family_col) to add intermediate aggregations and prepare for the final results.
    """
    def first_level_agg(self, gaps, tl_sizes=[3]):
        basic_aggs = [
            fn.count('*').alias("total_events")
        ]
        start_gap_aggs = [fn.first(MyDF.start_gap_col(n)).alias(MyDF.start_gap_col_name(n)) for n in gaps]
        soft_trackloss_aggs = [fn.sum(MyDF.trackloss_col(s, "soft")).alias(MyDF.trackloss_col_name(s, "soft")) for s in tl_sizes]
        hard_trackloss_aggs = [fn.sum(MyDF.trackloss_col(s, "hard")).alias(MyDF.trackloss_col_name(s, "hard")) for s in tl_sizes]
        all_aggs = basic_aggs + start_gap_aggs + soft_trackloss_aggs + hard_trackloss_aggs
        result_df = self.df.groupBy("patientID", self.family_col).agg(*all_aggs)
        return MyDF(result_df, self.family_col)
    
    """
    Groups by family_col, adding aggregations for counting patients with start gaps and tracklosses.
    """
    def aggregate_by_family(self, gaps, tl_sizes=[3]):
        basic_aggs = [
            fn.sum("total_events").alias("total_events"),
            fn.count("*").alias("distinct_patients")
        ]
        start_gap_aggs = [fn.sum(MyDF.start_gap_col(n)).alias(MyDF.start_gap_col_name(n)) for n in gaps]
        soft_trackloss_aggs = [fn.count(fn.when(MyDF.trackloss_col(s, "soft") >= 1, 1)).alias(MyDF.trackloss_col_name(s, "soft")) for s in tl_sizes]
        hard_trackloss_aggs = [fn.count(fn.when(MyDF.trackloss_col(s, "hard") >= 1, 1)).alias(MyDF.trackloss_col_name(s, "hard")) for s in tl_sizes]
        all_aggs = basic_aggs + start_gap_aggs + soft_trackloss_aggs + hard_trackloss_aggs
        return MyDF(self.df.groupBy(self.family_col).agg(*all_aggs).orderBy(self.family_col), self.family_col)
    
    """
    Groups by family and a single trackloss column, in order to count the number of patients for each trackloss count.
    """
    def aggregate_by_trackloss_count(self, n=3, kind="soft"):
        return MyDF(self.df\
            .groupBy(self.family_col, MyDF.trackloss_col(n, kind))\
            .agg(fn.count("*").alias("num_patients"))\
            .withColumnRenamed(MyDF.trackloss_col_name(n, kind), "{}_{}mo_trackloss_count".format(kind, n))\
            .orderBy(self.family_col, MyDF.trackloss_col(n, kind)), self.family_col
        )
    
    """
    Adds a column with a value 1 if the patient had a start_gap, or 0 if he didn't.
    
    @param n: the number of months without events after start date that define a "start gap"
    """
    def with_start_gap(self, n):
        new_name = MyDF.start_gap_col_name(n)
        beginning_plus_n = fn.add_months(fn.lit(dates.lower_bound), n)
        is_not_within_n = fn.when(fn.col("start") < beginning_plus_n, 0).otherwise(1)
        return MyDF(self.df.withColumn(new_name, fn.min(is_not_within_n).over(MyDF.window)), self.family_col)
    
    """
    Adds a column with a value 1 if the patient had a trackloss, or 0 if he didn't.
    
    @param kind: "soft" | "hard" defining the type of trackloss
    """
    def _find_tracklosses(self, new_window, n, kind):
        new_sorted_window = new_window.orderBy("start")
        next_date = fn.lead(fn.col("start"), 1).over(new_sorted_window.rowsBetween(1, 1))
        date_diff = fn.months_between(fn.col("next_date"), fn.col("start"))
        is_trackloss = fn.when(date_diff >= n, 1).otherwise(0)
        result = MyDF(
            self.df
                .withColumn("next_date", next_date)
                .withColumn(MyDF.trackloss_col_name(n, kind), is_trackloss)
                .drop("next_date")
            , self.family_col
        )
        return result
    
    def with_soft_trackloss(self, n):
        new_window = Window.partitionBy("patientID", self.family_col)
        return self._find_tracklosses(new_window, n, "soft")
    
    def with_hard_trackloss(self, n):
        new_window = Window.partitionBy("patientID")
        return self._find_tracklosses(new_window, n, "hard")
    
    def filter_only_affected(self):
        return MyDF(self.df.join(affected_patients, "patientID"), self.family_col)
        
    def cache(self):
        return MyDF(self.df.cache(), self.family_col)
    
    def repartition(self): 
        return MyDF(self.df.repartition("patientID", self.family_col), self.family_col)
        
    def filter_category(self, category):
        return MyDF(self.df.where(fn.col("category") == category), self.family_col)
    
    def distinct(self):
        return MyDF(self.df.distinct(), self.family_col)\
    
    def drop(self, col_name):
        return MyDF(self.df.drop(col_name), self.family_col)
        
    def show(self, lines=20):
        self.df.show(lines)
        
    def toPandas(self):
        return self.df.toPandas()

In [4]:
start_gaps = [1, 2, 3]
trackloss_lengths = [1, 2, 3]
def run(family_col="value"):
    base_df = MyDF(drugs, family_col)\
        .with_start_gap(1)\
        .with_start_gap(2)\
        .with_start_gap(3)\
        .with_hard_trackloss(1)\
        .with_hard_trackloss(2)\
        .with_hard_trackloss(3)\
        .with_soft_trackloss(1)\
        .with_soft_trackloss(2)\
        .with_soft_trackloss(3)\
        .cache()
        
    return base_df

## Stats by family

The following tables show the number of patients who had start gaps and tracklosses for each drug family.

In [5]:
for col in ["category", "value"]:
    first_level_agg =run(col).first_level_agg(start_gaps, trackloss_lengths)
    by_family_all = first_level_agg.aggregate_by_family(start_gaps, trackloss_lengths)
    by_family_affected_patients = first_level_agg.filter_only_affected().aggregate_by_family(start_gaps, trackloss_lengths)

    print("All patients: {}".format(patients.count()))
    if col == "category":
        for i in [1, 2, 3]:
            by_family_all = by_family_all.drop("with_soft_trackloss_{}mo".format(i))
            by_family_affected_patients = by_family_affected_patients.drop("with_soft_trackloss_{}mo".format(i))
    
    display(by_family_all.toPandas())
    print("Only affected: {}".format(affected_patients.count()))
    display(by_family_affected_patients.toPandas())

All patients: 2203891


Unnamed: 0,category,total_events,distinct_patients,with_1mo_start_gap,with_2mo_start_gap,with_3mo_start_gap,with_hard_trackloss_1mo,with_hard_trackloss_2mo,with_hard_trackloss_3mo
0,drug,36028506,2203891,2203891,2203891,2203891,1946248,805864,427480


Only affected: 32912


Unnamed: 0,category,total_events,distinct_patients,with_1mo_start_gap,with_2mo_start_gap,with_3mo_start_gap,with_hard_trackloss_1mo,with_hard_trackloss_2mo,with_hard_trackloss_3mo
0,drug,475239,25499,25499,25499,25499,22700,9929,4586


All patients: 2203891


Unnamed: 0,value,total_events,distinct_patients,with_1mo_start_gap,with_2mo_start_gap,with_3mo_start_gap,with_soft_trackloss_1mo,with_soft_trackloss_2mo,with_soft_trackloss_3mo,with_hard_trackloss_1mo,with_hard_trackloss_2mo,with_hard_trackloss_3mo
0,Antidepresseurs,3769164,503433,503433,503433,503433,381147,141331,62573,338475,76661,31274
1,Antihypertenseurs,24144660,1682284,1682284,1682284,1682284,1570642,663906,352241,1454819,542980,285282
2,Hypnotiques,7140832,1061783,1061783,1061783,1061783,730443,426945,239965,642956,212963,111005
3,Neuroleptiques,973850,130793,130793,130793,130793,91946,42696,20600,74251,15776,6395


Only affected: 32912


Unnamed: 0,value,total_events,distinct_patients,with_1mo_start_gap,with_2mo_start_gap,with_3mo_start_gap,with_soft_trackloss_1mo,with_soft_trackloss_2mo,with_soft_trackloss_3mo,with_hard_trackloss_1mo,with_hard_trackloss_2mo,with_hard_trackloss_3mo
0,Antidepresseurs,70358,9561,9561,9561,9561,7327,3102,1393,6248,1578,623
1,Antihypertenseurs,265546,19534,19534,19534,19534,17674,8087,3726,15028,5476,2491
2,Hypnotiques,122545,15852,15852,15852,15852,11601,6545,3381,9624,3102,1415
3,Neuroleptiques,16790,3077,3077,3077,3077,1846,948,466,1382,329,130


## Stats by trackloss count

The following tables show the distribution of patients for each number of tracklosses.

In [6]:
print("All patients: {}".format(patients.count()))
display(first_level_agg.aggregate_by_trackloss_count().df.toPandas())

print("Only affected: {}".format(affected_patients.count()))
display(first_level_agg.filter_only_affected().aggregate_by_trackloss_count().df.toPandas())

All patients: 2203891


Unnamed: 0,value,soft_3mo_trackloss_count,num_patients
0,Antidepresseurs,0,440860
1,Antidepresseurs,1,58052
2,Antidepresseurs,2,4454
3,Antidepresseurs,3,67
4,Antihypertenseurs,0,1330043
5,Antihypertenseurs,1,270493
6,Antihypertenseurs,2,74928
7,Antihypertenseurs,3,6820
8,Hypnotiques,0,821818
9,Hypnotiques,1,206580


Only affected: 32912


Unnamed: 0,value,soft_3mo_trackloss_count,num_patients
0,Antidepresseurs,0,8168
1,Antidepresseurs,1,1323
2,Antidepresseurs,2,70
3,Antihypertenseurs,0,15808
4,Antihypertenseurs,1,3272
5,Antihypertenseurs,2,432
6,Antihypertenseurs,3,22
7,Hypnotiques,0,12471
8,Hypnotiques,1,3072
9,Hypnotiques,2,303


In [7]:
print("All patients: {}".format(patients.count()))
display(first_level_agg.aggregate_by_trackloss_count(kind="hard").df.toPandas())

print("Only affected: {}".format(affected_patients.count()))
display(first_level_agg.filter_only_affected().aggregate_by_trackloss_count(kind="hard").df.toPandas())

All patients: 2203891


Unnamed: 0,value,hard_3mo_trackloss_count,num_patients
0,Antidepresseurs,0,472159
1,Antidepresseurs,1,29803
2,Antidepresseurs,2,1456
3,Antidepresseurs,3,15
4,Antihypertenseurs,0,1397002
5,Antihypertenseurs,1,220595
6,Antihypertenseurs,2,59242
7,Antihypertenseurs,3,5445
8,Hypnotiques,0,950778
9,Hypnotiques,1,99265


Only affected: 32912


Unnamed: 0,value,hard_3mo_trackloss_count,num_patients
0,Antidepresseurs,0,8938
1,Antidepresseurs,1,605
2,Antidepresseurs,2,18
3,Antihypertenseurs,0,17043
4,Antihypertenseurs,1,2197
5,Antihypertenseurs,2,278
6,Antihypertenseurs,3,16
7,Hypnotiques,0,14437
8,Hypnotiques,1,1328
9,Hypnotiques,2,87
