In [3]:
import numpy as np
import pandas as pd
import os
from lifelines import KaplanMeierFitter

def read_file(filename, columns=None, **kwargs):
    """Load a tabular file into a DataFrame, choosing the reader by file extension.

    Args:
        filename: Path to a ``.parquet`` or ``.csv`` file.
        columns: Optional list of columns to load. Forwarded as ``columns`` to
            ``pd.read_parquet`` and as ``usecols`` to ``pd.read_csv``.
        **kwargs: Extra keyword arguments passed straight to the pandas reader.

    Returns:
        The loaded ``pandas.DataFrame``.

    Raises:
        ValueError: If the extension is neither ``.parquet`` nor ``.csv``.
    """
    # Lower-case the extension so that e.g. "cohort.CSV" is still recognised.
    load_extension = os.path.splitext(filename)[-1].lower()
    if load_extension == ".parquet":
        return pd.read_parquet(filename, columns=columns, **kwargs)
    if load_extension == ".csv":
        return pd.read_csv(filename, usecols=columns, **kwargs)
    # Fail loudly here instead of implicitly returning None, which would only
    # surface later as an unrelated AttributeError in the caller.
    raise ValueError(
        "unsupported file extension {!r} for {!r}; expected '.parquet' or '.csv'".format(
            load_extension, filename
        )
    )

def censoring_weights(df, model_type = 'KM'):
    """Compute inverse censoring-survival weights for every row of ``df``.

    A censoring model is fit on the training rows (``is_train == 1``) with
    ``event_time`` as the duration and the complement of ``event_indicator`` as
    the observed-event flag, i.e. the model describes time to censoring. Each
    row's weight is ``1 / G(event_time_10yr - 1e-5)``, where ``G`` is the fitted
    survival function.

    Args:
        df: DataFrame with columns ``is_train``, ``event_time``,
            ``event_indicator`` and ``event_time_10yr``.
        model_type: Censoring model to use. Only ``'KM'`` (Kaplan-Meier) is
            supported.

    Returns:
        dict mapping each index label of ``df`` to its weight.

    Raises:
        ValueError: If ``model_type`` is not ``'KM'``.
    """
    if model_type == 'KM':
        censoring_model = KaplanMeierFitter()
    else:
        raise ValueError("censoring_model not defined")

    # Select the training rows once instead of re-running the query per column.
    train = df.query('is_train==1')

    # Censoring indicator = 1 - event indicator. The arithmetic complement is
    # used rather than `~`, because bitwise NOT on an integer 0/1 column yields
    # -1/-2 instead of 1/0; for boolean columns both forms agree.
    censored = 1.0 - train.event_indicator.astype(float)
    censoring_model.fit(train.event_time, censored)

    # Evaluate just before the stored time (offset of 1e-5) and invert.
    # NOTE(review): a survival value of 0 here produces an infinite weight.
    weights = 1 / censoring_model.survival_function_at_times(df.event_time_10yr.values - 1e-5)
    return dict(zip(df.index.values, weights.values))

def get_censoring(df, by_group=True, model_type='KM', groups=(1, 2, 3, 4)):
    """Compute censoring weights for ``df``, optionally fitting one model per group.

    Args:
        df: DataFrame accepted by ``censoring_weights``; must also contain a
            ``grp`` column when ``by_group`` is true.
        by_group: If true, fit a separate censoring model on the rows of each
            value in ``groups``; otherwise fit a single model on all of ``df``.
        model_type: Forwarded to ``censoring_weights``.
        groups: Values of ``grp`` to process when ``by_group`` is true. Rows
            whose ``grp`` is not listed receive no weight.

    Returns:
        ``pandas.Series`` named ``'weights'``, indexed by the index labels of
        ``df`` for which a weight was computed.
    """
    if by_group:
        weight_dict = {}
        for group in groups:
            group_df = df.loc[df.grp == group]
            weight_dict.update(censoring_weights(group_df, model_type))
    else:
        # Use the function's own arguments here; the previous version reached
        # for notebook-level globals and ignored what the caller passed in.
        weight_dict = censoring_weights(df, model_type)

    return pd.Series(weight_dict, name='weights')
    
# Input table containing all cohorts.
cohort_path = '/labs/shahlab/projects/agataf/data/pooled_cohorts/cohort_extraction/all_cohorts.csv'

# Run configuration: fold held out for validation, and how censoring is modelled.
val_fold_id = '1'
censoring_model_type='KM'
censoring_by_group = True

cohort = read_file(cohort_path)

# Training rows are those in neither the validation fold nor the "test" fold.
# fold_id is cast to str before comparing with the string labels above.
# NOTE(review): assumes fold labels look like '1' / 'test'. If pandas inferred
# integers for some rows, an uncast comparison would treat validation rows as
# training rows.
cohort = cohort.assign(
    is_train = lambda x: np.where(~x.fold_id.astype(str).isin([val_fold_id, "test"]), 1, 0)
)

# Discard any weights already stored in the file; tolerate the column being absent.
cohort = cohort.drop(columns=['weights'], errors='ignore')

all_weights = get_censoring(cohort, by_group = censoring_by_group, model_type = censoring_model_type)
# Attach the new weights by index label.
cohort = cohort.join(all_weights)

In [5]:
# add to train_model
# run training again, with censoring done by group, with larger range of lambda
# do the same but with MMD

In [None]:
    # train censoring model
    