# Setup

Before we can start using the `tensor2tensor` models, we first have to get our data into a format that `tensor2tensor` can digest. This means defining a custom `Problem` as follows:

In [1]:
#@title Sets up TF Eager execution (safe to re-run).

import tensorflow as tf

# Enable Eager execution - useful for seeing the generated data.
# Only call it when eager mode is not already on, so re-running this cell
# (e.g. during Restart & Run All iterations) is an explicit no-op.
if not tf.executing_eagerly():
    tf.enable_eager_execution()

In [2]:
#@title Setting a random seed.

import tensor2tensor.utils.trainer_lib as trainer_lib

RANDOM_SEED = 301  # fixed so that repeated runs produce the same outputs

# Seed the random number generators through tensor2tensor's helper.
trainer_lib.set_random_seed(RANDOM_SEED)

W0802 18:59:01.531362 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/utils/expert_utils.py:68: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.

W0802 18:59:02.064932 139639908783936 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

W0802 18:59:02.987431 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/utils/metrics_hook.py:28: The name tf.train.SessionRunHook is deprecated. Please use tf.estimator.SessionRunHook instead.

W0802 18:59:02.991913 139639908783936 deprecation_

In [3]:
#@title Run for setting up directories.

import os

# Experiment directories for the full-context run. expanduser has no effect on
# these particular paths (they contain no "~") but keeps "~"-based paths working.
DATA_DIR = os.path.expanduser("../data/t2t_experiments/transformer_moe/full_context/data")
OUTPUT_DIR = os.path.expanduser("../data/t2t_experiments/transformer_moe/full_context/output")
# Scratch directory handed to generate_data; the MIMIC Problem below ignores it.
TMP_DIR = os.path.expanduser("/mnt/")

# Create all three directories.
for _new_dir in (DATA_DIR, OUTPUT_DIR, TMP_DIR):
    tf.gfile.MakeDirs(_new_dir)
del _new_dir

In [4]:
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry

# Define the problem

In [5]:
@registry.register_problem
class MimicDischargeSummaries(text_problems.Text2TextProblem):
    """Text2Text problem mapping MIMIC context text to discharge summaries.

    Each example pairs one line of a source (context) file with the line at
    the same position in the target (discharge summary) file. The files live
    under ``PREPROCESSED_DIR`` and are chosen by dataset split and, for the
    non-full-context experiments, by context name.
    """

    # Location of the pre-split, line-aligned source/target text files.
    # Override on a subclass/instance to read the data from elsewhere.
    PREPROCESSED_DIR = "../data/preprocessed/"

    @property
    def is_generate_per_split(self):
        """Our data already has pre-existing splits, so generate per split."""
        return True

    def generate_samples(self, data_dir, tmp_dir, dataset_split):
        """Yield aligned {"inputs", "targets"} examples for one split.

        Args:
          data_dir: experiment data directory. If its path contains
            "full_context", the full-context source file is used. Otherwise
            the context name is the directory containing ``data_dir``
            (e.g. ".../other_contexts/h-gae/data" -> "h-gae").
          tmp_dir: unused; the preprocessed data is already on disk.
          dataset_split: a ``problem.DatasetSplit`` value; TRAIN -> "train",
            EVAL -> "val", anything else -> "test".

        Yields:
          dicts with "inputs" (one context line) and "targets" (the matching
          discharge-summary line). Lines are yielded as read, including
          their trailing newline.

        Raises:
          ValueError: if the source and target files have different numbers
            of lines, since the pairs would otherwise be misaligned.
        """
        import itertools
        import os

        del tmp_dir  # Unused: the preprocessed files already exist.

        if dataset_split == problem.DatasetSplit.TRAIN:
            dataset = "train"
        elif dataset_split == problem.DatasetSplit.EVAL:
            dataset = "val"
        else:
            dataset = "test"

        directory = self.PREPROCESSED_DIR
        tgt = os.path.join(directory, "tgt-" + dataset + ".txt")

        if "full_context" in str(data_dir):
            src = os.path.join(directory, "src-" + dataset + ".txt")
        else:
            # data_dir looks like ".../other_contexts/<context>/data"; take
            # <context> from the path structure rather than a fixed offset,
            # so the code survives changes to the parent directories.
            context = os.path.basename(
                os.path.dirname(os.path.normpath(str(data_dir))))
            src = os.path.join(directory, "other_contexts",
                               "src-" + dataset + "-" + context + ".txt")

        # `with` guarantees both files are closed even if the consumer stops
        # early or an error is raised mid-generation.
        with open(src, "r") as f_src, open(tgt, "r") as f_tgt:
            pairs = itertools.zip_longest(f_src, f_tgt)
            for line_no, (context_data, discharge_summary) in enumerate(pairs, start=1):
                if context_data is None or discharge_summary is None:
                    raise ValueError(
                        "%s and %s have different numbers of lines "
                        "(first unmatched line: %d)" % (src, tgt, line_no))
                yield {
                    "inputs": context_data,
                    "targets": discharge_summary,
                }

    @property
    def vocab_type(self):
        """Use a subword vocabulary for both inputs and targets."""
        # SUBWORD and CHARACTER are fully invertible -- but SUBWORD provides a good
        # tradeoff between CHARACTER and TOKEN.
        return text_problems.VocabType.SUBWORD

    @property
    def approx_vocab_size(self):
        """Approximate vocab size to generate. Only for VocabType.SUBWORD."""
        return 2**15  # ~32k - this is the default setting

    @property
    def dataset_splits(self):
        """Shard counts for the train / eval / test splits (80 / 10 / 10)."""
        return [{
            "split": problem.DatasetSplit.TRAIN,
            "shards": 80
        }, {
            "split": problem.DatasetSplit.EVAL,
            "shards": 10
        }, {
            "split": problem.DatasetSplit.TEST,
            "shards": 10
        }]

# Generate the data

First, we instantiate the problem. Generating the full-context data is currently switched off in the cell below; enable it to (re)generate that data.

In [6]:
mimic_problem = MimicDischargeSummaries()

# Full-context data generation is off by default; set this to True to
# (re)generate it. NOTE(review): this assumes DATA_DIR still holds the
# full_context directory from the setup cell. The per-context loop below
# reassigns DATA_DIR, so re-run the setup cell first if you re-run this one.
GENERATE_FULL_CONTEXT = False
if GENERATE_FULL_CONTEXT:
    mimic_problem.generate_data(DATA_DIR, TMP_DIR)

Next, we run data generation in a loop, once for each individual context type.

In [7]:
# Contexts to generate data for. Each name must match the <context> part of a
# ../data/preprocessed/other_contexts/src-<split>-<context>.txt file.
context_list = ['h','h-gae','h-gae-d','h-gae-p','h-gae-d-p','h-gae-d-p-m','h-gae-d-p-m-t','h-gae-d-p-m-l']

# The scratch directory is the same for every context, so set it up once.
TMP_DIR = os.path.expanduser("/mnt/")
tf.gfile.MakeDirs(TMP_DIR)

for context in context_list:
    # Per-context experiment directories. generate_samples derives the context
    # name from DATA_DIR's path, so keep the ".../other_contexts/<context>/data" layout.
    DATA_DIR = os.path.expanduser("../data/t2t_experiments/other_contexts/"+context+"/data")
    OUTPUT_DIR = os.path.expanduser("../data/t2t_experiments/other_contexts/"+context+"/output")

    # Create them.
    tf.gfile.MakeDirs(DATA_DIR)
    tf.gfile.MakeDirs(OUTPUT_DIR)

    mimic_problem.generate_data(DATA_DIR, TMP_DIR)

# NOTE(review): after this loop DATA_DIR / OUTPUT_DIR refer to the LAST context
# in context_list, and later cells read DATA_DIR.

W0802 19:05:31.212889 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:343: The name tf.gfile.Exists is deprecated. Please use tf.io.gfile.exists instead.

W0802 19:05:31.213634 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:349: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.

W0802 19:08:39.032947 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/generator_utils.py:355: The name tf.gfile.MakeDirs is deprecated. Please use tf.io.gfile.makedirs instead.

W0802 19:08:39.034780 139639908783936 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/text_encoder.py:944: The name tf.gfil

# View the generated data

In [7]:
Modes = tf.estimator.ModeKeys

# NOTE(review): after the per-context loop above, DATA_DIR points at the last
# context's data directory, so the example shown comes from that context.
eval_dataset = mimic_problem.dataset(Modes.EVAL, DATA_DIR)

# In eager mode a tf.data.Dataset is directly iterable, so the deprecated
# tf.contrib.eager.Iterator (and its Python-2 style .next()) is not needed.
eager_iterator = iter(eval_dataset)
example = next(eager_iterator)

input_tensor = example["inputs"]
target_tensor = example["targets"]

# The tensors are actually encoded using the generated vocabulary file -- you
# can inspect the actual vocab file in DATA_DIR.
print("Tensor Input: " + str(input_tensor))
print("Tensor Target: " + str(target_tensor))

W0729 17:59:46.342158 140696631490368 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/text_problems.py:394: The name tf.VarLenFeature is deprecated. Please use tf.io.VarLenFeature instead.

W0729 17:59:46.343035 140696631490368 deprecation_wrapper.py:119] From /home/aa5118/anaconda3/envs/tf/lib/python3.7/site-packages/tensor2tensor/data_generators/problem.py:705: The name tf.FixedLenFeature is deprecated. Please use tf.io.FixedLenFeature instead.



Tensor Input: tf.Tensor(
[   80    72    54  1451     6    42     6   121    18    55    72   398
   368     3   228     4   365     3   269     4   364     3  4213     4
   366     3  2468  3414  1737  1534  2069     7  9304    60    95     7
   659   825   329    60   518    77   661     7  6604    14  5907    15
   101  2002     7  4855   832  4317     7  3060    60    95    60   132
   438    15  3903  3060     7    95   732   234     7   101   661   179
  3527     7   469    15   101   180   131     4   358   308   367     3
    36     9    45   846   208   291     5   124   275   289     7   349
   291     5    97    48  2648   281     7   369  2174     5    21   136
    30   124   275   889   289     4   228     3   720    11    85     4
   337     3   237   243     5 10137     5   116     6    83     5    29
     7   214     5  1005     5    48     6    38     7   252   106   112
     5    27     9  1342     5   114     6    83     5    29     7   266
     5   134     9    41  

The cell below is deliberately left unexecuted to protect patient privacy: running it prints the decoded context data and discharge summary.

In [None]:
# We use the encoders to decode the tensors to the actual input text.
# Build the feature encoders once (presumably loading the vocabulary generated
# in DATA_DIR) instead of constructing them twice.
feature_encoders = mimic_problem.get_feature_encoders(data_dir=DATA_DIR)
input_encoder = feature_encoders["inputs"]
target_encoder = feature_encoders["targets"]

input_decoded = input_encoder.decode(input_tensor.numpy())
target_decoded = target_encoder.decode(target_tensor.numpy())

print("Decoded Input: " + input_decoded)
print("Decoded Target: " + target_decoded)