In [None]:
# Make the DomainML sources importable: the directory two levels above the
# current working directory is put on the module search path so that the
# `src.*` imports used further down resolve.
import os
import sys

module_path = os.path.abspath('../..')
already_registered = module_path in sys.path
if not already_registered:
    # Appended (not prepended) so installed packages keep their precedence.
    sys.path.append(module_path)

In [None]:
# Work around problems caused by the arguments the notebook server was started
# with: replace the whole argument vector, in place, by a single empty entry.
# NOTE(review): presumably some project code parses sys.argv - confirm.
sys.argv[:] = [""]

In [None]:
import random

import tensorflow as tf

# Fixed seeds for the two random number sources seeded in this notebook:
# TensorFlow's global generator and Python's built-in `random` module.
tensorflow_seed = 7796
random_seed = 82379498237

random.seed(random_seed)
tf.random.set_seed(tensorflow_seed)

from src.features import preprocessing
from src.features.preprocessing import mimic

# Preprocessor configuration: only the prediction column and the name of the
# sequence column are overridden; everything else keeps its default.
preprocessor_config = mimic.MimicPreprocessorConfig()
preprocessor_config.prediction_column = "level_0"
preprocessor_config.sequence_column_name = "level_all"

# Load the sequence dataframe through the MIMIC preprocessor.
mimic_preprocessor = preprocessing.MimicPreprocessor(preprocessor_config)
sequences_df = mimic_preprocessor.load_data()

from src.features import sequences
# The module is imported under an alias because the name `transformer` is
# bound to the transformer *instance* below. With the previous un-aliased
# import the instance shadowed the module, so re-running the constructor line
# without re-running the import failed.
from src.features.sequences import transformer as transformer_module

# Name of the dataframe column holding the sequences, as configured for the
# preprocessor that produced sequences_df.
sequence_column_name = preprocessor_config.sequence_column_name

# Sequence transformer configuration: x values come from "level_0", y values
# from "level_3".
transformer_config = sequences.SequenceConfig()
transformer_config.x_sequence_column_name = "level_0"
transformer_config.y_sequence_column_name = "level_3"
# NOTE(review): presumably makes every sample predict the full y sequence - confirm
# against SequenceConfig.
transformer_config.predict_full_y_sequence_wide = True

transformer = transformer_module.NextPartialSequenceTransformerFromDataframe(transformer_config)

# Metadata collected from the data; later used for translating/padding the
# sequences, for building the hierarchy (x_vocab) and for building the model.
metadata = transformer.collect_metadata(sequences_df, sequence_column_name)

# sequences.generate_train / sequences.generate_test cannot be used here: they
# internally call sequences.load_sequence_transformer, which would load the
# incorrect transformer in our case. The generator below therefore drives the
# transformer created in this notebook directly.

# Holds the (train, test) split once it has been computed. tf.data re-invokes
# the generator every time a dataset is iterated (each epoch, and separately
# for the train and the test dataset), so without this cache the split would be
# recomputed on every pass.
_train_test_split_cache = {}


def _get_train_test_split():
    """Return the (train_sequences, test_sequences) split, computing it on first use only."""
    if "split" not in _train_test_split_cache:
        train_sequences, test_sequences = transformer._split_train_test(sequences_df, sequence_column_name)
        # Materialized as lists so the cached split can be iterated repeatedly
        # even if the transformer hands back one-shot iterables.
        _train_test_split_cache["split"] = (list(train_sequences), list(test_sequences))
    return _train_test_split_cache["split"]


def generate(for_train):
    """Yield (x, y) training examples from either part of the train/test split.

    Args:
        for_train: if truthy the train sequences are used, otherwise the test
            sequences.

    Yields:
        Tuples ``(split_sequence.x_vecs_stacked, split_sequence.y_vec)``, one
        per split sequence, after the transformer has translated and padded it.
    """
    train_sequences, test_sequences = _get_train_test_split()
    relevant_sequences = train_sequences if for_train else test_sequences

    for sequence in relevant_sequences:
        for split_sequence in transformer._split_sequence(sequence):
            # Fills in the vectors on split_sequence that are yielded below.
            transformer._translate_and_pad(split_sequence, metadata)
            yield split_sequence.x_vecs_stacked, split_sequence.y_vec

def generate_train():
    """Zero-argument generator over the train examples (usable with Dataset.from_generator)."""
    yield from generate(for_train=True)

def generate_test():
    """Zero-argument generator over the test examples (usable with Dataset.from_generator)."""
    yield from generate(for_train=False)

# Input pipeline settings.
dataset_shuffle_buffer = 1000
dataset_shuffle_seed = 12345
batch_size = 32

# Both generators are declared as producing a pair of float32 tensors.
generator_output_types = (tf.float32, tf.float32)

# Train pipeline: generator -> seeded shuffle (reshuffled on every iteration)
# -> batches -> prefetch.
train_examples = tf.data.Dataset.from_generator(generate_train, output_types=generator_output_types)
train_examples_shuffled = train_examples.shuffle(
    dataset_shuffle_buffer,
    dataset_shuffle_seed,
    reshuffle_each_iteration=True,
)
train_dataset = train_examples_shuffled.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

# Test pipeline: same as above but without shuffling.
test_examples = tf.data.Dataset.from_generator(generate_test, output_types=generator_output_types)
test_dataset = test_examples.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

from src.features import knowledge

# Load the hierarchy dataframe with the same preprocessor configuration that
# was used for the sequences.
hierarchy_preprocessor = preprocessing.ICD9HierarchyPreprocessor(preprocessor_config)
hierarchy_df = hierarchy_preprocessor.load_data()

# Build the hierarchy knowledge from that dataframe, passing the x vocabulary
# collected in the sequence metadata.
knowledge_config = knowledge.KnowledgeConfig()
hierarchy = knowledge.HierarchyKnowledge(knowledge_config)
hierarchy.build_hierarchy_from_df(hierarchy_df, metadata.x_vocab)

In [None]:
from src.training.models.gram import GramModel

# Create the GRAM model and build it from the sequence metadata and the
# hierarchy knowledge prepared in the previous cell.
model = GramModel()
model.build(metadata, hierarchy)

In [None]:
# Training settings; they are reused unchanged by the second training run at
# the end of the notebook.
multilabel_classification = False
n_epochs = 2

# First training run on the train dataset, with the test dataset passed
# alongside (NOTE(review): presumably used for evaluation - confirm in
# GramModel.train_dataset).
model.train_dataset(train_dataset, test_dataset, multilabel_classification, n_epochs)

In [None]:
# Hard-coded identifiers of the edges whose corrective terms are updated.
# NOTE(review): presumably indices into the hierarchy's edges - confirm what
# update_corrective_terms expects.
edges_to_adjust = [ 534, 430, 1665, 169, 282, 1428, 1733, 1396, 1575, 379, 1760, 280, 766, 1292, 1570, 287, 1460, 1926, 370, 1433, 446, 743, 1772, 1315, 398, 913, 1771, 978, 1436, 1761, 1473, 1595, 1820, 1263, 710, 57, 1677, 227 ] # From Jin's analysis

model.update_corrective_terms(edges_to_adjust)
# Mark the prediction model as non-trainable before the next training run.
# NOTE(review): if prediction_model is a Keras model, a changed `trainable`
# flag only takes effect once the model is (re)compiled - confirm that
# train_dataset compiles before fitting.
model.prediction_model.trainable = False

In [None]:
# Second training run, after the corrective terms were updated and the
# prediction model was marked non-trainable; same datasets and settings as the
# first run.
model.train_dataset(train_dataset, test_dataset, multilabel_classification, n_epochs)