# Synthetic tabular dataset that resembles real data

In [208]:
import numpy as np
import pandas as pd
from scipy.stats import norm, genextreme, exponweib
from itertools import accumulate
import joblib
import matplotlib.pyplot as pp
import math
from sklearn.preprocessing import MinMaxScaler

In [486]:
numerical = ['hospital_stay_length', 'gcs', 'nb_acte', 'age']
categorical = ['gender', 'entry', 'output', 'entry_code', 'ica', 'ttt', 'ica_therapy', 'fever', 'o2_clinic', 'o2', 'hta', 'hct', 'tabagisme', 'etOH', 'diabete', 'headache', 'instable', 'vasospasme', 'ivh']
events = ['nimodipine',  'paracetamol', 'nad', 'corotrop', 'morphine', 'dve', 'atl', 'iot']
events_end = events + ['finish']

In [210]:
y = ['back2home', 'reabilitation', 'death']
y_probs = [0.443396, 0.432075, 0.124529]
n_patients = 10000

In [215]:
df = pd.DataFrame({
    'hospital_stay_length': map(round, genextreme.rvs(-0.4091639605356321, 13.2154345852118, 13.507892218123956, n_patients)),
    'gcs': map(round, norm.rvs(14.866037735849057, 1.079385463913648, n_patients)),
    'nb_acte': map(round, exponweib.rvs(1.7487636231551846, 0.7992842590334144, 0.9388125774311487, 22.6608165193314, n_patients)),
    
    'gender': np.random.choice(['F', 'M'], size=n_patients, p=[0.615094, 0.384906]),
    'entry': np.random.choice(['7', '6', '3', '13', '2', '8', '0', '1', '5'], size=n_patients, p=[0.289412, 0.254118, 0.157647, 0.145882, 0.103529, 0.023529, 0.018824, 0.004706, 0.002353]),
    'output': np.random.choice(y, size=n_patients, p=y_probs),
    'entry_code': np.random.choice(['3850', '2083', '1215', '3412', '2071', '3810', '2072', '3851', '3811', '5042', '3830', '2082', '2073', '3762', '3411', '1214', '2086', '3577', '1224', '1151', '2611', '1412', '2612', '1314', '1211', '3770', '2011', '5014', '3760'], size=n_patients, p=[0.501887, 0.192453, 0.084906, 0.033962, 0.030189, 0.020755, 0.018868, 0.016981, 0.013208, 0.011321, 0.011321, 0.009434, 0.009434, 0.007547, 0.00566, 0.003774, 0.003774, 0.003774, 0.001887, 0.001887, 0.001887, 0.001887, 0.001887, 0.001887, 0.001886, 0.001886, 0.001886, 0.001886, 0.001886]),
    'ica': np.random.choice(['ACoA', 'ACM', 'ACI', 'ACoP', 'ACA', 'PICA', 'TB', 'V', 'hyperdebit', 'ACP', 'AChoA', 'Dissection', 'ACerebS', 'BA', 'AICA', 'TN', 'Aucun', "ACoAde_l'artère_communicante_antérieur", 'ACL', 'JA'], size=n_patients, p=[0.309859, 0.205634, 0.171831, 0.073239, 0.067606, 0.056338, 0.053521, 0.011268, 0.008451, 0.005634, 0.005634, 0.005634, 0.005634, 0.002817, 0.002817, 0.002817, 0.002817, 0.002817, 0.002816, 0.002816]),
    'ttt': np.random.choice(['spire', 'remodeling', 'clip', 'web', 'flow_diverter'], size=n_patients, p=[0.933962, 0.024528, 0.020755, 0.018868, 0.001887]),
    'ica_therapy': np.random.choice(['0', 'loxen', 'amlodipine', 'nicardipin', 'lercanidipine', 'amlor', 'lercan', 'exforge', 'axeler'], size=n_patients, p=[0.966038, 0.007547, 0.007547, 0.00566, 0.003774, 0.003774, 0.001887, 0.001887, 0.001886]),

    'fever': np.random.choice(['0', 'fever'], size=n_patients, p=[0.898113, 0.101887]),
    'o2_clinic': np.random.choice(['0', 'low'], size=n_patients, p=[0.813208, 0.186792]),
    'o2': np.random.choice(['0', 'low'], size=n_patients, p=[0.722642, 0.277358]),
    'hta': np.random.choice(['0', '1'], size=n_patients, p=[0.935849, 0.064151]),
    'hct': np.random.choice(['0', '1', 'hypercholester'], size=n_patients, p=[0.949057, 0.049057, 0.001886]),
    'tabagisme': np.random.choice(['0', '1'], size=n_patients, p=[0.864151, 0.135849]),
    'etOH': np.random.choice(['0', '1'], size=n_patients, p=[0.958491, 0.041509]),
    'diabete': np.random.choice(['0', '1'], size=n_patients, p=[0.958491, 0.041509]),
    'headache': np.random.choice(['1', '0'], size=n_patients, p=[0.835849, 0.164151]),
    'instable': np.random.choice(['0', '1'], size=n_patients, p=[0.917625, 0.082375]),
    'vasospasme': np.random.choice(['1', '0'], size=n_patients, p=[0.984906, 0.015094]),
    'ivh': np.random.choice(['0', '1'], size=n_patients, p=[0.932075, 0.067925]),
    'age': map(round, genextreme.rvs(0.27689720964297965, 51.599845037531225, 14.34488206435922, n_patients))
    })


In [217]:
df

Unnamed: 0,hospital_stay_length,gcs,nb_acte,gender,entry,output,entry_code,ica,ttt,ica_therapy,fever,o2_clinic,o2,hta,hct,tabagisme,etOH,diabete,headache,instable,vasospasme,ivh,age
0,19,15,16,F,3,back2home,2071,ACM,spire,0,fever,0,0,0,0,0,0,0,1,0,1,0,67
1,74,12,68,M,7,reabilitation,3850,ACM,spire,0,0,0,0,0,0,0,0,0,1,0,1,0,43
2,12,16,124,F,6,back2home,2083,ACoA,spire,0,0,0,0,0,0,0,0,0,1,0,1,1,27
3,35,17,25,F,7,back2home,3850,ACI,spire,0,0,low,low,0,0,0,0,0,1,0,1,0,45
4,24,16,17,F,3,reabilitation,2083,ACoP,spire,0,0,0,0,0,0,0,0,0,1,0,1,0,44
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,4,14,16,M,6,reabilitation,3850,ACM,spire,0,0,0,0,0,0,0,0,0,1,0,1,0,59
9996,27,14,15,M,6,back2home,2083,ACoA,spire,0,0,0,low,0,0,0,0,0,1,0,1,0,30
9997,17,14,16,F,13,reabilitation,2083,TB,spire,0,0,0,0,0,0,0,0,0,1,0,1,1,45
9998,38,15,57,F,13,reabilitation,2083,ACoA,spire,0,0,0,0,0,0,0,0,0,1,0,1,0,64


In [218]:
df[numerical].corr()

Unnamed: 0,hospital_stay_length,gcs,nb_acte,age
hospital_stay_length,1.0,-0.003238,-0.002064,0.003623
gcs,-0.003238,1.0,0.016813,-0.001482
nb_acte,-0.002064,0.016813,1.0,0.007254
age,0.003623,-0.001482,0.007254,1.0


In [219]:
corr_matrix = df.apply(lambda x : pd.factorize(x)[0]).corr()
np.fill_diagonal(corr_matrix.values, 0)
corr_matrix

Unnamed: 0,hospital_stay_length,gcs,nb_acte,gender,entry,output,entry_code,ica,ttt,ica_therapy,fever,o2_clinic,o2,hta,hct,tabagisme,etOH,diabete,headache,instable,vasospasme,ivh,age
hospital_stay_length,0.0,0.003441,0.006627,-0.015681,-0.004968,-0.004668,-0.006673,-0.001038,0.013592,0.014952,-8.8e-05,-0.015371,0.007389,-0.01453,-0.008802,-0.018142,0.023614,-0.007132,0.006225,-0.015202,-0.016835,-0.001661,0.008325
gcs,0.003441,0.0,-0.020475,0.005068,-0.004144,0.009274,-0.005768,0.006791,-0.00587,0.001737,-0.014048,-0.007074,0.001766,-0.017061,0.011385,-0.013851,-0.013592,-0.007276,-0.009975,-0.005886,0.005193,0.002046,-0.007567
nb_acte,0.006627,-0.020475,0.0,0.013438,0.019135,0.000382,0.021014,0.009174,-0.005163,0.005278,-0.00356,0.011134,0.009933,0.004912,-0.016842,0.00373,-0.014818,0.013489,-0.005267,0.001616,0.006992,-0.001599,0.013237
gender,-0.015681,0.005068,0.013438,0.0,-0.035413,0.012712,-0.003638,-0.015928,-0.003516,0.006581,-0.001147,-0.003772,-0.006523,0.015892,0.003453,-0.002479,0.003702,0.00407,0.023242,-0.000754,-0.009496,-0.002067,-0.007891
entry,-0.004968,-0.004144,0.019135,-0.035413,0.0,0.001492,-0.008594,-0.002313,-0.002826,-0.012013,-0.005556,0.004464,0.005023,-0.007352,-0.010053,-0.013016,0.023406,0.00155,-0.005926,-0.004109,0.011664,0.011099,0.009073
output,-0.004668,0.009274,0.000382,0.012712,0.001492,0.0,-0.019956,-0.017934,-0.016075,-0.007557,-0.012788,-0.000261,0.002137,0.000709,-0.011939,-0.003513,0.001202,0.003342,-0.002174,-0.015553,-0.009729,-0.007101,0.00284
entry_code,-0.006673,-0.005768,0.021014,-0.003638,-0.008594,-0.019956,0.0,-0.017849,-0.007748,-0.001838,0.001303,5.7e-05,-0.007865,-0.015579,0.000389,0.00184,0.00335,-0.004186,-0.010736,0.002902,-0.009555,0.007168,-0.002663
ica,-0.001038,0.006791,0.009174,-0.015928,-0.002313,-0.017934,-0.017849,0.0,-0.009254,-0.010821,0.003587,0.013382,-0.005423,0.012428,-0.00972,-0.011071,-0.001049,-0.002659,-0.00091,-0.011623,0.003974,0.006186,0.01139
ttt,0.013592,-0.00587,-0.005163,-0.003516,-0.002826,-0.016075,-0.007748,-0.009254,0.0,-0.003084,-0.003802,-0.00399,-0.006235,0.004344,0.004331,-0.005514,0.001835,-0.01683,-0.010379,-0.01914,0.014922,0.018088,-0.01492
ica_therapy,0.014952,0.001737,0.005278,0.006581,-0.012013,-0.007557,-0.001838,-0.010821,-0.003084,0.0,0.014065,-0.004795,-0.008048,-0.00212,-0.005228,-0.000817,0.011078,0.019537,-0.006354,-0.006523,0.003907,0.000781,0.019123


In [220]:
df[categorical] = df[categorical].apply(lambda x : pd.factorize(x)[0])
df

Unnamed: 0,hospital_stay_length,gcs,nb_acte,gender,entry,output,entry_code,ica,ttt,ica_therapy,fever,o2_clinic,o2,hta,hct,tabagisme,etOH,diabete,headache,instable,vasospasme,ivh,age
0,19,15,16,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,67
1,74,12,68,1,1,1,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,43
2,12,16,124,0,2,0,2,1,0,0,1,0,0,0,0,0,0,0,0,0,0,1,27
3,35,17,25,0,1,0,1,2,0,0,1,1,1,0,0,0,0,0,0,0,0,0,45
4,24,16,17,0,0,1,2,3,0,0,1,0,0,0,0,0,0,0,0,0,0,0,44
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,4,14,16,1,2,1,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,59
9996,27,14,15,1,2,0,2,1,0,0,1,0,1,0,0,0,0,0,0,0,0,0,30
9997,17,14,16,0,3,1,2,8,0,0,1,0,0,0,0,0,0,0,0,0,0,1,45
9998,38,15,57,0,3,1,2,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,64


In [223]:
correlation_matrix = joblib.load('nantes_correlations.joblib')

In [339]:
# Cholesky decomposition to introduce correlations

L = np.linalg.cholesky(correlation_matrix)
synthetic_data = df @ L
synthetic_data.columns = df.columns

print(synthetic_data[:5])

   hospital_stay_length        gcs     nb_acte    gender     entry    output  \
0             27.408750  15.463376   16.876151  8.388600 -1.435159  6.928455   
1             99.555309  16.319167   65.199637  6.294708  0.121499  5.376513   
2             56.288059  24.284115  116.950650  3.207707  1.551535  2.809394   
3             46.042605  18.332474   24.961537  5.790602  0.015333  4.699021   
4             32.338527  16.943132   17.607540  5.743600 -0.618380  5.482162   

   entry_code       ica       ttt  ica_therapy     fever  o2_clinic        o2  \
0    1.186922 -2.171281  3.774313     1.479777 -0.902083  -0.820476 -8.266053   
1    1.739191 -1.295576  2.391743     0.931995  0.401492  -0.526574 -5.305079   
2    2.381174  0.224239  1.475583     0.522959  0.628369  -0.331877 -3.336005   
3    1.723149  0.542352  2.573684     1.057753  0.584018   0.476297 -4.586691   
4    2.723126  1.654733  2.448076     0.954081  0.388028  -0.538820 -5.428453   

        hta       hct  tabagisme

In [340]:
synthetic_data.corr()

Unnamed: 0,hospital_stay_length,gcs,nb_acte,gender,entry,output,entry_code,ica,ttt,ica_therapy,fever,o2_clinic,o2,hta,hct,tabagisme,etOH,diabete,headache,instable,vasospasme,ivh,age
hospital_stay_length,1.0,0.183486,0.201999,0.010107,-0.006785,0.00697,0.008417,-0.00266,0.010431,0.014277,-0.006432,-0.003815,-0.010909,0.015723,0.015644,-0.00105,0.014086,0.008772,-0.010487,0.003169,0.008755,0.012798,0.012444
gcs,0.183486,1.0,0.908385,-0.045559,0.01558,-0.020232,0.169234,-0.024295,-0.027584,0.011965,0.022829,0.031656,0.042688,-0.023331,-0.040943,0.051766,-0.03562,-0.040403,0.03448,-0.008465,-0.036486,-0.037082,-0.040941
nb_acte,0.201999,0.908385,1.0,0.019869,0.003393,0.019974,0.020391,0.004743,0.012534,0.01665,-0.009831,0.00634,-0.014975,0.015598,0.017351,-0.006474,0.014578,0.017372,-0.021407,0.012314,0.021088,0.018038,0.018511
gender,0.010107,-0.045559,0.019869,1.0,-0.266195,0.865478,-0.001634,0.055479,0.795452,0.331725,-0.506957,-0.401468,-0.917029,0.757122,0.942534,-0.76463,0.898393,0.91621,-0.909598,0.359554,0.902518,0.895684,0.949214
entry,-0.006785,0.01558,0.003393,-0.266195,1.0,-0.150053,0.310084,0.132004,-0.163458,-0.13086,0.085923,0.060687,0.189132,-0.177315,-0.196737,0.139477,-0.180087,-0.190401,0.18891,-0.090412,-0.18245,-0.182331,-0.195625
output,0.00697,-0.020232,0.019974,0.865478,-0.150053,1.0,0.135943,-0.127794,0.751321,0.300544,-0.500619,-0.401059,-0.876169,0.72646,0.905194,-0.736838,0.85626,0.882516,-0.87856,0.347886,0.865928,0.859042,0.910529
entry_code,0.008417,0.169234,0.020391,-0.001634,0.310084,0.135943,1.0,-0.028899,0.038757,0.016258,-0.028963,-0.027306,-0.050969,0.033926,0.050869,-0.035946,0.04967,0.047858,-0.052058,0.024132,0.045812,0.04958,0.051155
ica,-0.00266,-0.024295,0.004743,0.055479,0.132004,-0.127794,-0.028899,1.0,-0.092045,-0.041849,0.065519,0.056202,0.095197,-0.085989,-0.10318,0.068083,-0.097739,-0.100168,0.100411,-0.046161,-0.097229,-0.094095,-0.102119
ttt,0.010431,-0.027584,0.012534,0.795452,-0.163458,0.751321,0.038757,-0.092045,1.0,0.26382,-0.4541,-0.347986,-0.808058,0.662833,0.829014,-0.657935,0.788306,0.805294,-0.804247,0.319894,0.794724,0.788141,0.833582
ica_therapy,0.014277,0.011965,0.01665,0.331725,-0.13086,0.300544,0.016258,-0.041849,0.26382,1.0,-0.170888,-0.122533,-0.332595,0.355023,0.345568,-0.256044,0.319713,0.341879,-0.334786,0.132927,0.325273,0.316788,0.340789


In [348]:

def round_with_prob(x, n_cats):
    dec, ent = math.modf(x)
    ent = int(ent)
    return np.random.choice([ent, ent+1], size=1, p=[1-dec, dec])[0] % n_cats

def cut(x):
    dec, ent = math.modf(x)
    ent = int(ent)
    return ent

In [349]:
for feature in ['gender', 'fever', 'o2_clinic', 'o2', 'hta', 'tabagisme', 'etOH', 'diabete', 'headache', 'instable', 'vasospasme', 'ivh']:
    synthetic_data[[feature]] = MinMaxScaler(feature_range=(0,1.99)).fit_transform(synthetic_data[[feature]])
    synthetic_data[feature] = synthetic_data[feature].apply(lambda x: cut(x))

for feature, n_classes in zip(['entry', 'entry_code', 'ica', 'ttt', 'ica_therapy', 'hct'], [9, 29, 20, 5, 9, 3]):
    synthetic_data[[feature]] = MinMaxScaler(feature_range=(0, n_classes-0.01)).fit_transform(synthetic_data[[feature]])
    synthetic_data[feature] = synthetic_data[feature].apply(lambda x: cut(x))

In [350]:
arr = synthetic_data['output']
sorted_indices = np.argsort(arr)
transformed_array = np.zeros_like(arr)

transformed_array[sorted_indices[:4414]] = 0
transformed_array[sorted_indices[4414:8747]] = 1
transformed_array[sorted_indices[8747:]] = 2

synthetic_data['output'] = transformed_array

In [360]:
synthetic_data

Unnamed: 0,hospital_stay_length,gcs,nb_acte,gender,entry,output,entry_code,ica,ttt,ica_therapy,fever,o2_clinic,o2,hta,hct,tabagisme,etOH,diabete,headache,instable,vasospasme,ivh,age
0,27.408750,15.463376,16.876151,1,0,1.0,0,0,1,1,0,0,0,0,1,0,1,1,0,0,0,1,64.129138
1,99.555309,16.319167,65.199637,0,1,0.0,1,1,1,0,1,0,1,0,1,0,0,0,1,0,0,0,41.157506
2,56.288059,24.284115,116.950650,0,3,0.0,2,3,0,0,1,0,1,0,0,1,0,0,1,0,0,0,25.843085
3,46.042605,18.332474,24.961537,0,1,0.0,1,3,1,0,1,1,1,0,1,0,0,0,1,0,0,0,43.071809
4,32.338527,16.943132,17.607540,0,1,0.0,2,4,1,0,1,0,1,0,1,0,0,0,1,0,0,0,42.114657
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,12.304598,14.471077,16.788628,1,2,1.0,1,1,1,1,1,0,0,0,1,0,0,1,0,0,0,0,56.471927
9996,33.985456,14.730414,14.941905,0,3,0.0,2,2,0,0,1,0,1,0,0,1,0,0,1,0,0,0,28.714539
9997,25.174506,14.429871,16.328751,0,3,0.0,2,9,1,0,1,0,1,0,1,0,0,0,1,0,0,1,43.071809
9998,60.516443,18.268104,55.190213,1,3,1.0,2,2,1,1,1,0,0,0,1,0,1,1,0,0,0,1,61.257684


In [364]:
real_correlation = synthetic_data.corr().to_numpy()
synthetic_data.corr()

Unnamed: 0,hospital_stay_length,gcs,nb_acte,gender,entry,output,entry_code,ica,ttt,ica_therapy,fever,o2_clinic,o2,hta,hct,tabagisme,etOH,diabete,headache,instable,vasospasme,ivh,age
hospital_stay_length,1.0,0.183486,0.201999,0.009486,-0.008059,0.00833,0.007872,-0.00371,0.006001,0.013762,-0.005295,4.1e-05,-0.014577,0.013163,0.004814,0.000247,0.008896,0.008952,-0.005927,-0.001045,-0.002758,0.010226,0.012444
gcs,0.183486,1.0,0.908385,-0.029363,0.011032,-0.01647,0.169965,-0.024331,-0.025232,0.008961,0.004375,0.012323,0.021928,-0.022463,-0.040914,0.029945,-0.030649,-0.028945,0.018644,0.006057,-0.027196,-0.0285,-0.040941
nb_acte,0.201999,0.908385,1.0,0.019259,0.000148,0.015594,0.020023,0.005166,0.008531,0.017407,-0.004271,0.012643,-0.024903,0.00029,0.004694,-0.014051,0.008429,0.017086,-0.026286,0.006655,0.013163,0.014837,0.018511
gender,0.009486,-0.029363,0.019259,1.0,-0.196055,0.671314,-0.020288,0.029681,0.496325,0.375184,-0.046987,-0.006894,-0.646578,0.339295,0.553002,-0.390737,0.738511,0.755226,-0.647296,0.017269,0.467224,0.730876,0.768337
entry,-0.008059,0.011032,0.000148,-0.196055,1.0,-0.119934,0.307236,0.127761,-0.128475,-0.132774,0.001984,-0.023013,0.147032,-0.099824,-0.147896,0.093337,-0.136854,-0.144996,0.141948,-0.016824,-0.110853,-0.13503,-0.182603
output,0.00833,-0.01647,0.015594,0.671314,-0.119934,1.0,0.103999,-0.111625,0.556091,0.337135,-0.106549,-0.015951,-0.577471,0.442397,0.641054,-0.400218,0.689996,0.686791,-0.591162,0.031456,0.601773,0.68039,0.803509
entry_code,0.007872,0.169965,0.020023,-0.020288,0.307236,0.103999,1.0,-0.025451,0.013112,0.009409,-0.008323,-0.006781,-0.011606,0.023915,0.014981,-0.009786,0.020364,0.017029,-0.011368,0.004786,0.019863,0.021715,0.026844
ica,-0.00371,-0.024331,0.005166,0.029681,0.127761,-0.111625,-0.025451,1.0,-0.080162,-0.048093,0.022404,0.019488,0.076163,-0.056749,-0.084379,0.046387,-0.080027,-0.081373,0.077793,-0.013717,-0.069455,-0.082747,-0.101324
ttt,0.006001,-0.025232,0.008531,0.496325,-0.128475,0.556091,0.013112,-0.080162,1.0,0.230553,-0.075673,-0.007683,-0.448207,0.402734,0.804374,-0.384643,0.531881,0.50317,-0.452827,0.024662,0.665807,0.527878,0.72839
ica_therapy,0.013762,0.008961,0.017407,0.375184,-0.132774,0.337135,0.009409,-0.048093,0.230553,1.0,-0.022097,0.026223,-0.378048,0.215089,0.279111,-0.204533,0.342727,0.394884,-0.374385,0.014815,0.222184,0.334808,0.401385


In [365]:
correlation_matrix

array([[ 1.00000000e+00,  4.78437293e-02,  3.43283713e-01,
         1.02508632e-01, -2.06727316e-02,  1.37298195e-01,
        -1.44883511e-03,  4.99275969e-02,  5.50659589e-02,
         4.95612003e-02,  9.19353581e-03,  1.90268182e-02,
         6.45738240e-02,  4.72897530e-02, -2.28035890e-02,
        -3.14530814e-02,  2.18913468e-02,  9.54496385e-03,
        -1.40334023e-03, -6.15245909e-02, -1.87250102e-02,
         5.45164642e-02,  3.28142502e-02],
       [ 4.78437293e-02,  1.00000000e+00,  8.55006162e-02,
        -2.70010086e-02, -1.01092530e-01,  4.50397215e-02,
         1.00135670e-01, -2.01336101e-02,  1.00980381e-01,
         7.35329736e-02,  2.57710229e-02,  4.56393290e-02,
         2.13758747e-02,  6.60016367e-02,  1.57509176e-02,
         8.13476557e-02,  2.71002495e-01, -2.69150843e-02,
        -5.73157558e-02, -2.81794410e-02, -1.60112782e-02,
        -2.10220072e-02, -7.76173766e-03],
       [ 3.43283713e-01,  8.55006162e-02,  1.00000000e+00,
         5.49758801e-02, -1.5

In [495]:
transitions = pd.read_csv('./care_transitions_probs.csv', index_col=0)
transitions

Unnamed: 0,nimodipine,paracetamol,nad,corotrop,morphine,dve,atl,iot
nimodipine,0.0,0.526611,0.227941,0.018868,0.052381,0.043165,0.0,0.012384
paracetamol,0.202381,0.0,0.375,0.084906,0.366667,0.136691,0.081081,0.065015
nad,0.011905,0.008403,0.0,0.349057,0.07619,0.208633,0.0,0.080495
corotrop,0.011905,0.008403,0.022059,0.0,0.047619,0.021583,0.108108,0.018576
morphine,0.047619,0.008403,0.102941,0.056604,0.0,0.115108,0.0,0.04644
dve,0.107143,0.019608,0.014706,0.018868,0.009524,0.0,0.108108,0.44582
atl,0.0,0.0,0.0,0.0,0.009524,0.007194,0.0,0.003096
iot,0.428571,0.151261,0.154412,0.113208,0.066667,0.230216,0.243243,0.0
finish,0.190476,0.277311,0.102941,0.358491,0.371429,0.23741,0.459459,0.328173


In [500]:
start_probs = [0.47381546, 0.09476309, 0.00997506, 0, 0.00997506, 0.24189526, 0.00249377, 0.16708229, 0]

def generate_care_path():
    event = np.random.choice(events_end, size=1, p=start_probs)[0]
    path = [event]

    while event != 'finish':
        event = np.random.choice(events_end, size=1, p=transitions[event].values)[0]
        if event in path:
            event = 'finish'
        path += [event]
    
    return path

def generate_times_path(path):
    path = path[:-1]
    indv_times = map(round, norm.rvs(24, 5, len(path)))
    acc_times = list(accumulate(indv_times))

    sol = [-1] * len(events)
    for i, e in enumerate(path):
        sol[events.index(e)] = acc_times[i]
    
    return sol

In [602]:
df_events = pd.DataFrame([generate_times_path(generate_care_path()) for _ in range(n_patients)], columns=events)
df_events

Unnamed: 0,nimodipine,paracetamol,nad,corotrop,morphine,dve,atl,iot
0,19,-1,-1,-1,-1,-1,-1,46
1,25,76,-1,-1,-1,-1,-1,52
2,-1,109,80,56,-1,29,-1,143
3,23,-1,-1,-1,47,-1,-1,75
4,28,64,-1,-1,-1,44,-1,-1
...,...,...,...,...,...,...,...,...
9995,-1,69,-1,-1,41,18,-1,85
9996,-1,-1,59,-1,-1,24,-1,-1
9997,25,47,-1,-1,94,77,-1,-1
9998,-1,-1,-1,-1,-1,-1,-1,19


In [611]:
full_data = pd.concat([synthetic_data, df_events], axis=1)
full_data

Unnamed: 0,hospital_stay_length,gcs,nb_acte,gender,entry,output,entry_code,ica,ttt,ica_therapy,fever,o2_clinic,o2,hta,hct,tabagisme,etOH,diabete,headache,instable,vasospasme,ivh,age,nimodipine,paracetamol,nad,corotrop,morphine,dve,atl,iot
0,27.408750,15.463376,16.876151,1,0,1.0,0,0,1,1,0,0,0,0,1,0,1,1,0,0,0,1,64.129138,19,-1,-1,-1,-1,-1,-1,46
1,99.555309,16.319167,65.199637,0,1,0.0,1,1,1,0,1,0,1,0,1,0,0,0,1,0,0,0,41.157506,25,76,-1,-1,-1,-1,-1,52
2,56.288059,24.284115,116.950650,0,3,0.0,2,3,0,0,1,0,1,0,0,1,0,0,1,0,0,0,25.843085,-1,109,80,56,-1,29,-1,143
3,46.042605,18.332474,24.961537,0,1,0.0,1,3,1,0,1,1,1,0,1,0,0,0,1,0,0,0,43.071809,23,-1,-1,-1,47,-1,-1,75
4,32.338527,16.943132,17.607540,0,1,0.0,2,4,1,0,1,0,1,0,1,0,0,0,1,0,0,0,42.114657,28,64,-1,-1,-1,44,-1,-1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,12.304598,14.471077,16.788628,1,2,1.0,1,1,1,1,1,0,0,0,1,0,0,1,0,0,0,0,56.471927,-1,69,-1,-1,41,18,-1,85
9996,33.985456,14.730414,14.941905,0,3,0.0,2,2,0,0,1,0,1,0,0,1,0,0,1,0,0,0,28.714539,-1,-1,59,-1,-1,24,-1,-1
9997,25.174506,14.429871,16.328751,0,3,0.0,2,9,1,0,1,0,1,0,1,0,0,0,1,0,0,1,43.071809,25,47,-1,-1,94,77,-1,-1
9998,60.516443,18.268104,55.190213,1,3,1.0,2,2,1,1,1,0,0,0,1,0,1,1,0,0,0,1,61.257684,-1,-1,-1,-1,-1,-1,-1,19


In [612]:
full_data.to_csv('syn_data.csv')