# Site Clustering

In [None]:
import pandas as pd
from google.cloud import bigquery
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import umap
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import collections
# import plotly.offline as pyo

# pyo.init_notebook_mode()

In [None]:
# Fill these in before running. curation_project_id hosts {dataset_id} and the
# lookup_tables dataset; pdr_project_id hosts rdr_ops_data_view.
curation_project_id, pdr_project_id, dataset_id = "", "", ""

## Load Site Attributes

In [None]:
# Build one row of attributes per EHR site (src_hpo_id). Every CTE is combined
# with INNER JOINs, so a site missing from any of them (no OP/IP/ER visits, no
# drug exposures, or no consortium mapping) is dropped from the result.
# NOTE(review): if hpo_site_id_mappings holds several Org_IDs for one HPO_ID,
# the consortium join duplicates that site's row -- confirm it is 1:1.
# The drug-exposure attribute is a true (approximate) median via
# APPROX_QUANTILES; a plain AVG would give the mean, contradicting its name.
site_attribute_query = f"""
    -- participant size: EHR persons contributed by each site
    WITH participant_size AS (
      SELECT src_hpo_id, COUNT(*) total
      FROM `{curation_project_id}.{dataset_id}.unioned_ehr_person` p
      JOIN `{curation_project_id}.{dataset_id}._mapping_person` mp
        ON mp.person_id = p.person_id
      GROUP BY src_hpo_id
    ),
    -- visit counts per site rolled up to the OP / IP / ER ancestor concepts;
    -- a visit concept under more than one of these ancestors is counted once
    -- per matching ancestor
    visit_type_size AS (
      SELECT
        src_hpo_id, IFNULL(OP, 0) OP, IFNULL(IP, 0) IP,
        IFNULL(ER, 0) ER
      FROM (
        SELECT
          src_hpo_id, anc_visit_concept.concept_code visit_type, COUNT(*) total
        FROM `{curation_project_id}.{dataset_id}.unioned_ehr_visit_occurrence` vo
        JOIN `{curation_project_id}.{dataset_id}._mapping_visit_occurrence` mvo
          ON mvo.visit_occurrence_id = vo.visit_occurrence_id
        JOIN `{curation_project_id}.{dataset_id}.concept` desc_visit_concept
          ON desc_visit_concept.concept_id = vo.visit_concept_id
        JOIN `{curation_project_id}.{dataset_id}.concept_ancestor` ca
          ON ca.descendant_concept_id = vo.visit_concept_id
        JOIN `{curation_project_id}.{dataset_id}.concept` anc_visit_concept
          ON anc_visit_concept.concept_id = ca.ancestor_concept_id
            AND anc_visit_concept.concept_code IN ('OP', 'IP', 'ER')
        GROUP BY src_hpo_id, visit_type
      ) a
      PIVOT (SUM(total) FOR visit_type IN ('OP', 'IP', 'ER'))
    ),
    -- consortium: RDR HPO name of the organization mapped to each site
    consortium AS (
      SELECT
        LOWER(map.HPO_ID) src_hpo_id, hpo.name consortium
      FROM `{curation_project_id}.lookup_tables.hpo_site_id_mappings` map
      JOIN `{pdr_project_id}.rdr_ops_data_view.v_organization` org
        ON map.Org_ID = org.external_id
      JOIN `{pdr_project_id}.rdr_ops_data_view.v_hpo` hpo
        ON hpo.hpo_id = org.hpo_id
    ),
    -- median number of calendar-year boundaries between each drug exposure's
    -- start date and today, per site (APPROX_QUANTILES(x, 2)[OFFSET(1)] is
    -- the approximate median)
    median_drug_exposure_years_elapsed AS (
        SELECT
            src_hpo_id,
            APPROX_QUANTILES(
                DATE_DIFF(CURRENT_DATE(), drug_exposure_start_date, YEAR), 2
            )[OFFSET(1)] median_drug_exposure_years_elapsed
        FROM `{curation_project_id}.{dataset_id}.unioned_ehr_drug_exposure` de
        JOIN `{curation_project_id}.{dataset_id}._mapping_drug_exposure` mde
            ON mde.drug_exposure_id = de.drug_exposure_id
        GROUP BY src_hpo_id
    )
    -- combined attributes
    SELECT
      ps.src_hpo_id, ps.total participant_size,
      IP, OP, ER, con.consortium, median_drug_exposure_years_elapsed
    FROM participant_size ps
    JOIN visit_type_size vts
      ON vts.src_hpo_id = ps.src_hpo_id
    JOIN consortium con
      ON con.src_hpo_id = ps.src_hpo_id
    JOIN median_drug_exposure_years_elapsed mdes
        ON mdes.src_hpo_id = ps.src_hpo_id
    ORDER BY src_hpo_id
"""

client = bigquery.Client()
site_attributes = client.query(site_attribute_query).to_dataframe()

In [None]:
# Derived visit-mix attributes, added as new columns of site_attributes:
#   'IP/OP ratio' - inpatient visits per outpatient visit (presumably inf/NaN
#                   for a site with OP == 0; check before clustering on it)
#   'IP+ER'       - combined inpatient and emergency visit count
site_attributes['IP/OP ratio'] = site_attributes['IP'].div(site_attributes['OP'])
site_attributes['IP+ER'] = site_attributes['IP'].add(site_attributes['ER'])

In [None]:
# Show the site attribute table, including the derived columns.
site_attributes

## Clustering Class

In [None]:
class Cluster:
    """KMeans clustering of DataFrame rows, with plotting helpers.

    Rows of ``data`` are clustered with KMeans on the columns listed in
    ``cluster_attribute_cols``; the values in ``cluster_target_col`` are used
    as point labels in the PCA / UMAP scatter plots. Raw KMeans labels can be
    replaced by human-readable names with ``assign_cluster_names``.
    """

    # Sequence/mapping from KMeans label to cluster name; set by
    # assign_cluster_names(), None until then.
    cluster_names_map = None
    # Per-row cluster names aligned with cluster_labels; None until
    # assign_cluster_names() is called.
    cluster_names = None

    def __init__(self,
                 data,
                 cluster_target_col=None,
                 cluster_attribute_cols=None,
                 n_clusters=3,
                 clustering_name=None):
        """Run KMeans on ``data[cluster_attribute_cols]``.

        Args:
            data: DataFrame with one row per item to cluster.
            cluster_target_col: column identifying each row; used as point
                text in the scatter plots. Required.
            cluster_attribute_cols: column name, or list of column names,
                whose values are clustered. Required. Values are used
                unscaled, so with several columns the one with the largest
                spread dominates the KMeans distance.
            n_clusters: number of KMeans clusters.
            clustering_name: human-readable name used in default plot titles.

        Raises:
            ValueError: if cluster_target_col or cluster_attribute_cols is
                missing.
        """
        if cluster_target_col is None or cluster_attribute_cols is None:
            raise ValueError(
                'cluster_target_col and cluster_attribute_cols are required')

        super().__init__()

        # A single column name is shorthand for a one-column list, so that
        # data[...] below always yields a 2-D frame for KMeans.
        if isinstance(cluster_attribute_cols, str):
            cluster_attribute_cols = [cluster_attribute_cols]

        attribute_distributions = data[cluster_attribute_cols].to_numpy()

        print("Running KMeans...")
        # Fixed random_state keeps labels (and any names assigned by label
        # position) stable across runs on the same data and library version.
        kmeans = KMeans(n_clusters=n_clusters,
                        random_state=0).fit(attribute_distributions)

        self.data = data
        self.cluster_labels = kmeans.labels_
        self.cluster_targets = data[cluster_target_col]
        self.cluster_attribute_cols = cluster_attribute_cols
        self.attribute_distributions = attribute_distributions
        self.n_clusters = n_clusters
        self.clustering_name = clustering_name

    def _display_labels(self):
        """Return per-row cluster names if assigned, else the KMeans labels."""
        if self.cluster_names is not None:
            return self.cluster_names
        return self.cluster_labels

    def _point_text(self):
        """Return the per-row target values as a plain, index-free array.

        An index-free array is paired with the numpy plot coordinates purely
        by position, even when ``data`` was a filtered frame whose index is
        not 0..n-1.
        """
        return self.cluster_targets.to_numpy()

    def plot_distributions(self,
                           yaxis_title=None,
                           xaxis_title=None,
                           title=None,
                           **kwargs):
        """Box-plot the clustered attribute(s), one facet column per cluster.

        Args:
            yaxis_title: optional y-axis title, applied on its own.
            xaxis_title: optional x-axis title, applied on its own.
            title: figure title; defaults to
                '<clustering_name> Cluster Distributions'.
            **kwargs: forwarded to ``plotly.express.box``.

        Returns:
            The plotly Figure.
        """
        if title is None:
            title = f'{self.clustering_name} Cluster Distributions'
        fig = px.box(self.data,
                     y=self.cluster_attribute_cols,
                     facet_col=self._display_labels(),
                     title=title,
                     **kwargs)

        # Each axis title is applied independently, so passing only one works.
        axis_titles = {}
        if yaxis_title is not None:
            axis_titles['yaxis_title'] = yaxis_title
        if xaxis_title is not None:
            axis_titles['xaxis_title'] = xaxis_title
        if axis_titles:
            fig.update_layout(**axis_titles)

        return fig

    def pca(self, title=None, **kwargs):
        """Scatter-plot the attributes projected onto principal components.

        Two components are used, or one when only a single attribute was
        clustered (points then vary along the x-axis only).

        Args:
            title: figure title; defaults to '<clustering_name> PCA'.
            **kwargs: forwarded to ``plotly.express.scatter``.

        Returns:
            The plotly Figure, coloured by cluster.
        """
        if title is None:
            title = f'{self.clustering_name} PCA'

        n_components = min(2, len(self.cluster_attribute_cols))
        pcs = PCA(n_components=n_components).fit_transform(
            self.attribute_distributions)

        coords = {'x': pcs[:, 0]}
        if n_components > 1:
            coords['y'] = pcs[:, 1]

        return px.scatter(**coords,
                          color=self._display_labels(),
                          text=self._point_text(),
                          title=title,
                          **kwargs)

    def umap(self, title=None, **kwargs):
        """Scatter-plot a 2-D UMAP embedding of the attributes by cluster.

        No random_state is passed to UMAP, so the layout may differ between
        runs.

        Args:
            title: figure title; defaults to '<clustering_name> Cluster UMAP'.
            **kwargs: forwarded to ``plotly.express.scatter``.

        Returns:
            The plotly Figure, coloured by cluster.
        """
        if title is None:
            title = f'{self.clustering_name} Cluster UMAP'

        # ``umap`` here is the module imported at file level, not this method.
        reducer = umap.UMAP(min_dist=0)
        umap_embedding = reducer.fit_transform(self.attribute_distributions)
        return px.scatter(x=umap_embedding[:, 0],
                          y=umap_embedding[:, 1],
                          color=self._display_labels(),
                          text=self._point_text(),
                          title=title,
                          **kwargs)

    def assign_cluster_names(self, cluster_names):
        """Attach human-readable names to the KMeans clusters.

        Args:
            cluster_names: sequence (or mapping) indexed by KMeans label, so
                ``cluster_names[k]`` names label ``k``. KMeans labels are not
                ordered by cluster magnitude; check plot_distributions() to
                choose the order.

        Raises:
            ValueError: if the number of names differs from n_clusters.
        """
        if len(cluster_names) != self.n_clusters:
            raise ValueError(
                f'expected {self.n_clusters} cluster names, '
                f'got {len(cluster_names)}')

        self.cluster_names_map = cluster_names
        self.cluster_names = [
            cluster_names[cluster_label]
            for cluster_label in self.cluster_labels
        ]

    def member_counts(self):
        """Return a Counter of rows per cluster (by name if assigned, else label)."""
        return collections.Counter(self._display_labels())

## Cluster by Inpatient/Outpatient Ratio

In [None]:
# Single feature: the derived IP/OP ratio, split into 3 KMeans clusters.
ip_op_ratio_cluster = Cluster(
    site_attributes,
    cluster_target_col='src_hpo_id',
    cluster_attribute_cols='IP/OP ratio',
    n_clusters=3,
    clustering_name='Inpatient/Outpatient Ratio',
)

In [None]:
# Box plot of the IP/OP ratio, one facet per cluster.
ip_op_ratio_cluster.plot_distributions()

In [None]:
# PCA scatter coloured by cluster (one component, since one attribute is clustered).
ip_op_ratio_cluster.pca()

In [None]:
# UMAP embedding coloured by cluster (no UMAP seed is set, so layout may vary per run).
ip_op_ratio_cluster.umap()

In [None]:
# Name the clusters. The list is indexed by KMeans label (label 0 -> first
# name), and KMeans labels are not ordered by magnitude, so confirm this order
# against the distribution plot before relying on the names.
ip_op_ratio_cluster.assign_cluster_names(
    ['low-ip/op-ratio', 'mid-ip/op-ratio', 'high-ip/op-ratio']
)

# Number of sites per named cluster.
ip_op_ratio_cluster_member_counts = ip_op_ratio_cluster.member_counts()
ip_op_ratio_cluster_member_counts

## Cluster by Participant Size

In [None]:
# Single feature: participant_size (EHR persons per site), 3 KMeans clusters.
participant_size_cluster = Cluster(
    site_attributes,
    cluster_target_col='src_hpo_id',
    cluster_attribute_cols='participant_size',
    n_clusters=3,
    clustering_name='Participant Size',
)

In [None]:
# Box plot of participant_size, one facet per cluster.
participant_size_cluster.plot_distributions()

In [None]:
# PCA scatter coloured by cluster (one component, since one attribute is clustered).
participant_size_cluster.pca()

In [None]:
# UMAP embedding coloured by cluster (no UMAP seed is set, so layout may vary per run).
participant_size_cluster.umap()

In [None]:
# Name the clusters. Names are indexed by KMeans label, which carries no
# magnitude order -- presumably why this list is low/high/mid. Re-check the
# mapping against the distribution plot whenever the data changes.
participant_size_cluster.assign_cluster_names(
    ['low-participant-size', 'high-participant-size', 'mid-participant-size']
)

# Number of sites per named cluster.
participant_size_cluster_member_counts = participant_size_cluster.member_counts()
participant_size_cluster_member_counts

## Cluster by Inpatient, Outpatient, and Emergency Visit Counts

In [None]:
# Three features: raw IP, OP and ER visit counts, 3 KMeans clusters. The
# counts are not rescaled, so the attribute with the largest spread dominates
# the clustering distance.
visit_magnitude_cluster = Cluster(
    site_attributes,
    cluster_target_col='src_hpo_id',
    cluster_attribute_cols=['IP', 'OP', 'ER'],
    n_clusters=3,
    clustering_name='Visit Magnitude',
)

In [None]:
# Box plots of the IP, OP and ER counts, one facet per cluster.
visit_magnitude_cluster.plot_distributions()

In [None]:
# 2-component PCA of (IP, OP, ER) coloured by cluster.
visit_magnitude_cluster.pca()

In [None]:
# UMAP embedding coloured by cluster (no UMAP seed is set, so layout may vary per run).
visit_magnitude_cluster.umap()

In [None]:
# Name the clusters (list indexed by KMeans label; confirm the order against
# the distribution plot).
# NOTE(review): this clustering uses IP, OP and ER counts, but the names
# describe IP+ER size only -- confirm the intended naming/feature set.
visit_magnitude_cluster.assign_cluster_names(
    ['low-IP+ER-visit-size', 'mid-IP+ER-visit-size', 'high-IP+ER-visit-size']
)

# Number of sites per named cluster.
visit_magnitude_cluster_member_counts = visit_magnitude_cluster.member_counts()
visit_magnitude_cluster_member_counts

## Cluster by Median Years Since Drug Exposure Start

In [None]:
# Single feature: median years elapsed since drug exposure start, 4 KMeans
# clusters, with site ut_health_tyler excluded.
# NOTE(review): the reason for excluding ut_health_tyler is not recorded here
# (presumably an outlier) -- document it.
drug_exposure_date_cluster = Cluster(
    site_attributes[site_attributes['src_hpo_id'] != 'ut_health_tyler'],
    cluster_target_col='src_hpo_id',
    cluster_attribute_cols='median_drug_exposure_years_elapsed',
    n_clusters=4,
    clustering_name='Median Drug Exposure Start Date',
)

In [None]:
# Box plot of median years elapsed, one facet per cluster.
drug_exposure_date_cluster.plot_distributions()

In [None]:
# PCA scatter coloured by cluster (one component, since one attribute is clustered).
drug_exposure_date_cluster.pca()

In [None]:
# UMAP embedding coloured by cluster (no UMAP seed is set, so layout may vary per run).
drug_exposure_date_cluster.umap()

In [None]:
# Sites per cluster, keyed by raw KMeans label (no names are assigned to this clustering).
drug_exposure_date_cluster.member_counts()

In [None]:
# TODO(review): name the drug-exposure clusters before enabling this cell.
# The names below were copied from the visit-magnitude cell and do not fit:
# this clustering has n_clusters=4, so assign_cluster_names would reject a
# 3-element list.
# drug_exposure_date_cluster.assign_cluster_names(
#     ['low-IP+ER-visit-size', 'mid-IP+ER-visit-size', 'high-IP+ER-visit-size'])

# drug_exposure_date_cluster_member_counts = drug_exposure_date_cluster.member_counts()
# drug_exposure_date_cluster_member_counts