In [1]:
import pandas as pd
import numpy as np
import os
import seaborn as sns
import json
from glob import glob
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, recall_score, precision_score, confusion_matrix, roc_auc_score
from lightgbm import LGBMClassifier, Booster
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
import joblib



Вспомогательный блок функций

In [2]:
def get_correct_date_string(date):
    return pd.to_datetime(date, utc=True, errors = 'coerce').dt.strftime('%Y-%m-%d')

def time_diff_years(a, b):
    return (pd.to_datetime(a, errors = 'coerce') - pd.to_datetime(b, errors = 'coerce')) / np.timedelta64(1, 'Y')

def get_empty_list_nan(x):
    if isinstance(x, list):
        return x
    else:
        return []
    
try:
    from sympy.mpmath import mp
except ImportError:
    from mpmath import mp

def get_k_m_numbers_of_pi(k, m):
    mp.dps = m + 1
    return int(str(mp.pi)[k + 1:m + 1]) 

def get_correct_patient_id_df(df, patient_id_column="patient_id", number=0):
    df_new = df.copy()
    df_new[patient_id_column] -= number
    return df_new




def random_dates(start, end, n=10):
    start_u = pd.to_datetime(start).value//10**9
    end_u = pd.to_datetime(end).value//10**9
    return pd.to_datetime(np.random.randint(start_u, end_u, n), unit='s').strftime('%Y-%m-%d')

def get_filtered_data(df, repeat=10, date_start="2019-01-01", date_end="2020-01-01"):
    new_df = df.append([df] * repeat, ignore_index=True)
    new_df["date"] = random_dates(date_start, date_end, len(new_df))
    new_df = new_df[(new_df["death_dt"].isna()) | (new_df["death_dt"] > new_df["date"])]
    return new_df.reset_index(drop=True)

def get_filtered_data_1(df, repeat=10, date_start_1="2019-01-01", date_end_1="2020-01-01",
                        date_start_2="2019-01-01", date_end_2="2020-01-01", frac=0.5):
    new_df = df.append([df] * repeat, ignore_index=True).sample(frac=1).reset_index(drop=True)
    new_df_1 = new_df[:int(len(new_df) * frac)]
    new_df_1["date"] = date_end_1
    new_df_1 = new_df_1[(new_df_1["death_dt"].isna()) | (new_df_1["death_dt"] > new_df_1["date"])]
    new_df_2 = new_df[int(len(new_df) * frac):]
    new_df_2["date"] = date_end_2
    new_df_2 = new_df_2[(new_df_2["death_dt"].isna()) | (new_df_2["death_dt"] > new_df_2["date"])]
    print(len(new_df_1), len(new_df_2), len(new_df))
    return pd.concat([new_df_1, new_df_2]).reset_index(drop=True)

def get_filtered_data_2(df, repeat=10, date_start_1="2019-01-01", date_end_1="2020-01-01",
                        date_start_2="2019-01-01", date_end_2="2020-01-01", frac=0.5):
    new_df = df.append([df] * repeat, ignore_index=True).sample(frac=1).reset_index(drop=True)
    new_df_1 = new_df[:int(len(new_df) * frac)]
    new_df_1["date"] = random_dates(date_start_1, date_end_1, len(new_df_1))
    new_df_1 = new_df_1[(new_df_1["death_dt"].isna()) | (new_df_1["death_dt"] > new_df_1["date"])]
    new_df_2 = new_df[int(len(new_df) * frac):]
    new_df_2["date"] = date_end_2
    new_df_2 = new_df_2[(new_df_2["death_dt"].isna()) | (new_df_2["death_dt"] > new_df_2["date"])]
    print(len(new_df_1), len(new_df_2), len(new_df))
    return pd.concat([new_df_1, new_df_2]).reset_index(drop=True)


def get_group_date_feature(data, patient_id_column="patient_id", date_column="date",
                           value_column="some", group_column_name="loc",
                           filter_column=None, filter_value=None):
    if not pd.isna(filter_column) and not pd.isna(filter_value):
        data = data[data[filter_column] == filter_value]
    data = data.copy()
    data[group_column_name] = data.apply(lambda x: (x[date_column], x[value_column]), axis=1)
    return data.groupby(patient_id_column)[group_column_name].agg(lambda x: sorted(list(x)))


def get_target_data(data, folder=None, filename=None, filter_criteria=[{"main_diag_discharge": "I21"}], 
                    patient_id_column="patient_id", date_column="date"):
    data = data[(~data[patient_id_column].isna()) & (~data[date_column].isna())].copy()
    data[date_column] = get_correct_date_string(data[date_column])
    for loc_dict in filter_criteria:
        filters = []
        for column in loc_dict:
            if isinstance(loc_dict[column], str):
                filters.append(data[column].apply(lambda x: loc_dict[column] in str(x)))
            else:
                filters.append(data[column].apply(loc_dict[column]))
        data = data[np.logical_or.reduce(np.array(filters))]
    target = data.groupby(patient_id_column)[date_column].agg(lambda x: sorted(list(x))).reset_index()
    target = target.rename({patient_id_column: "patient_id", date_column: "infarction_dates"}, axis=1)
    if not pd.isna(folder) and not pd.isna(filename):
        target.to_csv(os.path.join(folder, filename), index=False)
    return target

def get_union_targets(targets):
    target = targets[0].copy()
    for new_target in targets[1:]:
        target = pd.merge(target, new_target, suffixes=("", "_new"),on="patient_id", how="outer")
        target["infarction_dates"] = target.apply(lambda x: sorted(list(set(get_empty_list_nan(x["infarction_dates"]) + get_empty_list_nan(x["infarction_dates_new"])))),
                                                  axis=1)
        target = target.drop(["infarction_dates_new"], axis=1)
    return target


def random_dates(start, end, n=10):
    start_u = pd.to_datetime(start).value//10**9
    end_u = pd.to_datetime(end).value//10**9
    return pd.to_datetime(np.random.randint(start_u, end_u, n), unit='s').strftime('%Y-%m-%d')

def get_filtered_data(df, repeat=10, date_start="2020-01-01", date_end="2021-01-01"):
    new_df = df.append([df] * (repeat - 1), ignore_index=True)
    new_df["date"] = random_dates(date_start, date_end, len(new_df))
    new_df = new_df[(new_df["death_dt"].isna()) | (new_df["death_dt"] > new_df["date"])]
    return new_df.reset_index(drop=True)

def get_all_before(x, date, before=True):
    if not isinstance(x, list):
        return None
    if isinstance(x[0], tuple):
        if before:
            res = [y for y in x if y[0] < date]
        else:
            res = [y for y in x if y[0] > date]
    else:
        if before:
            res = [y for y in x if y < date]
        else:
            res = [y for y in x if y > date]
    if not res:
        return None
    return res
    
def get_nearest_in_loc(x, value):
    values = list(filter(lambda y: value in str(y[1]), x))
    if not values:
        return []
    return max(values) 


def get_infarction(df, years=2):
    return df.apply(lambda x: int(any([0 < time_diff_years(y, x["date"]) < years for y in x["infarction_dates"]])) if isinstance(x["infarction_dates"], list) else 0, axis=1)

def get_infarction_regression(df):
    return df.apply(lambda x: min([time_diff_years(y, x["date"]) for y in x["infarction_dates"]]) if isinstance(x["infarction_dates"], list) else 9999, axis=1)

def get_nearest_value(df, column):
    return df.apply(lambda x: max(x[column])[1] if isinstance(x[column], list) else None, axis=1)

def get_nearest_diag_in(df, diag):
    return df.apply(lambda x: 1 if isinstance(x["diags"], list) and get_nearest_in_loc(x["diags"], diag) else 0, axis=1)

def get_nearest_diag_in_time_diff(df, diag):
    return df.apply(lambda x: time_diff_years(x["date"], get_nearest_in_loc(x["diags"], diag)[0]) if isinstance(x["diags"], list) and get_nearest_in_loc(x["diags"], diag) else None, axis=1)

def get_nearest_time_diff(df, column):
    return df.apply(lambda x: time_diff_years(x["date"], max(x[column])[0]) if isinstance(x[column], list) else None, axis=1)

def get_agg_value(df, column, agg_function):
    return df.apply(lambda x: agg_function([y[1] for y in x[column]]) if isinstance(x[column], list) else None, axis=1)



Считываем нужную часть данных

In [3]:
base_info_sample = pd.read_csv("data/target_less_15_nan_full_new_test_2.csv")

In [4]:
patient_id = set(base_info_sample["patient_id"])

In [5]:
base_info = pd.read_csv("data/data_less_nan_15.csv", usecols=["patient_id"])


In [6]:
index = set(base_info[base_info["patient_id"].isin(patient_id)].index + 1)

Объединяем данные

In [7]:
base_info = pd.read_csv("data/data_less_nan_15.csv", skiprows=lambda x: x not in index and x)


  base_info = pd.read_csv("data/data_less_nan_15.csv", skiprows=lambda x: x not in index and x)


In [8]:
base_info = base_info.drop(["age"], axis=1)

In [9]:
base_info = base_info.join(base_info_sample.set_index("patient_id"), on="patient_id")

Подсчитываем число пропусков

In [10]:
base_info["nan_count"] = base_info.isna().sum(axis=1)

Подсчитываем среднее число историй болезней с сердечно-сосудистыми заболеваниями

In [11]:
base_info[base_info["diags"].apply(lambda x: "I" in str(x))]["target"].mean()

0.021786763148930684

Копируем данные

In [12]:
base_info_not_nan = base_info.copy()

Предобрабатываем столбцы с анализами, используя разные подходы

In [13]:
COLUMNS = ['Антропометрия_ЧСС',
       'ЭКГ_qtc интервал', 'ЭКГ_r интервал', 'Курение_нет',
       'Анализ_Триглицериды', 'Антропометрия_Диастолическое давление',
       'Анализ_Креатинин в крови', 'Антропометрия_Вес',
       'Анализ_Общий холестерин', 'Анализ_АСТ', 'ЭКГ_qrs интервал',
       'Курение_бросил(а)', 'Анализ_КФК', 'ЭКГ_pq интервал', 'Анализ_ЛПНП',
       'Анализ_Белок', 'ЭКГ_гипертрофия', 'Анализ_Билирубин',
       'Антропометрия_Систолическое давление', 'Анализ_BNP', 'Анализ_ЛПВП',
       'Курение_в прошлом', 'ЭКГ_qt интервал', 'Анализ_АЛТ',
       'Антропометрия_Рост']

In [14]:
COLUMNS_LAST = []
for column in COLUMNS:
    print(column)
    COLUMNS_LAST.append(column + "_last")
    base_info_not_nan[column + "_last"] = base_info_not_nan[column].apply(lambda x: eval(x.replace("nan", "None"))[-1][1] if not pd.isna(x) else x)

Антропометрия_ЧСС
ЭКГ_qtc интервал
ЭКГ_r интервал
Курение_нет
Анализ_Триглицериды
Антропометрия_Диастолическое давление
Анализ_Креатинин в крови
Антропометрия_Вес
Анализ_Общий холестерин
Анализ_АСТ
ЭКГ_qrs интервал
Курение_бросил(а)
Анализ_КФК
ЭКГ_pq интервал
Анализ_ЛПНП
Анализ_Белок
ЭКГ_гипертрофия
Анализ_Билирубин
Антропометрия_Систолическое давление
Анализ_BNP
Анализ_ЛПВП
Курение_в прошлом
ЭКГ_qt интервал
Анализ_АЛТ
Антропометрия_Рост


In [15]:
COLUMNS_ADD = []
for column in COLUMNS:
    print(column)
    COLUMNS_ADD.append(column + "_last_time_diff")
    base_info_not_nan[column + "_last_time_diff"] = base_info_not_nan.apply(lambda x: time_diff_years(x["date"], eval(x[column].replace("nan", "None"))[-1][0]) if not pd.isna(x[column]) else x[column], axis=1)

Антропометрия_ЧСС
ЭКГ_qtc интервал
ЭКГ_r интервал
Курение_нет
Анализ_Триглицериды
Антропометрия_Диастолическое давление
Анализ_Креатинин в крови
Антропометрия_Вес
Анализ_Общий холестерин
Анализ_АСТ
ЭКГ_qrs интервал
Курение_бросил(а)
Анализ_КФК
ЭКГ_pq интервал
Анализ_ЛПНП
Анализ_Белок
ЭКГ_гипертрофия
Анализ_Билирубин
Антропометрия_Систолическое давление
Анализ_BNP
Анализ_ЛПВП
Курение_в прошлом
ЭКГ_qt интервал
Анализ_АЛТ
Антропометрия_Рост


Загружаем различную информацию об взаимосвязях заболеваний

In [16]:
with open('data/diags_imply_dict.pickle', 'rb') as handle:
    diags_imply_dict = pickle.load(handle)
    
with open('data/diags_imply_dict_scaled.pickle', 'rb') as handle:
    diags_imply_dict_scaled = pickle.load(handle)
    
with open('data/diags_imply_dict_weight.pickle', 'rb') as handle:
    diags_imply_dict_weight = pickle.load(handle)
    
with open('data/diags_imply_dict_weight_scaled.pickle', 'rb') as handle:
    diags_imply_dict_weight_scaled = pickle.load(handle)
    
with open('data/diags_imply_dict_rude.pickle', 'rb') as handle:
    diags_imply_dict_rude = pickle.load(handle)
    
with open('data/diags_imply_dict_scaled_rude.pickle', 'rb') as handle:
    diags_imply_dict_scaled_rude = pickle.load(handle)
    
with open('data/diags_imply_dict_weight_rude.pickle', 'rb') as handle:
    diags_imply_dict_weight_rude = pickle.load(handle)
    
with open('data/diags_imply_dict_weight_scaled_rude.pickle', 'rb') as handle:
    diags_imply_dict_weight_scaled_rude = pickle.load(handle)
    
with open('data/dict_imply_webgraph_10000_1000.pickle', 'rb') as handle:
    dict_imply_webgraph_10000_1000 = pickle.load(handle)

Предподготавливаем столбец с диагнозами

In [17]:
base_info_not_nan["diags"] = base_info_not_nan["diags"].apply(lambda x: eval(x) if not pd.isna(x) else x)

In [18]:
base_info_not_nan["diags_prepare"] = base_info_not_nan.apply(lambda x: [(time_diff_years(x["date"], y[0]), y[1]) for y in x["diags"]], axis=1)

In [19]:
base_info_not_nan["diags_prepare"] = base_info_not_nan.apply(lambda x: [(y[0], y[1].split(".")[0]) for y in x["diags_prepare"] if y[0] > 0], axis=1)

Готовим весовую функцию

In [20]:
def weight_f(x, alpha=1, eps=1/365):
    if pd.isna(x):
        return weight_f(0, alpha=alpha, eps=eps)
    return 1 / (x + eps) ** alpha

Готовим различные взвешивания с искючениями таргетных заболеваний

In [21]:
INFARCT_CODES = ["I21", "I22", "I23", "I63"]

def get_imply_weight(x, imply_dict, date, weight_method="const", weight_func=lambda x: 1 / (x + 1),
                     bad_codes=["I21", "I22", "I23", "I63"]):
    score = 0
    x = list(filter(lambda y: y[1] not in bad_codes, x))
    if not x:
        return None
    norm = 0
    for i in range(len(x)):
        if weight_method == "const":
            loc_norm = 1
        elif weight_method == "index":
            loc_norm = weight_func(i)
        else:
            loc_norm = weight_func(x[len(x) - 1 - i][0])
        norm += loc_norm
        for code in INFARCT_CODES:
            if (x[len(x) - 1 - i][1], code) in imply_dict:
                score += imply_dict[(x[len(x) - 1 - i][1], code)] / loc_norm
    score /= norm
    return score
                
            
def get_last_imply(x, imply_dict, bad_codes=["I21", "I22", "I23", "I63"]):
    x = list(filter(lambda y: y[1] not in bad_codes, x))
    return get_imply_weight(x[-1:], imply_dict, None, weight_method="const", weight_func=lambda x: 1 / (x + 1),
                     bad_codes=bad_codes)
    

Запускаем процесс получения фичей по взаимосвязям заболеваний

In [None]:
DIAGS_COLUMNS = []

dicts_dict = {"diags_imply_dict_rude": diags_imply_dict_rude,
              "diags_imply_dict_scaled_rude": diags_imply_dict_scaled_rude,
              "diags_imply_dict_weight_rude": diags_imply_dict_weight_rude,
              "diags_imply_dict_weight_scaled_rude": diags_imply_dict_weight_scaled_rude,
              "dict_imply_webgraph_10000_1000": dict_imply_webgraph_10000_1000}

weight_func_dict = {"weight_f_index": lambda x: weight_f(x, eps=1),
                    "weight_f_years": weight_f, 
                    "weight_f_alpha0.5_index": lambda x: weight_f(x, eps=1, alpha=0.5),
                    "weight_f_alpha0.5_years": lambda x: weight_f(x, alpha=0.5)}
methods = ["const", "index", "year"]
for imply_dict in dicts_dict:
    base_info_not_nan["{}_last".format(imply_dict)] = base_info_not_nan.apply(lambda x: get_last_imply(x["diags_prepare"], dicts_dict[imply_dict], bad_codes=["I21", "I22", "I23", "I63"]), axis=1)
    DIAGS_COLUMNS.append("{}_last".format(imply_dict))
    print("{}_last".format(imply_dict))
    for method in methods:
        if method == "const":
            base_info_not_nan[imply_dict] = base_info_not_nan.apply(lambda x: get_imply_weight(x["diags_prepare"], dicts_dict[imply_dict], x["date"], weight_method="const", weight_func=lambda x: 1 / (x + 1),
                                                                                               bad_codes=["I21", "I22", "I23", "I63"]), axis=1)
            DIAGS_COLUMNS.append(imply_dict)
            print(imply_dict)
        else:
            for weight_func in weight_func_dict:
                base_info_not_nan["{}_{}_{}".format(imply_dict, method, weight_func)] = base_info_not_nan.apply(lambda x: get_imply_weight(x["diags_prepare"], dicts_dict[imply_dict], x["date"], weight_method=method, weight_func=weight_func_dict[weight_func],
                                                                                                                                           bad_codes=["I21", "I22", "I23", "I63"]), axis=1)
                DIAGS_COLUMNS.append("{}_{}_{}".format(imply_dict, method, weight_func))
                print("{}_{}_{}".format(imply_dict, method, weight_func))
base_info_not_nan

Получаем дополнительные фичи

In [23]:
ADD = ["age", "weight", "imt"]

base_info_not_nan["diags_count"] = base_info_not_nan["diags"].apply(len)
ADD.append("diags_count")
base_info_not_nan["gender_int"] = base_info_not_nan["gender"] == "M"
ADD.append("gender_int")

Разбиваем на train и test согласно генерации датасета по датам

In [24]:
train = base_info_not_nan[base_info_not_nan["date"] < "2021-11-01"]
test = base_info_not_nan[base_info_not_nan["date"] == "2021-11-01"]

Создаем модель машинного обучения

In [25]:
clf = LGBMClassifier(n_estimators=500, random_state=42) 

In [26]:
clf.fit(train[COLUMNS_LAST + COLUMNS_ADD + ADD + DIAGS_COLUMNS ], train["target"])

LGBMClassifier(n_estimators=500, random_state=42)

In [27]:
predict_proba = clf.predict_proba(test[COLUMNS_LAST + COLUMNS_ADD + ADD + DIAGS_COLUMNS])[:, 1]

In [28]:
test["target"].mean()

0.016543796747639956

Смотрим на метрики качества в разрезе различных порогов

In [29]:
for i in range(10):
    tn, fp, fn, tp = confusion_matrix(test["target"], predict_proba > i * 0.1).ravel()
    print(f1_score(test["target"], predict_proba > i * 0.1), precision_score(test["target"], predict_proba > i * 0.1), recall_score(test["target"], predict_proba > i * 0.1))
    print(tn / (tn+fp), tp /(tp + fn))
    print("---")

0.03254910767361065 0.016543796747639956 1.0
0.0 1.0
---
0.21113852661261895 0.13484916704187302 0.4862012987012987
0.9475265234785696 0.4862012987012987
---
0.21756856931060045 0.20020463847203274 0.2382305194805195
0.9839903327553012 0.2382305194805195
---
0.15960912052117265 0.24098360655737705 0.11931818181818182
0.993678058904652 0.11931818181818182
---
0.10407689758037784 0.28390596745027125 0.06371753246753246
0.9972964485164603 0.06371753246753246
---
0.063003663003663 0.3233082706766917 0.0349025974025974
0.9987711129620274 0.0349025974025974
---
0.030104206869934386 0.30708661417322836 0.015827922077922076
0.9993992107814356 0.015827922077922076
---
0.014251781472684084 0.2903225806451613 0.007305194805194805
0.9996996053907178 0.007305194805194805
---
0.0056112224448897794 0.22580645161290322 0.002840909090909091
0.999836148394937 0.002840909090909091
---
0.0008084074373484237 0.1 0.00040584415584415587
0.9999385556481014 0.00040584415584415587
---


In [30]:
i = 1

tn, fp, fn, tp = confusion_matrix(test["target"], predict_proba > i * 0.1).ravel()
print(f1_score(test["target"], predict_proba > i * 0.1), precision_score(test["target"], predict_proba > i * 0.1), recall_score(test["target"], predict_proba > i * 0.1))
print(tn / (tn+fp), tp /(tp + fn))
print("---")

0.21113852661261895 0.13484916704187302 0.4862012987012987
0.9475265234785696 0.4862012987012987
---


Сохраняем результат предсказания на тестовых данных

In [31]:
pred = pd.DataFrame()
pred["patient_id"] = test["patient_id"]
pred["target"] = (predict_proba > 0.1).astype(int)
pred["p"] = predict_proba
pred.to_csv("data/pred_1year_1.csv", index=False)

Сохраняем модель

In [32]:
joblib.dump(clf, "data/model.pkl")

['data/model.pkl']