In [None]:
# Make the DomainML repository root importable so the `src.*` packages used
# below can be imported. The path is resolved against the kernel's current
# working directory (two levels up); this assumes the kernel was started in
# the notebook's own folder — adjust if launched from elsewhere.
import os
import sys

module_path = os.path.abspath(os.path.join("..", ".."))
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
# Inside a notebook, sys.argv carries the kernel/server launch arguments.
# Replace them (in place, keeping the same list object) with a lone empty
# program name, presumably so argument parsing in imported modules does not
# trip over them.
sys.argv[:] = [""]

In [None]:
from pathlib import Path
from typing import List

from src.features import preprocessing

class MinimalMimicPreprocessorConfig:
    """Minimal configuration passed to ``preprocessing.MimicPreprocessor``.

    Only the attributes this notebook sets are defined here; which of them the
    preprocessor actually consults is decided in ``src.features.preprocessing``.
    """

    admission_file: Path = Path("data/ADMISSIONS.csv")
    diagnosis_file: Path = Path("data/DIAGNOSES_ICD.csv")
    min_admissions_per_user: int = 2
    add_icd9_info_to_sequences: bool = False # Avoid this if possible, because web requests are slow
    # This file must not exist (presumably so the preprocessor skips
    # cluster-based processing — confirm in src.features.preprocessing).
    cluster_file: Path = Path("data/invalid.file")
    replace_keys: List[str] = []
    prediction_column: str = ""

    def __init__(self) -> None:
        # A list stored on the class is shared by the class and every
        # instance, so an in-place mutation through one config would silently
        # leak into all others. Give each instance its own copy of the default.
        self.replace_keys: List[str] = list(type(self).replace_keys)

# Run the MIMIC preprocessor with the minimal configuration defined above;
# the returned frame feeds the metadata and dataset cells below.
mimic_config = MinimalMimicPreprocessorConfig()
sequences_df = preprocessing.MimicPreprocessor(mimic_config).load_data()

In [None]:
from src.features import sequences

# Column of `sequences_df` holding the code sequences (presumably ICD-9 codes
# truncated to three digits, per the name — confirm in preprocessing). The
# same column name is later handed to the dataset generators.
sequence_column_name = "icd9_code_converted_3digits"

# Collect vocabulary/metadata from the sequences; `metadata.x_vocab` is used
# when building the ICD-9 hierarchy.
transformer = sequences.load_sequence_transformer()
metadata = transformer.collect_metadata(
    sequences_df,
    sequence_column_name,
)

In [None]:
import tensorflow as tf
import random

# Pipeline settings and seeds. `tf.random.set_seed` fixes TensorFlow's global
# seed and `random.seed` fixes Python's `random` module; the shuffle below also
# gets an explicit op-level seed.
# NOTE(review): NumPy's RNG is not seeded here — if src.features.sequences
# draws from it, results are not reproducible. TODO confirm.
sequence_df_pkl_file = "data/sequences_df.pkl"
dataset_shuffle_buffer = 1000
dataset_shuffle_seed = 12345
batch_size = 32

tensorflow_seed = 7796
random_seed = 82379498237

tf.random.set_seed(tensorflow_seed)
random.seed(random_seed)

# The generators receive the pickle path (not the DataFrame itself), so the
# sequences are persisted to disk first.
sequences_df.to_pickle(sequence_df_pkl_file)


def _make_sequence_dataset(generator, shuffle):
    """Build a batched, prefetched ``tf.data`` pipeline from a sequence generator.

    ``generator`` is invoked by ``from_generator`` with the pickle path and the
    sequence column name (TensorFlow passes these string args as bytes) and must
    yield 2-tuples of float32 tensors. Elements are cached in memory
    (``cache("")``) after the first pass; when ``shuffle`` is true they are
    shuffled after caching, reshuffling on every iteration, before batching.
    """
    dataset = tf.data.Dataset.from_generator(
        generator,
        args=(sequence_df_pkl_file, sequence_column_name),
        output_types=(tf.float32, tf.float32),
    ).cache("")
    if shuffle:
        dataset = dataset.shuffle(
            dataset_shuffle_buffer,
            dataset_shuffle_seed,
            reshuffle_each_iteration=True,
        )
    return dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)


train_dataset = _make_sequence_dataset(sequences.generate_train, shuffle=True)
test_dataset = _make_sequence_dataset(sequences.generate_test, shuffle=False)

In [None]:
from src.features import knowledge

class MinimalICD9HierarchyPreprocessorConfig:
    """Minimal configuration passed to ``preprocessing.ICD9HierarchyPreprocessor``.

    Only the attributes this notebook sets are defined here; which of them the
    preprocessor actually consults is decided in ``src.features.preprocessing``.
    """

    replace_keys: List[str] = []
    prediction_column: str = ""
    icd9_file: Path = Path("data/icd9.csv")

    def __init__(self) -> None:
        # A list stored on the class is shared by the class and every
        # instance, so an in-place mutation through one config would silently
        # leak into all others. Give each instance its own copy of the default.
        self.replace_keys: List[str] = list(type(self).replace_keys)

# Load the ICD-9 hierarchy table with the minimal configuration above.
icd9_config = MinimalICD9HierarchyPreprocessorConfig()
hierarchy_df = preprocessing.ICD9HierarchyPreprocessor(icd9_config).load_data()

# Build the hierarchy knowledge from that table together with the sequence
# vocabulary (`metadata.x_vocab`); presumably this limits the hierarchy to
# codes seen in the sequences — confirm in src.features.knowledge.
knowledge_config = knowledge.KnowledgeConfig()
hierarchy = knowledge.HierarchyKnowledge(knowledge_config)
hierarchy.build_hierarchy_from_df(hierarchy_df, metadata.x_vocab)

In [None]:
from src.training import models

# Instantiate `models.GramModel` and build it from the sequence metadata
# (from `transformer.collect_metadata`) and the ICD-9 `hierarchy` built above.
model = models.GramModel()
model.build(metadata, hierarchy)

In [None]:
# Training settings, forwarded positionally to `GramModel.train_dataset`.
n_epochs = 10
# Presumably selects a single-label (rather than multi-label) objective —
# confirm against src.training.models.
multilabel_classification = False

model.train_dataset(
    train_dataset,
    test_dataset,
    multilabel_classification,
    n_epochs,
)