In [None]:
# Read args.
# Every run option is declared here with its default. The parser is given an explicit empty
# list, so whatever sys.argv holds inside the notebook kernel is ignored and only these
# defaults are used; edit a default below to change a run.
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--data_folder", type=str, default="ehr",
    help="data folder name")
parser.add_argument("-a", "--aggregator", type=str, default="transformer",
    help="aggregator name", choices=['conflation', 'sanity_check', 'sanity_check_transformer', 'transformer'])
parser.add_argument("-r", "--resume", type=str, default="kmembert-base",
    help="result folder in which the saved checkpoint will be reused")
parser.add_argument("-e", "--epochs", type=int, default=2,
    help="number of epochs")
parser.add_argument("-nr", "--nrows", type=int, default=None,
    help="maximum number of samples for training and validation")
parser.add_argument("-k", "--print_every_k_batch", type=int, default=1,
    help="prints training loss every k batch")
parser.add_argument("-dt", "--days_threshold", type=int, default=365,
    help="days threshold to convert into classification task")
parser.add_argument("-lr", "--learning_rate", type=float, default=1e-4,
    help="model learning rate")
parser.add_argument("-wg", "--weight_decay", type=float, default=0,
    help="the weight decay for L2 regularization")
parser.add_argument("-p", "--patience", type=int, default=4,
    help="number of decreasing accuracy epochs to stop the training")
parser.add_argument("-me", "--max_ehrs", type=int, default=4,
    help="maximum number of ehrs to be used for multi ehrs prediction")
parser.add_argument("-nh", "--nhead", type=int, default=8,
    help="number of transformer heads")
parser.add_argument("-nl", "--num_layers", type=int, default=4,
    help="number of transformer layers")
parser.add_argument("-od", "--out_dim", type=int, default=2,
    help="transformer out_dim (1 regression or 2 density)")
parser.add_argument("-td", "--time_dim", type=int, default=8,
    help="transformer time_dim")
# [] rather than "" (argparse calls list() on it either way): clearer intent, same result.
args = parser.parse_args([])

from kmembert.utils import Config
from kmembert.models import TransformerAggregator
from kmembert.dataset import PredictionsDataset
from torch.utils.data import DataLoader
from collections import defaultdict
from kmembert.dataset import PredictionsDataset
from kmembert.models import TransformerAggregator
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, roc_auc_score, roc_curve, r2_score, mean_squared_error
from kmembert.utils import pretty_time, printc, create_session, save_json, get_label_threshold, get_error, time_survival_to_label, collate_fn
from kmembert.utils import create_session, get_label_threshold, collate_fn
from sklearn.metrics import confusion_matrix
from time import time

import pandas as pd
import seaborn as sns
import torch
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import argparse
import torch
import sys

In [None]:
# Read data: load the test_rs1_ind split and peek at its first rows.
# NOTE(review): `test` is reassigned by a later cell that loads a different test.csv.
import pandas as pd

test = pd.read_csv("test_rs1_ind.csv")
test.head(n=3)

In [None]:
# Set up the evaluation session for `args`, validate the transformer width against nhead,
# and wrap the prediction dataset in a one-sample-per-batch loader.
path_dataset, path_result, device, config = create_session(args)

# Per the assertion message, d_model = 768 + time_dim and must be divisible by nhead
# (presumably so it splits evenly across the attention heads).
# NOTE(review): 768 is presumably the encoder hidden size — confirm and name it.
assert not (768 + args.time_dim) % args.nhead, f'd_model (i.e. 768 + time_dim) must be divisible by nhead. Found time_dim {args.time_dim} and nhead {args.nhead}'

config.label_threshold = get_label_threshold(config, path_dataset)

dataset = PredictionsDataset(
    path_dataset,
    config,
    output_hidden_states=True,
    device=device,
    train=False,
)
# NOTE(review): shuffle=True on a train=False dataset makes the batch order (and the debug
# output of the cells below) differ between runs; confirm this is intended.
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# Debug: walk the entire loader (no break) and print, for every batch, the type of the
# tensor rebuilt from data[0][1].
# NOTE(review): this cell unpacks each batch as (*data, labels, raw) while the next cell
# unpacks (*data, labels); the two bind `data` differently — confirm which matches what
# collate_fn actually returns.
for i, (*data, labels, raw) in enumerate(loader):
    # Rebind `data` to [[data[0][0]], float32 tensor of data[0][1]]; the right-hand side is
    # evaluated before the rebinding, so it still reads the unpacked batch fields.
    # NOTE(review): if data[0][1] is already a tensor, torch.tensor() copies it and warns;
    # `.to(torch.float32)` would be the idiomatic form — confirm its type.
    data = [[data[0][0]], torch.tensor(data[0][1], dtype=torch.float32)]
    print(type(data[1]))

In [None]:
# Debug: walk the entire loader and print the dtype of data[1] for every batch.
# NOTE(review): unlike the previous cell this unpacks (*data, labels), so `data` holds every
# batch field except the last one — confirm this matches collate_fn's output.
for i, (*data, labels) in enumerate(loader):
    print(data[1].dtype)

In [None]:
# Scratch check: `dtype=float` (Python float) is torch.float64, spelled out explicitly here.
# The batch-inspection loop above converts to torch.float32 instead.
torch.tensor(np.array([22, 15, 0, 0]), dtype=torch.float64)

In [None]:
# Function to convert dates into survival times (in days).
from datetime import datetime
def strDate_to_days_cr(row, date_format="%Y-%m-%d"):
    '''
    Days elapsed between a row's 'Date cr' and its 'Date deces'.

    Original note: meant for frames that have no FLAG_DECES column.

    Args:
        row: mapping (e.g. a DataFrame row passed by ``DataFrame.apply(axis=1)``) holding
            date strings under 'Date deces' (death) and 'Date cr' (presumably the report date).
        date_format: ``datetime.strptime`` format shared by both dates.

    Returns:
        int: whole days from 'Date cr' to 'Date deces'; negative if death comes first.

    Raises:
        ValueError: if either string does not match ``date_format``.
    '''
    # Parse the death date first so a malformed 'Date deces' is the error reported first.
    death_date = datetime.strptime(row['Date deces'], date_format)
    report_date = datetime.strptime(row['Date cr'], date_format)
    return (death_date - report_date).days
# NOTE(review): this redefinition shadows the time_survival_to_label imported from
# kmembert.utils in the first cell; later cells use this local version.
def time_survival_to_label(time_survival, mean_time_survival=800):
    """
    Map survival time(s) to a label via ``1 - exp(-time_survival / mean_time_survival)``.

    Built on ``np.exp``, so it works element-wise on scalars, numpy arrays and pandas Series.
    Non-negative times land in [0, 1): a time of 0 maps to exactly 0 and the label tends
    to 1 as the time grows. Negative times give negative labels.

    Args:
        time_survival: survival time(s), in the same unit as ``mean_time_survival``.
        mean_time_survival: scale of the exponential (800 by default).

    Returns:
        Label(s) with the same shape/type as ``time_survival``.
    """
    decay = np.exp(-time_survival / mean_time_survival)
    return 1 - decay

In [None]:
# Load the raw test split of the configured data folder and derive survival targets.
# os.path.join keeps the path portable across operating systems (a backslash-separated
# literal only resolves on Windows); with the default --data_folder ("ehr") this is
# data/ehr/test.csv.
test = pd.read_csv(os.path.join("data", args.data_folder, "test.csv"))

# The date columns are sliced as YYYYMMDD (e.g. 20190105 -> "2019-01-05") to match the
# "%Y-%m-%d" format strDate_to_days_cr parses by default. They are cast to str first
# because read_csv may have parsed them as integers.
for date_col in ("Date cr", "Date deces"):
    date_text = test[date_col].astype(str)
    test[date_col] = date_text.str[:4] + "-" + date_text.str[4:6] + "-" + date_text.str[6:]

# Survival time in days ('Date deces' minus 'Date cr'), then the exponential survival label.
test['time_surv_day'] = test.apply(strDate_to_days_cr, axis=1)
# time_survival_to_label is built on np.exp, so it applies element-wise to the whole column.
test['time_surv'] = time_survival_to_label(test['time_surv_day'])

# Show only a few rows: this is EHR data and may contain patient-identifying fields.
test.head(3)