In [1]:
from pyspark.sql import functions as fn
from pyspark.sql import Window
from IPython.display import display
from datetime import datetime
import collections

# Location of the parquet outputs of the study.
# NOTE(review): the dataset names below are appended to root_path without a
# path separator (e.g. ".../fallpatients"); confirm this matches the layout on
# the cluster before changing it.
root_path = "/user/silva/output/fall"

patients_path = "{}patients".format(root_path)
diagnoses_path = "{}diagnoses".format(root_path)
drugs_path = "{}drug-purchases".format(root_path)
fractures_path = "{}fractures".format(root_path)

# Study window applied to the "start" column of every event table:
# lower bound inclusive, upper bound exclusive (see filter_date).
Dates = collections.namedtuple('dates', ['lower_bound', 'upper_bound'])
dates = Dates(datetime(2006, 1, 1), datetime(2011, 1, 1))

def filter_date(df):
    """Keep only the rows of ``df`` whose ``start`` lies in the study window.

    The window is read from the module-level ``dates`` tuple: ``lower_bound``
    is inclusive and ``upper_bound`` is exclusive.

    :param df: Spark DataFrame with a ``start`` column.
    :return: the filtered DataFrame.
    """
    start = fn.col("start")
    on_or_after_lower = start >= fn.lit(dates.lower_bound)
    before_upper = start < fn.lit(dates.upper_bound)
    return df.where(on_or_after_lower & before_upper)

# Load the event tables, restricted to the study window, and keep the ones
# that are reused below in Spark's cache.
# NOTE(review): sqlContext is not defined in this notebook; presumably it is
# injected by the PySpark kernel - confirm.
drugs = filter_date(sqlContext.read.parquet(drugs_path)).cache()
diagnoses = filter_date(sqlContext.read.parquet(diagnoses_path)).cache()
fractures = filter_date(sqlContext.read.parquet(fractures_path))
# union() matches columns by position, so this assumes drugs and diagnoses
# share the same schema - TODO confirm.
events = drugs.union(diagnoses).cache()
# Patients are not date-filtered: the table is read as-is.
patients = sqlContext.read.parquet(patients_path).cache()

# "Affected" patients are those with at least one fracture in the window.
affected_patients = fractures.select("patientID").distinct().cache()

# Sanity checks: size of the affected cohort and the actual date range covered
# by the drug and diagnosis events.
print("affected patients: {}".format(affected_patients.count()))
display(events.agg(fn.min("start"), fn.max("start")).toPandas())

affected patients: 32912


Unnamed: 0,min(start),max(start)
0,2009-10-02,2010-12-31


In [2]:
# Show the distinct drug classes found in the purchases. The result is only
# displayed once and never referenced again, so it is not cached (caching it
# would pin data in executor memory with no way to unpersist it).
drugs.select("value").distinct().show()

+-----------------+
|            value|
+-----------------+
|Antihypertenseurs|
|      Hypnotiques|
|  Antidepresseurs|
|   Neuroleptiques|
+-----------------+



In [3]:
from itertools import chain

class MyDF(object):
    """Drug purchases joined with patient data, plus derived age columns.

    ``self.df`` holds the Spark DataFrame and is replaced (not mutated) every
    time a derived column is added. The derived columns build on each other:

    * ``referenceDate``: ``start`` truncated to the first day of its year;
    * ``ageInMonths``: months between ``birthDate`` and ``referenceDate``;
    * ``age``: ``ageInMonths div 12``;
    * ``ageBucket``: ``ageInMonths div (12*5)``, i.e. 5-year buckets.

    Missing prerequisites are added on demand. Their presence is checked
    explicitly through ``self.df.columns`` rather than by catching whatever
    exception Spark raises, so unrelated errors are no longer swallowed.
    """

    def __init__(self, patients, drugs, cohort_name, max_age=150):
        """Join ``drugs`` with ``patients`` on ``patientID`` and cache the result.

        :param patients: Spark DataFrame of patients (needs ``patientID`` and,
            for the age columns, ``birthDate``).
        :param drugs: Spark DataFrame of purchases (needs ``patientID`` and,
            for the age columns, ``start``).
        :param cohort_name: free-form label stored on the instance.
        :param max_age: exclusive upper limit used to build ``bucket_mapping``.
        """
        self.df = drugs.join(patients, "patientID", "inner").cache()
        self.cohort_name = cohort_name
        self.bucket_mapping = self._get_string_maps(max_age)

    def _get_string_maps(self, max_age):
        """Map bucket index -> label such as ``"[0, 5["`` for 5-year buckets.

        Buckets are the consecutive pairs of ``range(0, max_age, 5)``.
        """
        age_lists = range(0, max_age, 5)
        buckets = zip(age_lists[:-1], age_lists[1:])
        return {
            index: "[{}, {}[".format(lower, upper)
            for index, (lower, upper) in enumerate(buckets)
        }

    def _has_column(self, name):
        """Return True when ``self.df`` already has a column called ``name``."""
        return name in self.df.columns

    def _add_reference_date(self):
        """Add ``referenceDate``: ``start`` truncated to the start of its year."""
        self.df = self.df.withColumn(
            "referenceDate", fn.trunc(fn.col("start"), "year")
        )

    def _add_age_in_months(self):
        """Add ``ageInMonths``, creating ``referenceDate`` first if needed."""
        if not self._has_column("referenceDate"):
            self._add_reference_date()
        self.df = self.df.withColumn(
            "ageInMonths",
            fn.months_between(fn.col("referenceDate"), fn.col("birthDate")),
        )

    def add_age(self):
        """Add ``age`` (whole years), creating ``ageInMonths`` first if needed."""
        if not self._has_column("ageInMonths"):
            self._add_age_in_months()
        self.df = self.df.withColumn("age", fn.expr("ageInMonths div 12"))

    def add_age_bucket(self):
        """Add ``ageBucket`` (5-year bucket index), creating ``ageInMonths`` if needed."""
        if not self._has_column("ageInMonths"):
            self._add_age_in_months()
        self.df = self.df.withColumn("ageBucket", fn.expr("ageInMonths div (12*5)"))

In [4]:
from IPython.display import display
import seaborn as sns

import matplotlib.ticker as ticker

# Global seaborn settings for the figures drawn by this notebook: the
# "whitegrid" style and the "poster" plotting context.
sns.set_style("whitegrid")
sns.set_context("poster")


def boxes_mapping():
    """Return ``{bucket index: label}`` for purchase-count buckets of width 6.

    Labels look like ``"[0, 6["``; the last bucket is ``"[990, 996["``.
    """
    lower_bounds = range(0, 1000, 6)[:-1]
    labels = {}
    for index, lower in enumerate(lower_bounds):
        labels[index] = "[{}, {}[".format(lower, lower + 6)
    return labels


def age_mapping(max_age):
    """Return ``{bucket index: label}`` for 5-year age buckets.

    Buckets are the consecutive pairs of ``range(0, max_age, 5)``, labelled
    like ``"[0, 5["``.

    :param max_age: exclusive upper limit of the range the buckets are cut from.
    """
    lower_bounds = range(0, max_age, 5)[:-1]
    labels = {}
    for index, lower in enumerate(lower_bounds):
        labels[index] = "[{}, {}[".format(lower, lower + 5)
    return labels

def special_boxes_mapping():
    """Return ``{bucket index: label}`` for the four coarse purchase buckets."""
    labels = ["1 Achat", "2 Achats", "3 à 12 Achats", "Plus de 12 Achats"]
    return dict(enumerate(labels))


def count_box_to_bucket(number_boxes):
    """Return the index of the width-6 bucket that ``number_boxes`` falls in."""
    boxes_per_bucket = 6
    return number_boxes // boxes_per_bucket


class Stats(object):
    """Purchase counts per patient for one drug selection, with plot helpers.

    ``self.stats`` is a pandas DataFrame with one row per
    (``patientID``, ``ageBucket``, ``gender``) group and these columns:

    * ``count``: number of purchase rows in the group;
    * ``boxBucket``: ``count`` bucketed by ``count_box_to_bucket``;
    * ``specialBucket``: ``count`` bucketed into 1 / 2 / 3-12 / more than 12.

    ``self.filtered_stats`` is the subset with ``12 < ageBucket < 20`` and
    ``count < 50``; it feeds the per-age and per-gender plots.

    Every ``distribution_*`` method draws through seaborn on the current
    matplotlib figure and returns the resulting Axes.
    """

    def __init__(self, patients, drugs, drugs_name):
        """Aggregate the purchases per patient and derive the bucket columns.

        :param patients: Spark DataFrame of patients, forwarded to ``MyDF``.
        :param drugs: Spark DataFrame of purchases, forwarded to ``MyDF``.
        :param drugs_name: label inserted in the plot titles.
        """
        self.myDF = MyDF(patients, drugs, drugs_name)
        self.myDF.add_age_bucket()
        self.stats = (
            self.myDF.df
            .groupby("patientID", "ageBucket", "gender")
            .count()
            .toPandas()
        )
        self.drugs_name = drugs_name
        self.boxes_mapping = boxes_mapping()
        self.age_mapping = age_mapping(150)
        self.special_boxes_mapping = special_boxes_mapping()

        # Reuse the module-level helper so the bucket width lives in one place.
        self.stats["boxBucket"] = self.stats["count"].apply(count_box_to_bucket)
        self.stats["specialBucket"] = self.stats["count"].apply(self._special_box_bucket)

        self.filtered_stats = self.stats[
            (self.stats.ageBucket < 20)
            & (self.stats.ageBucket > 12)
            & (self.stats["count"] < 50)
        ].copy()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _special_box_bucket(number_boxes):
        """Map a purchase count to its coarse bucket index.

        1 -> 0, 2 -> 1, 3..12 -> 2, anything else -> 3.
        """
        if number_boxes == 1:
            return 0
        if number_boxes == 2:
            return 1
        if 2 < number_boxes < 13:
            return 2
        return 3

    @staticmethod
    def _label_to_key(text):
        """Turn a tick or legend label back into its integer bucket index.

        Going through ``float`` also accepts labels such as ``"13.0"``, which
        appear when the plotted column has a float dtype.
        """
        return int(float(text))

    def _percentage_of_population(self, values):
        """Return the share of ``self.stats`` rows that ``values`` represents, in percent."""
        # float() keeps this a true division under Python 2 as well.
        return len(values) / float(len(self.stats)) * 100

    def _mapping_for(self, special):
        """Return the label mapping that matches the chosen bucket column."""
        return self.special_boxes_mapping if special else self.boxes_mapping

    def _relabel_xticks(self, ax, mapping, rotation):
        """Replace the bucket indices on the x axis with the labels in ``mapping``."""
        labels = [
            mapping[self._label_to_key(tick.get_text())]
            for tick in ax.get_xticklabels()
        ]
        ax.set_xticklabels(labels, rotation=rotation)

    def _relabel_legend(self, ax, mapping):
        """Replace the bucket indices in the legend with the labels in ``mapping``."""
        for label in ax.get_legend().get_texts():
            label.set_text(mapping[self._label_to_key(label.get_text())])

    @staticmethod
    def _apply_logscale(ax, logscale):
        """Switch the y axis to a log scale when ``logscale`` is true."""
        if logscale:
            # NOTE(review): Matplotlib 3.3 renamed `nonposy` to `nonpositive`
            # and later releases dropped the old name; switch the keyword once
            # the target Matplotlib version is confirmed.
            ax.set_yscale("log", nonposy='clip')

    def _grouped_bucket_plot(self, group_column, special, percentage, percentage_ylabel):
        """Plot the purchase buckets of ``filtered_stats`` split by ``group_column``.

        Draws percentages within each group when ``percentage`` is true and raw
        patient counts otherwise, then places and relabels the legend.
        """
        criterion = "specialBucket" if special else "boxBucket"
        palette = sns.color_palette("Paired", n_colors=9)

        if percentage:
            patient_counts = (self.filtered_stats.groupby([group_column])[criterion]
                              .value_counts(normalize=True)
                              .rename('percentage')
                              .mul(100)
                              .reset_index()
                              .sort_values(group_column))
            ax = sns.barplot(x=group_column, y="percentage", hue=criterion,
                             data=patient_counts, palette=palette)
            ax.set_ylabel(percentage_ylabel)
        else:
            ax = sns.countplot(x=group_column, hue=criterion,
                               data=self.filtered_stats, palette=palette)
            ax.set_ylabel("Nombre de patient")

        # Legend goes outside the axes, on the right-hand side.
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5),
                  title="Nombre d'achat par\nan par patient")
        self._relabel_legend(ax, self._mapping_for(special))
        return ax

    # ------------------------------------------------------------------
    # Public plots
    # ------------------------------------------------------------------

    def distribution_box(self, logscale=False, percentage=False):
        """Plot how many patients have each raw purchase count.

        :param logscale: use a logarithmic y axis.
        :param percentage: plot the share of all rows instead of the number of rows.
        :return: the matplotlib Axes drawn on.
        """
        if percentage:
            ax = sns.barplot(x="count", y="count", data=self.stats,
                             estimator=self._percentage_of_population)
            ax.set_ylabel("Pourcentage sur population")
        else:
            ax = sns.countplot(x="count", data=self.stats)
            ax.set_ylabel("Nombre de patient")

        # NOTE(review): seaborn places categorical bars at positions 0..n-1, so
        # these numeric ticks presumably show bar positions rather than the
        # purchase counts themselves - confirm on a rendered figure.
        ax.xaxis.set_major_locator(ticker.MultipleLocator(5))
        ax.xaxis.set_major_formatter(ticker.ScalarFormatter())

        ax.set_xlabel("Nombre d'achat sur une annee pour un patient")
        ax.set_title("Distribution de nombre d'achat de {} par an-patient".format(self.drugs_name))

        self._apply_logscale(ax, logscale)
        return ax

    def distribution_box_bucket(self, special=False, logscale=False, percentage=False):
        """Plot how many patients fall in each purchase-count bucket.

        :param special: use the coarse ``specialBucket`` column instead of ``boxBucket``.
        :param logscale: use a logarithmic y axis.
        :param percentage: plot the share of all rows instead of the number of rows.
        :return: the matplotlib Axes drawn on.
        """
        criterion = "specialBucket" if special else "boxBucket"

        if percentage:
            ax = sns.barplot(x=criterion, y=criterion, data=self.stats,
                             estimator=self._percentage_of_population)
            ax.set_ylabel("Pourcentage sur population")
        else:
            ax = sns.countplot(x=criterion, data=self.stats)
            ax.set_ylabel("Nombre de patient")

        self._relabel_xticks(ax, self._mapping_for(special), rotation=0 if special else 90)

        ax.set_xlabel("Nombre d'achat sur une annee pour un patient")
        ax.set_title("Distribution de nombre d'achat de {} par an-patient".format(self.drugs_name))

        self._apply_logscale(ax, logscale)
        return ax

    def distribution_box_bucket_age_bucket(self, special=False, logscale=False, percentage=False):
        """Plot the purchase buckets per age bucket, on ``filtered_stats``.

        :param special: use the coarse ``specialBucket`` column instead of ``boxBucket``.
        :param logscale: use a logarithmic y axis.
        :param percentage: plot percentages within each age bucket instead of counts.
        :return: the matplotlib Axes drawn on.
        """
        ax = self._grouped_bucket_plot(
            "ageBucket", special, percentage, "Pourcentage dans la tranche d'age"
        )

        self._relabel_xticks(ax, self.age_mapping, rotation=0 if special else 90)

        ax.set_xlabel("Tranche d'age du patient")
        ax.set_title("Distribution de nombre d'achat de {}\npar an-patient suivant les tranches d'age".format(self.drugs_name))

        self._apply_logscale(ax, logscale)
        return ax

    def distribution_box_bucket_gender(self, special=False, logscale=False, percentage=False):
        """Plot the purchase buckets per gender, on ``filtered_stats``.

        :param special: use the coarse ``specialBucket`` column instead of ``boxBucket``.
        :param logscale: use a logarithmic y axis.
        :param percentage: plot percentages within each gender instead of counts.
        :return: the matplotlib Axes drawn on.
        """
        ax = self._grouped_bucket_plot(
            "gender", special, percentage, "Pourcentage selon le genre"
        )

        # NOTE(review): the labels are hard-coded, which assumes exactly two
        # gender values that plot in the order (male, female) - confirm against
        # the coding of the `gender` column.
        ax.set_xticklabels(["Homme", "Femme"])

        ax.set_xlabel("Genre")
        ax.set_title("Distribution de nombre d'achat de {}\npar an-patient suivant le genre".format(self.drugs_name))

        self._apply_logscale(ax, logscale)
        return ax
    

In [5]:
# Apply the seaborn "whitegrid" style and "poster" context before the figures
# are generated.
sns.set_style("whitegrid")
sns.set_context("poster")

In [10]:
from itertools import product
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt

def save_pdf(patients, drugs, drugs_class, description, close_figures=True):
    """Write every distribution plot for one drug class into a single PDF.

    The purchases are restricted to rows whose ``value`` equals ``drugs_class``
    and the plots come from a ``Stats`` object built on them. The file is named
    ``'<drugs_class>_<description>_stats.pdf'`` and is written to the current
    working directory.

    Page order: for each plot family (raw counts, buckets, buckets per age
    bucket, buckets per gender) and, where applicable, for ``special`` True
    then False: the percentage plot, the count plot on a log scale, then the
    count plot on a linear scale.

    :param patients: Spark DataFrame of patients.
    :param drugs: Spark DataFrame of purchases with a ``value`` column.
    :param drugs_class: drug class to keep, also used in titles and file name.
    :param description: cohort label, also used in titles and file name.
    :param close_figures: close each figure once it is saved (default). A call
        produces 21 figures, so leaving them open accumulates memory across
        calls; pass False to keep them open, e.g. for inline display.
    """
    stats = Stats(patients, drugs.where(fn.col("value") == drugs_class), "{}({})".format(drugs_class, description))

    # Percentage plots are drawn once on a linear scale; count plots are drawn
    # on a log scale and then on a linear scale.
    scale_variants = [
        {"percentage": True, "logscale": False},
        {"percentage": False, "logscale": True},
        {"percentage": False, "logscale": False},
    ]
    special_variants = [{"special": True}, {"special": False}]

    # (plot method, extra keyword sets, leave room on the right for the legend)
    sections = [
        (stats.distribution_box, [{}], False),
        (stats.distribution_box_bucket, special_variants, False),
        (stats.distribution_box_bucket_age_bucket, special_variants, True),
        (stats.distribution_box_bucket_gender, special_variants, True),
    ]

    with PdfPages('{}_{}_stats.pdf'.format(drugs_class, description)) as pdf:
        for plot, extra_variants, has_side_legend in sections:
            for extra in extra_variants:
                for scale in scale_variants:
                    fig = plt.figure()
                    try:
                        plot(**dict(scale, **extra))
                        plt.tight_layout()
                        if has_side_legend:
                            # The legend sits outside the axes.
                            plt.subplots_adjust(right=0.75)
                        pdf.savefig(fig)
                    finally:
                        if close_figures:
                            plt.close(fig)

# Tous les patients

In [12]:
# Drug classes to report on; one PDF is produced per class.
drugs_classes = [
    "Antihypertenseurs",
    "Hypnotiques",
    "Antidepresseurs",
    "Neuroleptiques",
]

# Reports computed on the whole patient table.
for drug_class in drugs_classes:
    save_pdf(patients, drugs, drug_class, description="tout_patients")



# Affected patients: patients with fractures

In [15]:
# Patient rows restricted to the patients that have at least one fracture.
affected = patients.join(affected_patients, on="patientID", how="inner").cache()

In [17]:
# Same reports as above, computed on the patients with fractures only.
for drug_class in drugs_classes:
    save_pdf(affected, drugs, drug_class, description="avec fracture")

