In [0]:
# Install packages
!pip install tqdm swifter openpyxl lifelines

In [0]:
# Importing necessary libraries for data manipulation, plotting, and statistical analysis
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from pandas.tseries.offsets import MonthEnd
import textwrap
import requests
import random
import time
import os
from tqdm.notebook import tqdm

# Machine Learning and statistical modeling libraries
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.metrics import roc_auc_score
from statsmodels.stats.outliers_influence import variance_inflation_factor
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.statistics import survival_difference_at_fixed_point_in_time_test
from lifelines.exceptions import ConvergenceError, StatisticalWarning

# Importing libraries for handling complex operations efficiently
import swifter
import openpyxl
import numpy.linalg as la

# Configuration to handle warnings in the code
import warnings
# Ignore specific statistical warnings from lifelines library
warnings.filterwarnings('ignore', category=StatisticalWarning)

In [0]:
def find_similar_atc_codes(atc_name, n_level, atc_data):
    """
    Find the names of drugs that share an ATC prefix with the given drug.

    Looks up every ATC code whose name equals ``atc_name`` (case-insensitive
    exact match, not a substring match). Each such code is truncated to the
    prefix length of the requested ATC level. The function then collects the
    names of all 7-character ATC codes that start with one of those prefixes.
    All codes belonging to ``atc_name`` itself are excluded. A drug listed under
    several ATC codes therefore never appears in its own result.

    Args:
        atc_name (str): Drug name to look up (compared case-insensitively).
        n_level (int): ATC level of the shared prefix, 1 to 4. The prefix lengths
            are 1, 3, 4 and 5 characters respectively.
        atc_data (pd.DataFrame): Table with ``atc_code`` and ``atc_name`` columns.

    Returns:
        list of str: Unique names of the similar drugs, sorted alphabetically so
        the output is identical across interpreter runs. Returns an empty list
        when ``atc_name`` is not found.

    Raises:
        ValueError: If ``n_level`` is not between 1 and 4.
    """
    # Number of leading code characters that identify each ATC level
    atc_length = {1: 1, 2: 3, 3: 4, 4: 5}

    if n_level not in atc_length:
        raise ValueError("Invalid ATC level. Please choose a level between 1 and 4.")

    codes = atc_data['atc_code']
    names = atc_data['atc_name']

    # Exact (case-insensitive) name match; one drug may carry several ATC codes
    own_codes = set(codes[names.str.lower() == atc_name.lower()].dropna())
    if not own_codes:
        return []

    prefix_len = atc_length[n_level]
    prefixes = {code[:prefix_len] for code in own_codes}

    # Codes under any of the drug's prefixes; na=False keeps missing codes out of the mask
    in_group = pd.Series(False, index=codes.index)
    for prefix in prefixes:
        in_group |= codes.str.startswith(prefix, na=False)

    # Only 7-character codes, and never the drug's own codes
    mask = in_group & (codes.str.len() == 7) & ~codes.isin(own_codes)

    # sorted(): set iteration order of strings varies between runs (hash randomization)
    return sorted(set(names[mask].dropna()))

def sep_icd_code(x):
    """
    Extract the leading ICD code from each line of a multi-line string.

    Each non-blank line is expected to look like ``"<ICD code> <description>"``.
    The first whitespace-separated token of the line is taken as the code.
    Blank lines, including the empty piece after a trailing newline, are skipped
    so that no empty-string code is returned. Windows (CRLF) line endings and
    leading indentation are tolerated.

    Parameters:
    - x (str): ICD codes with descriptions, one per line. A non-string value,
      such as ``None`` or a missing spreadsheet cell read as NaN, yields ``[]``.

    Returns:
    - list: The ICD codes, in the order they appear.
    """
    if not isinstance(x, str):
        return []

    # splitlines() handles both LF and CRLF; split() ignores leading/extra whitespace
    return [line.split()[0] for line in x.splitlines() if line.strip()]

In [0]:
class SQLTable:
    """
    Query from RD database from VUMC.

    Builds one Spark SQL table from the ``rd_omop_prod`` schema. Any existing
    table with the same name is dropped first. Uses the global ``spark``
    session, which is not defined in this class (presumably provided by the
    notebook runtime).

    The meaning of ``table_name`` depends on the category:

    - 'drug', 'cancer', 'drug_count', 'phecode' and 'bmi': a single table name.
    - 'record', 'demo' and 'death': a list
      ``[output_table, group_table_a, group_table_b]``. The output is restricted
      to persons present in either group table.

    The 'phecode' and 'bmi' queries also read an existing ``pat_cohort`` table,
    which must have ``person_id`` and ``phecode_end_date`` columns.
    """

    # Categories understood by create_table()
    SUPPORTED_CATEGORIES = ('drug', 'cancer', 'record', 'demo', 'death',
                            'drug_count', 'phecode', 'bmi')

    def __init__(self, table_name, items=None, category=None):
        """Initialize SQLTable with table name, items, and category."""
        self.table_name = table_name
        self.items = items
        self.category = category

    @staticmethod
    def _escape_literal(value, quote):
        """Escape backslashes and ``quote`` characters in ``value`` with a backslash.

        The result can be placed inside a Spark SQL string literal delimited by
        ``quote``. Spark processes backslash escapes in literals by default.
        """
        return str(value).replace('\\', '\\\\').replace(quote, '\\' + quote)

    @staticmethod
    def _join_or_default(fragments, separator, default):
        """Join SQL fragments with ``separator``, or return ``default`` if there are none.

        This prevents invalid SQL such as ``()`` or ``IN ()`` when a code list is empty.
        """
        return separator.join(fragments) if fragments else default

    def format_query_items(self):
        """Format ``self.items`` as SQL fragments for the current category.

        - 'drug': ``self.items`` is a list of names. Returns one OR-joined string of
          ``LOWER(c.concept_name) LIKE "%<name>%"`` (case-insensitive substring) matches.
        - 'cancer': ``self.items`` is a dict with 'icd10' and 'icd9' lists of code
          prefixes. Returns a tuple (icd10_predicate, icd9_predicate) of OR-joined
          ``concept_code LIKE "<prefix>%"`` matches.
        - 'phecode': ``self.items`` is a dict with 'icd10' and 'icd9' lists of exact
          codes. Returns a tuple (icd10_list, icd9_list) of comma-separated,
          single-quoted literals for use inside ``IN (...)``.
        - Any other category: returns None.

        Quote characters and backslashes inside items are escaped. An empty list
        yields ``FALSE`` for the OR-predicates or ``NULL`` for the IN-lists. Either
        keeps the generated SQL valid while matching nothing.
        """
        if self.category == 'drug':
            likes = [
                'LOWER(c.concept_name) LIKE "%{}%"'.format(self._escape_literal(item.lower(), '"'))
                for item in self.items
            ]
            return self._join_or_default(likes, " OR\n                        ", 'FALSE')

        if self.category == 'cancer':
            separator = " OR\n                                    "
            return tuple(
                self._join_or_default(
                    ['concept_code LIKE "{}%"'.format(self._escape_literal(item, '"'))
                     for item in self.items[version]],
                    separator,
                    'FALSE',
                )
                for version in ('icd10', 'icd9')
            )

        if self.category == 'phecode':
            return tuple(
                self._join_or_default(
                    ["'{}'".format(self._escape_literal(item, "'")) for item in self.items[version]],
                    ", ",
                    'NULL',
                )
                for version in ('icd10', 'icd9')
            )

        return None

    def drop_table(self):
        """Drop the SQL table if it exists.

        For list-valued ``table_name``, only the first entry (the output table) is dropped.
        """
        if isinstance(self.table_name, list):
            spark.sql(f'DROP TABLE IF EXISTS {self.table_name[0]}')
        else:
            spark.sql(f'DROP TABLE IF EXISTS {self.table_name}')

    def create_table(self):
        """Drop and re-create the SQL table for ``self.category``.

        Contents of the new table by category:

        - 'drug': distinct (person_id, drug_exposure_start_date) rows whose
          ingredient name matches any of ``self.items``.
        - 'cancer': the earliest condition_start_date per person, for conditions whose
          source concept matches the ICD-10-CM / ICD-9-CM prefixes.
        - 'record': the first and last condition_start_date per person in the group tables.
        - 'demo': birth date, gender, race name and ethnicity name for persons in the
          group tables.
        - 'death': death rows for persons in the group tables.
        - 'drug_count': the number of distinct persons per lower-cased ingredient name.
        - 'phecode': distinct (person_id, condition_start_date, phecode) rows for
          ``pat_cohort`` conditions on or before phecode_end_date. Only conditions that
          map to the phecodes of the given ICD codes are included.
        - 'bmi': the most recent cleaned BMI per ``pat_cohort`` person, measured on or
          before phecode_end_date.

        Raises:
            ValueError: If ``self.category`` is not in ``SUPPORTED_CATEGORIES``. The
                check runs before anything is dropped, so an unknown category leaves
                the existing table untouched.
        """
        if self.category not in self.SUPPORTED_CATEGORIES:
            raise ValueError(
                f"Unsupported category {self.category!r}; "
                f"expected one of {', '.join(self.SUPPORTED_CATEGORIES)}."
            )

        self.drop_table()

        if self.category == 'drug':
            sql_query = textwrap.dedent(f'''
                CREATE TABLE {self.table_name} (
                    SELECT
                        DISTINCT de.person_id, de.drug_exposure_start_date 
                    FROM 
                        rd_omop_prod.drug_exposure de
                    JOIN
                        rd_omop_prod.drug_strength ds
                    ON 
                        de.drug_concept_id = ds.drug_concept_id
                    JOIN 
                        rd_omop_prod.concept c
                    ON 
                        c.concept_id = ds.ingredient_concept_id
                    WHERE (
                        {self.format_query_items()}
                    ) 
                )
            ''')

        elif self.category == 'cancer':
            string_icd10, string_icd9 = self.format_query_items()
            sql_query = textwrap.dedent(f'''
                CREATE TABLE {self.table_name} (
                    SELECT
                        person_id, MIN(condition_start_date) AS cancer_start_date
                    FROM
                        rd_omop_prod.condition_occurrence
                    WHERE 
                        condition_source_concept_id IN (
                            SELECT concept_id
                            FROM rd_omop_prod.concept
                            WHERE (
                                    ({string_icd10})
                                AND vocabulary_id = 'ICD10CM')
                            OR (
                                    ({string_icd9})
                                AND vocabulary_id = 'ICD9CM')
                        )
                    GROUP BY person_id
                )
            ''')

        elif self.category == 'record':
            sql_query = textwrap.dedent(f'''
                CREATE TABLE {self.table_name[0]} (
                    SELECT
                        person_id, MIN(condition_start_date) AS record_start_date, 
                        MAX(condition_start_date) AS record_end_date
                    FROM
                        rd_omop_prod.condition_occurrence
                    WHERE person_id IN (
                        SELECT person_id FROM {self.table_name[1]}
                        UNION
                        SELECT person_id FROM {self.table_name[2]}
                    )
                    GROUP BY person_id
                )                       
            ''')

        elif self.category == 'demo':
            sql_query = textwrap.dedent(f'''
                CREATE TABLE {self.table_name[0]} (
                    WITH table1 AS (
                        SELECT 
                            person_id, 
                            birth_datetime,
                            gender_source_value, 
                            ethnicity_concept_id, 
                            race_concept_id
                        FROM 
                            rd_omop_prod.person
                        WHERE person_id IN (
                            SELECT person_id FROM {self.table_name[1]}
                            UNION
                            SELECT person_id FROM {self.table_name[2]}
                        )
                    ),
                    table2 as (
                        SELECT
                            table1.*, 
                            c.concept_name AS race_source_value
                        FROM table1
                        LEFT JOIN
                            rd_omop_prod.concept c ON table1.race_concept_id = c.concept_id
                    )
                    SELECT
                        table2.person_id, table2.birth_datetime, table2.gender_source_value, 
                        table2.race_source_value, c.concept_name AS ethnicity_source_value
                    FROM table2
                    LEFT JOIN
                        rd_omop_prod.concept c ON table2.ethnicity_concept_id = c.concept_id
                    WHERE table2.birth_datetime < GETDATE()
                )
            ''')

        elif self.category == 'death':
            sql_query = textwrap.dedent(f'''
                CREATE TABLE {self.table_name[0]} (
                    SELECT 
                        person_id, death_date, cause_source_value, cause_source_concept_id
                    FROM 
                        rd_omop_prod.death
                    WHERE person_id IN (
                        SELECT person_id FROM {self.table_name[1]}
                        UNION
                        SELECT person_id FROM {self.table_name[2]}
                    )
                )                        
            ''')

        elif self.category == 'drug_count':
            sql_query = textwrap.dedent(f'''
                CREATE TABLE {self.table_name} (
                    WITH table1 AS (
                        SELECT 
                            ingredient_concept_id, 
                            COUNT(DISTINCT de.person_id) AS count_person_id
                        FROM rd_omop_prod.drug_exposure de
                        JOIN rd_omop_prod.drug_strength ds
                        ON de.drug_concept_id = ds.drug_concept_id
                        GROUP BY ingredient_concept_id
                    )

                    SELECT 
                        LOWER(c.concept_name) AS ingredient_name,
                        table1.count_person_id
                    FROM table1
                    JOIN rd_omop_prod.concept c
                    ON table1.ingredient_concept_id = c.concept_id
                )         
            ''')

        elif self.category == 'phecode':
            string_icd10, string_icd9 = self.format_query_items()
            sql_query = textwrap.dedent(f'''
                CREATE TABLE {self.table_name}( 
                    WITH table1 AS (
                        SELECT 
                            DISTINCT phecode
                        FROM 
                            rd_omop_prod.x_phecodes_icd_mapping
                        WHERE 
                            (icd in ({string_icd9}) AND icd_flag = 9)
                        OR
                            ((icd in ({string_icd10}) AND icd_flag = 10))    
                    ),
                    table2 AS (
                        SELECT 
                            DISTINCT p.person_id, o.condition_start_date, o.condition_source_concept_id
                        FROM
                            rd_omop_prod.condition_occurrence o
                        JOIN
                            pat_cohort p
                        ON
                            o.person_id = p.person_id
                        WHERE
                            o.condition_start_date <= p.phecode_end_date
                    )

                    SELECT 
                        DISTINCT t2.person_id, t2.condition_start_date, cp.phecode
                    FROM 
                        table2 t2
                    JOIN 
                        (SELECT
                            c.concept_id, p.phecode
                        FROM 
                            rd_omop_prod.x_phecodes_icd_mapping p
                        JOIN 
                            table1 t1
                        ON 
                            t1.phecode = p.phecode
                        JOIN
                            rd_omop_prod.concept c
                        ON 
                            p.ICD = c.concept_code
                            AND 
                            ((c.vocabulary_id = 'ICD10CM' AND p.icd_flag = 10)
                                OR (c.vocabulary_id = 'ICD9CM' AND p.icd_flag = 9))) cp
                    ON 
                        t2.condition_source_concept_id = cp.concept_id
                )                     
            ''')

        else:  # self.category == 'bmi' (validated above)
            sql_query = textwrap.dedent(f'''
                CREATE TABLE {self.table_name}( 
                WITH table1 AS (
                SELECT 
                    DISTINCT p.person_id, 
                    m.value_as_number AS bmi, 
                    m.measurement_date,
                    ROW_NUMBER() OVER (PARTITION BY p.person_id ORDER BY m.measurement_date DESC) AS rank
                FROM 
                    rd_omop_prod.x_vs_bmi_clean b
                JOIN 
                    pat_cohort p
                ON 
                    b.person_id = p.person_id
                JOIN 
                    rd_omop_prod.measurement m
                ON 
                    b.measurement_id = m.measurement_id
                WHERE 
                    m.measurement_date <= p.phecode_end_date
                AND 
                    b.x_is_cleaned = 'Y'
                )

                SELECT
                    person_id, bmi
                FROM
                    table1
                WHERE
                    rank = 1
                )                     
            ''')

        spark.sql(sql_query)


In [0]:
class Dataset:
    """Build the cohort tables in Spark and load them into pandas."""

    def __init__(self, controlled_drugs, cancer_icd_codes):
        """Keep the inputs used to build the cohort tables.

        Args:
            controlled_drugs (list): Drug names. They are passed as ``items`` to the
                'drug' SQLTable that builds ``pat_group1``.
            cancer_icd_codes (dict): ICD code prefixes (``{'icd10': [...], 'icd9': [...]}``).
                They are passed as ``items`` to the 'cancer' SQLTable that builds ``pat_cancer``.
        """
        self.cancer_icd_codes = cancer_icd_codes
        self.controlled_drugs = controlled_drugs

    @staticmethod
    def spark_to_df(table_name):
        """Return every row of the Spark table ``table_name`` as a pandas DataFrame."""
        query = f'SELECT * FROM {table_name}'
        return spark.sql(query).toPandas()

    def collect_data(self):
        """Re-create the Spark cohort tables and return them as pandas DataFrames.

        Tables are (re)built in this order:
        ``pat_group1`` (exposures to the controlled drugs), ``pat_cancer`` (first
        matching diagnosis per patient), then ``pat_record`` and ``pat_demo``
        (record span and demographics of patients in ``pat_group1`` or ``pat_group2``).

        Note: ``pat_group2`` is not created here. It must already exist, presumably
        built earlier for the treated drug (TODO confirm). Otherwise the record and
        demo queries fail or read a stale table.

        Returns:
            tuple: (pat_control, pat_record, pat_cancer, pat_demo) DataFrames.
        """
        table_specs = [
            ('pat_group1', self.controlled_drugs, 'drug'),
            ('pat_cancer', self.cancer_icd_codes, 'cancer'),
            (['pat_record', 'pat_group1', 'pat_group2'], None, 'record'),
            (['pat_demo', 'pat_group1', 'pat_group2'], None, 'demo'),
        ]
        for name, items, category in table_specs:
            SQLTable(table_name=name, items=items, category=category).create_table()

        # Read back as: control group, records, cancer patients, demographics
        frames = [self.spark_to_df(name) for name in ('pat_group1', 'pat_record', 'pat_cancer', 'pat_demo')]
        return tuple(frames)


In [0]:
class DataProcessor:
    """Merge the cohort tables and derive the per-patient analysis variables."""

    def __init__(self, pat_demo, pat_case, pat_control, pat_cancer, pat_record):
        """Initializes the DataProcessor class with necessary datasets.

        Args:
            pat_demo (DataFrame): Demographics with at least ``person_id``,
                ``birth_datetime`` and ``race_source_value``.
            pat_case (DataFrame): Treated group with ``person_id`` and ``drug_start_date``.
            pat_control (DataFrame): Control group with the same columns as ``pat_case``.
            pat_cancer (DataFrame): ``person_id`` and ``cancer_start_date`` of cancer patients.
            pat_record (DataFrame): ``person_id``, ``record_start_date`` and ``record_end_date``.
        """
        self.pat_demo = pat_demo
        self.pat_case = pat_case
        self.pat_control = pat_control
        self.pat_cancer = pat_cancer
        self.pat_record = pat_record

    @staticmethod
    def aggregate_drug_record(df, min_date='1990-01-01', window_months=36, min_records=2):
        """Collapse drug exposure rows into one row per patient.

        - Exposures on or before ``min_date`` are discarded as false records.
        - ``drug_start_date`` is each patient's earliest remaining exposure.
        - ``record_counts`` is the number of exposure rows within ``window_months``
          of ``drug_start_date``, inclusive.
        - Only patients with at least ``min_records`` such rows are kept.

        The input frame is not modified. An empty result always has the columns
        ``person_id``, ``drug_start_date`` and ``record_counts``. This holds when
        no exposure survives the date filter and when only one row is given.

        Args:
            df (DataFrame): Rows with ``person_id`` and ``drug_exposure_start_date``.
            min_date (str or Timestamp): Exposures must be strictly after this date.
            window_months (int): Length of the counting window after the first exposure.
            min_records (int): Minimum number of exposures in the window to keep a patient.

        Returns:
            DataFrame: Columns ``person_id``, ``drug_start_date`` and ``record_counts``,
            with a fresh RangeIndex.
        """
        result_columns = ['person_id', 'drug_start_date', 'record_counts']

        # assign() returns a new frame, so the caller's df is left untouched
        exposures = df.assign(drug_exposure_start_date=pd.to_datetime(df['drug_exposure_start_date']))
        exposures = exposures[exposures['drug_exposure_start_date'] > pd.Timestamp(min_date)]
        if exposures.empty:
            return pd.DataFrame(columns=result_columns)

        person = exposures['person_id']
        start_dates = (exposures.groupby('person_id')['drug_exposure_start_date']
                       .min().rename('drug_start_date'))

        # Every exposure is on/after its patient's first date, so only the upper bound matters
        window_end = person.map(start_dates) + pd.DateOffset(months=window_months)
        counts = ((exposures['drug_exposure_start_date'] <= window_end)
                  .groupby(person).sum().rename('record_counts'))

        summary = pd.concat([start_dates, counts], axis=1).reset_index()
        return summary[summary['record_counts'] >= min_records].reset_index(drop=True)

    def merge_data(self):
        """Merge the cohort tables into one frame for the final analysis.

        This method:

        - adds boolean ``is_case``, ``is_control`` and ``is_cancer`` flags to
          ``self.pat_demo`` in place;
        - keeps only patients in exactly one of the case and control groups;
        - attaches the drug start date (inner join), the cancer start date and the
          record dates (left joins);
        - collapses race via ``map_race_value``;
        - drops patients without a condition record (missing ``record_start_date``).

        Returns:
            DataFrame: The merged table.
        """
        demo = self.pat_demo
        demo['is_case'] = demo['person_id'].isin(self.pat_case['person_id'])
        demo['is_control'] = demo['person_id'].isin(self.pat_control['person_id'])
        demo['is_cancer'] = demo['person_id'].isin(self.pat_cancer['person_id'])

        # Exclude patients who took both drugs: keep those in exactly one group (XOR)
        df = demo[demo['is_case'] ^ demo['is_control']]

        df = pd.merge(df, pd.concat([self.pat_case, self.pat_control]), on='person_id', how='inner')
        df = pd.merge(df, self.pat_cancer, on='person_id', how='left')
        df = pd.merge(df, self.pat_record, on='person_id', how='left')
        df['race_source_value'] = df['race_source_value'].map(self.map_race_value)

        # Drop patients with no ICD code record
        return df.dropna(subset=['record_start_date'])

    @staticmethod
    def map_race_value(x):
        """Keep 'White', 'Black or African American' and 'No matching concept'.

        Every other value, including missing values, is mapped to 'Other'.
        """
        if x not in ['White', 'Black or African American', 'No matching concept']:
            return 'Other'
        return x

    def calculate_variables(self, df, min_baseline_months=12, min_cancer_baseline_months=15,
                            min_age_months=40 * 12, phecode_lag_months=3):
        """Derive time periods, apply the eligibility filters and compute ages.

        All periods are in months, as computed by ``calculate_time_period``.

        Filters, applied in order:

        - At least ``min_baseline_months`` of record before the first drug exposure.
        - A total record span of at least ``min_baseline_months``.
        - Cancer patients also need ``min_cancer_baseline_months`` of record before
          diagnosis. NOTE(review): the default of 15 presumably equals 12 months of
          baseline plus the 3-month ``phecode_lag_months``; confirm.
        - Age of at least ``min_age_months``. This is the age at cancer diagnosis for
          cancer patients and the age at the last record for everyone else.

        Adds ``phecode_end_date`` as a ``datetime.date``. For cancer patients it is the
        diagnosis date minus ``phecode_lag_months``; for everyone else it is the last
        record date.

        The input frame is not modified.

        Returns:
            DataFrame: The filtered rows with the derived columns added.
        """
        df = df.copy()
        for column in ('drug_start_date', 'cancer_start_date', 'record_start_date', 'record_end_date'):
            df[column] = pd.to_datetime(df[column])

        # Baseline period calculations
        df['drug_baseline_period'] = self.calculate_time_period(df, 'drug_start_date', 'record_start_date')
        df['cancer_baseline_period'] = self.calculate_time_period(df, 'cancer_start_date', 'record_start_date')
        df['record_baseline_period'] = self.calculate_time_period(df, 'record_end_date', 'record_start_date')
        df['time_to_event'] = self.calculate_time_period(df, 'cancer_start_date', 'drug_start_date')

        is_cancer = df['is_cancer'].astype(bool)
        eligible = (
            (df['drug_baseline_period'] >= min_baseline_months)
            & (df['record_baseline_period'] >= min_baseline_months)
            & (~is_cancer | (df['cancer_baseline_period'] >= min_cancer_baseline_months))
        )
        # .copy() so later column assignments do not hit a view of the filtered slice
        df = df[eligible].copy()

        # Ages (in months) at different time points
        df['birth_date'] = pd.to_datetime(df['birth_datetime'])
        df['cancer_age'] = self.calculate_time_period(df, 'cancer_start_date', 'birth_date')
        df['record_end_age'] = self.calculate_time_period(df, 'record_end_date', 'birth_date')

        df['age'] = df['cancer_age'].where(df['is_cancer'].astype(bool), df['record_end_age'])
        df = df[df['age'] >= min_age_months].copy()

        # Phecodes are collected up to this date (before diagnosis for cancer patients)
        lagged_cancer_date = df['cancer_start_date'] - pd.DateOffset(months=phecode_lag_months)
        phecode_end = lagged_cancer_date.where(df['is_cancer'].astype(bool), df['record_end_date'])
        df['phecode_end_date'] = phecode_end.dt.date

        return df

    @staticmethod
    def calculate_time_period(x, column_name_1, column_name_2):
        """Return ``x[column_name_1] - x[column_name_2]`` in months.

        Months are the whole-day component of the difference divided by 30.4375,
        the mean month length (365.25 / 12). Missing dates give NaN.
        """
        period = pd.to_timedelta(x[column_name_1] - x[column_name_2])
        return period.dt.days / 30.4375


In [0]:
class PropensityScoreLR:
    """Logistic-regression propensity score model tuned for covariate balance.

    Hyper-parameters are chosen by how many confounders are balanced after
    inverse probability of treatment weighting (IPTW).
    """

    # Covariates whose weighted absolute SMD is at or below this value count as balanced
    SMD_THRESHOLD = 0.1

    def __init__(self, confounder, treatment, random_seed):
        """
        Initializes the PropensityScoreLR with data and treatment label column.

        Args:
            confounder (np.array): The confounder variables, one row per patient.
            treatment (np.array): The treatment labels (boolean or 0/1), aligned with ``confounder``.
            random_seed (int): The random state used in cross validation split and logistic regression
        """
        self.confounder = confounder
        self.treatment = treatment
        self.random_seed = random_seed

    @staticmethod
    def truncate(array, a_min=5e-06, a_max=5e1):
        """Clip values to the fixed range [a_min, a_max] (defaults 5e-06 and 50).

        Used to bound extreme IPTW weights. The bounds are fixed values, not percentiles.
        """
        return np.clip(array, a_min=a_min, a_max=a_max)

    @staticmethod
    def weighted_mean(x, w):
        """Column-wise weighted mean of ``x`` with weights ``w`` (broadcast against ``x``)."""
        return np.sum(np.multiply(x, w), axis=0) / w.sum()

    @staticmethod
    def weighted_var(x, w):
        """Column-wise weighted variance with the unbiased correction for reliability weights.

        Computes ``sum(w * (x - mean)^2) * V1 / (V1^2 - V2)``, where ``V1 = sum(w)`` and
        ``V2 = sum(w^2)``. The result is undefined (division by zero) when all the weight
        sits on a single row, e.g. a group with one member.
        """
        m_w = PropensityScoreLR.weighted_mean(x, w)
        nw, nsw = w.sum(), (w ** 2).sum()
        var = np.multiply((x - m_w) ** 2, w)
        return np.sum(var, axis=0) * (nw / (nw ** 2 - nsw))

    @staticmethod
    def cal_IPTW(y, ps):
        """Stabilized inverse probability of treatment weights.

        Treated units get ``P(T) / ps`` and controls get ``(1 - P(T)) / (1 - ps)``, where
        ``P(T)`` is the observed treated fraction. Both are clipped with ``truncate``.

        Args:
            y (array-like): Treatment indicator (True/1 = treated, False/0 = control).
            ps (np.ndarray): 1-D propensity scores aligned with ``y``.

        Returns:
            tuple: (treated_w, controlled_w) column vectors of shape (n_treated, 1) and
            (n_controlled, 1), in the row order of ``y``.
        """
        ones_idx, zeros_idx = np.where(y == True), np.where(y == False)
        p_T = len(ones_idx[0]) / (len(ones_idx[0]) + len(zeros_idx[0]))
        treated_w = p_T / ps[ones_idx]
        controlled_w = (1 - p_T) / (1. - ps[zeros_idx])
        treated_w, controlled_w = PropensityScoreLR.truncate(treated_w), PropensityScoreLR.truncate(controlled_w)
        return np.reshape(treated_w, (len(treated_w), 1)), np.reshape(controlled_w, (len(controlled_w), 1))

    @staticmethod
    def cal_SMD(X, y, propensity_score):
        """Absolute IPTW-weighted standardized mean difference of each column of ``X``.

        The pooled standard deviation is ``sqrt((var_treated + var_control) / 2)``.
        Columns whose pooled standard deviation is 0 get an SMD of 0.
        """
        ones_idx, zeros_idx = np.where(y == True), np.where(y == False)
        treated_w, controlled_w = PropensityScoreLR.cal_IPTW(y, propensity_score)
        treated_X, controlled_X = X[ones_idx], X[zeros_idx]

        treated_X_w_mu = PropensityScoreLR.weighted_mean(treated_X, treated_w)
        controlled_X_w_mu = PropensityScoreLR.weighted_mean(controlled_X, controlled_w)
        treated_X_w_var = PropensityScoreLR.weighted_var(treated_X, treated_w)
        controlled_X_w_var = PropensityScoreLR.weighted_var(controlled_X, controlled_w)
        VAR = np.sqrt((treated_X_w_var + controlled_X_w_var) / 2)
        SMD = np.divide(np.abs(treated_X_w_mu - controlled_X_w_mu), VAR,
                        out=np.zeros_like(treated_X_w_mu), where=(VAR != 0))
        return SMD

    def n_weight(self, estimator, X, y):
        """GridSearchCV scorer: number of balanced covariates plus held-out AUC.

        Higher scores are better. The balance term counts covariates whose weighted SMD
        is at or below ``SMD_THRESHOLD``. It is computed on the full ``self.confounder`` /
        ``self.treatment`` data, not just the fold, using ``estimator``'s scores. The AUC
        term lies in [0, 1] and is computed on the held-out fold ``X``/``y``, so it mostly
        acts as a tie-breaker between models with the same count.
        """
        propensity_score = estimator.predict_proba(self.confounder)[:, 1]
        SMD = self.cal_SMD(self.confounder, self.treatment, propensity_score)
        ps = estimator.predict_proba(X)[:, 1]
        auc = roc_auc_score(y, ps)
        return np.count_nonzero(SMD <= self.SMD_THRESHOLD) + auc

    def fit_model(self):
        """Fit the logistic regression model with GridSearchCV.

        The grid covers ``C`` and the L1/L2 penalty (liblinear solver). Each candidate is
        scored with ``n_weight`` over a shuffled 10-fold split seeded with ``random_seed``.

        Returns:
            tuple: (best_estimator_, best_params_) of the grid search.
        """
        model = LogisticRegression()
        cv = KFold(n_splits=10, shuffle=True, random_state=self.random_seed)
        param_grid = {
            'C': [0.005, 0.01, 0.05, 0.1, 0.5],
            'penalty': ['l1', 'l2'],
            'random_state': [self.random_seed],
            'solver': ['liblinear'],
            'max_iter': [1000],
        }
        grid_search = GridSearchCV(model, param_grid,
                                   scoring=self.n_weight,
                                   cv=cv, n_jobs=-1)
        grid_search.fit(self.confounder, self.treatment)
        return grid_search.best_estimator_, grid_search.best_params_

In [0]:

# Define a function to record and write timing information
def record_and_write_timing(block_name, start_time,
                            log_path="/Workspace/Users/qingyuan.song.1@vumc.org/Protein Cancer Risk/timing_info.txt"):
    """Append the wall-clock time elapsed since ``start_time`` to a log file.

    Writes one line of the form ``"<block_name>: <seconds> seconds"``.

    Args:
        block_name (str): Label written at the start of the log line.
        start_time (float): Start timestamp taken with ``time.time()``.
        log_path (str): File to append to. Defaults to the previously hard-coded location.
    """
    execution_time = time.time() - start_time
    with open(log_path, "a") as file:
        file.write(f"{block_name}: {execution_time} seconds\n")

def save_output_checkpoint(output_df, output_file_path):
    """Write ``output_df`` as CSV to ``output_file_path``, overwriting the file and omitting the index."""
    output_df.to_csv(path_or_buf=output_file_path, index=False)

In [0]:
# Define the root folder: absolute workspace path holding all inputs and outputs.
# NOTE(review): the same path is hard-coded again in record_and_write_timing.
root_folder = '/Workspace/Users/qingyuan.song.1@vumc.org/Protein Cancer Risk'
# Drug information workbook (sheet 'S16' is read below)
drug_file_path = os.path.join(root_folder, 'protein_supplementary_tables_20240227.xlsx')
# Map between ATC code and drug name
atc_file_path = os.path.join(root_folder, 'WHO ATC-DDD 2021-12-03.csv')
# Confounder ICD codes for each type of cancer
confounder_icd_code_file_path = os.path.join(root_folder, 'Cancer_Confounders_with_ICD.csv')

# Outputs: a run timestamp in the file names keeps each run from overwriting earlier results
date_time_suffix = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_file_path = os.path.join(root_folder, f"output_{date_time_suffix}.csv")
km_curve_file_path = os.path.join(root_folder, f"KM_curve_{date_time_suffix}.csv" )


# ATC level used to group similar drugs (presumably passed to find_similar_atc_codes;
# level 2 = first 3 code characters there)
n_level = 2
# Seed for reproducibility (presumably passed to PropensityScoreLR in later cells)
random_seed = 37

# skiprows=1 skips the sheet's first row (presumably a title line above the header)
drug_table = pd.read_excel(drug_file_path, sheet_name='S16', skiprows=1)
# Keep rows with MaxPhase == 4 whose IS_APPROVED_AntiCancerDrug flag is False
drug_table = drug_table[(drug_table['MaxPhase'] == 4) & (drug_table['IS_APPROVED_AntiCancerDrug'] == False)].reset_index(drop=True)

atc_df = pd.read_csv(atc_file_path)
confounder_table = pd.read_csv(confounder_icd_code_file_path)

# ICD-10 / ICD-9 code prefixes per cancer type, in the {'icd10': [...], 'icd9': [...]}
# layout that SQLTable(category='cancer') turns into prefix LIKE matches
cancer_icd_codes_dict = {'BRCA':{'icd10':['C50'], 'icd9':['174', '175']}, 'COADREAD':{'icd10':['C18', 'C19', 'C20'], 'icd9':['153', '154']}, 'PRAD':{'icd10':['C61'], 'icd9':['185']}} 

# Empty Kaplan-Meier curve table; writing it now creates the CSV with just the header row
km_df = pd.DataFrame(columns=['Cancer type',
                              'Treated drug',
                              'Controlled drug',
                              'Timeline',
                              'Treated',
                              'Treated_lower_0.95',
                              'Treated_upper_0.95',
                              'Controlled',
                              'Controlled_lower_0.95',
                              'Controlled_upper_0.95'])

save_output_checkpoint(km_df, km_curve_file_path)

output_df = pd.DataFrame(columns=['cancer type',
                                  'treated drug generic name',
                                  'controlled drug generic name', 
                                  'unbalanced_covariate_percentage',
                                  'weighted KM survival difference',
                                  'weighted KM ATE',
                                  'weighted KM p value',
                                  'weighted CoxPH harzard ratio',
                                  'weighted CoxPH standard estimation',
                                  'weighted CoxPH confidence intervals lower',
                                  'weighted CoxPH confidence intervals upper',
                                  'weighted CoxPH P value',
                                  'treated group size', 
                                  'controlled group size',
                                  'treated cancer group size',
                                  'controlled cancer group size',
                                  'LR best hyperparameter'])

save_output_checkpoint(output_df, output_file_path)

SQLTable(table_name='pat_ingredient', category='drug_count').create_table()
pat_ingredient = Dataset.spark_to_df('pat_ingredient')
valid_drug = pat_ingredient.loc[pat_ingredient['count_person_id'] >= 500, 'ingredient_name']

In [0]:
def _is_valid_drug(drug_name):
    """Return True if any eligible ingredient name contains `drug_name` (case-insensitive, literal)."""
    # regex=False: drug/ATC names can contain '(', '+', '.', ... which must match literally
    return bool(valid_drug.str.contains(drug_name, case=False, regex=False, na=False).any())


for index, row in tqdm(drug_table.iterrows(), total=drug_table.shape[0], desc='treated loop'):
    cancer_icd_codes = cancer_icd_codes_dict[row['Cancer']]
    # Start timing for treated drug
    start_time_block1 = time.time()

    treated_drug_raw = row['MolecularBrandName']
    treated_drug = treated_drug_raw.split(' ')[0]
    if not _is_valid_drug(treated_drug):
        continue

    candidate_controlled_drug = find_similar_atc_codes(treated_drug, n_level, atc_df)
    if treated_drug_raw != treated_drug:
        candidate_controlled_drug.extend(find_similar_atc_codes(treated_drug_raw, n_level, atc_df))
    # Both lookups can return the same ATC name; de-duplicate (keeping order) so the same
    # comparison is not analysed and written twice.
    candidate_controlled_drug = [drug for drug in dict.fromkeys(candidate_controlled_drug) if _is_valid_drug(drug)]
    if len(candidate_controlled_drug) == 0:
        continue

    print('treated drug:', treated_drug_raw)
    SQLTable(table_name='pat_group2', items=[treated_drug], category='drug').create_table()
    pat_case = Dataset.spark_to_df('pat_group2')
    pat_case = DataProcessor.aggregate_drug_record(pat_case)

    record_and_write_timing("Treated patient SQL:\t", start_time_block1)

    # Confounder ICD codes depend only on the cancer type (plus the general 'G' rows),
    # so build them once per treated drug rather than once per comparator.
    confounder_rows = confounder_table[confounder_table['Cancer'].isin([row['Cancer'], 'G'])]
    icd9_codes = list({code for codes in confounder_rows.ICD9.apply(sep_icd_code) for code in codes})
    icd10_codes = list({code for codes in confounder_rows.ICD10.apply(sep_icd_code) for code in codes})

    for candidate_drug in tqdm(candidate_controlled_drug, desc='controlled loop', leave=False):

        print(candidate_drug)
        # Start timing for controlled drug
        start_time_block2 = time.time()
        controlled_drug = [candidate_drug]

        pat_control, pat_record, pat_cancer, pat_demo = Dataset(controlled_drug, cancer_icd_codes).collect_data()
        pat_control = DataProcessor.aggregate_drug_record(pat_control)

        record_and_write_timing("Controlled patient SQL:\t", start_time_block2)

        # Start timing for processing SQL
        start_time_block3 = time.time()
        dp = DataProcessor(pat_demo, pat_case, pat_control, pat_cancer, pat_record)
        df = dp.merge_data()
        df = dp.calculate_variables(df)

        # Keep rows with a positive time_to_event or a missing one (presumably no cancer
        # event); missing values are filled with the record_end_date - drug_start_date span.
        df_ = df[(df['time_to_event'] > 0) | (df['time_to_event'].isna())].copy()
        df_['time_to_event'] = df_['time_to_event'].fillna(DataProcessor.calculate_time_period(df_, 'record_end_date', 'drug_start_date'))

        df_['gender'] = df_['gender_source_value'].apply(lambda x: int(x == 'M'))

        # Sex-specific cancers: restrict to one sex and drop the now-constant gender column
        if row['Cancer'] == 'BRCA':
            df_ = df_[df_['gender'] == 0].drop('gender', axis=1)
        elif row['Cancer'] == 'PRAD':
            df_ = df_[df_['gender'] == 1].drop('gender', axis=1)

        # Skip comparisons where either arm is too small
        if (df_['is_case'].sum() < 250) or (df_['is_control'].sum() < 250):
            continue

        df_['ethnicity_Hispanic or Latino'] = (df_['ethnicity_source_value'] == 'Hispanic or Latino').astype(int)
        pat_race_d = pd.get_dummies(df_[['person_id', 'race_source_value']], columns=['race_source_value'], prefix='race')
        # reindex (instead of column selection) keeps both race indicators even when a
        # category is absent from this cohort, instead of raising KeyError
        pat_race_d = pat_race_d.reindex(columns=['person_id', 'race_Black or African American', 'race_White'], fill_value=0)

        record_and_write_timing("Processing SQL:\t", start_time_block3)

        # Start timing for phecode SQL
        start_time_block4 = time.time()

        spark_df = spark.createDataFrame(df_[['person_id', 'phecode_end_date']])
        spark_df.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable("pat_cohort")

        SQLTable(table_name='pat_phecode', items={'icd9': icd9_codes, 'icd10': icd10_codes}, category='phecode').create_table()
        pat_phecode = Dataset.spark_to_df('pat_phecode')

        # Truncate phecodes at the '.', then build one indicator column per phecode
        pat_phecode['phecode'] = pat_phecode.phecode.apply(lambda x: x.split('.')[0])
        pat_phecode = pat_phecode[['person_id', 'phecode']].drop_duplicates()

        pat_phecode_d = pd.get_dummies(pat_phecode, columns=['phecode'])
        pat_phecode_d = pat_phecode_d.groupby('person_id').agg('sum')

        SQLTable(table_name='pat_bmi', category='bmi').create_table()
        pat_bmi = Dataset.spark_to_df('pat_bmi')
        pat_bmi['bmi'] = pat_bmi['bmi'].astype(float)

        if (row['Cancer'] == 'BRCA') or (row['Cancer'] == 'PRAD'):
            data = pd.merge(df_[['person_id', 'age', 'ethnicity_Hispanic or Latino']], pat_race_d, how='inner', on='person_id')
        else:
            data = pd.merge(df_[['person_id', 'age', 'gender', 'ethnicity_Hispanic or Latino']], pat_race_d, how='inner', on='person_id')
        data = pd.merge(data, pat_phecode_d.reset_index(), how='inner', on='person_id')
        data = pd.merge(data, pat_bmi, how='inner', on='person_id')

        # Min-max scale the continuous covariates to [0, 1]
        data['age'] = (data['age'] - data['age'].min()) / (data['age'].max() - data['age'].min())
        data['bmi'] = (data['bmi'] - data['bmi'].min()) / (data['bmi'].max() - data['bmi'].min())

        record_and_write_timing("Phecode SQL:\t", start_time_block4)

        # Start timing for IPTW LR
        start_time_block5 = time.time()

        data = pd.merge(data, df_[['person_id', 'is_case', 'is_cancer', 'time_to_event']], how='inner', on='person_id')
        data = data.dropna().reset_index(drop=True)

        # Covariates: every column between person_id and the three trailing outcome columns
        confounder = data.iloc[:, 1:-3].to_numpy()
        treatment = data['is_case'].to_numpy()

        lr = PropensityScoreLR(confounder, treatment, random_seed)
        best_model, best_params = lr.fit_model()
        propensity_score = best_model.predict_proba(confounder)[:, 1]
        smd = PropensityScoreLR.cal_SMD(confounder, treatment, propensity_score)
        p_unbalanced = len(np.where(smd > 0.1)[0]) / len(smd)

        record_and_write_timing("IPTW LR:\t", start_time_block5)

        # Start timing for survival analysis
        start_time_block6 = time.time()

        duration = data['time_to_event'].to_numpy()

        ones_idx, zeros_idx = np.where(treatment == True), np.where(treatment == False)
        treated_w, controlled_w = PropensityScoreLR.cal_IPTW(treatment, propensity_score)
        treated_duration, controlled_duration = duration[ones_idx], duration[zeros_idx]

        cph = CoxPHFitter()

        cancer = data['is_cancer'].to_numpy()
        treated_cancer, controlled_cancer = cancer[ones_idx], cancer[zeros_idx]

        # Weighted proportion of each arm with the cancer outcome (not written to the outputs)
        treated_prob = np.matmul(treated_cancer, treated_w) / np.sum(treated_w)
        controlled_prob = np.matmul(controlled_cancer, controlled_w) / np.sum(controlled_w)

        # CoxPH: IPTW-weighted, additionally adjusted for covariates still unbalanced (SMD > 0.1)
        weight = np.zeros(len(cancer))
        weight[ones_idx] = treated_w.squeeze()
        weight[zeros_idx] = controlled_w.squeeze()
        cox_data = pd.DataFrame({'T': duration, 'event': cancer, 'treatment': treatment, 'weights': weight})
        cox_data = pd.concat([cox_data, data.iloc[:, 1:-3].iloc[:, np.where(smd > 0.1)[0]]], axis=1)

        try:
            cph.fit(cox_data, duration_col='T', event_col='event', weights_col='weights', robust=True)
        except Exception:
            try:
                # Retry after dropping collinear covariates (VIF > 10). Cast to float so bool
                # columns do not turn the design matrix into an object-dtype array.
                vif_input = cox_data.astype(float)
                vif = pd.DataFrame({
                    'Feature': cox_data.columns,
                    'VIF': [variance_inflation_factor(vif_input, i) for i in range(vif_input.shape[1])],
                })
                # Never drop the duration/event/weight columns, nor 'treatment', whose
                # coefficient is read from cph.summary below
                protected = {'T', 'event', 'weights', 'treatment'}
                high_vif_features = [ft for ft in vif.loc[vif['VIF'] > 10, 'Feature'] if ft not in protected]
                cox_data_filtered = cox_data.drop(columns=high_vif_features)
                cph.fit(cox_data_filtered, duration_col='T', event_col='event', weights_col='weights', robust=True)
            except Exception as e:
                # Handle Exception, general
                print(f"Exception encountered: {e}. Continuing with next dataset.")
                continue

        kmf_A = KaplanMeierFitter()
        kmf_B = KaplanMeierFitter()

        # Kaplan Meier (IPTW-weighted), compared at times 36 and 60 (presumably months)
        treated_kmf = kmf_A.fit(treated_duration, treated_cancer, label="Treated", weights=treated_w)
        controlled_kmf = kmf_B.fit(controlled_duration, controlled_cancer, label="Controlled", weights=controlled_w)
        results_w = survival_difference_at_fixed_point_in_time_test([36, 60], treated_kmf, controlled_kmf)
        survival_treated = treated_kmf.predict([36, 60]).to_numpy()
        survival_controlled = controlled_kmf.predict([36, 60]).to_numpy()
        ate_w = survival_treated - survival_controlled

        # Survival curves plus confidence intervals. lifelines names these columns after the
        # fit label ('Treated', 'Treated_lower_0.95', ...), which matches the km_df header.
        treated_survival_df = treated_kmf.survival_function_.join(treated_kmf.confidence_interval_)
        controlled_survival_df = controlled_kmf.survival_function_.join(controlled_kmf.confidence_interval_)

        # Align both curves on the union of time points and use the km_df header names
        # ('Timeline', 'Cancer type', ...) so concat does not create parallel columns.
        combined_df = (
            treated_survival_df.join(controlled_survival_df, how='outer')
            .rename_axis('Timeline')
            .reset_index()
        )
        combined_df['Cancer type'] = row['Cancer']
        # Same treated-drug name as output_df, so the two files can be joined
        combined_df['Treated drug'] = treated_drug_raw
        combined_df['Controlled drug'] = candidate_drug

        km_df = pd.concat([km_df, combined_df], ignore_index=True)

        save_output_checkpoint(km_df, km_curve_file_path)

        record_and_write_timing("Survival analysis\t", start_time_block6)

        new_row = {
            'cancer type': row['Cancer'],
            'treated drug generic name': treated_drug_raw,
            'controlled drug generic name': candidate_drug,
            'unbalanced_covariate_percentage': p_unbalanced,
            'weighted KM survival difference': results_w.test_statistic.to_dict(),
            'weighted KM ATE': ate_w,
            'weighted KM p value': results_w.p_value,
            'weighted CoxPH harzard ratio': cph.hazard_ratios_['treatment'],
            'weighted CoxPH standard estimation': cph.summary.loc['treatment', 'se(coef)'],
            'weighted CoxPH confidence intervals lower': cph.summary.loc['treatment', 'exp(coef) lower 95%'],
            'weighted CoxPH confidence intervals upper': cph.summary.loc['treatment', 'exp(coef) upper 95%'],
            'weighted CoxPH P value': cph.summary.loc['treatment', 'p'],
            'treated group size': data['is_case'].sum(),
            # astype(bool) so '~' is a logical NOT even if is_case is stored as 0/1 integers
            'controlled group size': (~data['is_case'].astype(bool)).sum(),
            'treated cancer group size': treated_cancer.sum(),
            'controlled cancer group size': controlled_cancer.sum(),
            'LR best hyperparameter': best_params,
        }
        # DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead
        output_df = pd.concat([output_df, pd.DataFrame([new_row])], ignore_index=True)
        save_output_checkpoint(output_df, output_file_path)