# Training an ML Model on TensorFlow Datasets
## Prerequisites

In [1]:
import datetime
import gc
import glob
import json
import os
import random
import shutil
import time
from typing import Iterable, Callable, Dict, Any, Tuple, Optional, List, Union

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.python import keras as K

from mmproteo.utils import log, paths, utils, visualization
from mmproteo.utils.formats.mz import FilteringProcessor, MzmlidFileStatsCreator
from mmproteo.utils.formats.tf_dataset import Parquet2DatasetFileProcessor
from mmproteo.utils.processing import ItemProcessor

In [2]:
# Show every column and up to 1000 rows when a DataFrame is displayed in the notebook.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 1000)

In [3]:
# Logger handed to the mmproteo helpers and used by the cells below; verbose output is switched off.
logger = log.DummyLogger(verbose=False)

INFO: Printing to Stdout


## Configuration

In [4]:
pwd

'/tf/workspace/notebooks'

In [5]:
PROJECT = "PXD010000"
DUMP_PATH = os.path.join("..", "dumps", PROJECT)
TRAINING_COLUMNS_DUMP_PATH = os.path.join(DUMP_PATH, "training_columns")
FILES_PATH = os.path.join(TRAINING_COLUMNS_DUMP_PATH, "*_mzmlid.parquet")
STATISTICS_FILE_PATH = os.path.join(TRAINING_COLUMNS_DUMP_PATH, "statistics.parquet")
DATASET_DUMP_PATH = os.path.join(TRAINING_COLUMNS_DUMP_PATH, "tf_datasets")
PROCESSING_FILE_PATH = os.path.join(DATASET_DUMP_PATH, "processing_info.json")

In [6]:
# Column names shared by the preprocessing metadata and the dataset element structure.
SEQ = 'peptide_sequence'
MZ = 'mz_array'
INT = 'intensity_array'

In [7]:
# Model inputs (order matters: it defines the order of the model's input layers).
TRAINING_DATA_COLUMNS = [MZ, INT]
# Model targets.
TARGET_DATA_COLUMNS = [SEQ]
# Columns whose values were used to split the datasets on disk.
SPLIT_VALUE_COLUMNS = ['species', 'istrain']

In [8]:
# Load the preprocessing metadata that was written alongside the TF datasets.
with open(PROCESSING_FILE_PATH, 'r') as file:
    PROCESSING_INFO = json.load(file)

# Lookup tables derived from the metadata. The cells below rely on these names
# (element specs, input layers, loss masking, decoding of predictions).
PADDING_CHARACTERS = PROCESSING_INFO['padding_characters']
PADDING_LENGTHS = PROCESSING_INFO['padding_lengths']
# JSON object keys are always strings; convert them back to integers so that
# class indices (e.g. from `np.argmax`) can be looked up directly.
idx_to_char = {int(idx): char for idx, char in PROCESSING_INFO['idx_to_char'].items()}
char_to_idx = {char: idx for idx, char in idx_to_char.items()}
# All symbols ordered by their class index (includes the padding character).
ALPHABET = [idx_to_char[idx] for idx in sorted(idx_to_char)]

PROCESSING_INFO

{'padding_characters': {'peptide_sequence': '_',
  'mz_array': 0.0,
  'intensity_array': 0.0},
 'padding_lengths': {'mz_array': 2354,
  'intensity_array': 2354,
  'peptide_sequence': 50},
 'idx_to_char': {'0': 'A',
  '1': 'C',
  '2': 'D',
  '3': 'E',
  '4': 'F',
  '5': 'G',
  '6': 'H',
  '7': 'I',
  '8': 'K',
  '9': 'L',
  '10': 'M',
  '11': 'M(Oxidation)',
  '12': 'N',
  '13': 'P',
  '14': 'Q',
  '15': 'R',
  '16': 'S',
  '17': 'T',
  '18': 'V',
  '19': 'W',
  '20': 'Y',
  '21': '_'},
 'normalization': {'intensity_array': '<function base_peak_normalize at 0x7fa6046d5158>'},
 'split_value_columns': ['species', 'istrain'],
 'training_data_columns': ['mz_array', 'intensity_array'],
 'target_data_columns': ['peptide_sequence']}

## Loading Tensorflow Datasets

### ... by species annotation with train-test-eval split

In [9]:
# When True, an existing dataset-to-split assignment dump is reused (passed as `skip_existing`)
# and the preloading benchmark is skipped. Currently, there is no actual dataset cache.
KEEP_CACHE = True

In [10]:
# Assign the dataset directories to the Train/Test/Eval splits, grouped by the
# value of the second-to-last path component. The assignment is dumped to (and,
# with KEEP_CACHE, reused from) `dataset_file_paths.json`.
dataset_file_paths = paths.assign_wildcard_paths_to_splits_grouped_by_path_position_value(
    wildcard_path=os.path.join(
        DATASET_DUMP_PATH,
        '*',  # filename
        '*',  # species
        '*',  # istrain
    ),
    path_position=-2,
    splits={
        "Train": 0.4,
        "Test": 0.5,
        "Eval": 0.6,
    },
    paths_dump_file=os.path.join(DATASET_DUMP_PATH, "dataset_file_paths.json"),
    skip_existing=KEEP_CACHE,
    logger=logger,
)

print()
print("assigned dataset files:")
visualization.print_list_length_in_dict(dataset_file_paths)

INFO: found file paths dump '../dumps/PXD010000/training_columns/tf_datasets/dataset_file_paths.json'

assigned dataset files:
#Train = 89
e.g.: ../dumps/PXD010000/training_columns/tf_datasets/Biodiversity_C_indologenes_LIB_aerobic_02_03May16_Samwise_16-03-32_mzmlid.parquet/Chryseobacterium_indologenes/Train
#Test = 17
e.g.: ../dumps/PXD010000/training_columns/tf_datasets/Biodiversity_A_cryptum_FeTSB_anaerobic_1_01Jun16_Pippin_16-03-39_mzmlid.parquet/Acidiphilium_cryptum_JF-5/Train
#Eval = 29
e.g.: ../dumps/PXD010000/training_columns/tf_datasets/Biodiversity_B_fragilis_CMcarb_anaerobic_01_01Feb16_Arwen_15-07-13_mzmlid.parquet/Bacteroides_fragilis_638R/Train


### Loading corresponding TF datasets

In [None]:
# Structure of one stored dataset element: ((mz, intensity), peptide_sequence).
# Both inputs are float32 vectors; the target is an int8 vector of class indices.
element_spec = (
    tuple(
        tf.TensorSpec(shape=(PADDING_LENGTHS[column],), dtype=tf.float32)
        for column in (MZ, INT)
    ),
    tf.TensorSpec(shape=(PADDING_LENGTHS[SEQ],), dtype=tf.int8),
)
element_spec

In [None]:
# Build one dataset per split by interleaving the elements of all dataset
# directories assigned to that split. Files are read in parallel and the
# element order is explicitly allowed to be non-deterministic.
merged_datasets = {
    split_name: tf.data.Dataset.from_tensor_slices(split_file_paths).interleave(
        lambda dataset_path: tf.data.experimental.load(
            path=dataset_path,
            element_spec=element_spec,
            compression='GZIP',
        ),
        num_parallel_calls=os.cpu_count(),
        deterministic=False,
    )
    for split_name, split_file_paths in dataset_file_paths.items()
}

merged_datasets

## Configuring Tensorflow Datasets

In [None]:
BATCH_SIZE = 32

# although not all data fits into a 200k buffer, the interleaving should make it sufficiently random
SHUFFLE_BUFFER_SIZE = 2*10**5

### Caching (currently abandoned because of excessive RAM usage)

### Preloading

In [None]:
def fill_cache(dataset, name: Optional[str] = None):
    """
    Run a benchmark over the dataset so that it is processed once from start to end.

    :param dataset: the dataset to iterate
    :param name:    optional label that is printed before the benchmark result
    :return:        the unchanged dataset that was passed in
    """
    if name is not None:
        print(f"{name}:")

    benchmark_result = tfds.benchmark(dataset)
    display(benchmark_result)

    # free what the benchmark left behind, then pause before the next dataset
    gc.collect()
    logger.info("filled a cache - waiting 10 seconds")
    print()
    time.sleep(10)

    return dataset

In [None]:
# Only run the (slow) preloading benchmark when the cache flag is disabled.
if not KEEP_CACHE:
    merged_datasets = {
        split_name: fill_cache(split_dataset, name=split_name)
        for split_name, split_dataset in merged_datasets.items()
    }

### Shuffling, Batching, Prefetching

In [None]:
# Every split is reshuffled on each iteration, grouped into full batches only
# (incomplete trailing batches are dropped), and prefetched in the background.
merged_datasets = {
    split_name: (
        split_dataset
        .shuffle(SHUFFLE_BUFFER_SIZE, reshuffle_each_iteration=True)
        .batch(BATCH_SIZE, drop_remainder=True)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
    for split_name, split_dataset in merged_datasets.items()
}
merged_datasets

In [None]:
# Keys of `merged_datasets`; they match the split names used when assigning the dataset files.
TRAINING_TYPE = 'Train'
TEST_TYPE = 'Test'
EVAL_TYPE = 'Eval'

## Building the Tensorflow Model

In [None]:
# One Keras input per training column, named after the column and sized to its padding length.
named_input_layers = {col: tf.keras.layers.Input(shape=(PADDING_LENGTHS[col],), name=col) for col in TRAINING_DATA_COLUMNS}
named_input_layers

In [None]:
# Same input layers as a list, in the order of TRAINING_DATA_COLUMNS (used as the model's `inputs`).
named_input_layers_list = [ named_input_layers[col] for col in TRAINING_DATA_COLUMNS ]
named_input_layers_list

In [None]:
# Wrap every input in a Masking layer that uses the column's padding value as mask value.
masked_input_layers = {
    col: tf.keras.layers.Masking(mask_value=PADDING_CHARACTERS[col], name=f"masked_{col}")(input_layer)
    for col, input_layer in named_input_layers.items()
}
masked_input_layers

In [None]:
# Masked inputs as a list, in the order of TRAINING_DATA_COLUMNS.
masked_input_layers_list = [ masked_input_layers[col] for col in TRAINING_DATA_COLUMNS ]
masked_input_layers_list

In [None]:
class MaskedLoss(K.losses.LossFunctionWrapper):
    """
    Loss wrapper that ignores the padded tail of every target sequence.

    For each sequence, the number of positions that differ from `masking_value`
    is counted and increased by one (so the first padding character still
    contributes). The wrapped per-position loss is multiplied by a sequence
    mask of that length, summed over the positions, and divided by the length.

    :param loss_function: callable `(y_true, y_pred)` returning one loss value per position
    :param masking_value: the class index that marks padding in `y_true`
    :param name:          name of the loss
    :param reduction:     Keras loss reduction, forwarded to the parent class
    """

    def __init__(self, loss_function, masking_value, name='masked_loss', reduction=tf.keras.losses.Reduction.NONE):
        def _masked_loss(y_true, y_pred):
            y_true = tf.squeeze(y_true, name="masked_loss__squeezed_y_true")
            y_pred = tf.squeeze(y_pred, name="masked_loss__squeezed_y_pred")

            # 1.0 where the target is padding, 0.0 elsewhere
            is_padding = tf.equal(y_true, masking_value, name="masked_loss__is_masking_value")
            is_padding = tf.cast(is_padding, tf.float32, name="masked_loss__is_masking_value_float")

            # invert: 1.0 where the target is a real token
            one = tf.constant(value=1, dtype=tf.float32)
            is_token = tf.math.subtract(one, is_padding, name="masked_loss__is_masking_value_inverted")

            token_counts = tf.math.reduce_sum(is_token, axis=-1, name="masked_loss__sum_to_get_lengths")
            # to also include the first padding character
            lengths = tf.math.add(token_counts, 1, name="masked_loss__sum_to_include_first_padding")

            sequence_mask = tf.sequence_mask(
                lengths=lengths,
                # pre-last dimension = padding length; last dimension = one-hot-encoded alphabet
                maxlen=y_pred.shape[-2],
                dtype=tf.float32,
                name="masked_loss__create_sequence_mask"
            )

            position_losses = loss_function(y_true, y_pred)
            position_losses = tf.math.multiply(
                position_losses, sequence_mask, name="masked_loss__apply_sequence_mask"
            )
            sequence_losses = tf.math.reduce_sum(position_losses, axis=-1, name="masked_loss__sum_losses")
            return tf.math.divide_no_nan(sequence_losses, lengths, name="masked_loss__average_losses")

        super().__init__(_masked_loss, name=name, reduction=reduction)

In [None]:
# Sparse categorical cross-entropy that is masked with the class index of the
# sequence padding character. The masking value is int8 like the stored targets.
masked_loss = MaskedLoss(
    loss_function=tf.keras.losses.sparse_categorical_crossentropy,
    masking_value=tf.constant(
        value=char_to_idx[PADDING_CHARACTERS[SEQ]],
        dtype=tf.int8
    )
)

In [None]:
# Combine the masked inputs by element-wise addition.
x = masked_input_layers_list[0]
for input_layer in masked_input_layers_list[1:]:
    x = x + input_layer

x = tf.keras.layers.Flatten(name="flattened_masked_inputs")(x)

# Four Dense(2048) layers, each followed by 10% dropout.
# NOTE(review): no activation is passed to Dense, so these layers are linear - confirm this is intended.
for _ in range(4):
    x = tf.keras.layers.Dense(2**11)(x)
    x = tf.keras.layers.Dropout(0.1)(x)

# One output unit per (sequence position, alphabet symbol) ...
x = tf.keras.layers.Dense(PADDING_LENGTHS[SEQ]*len(ALPHABET))(x)

# ... reshaped to (batch, sequence position, alphabet symbol) ...
x = tf.reshape(x,(-1, PADDING_LENGTHS[SEQ], len(ALPHABET)))

# ... and turned into a probability distribution over the alphabet per position.
x = tf.keras.activations.softmax(x)

model = tf.keras.Model(inputs=named_input_layers_list, outputs=x, name='mmproteo')
# NOTE(review): only the loss is masked; the two metrics below are the plain Keras metrics.
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=masked_loss,
              metrics=[
                  tf.keras.metrics.SparseCategoricalAccuracy(),
                  tf.keras.metrics.SparseCategoricalCrossentropy()
              ]
             )
model.summary()

## Training the Tensorflow Model

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [None]:
# TensorBoard working directory inside the project dump, and the log directory below it.
TENSORBOARD_DIR = os.path.join(DUMP_PATH, "tensorboard")
TENSORBOARD_LOG_DIR = os.path.join(TENSORBOARD_DIR, "logs")

In [None]:
# Clear any logs from previous runs by deleting the whole TensorBoard directory.
# A missing directory (e.g. on the first run) is not an error.
try:
    shutil.rmtree(TENSORBOARD_DIR)
except FileNotFoundError:
    pass

In [None]:
# Log each training run into its own timestamped sub-directory below "<logs>/fit".
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=os.path.join(TENSORBOARD_LOG_DIR, "fit", datetime.datetime.now().strftime("%Y%m%d-%H%M%S")), 
    histogram_freq=1
)

In [None]:
%tensorboard --logdir $TENSORBOARD_LOG_DIR

In [None]:
# Both datasets are repeated indefinitely, so the length of an epoch and of each
# validation pass is defined by `steps_per_epoch` and `validation_steps` (in batches).
model.fit(x=merged_datasets[TRAINING_TYPE].repeat(),
          validation_data=merged_datasets[TEST_TYPE].repeat(), 
          validation_steps=500,
          epochs=30,
          steps_per_epoch=10_000,
          callbacks=[tensorboard_callback]
         )

## Evaluating the Tensorflow Model

In [None]:
# Evaluate on the repeated Eval split; the number of batches is chosen so that
# roughly 40000 samples are processed.
model.evaluate(
    merged_datasets[EVAL_TYPE].repeat(),
    steps=int(40000 / BATCH_SIZE),
)

In [None]:
def unzip(tuple_list: Iterable[Tuple[Any, Any]]) -> Tuple[Iterable[Any], Iterable[Any]]:
    """
    Transpose an iterable of pairs into a tuple of tuples (the inverse of `zip`).

    :param tuple_list: iterable of equally long tuples, e.g. `(x, y)` pairs
    :return:           one tuple per tuple position; an empty tuple for empty input
    """
    columns = zip(*tuple_list)
    return tuple(columns)

In [None]:
# Separator placed between the decoded symbols of a sequence.
SEPARATOR = " "
# Column names of the evaluation DataFrame.
PREDICTED = "predicted"
TRUE = "true"

In [None]:
def decode_onehot(array: np.ndarray) -> np.ndarray:
    """
    Collapse the last axis of `array` to the index of its largest entry.

    :param array: array whose last axis holds one score per class
    :return:      array of class indices with the last axis removed
    """
    class_indices = np.argmax(array, axis=-1)
    return class_indices

# Element-wise lookup of class indices in `idx_to_char`.
# NOTE(review): assumes the keys of `idx_to_char` are integers (not the string keys of the JSON dump) - confirm.
decode_idx: Callable[[np.ndarray], np.ndarray] = np.vectorize(idx_to_char.get)

def concat_letter_rows(array: np.ndarray) -> np.ndarray:
    """
    Join the strings along the last axis of `array` with `SEPARATOR`.

    The rows are joined in plain Python and converted to an array afterwards, so
    the resulting string dtype is wide enough for the longest row.
    `np.apply_along_axis` is deliberately avoided here: it sizes its output from
    the result of the first row and silently truncates longer results.

    :param array: array of strings; the last axis holds the symbols of one sequence
    :return:      array of joined strings with the last axis removed
    """
    array = np.asarray(array)
    leading_shape = array.shape[:-1]
    row_count = int(np.prod(leading_shape, dtype=int))
    rows = array.reshape(row_count, array.shape[-1])
    joined = np.array([SEPARATOR.join(row) for row in rows], dtype=str)
    return joined.reshape(leading_shape)

def decode(array: np.ndarray, onehot: bool = True):
    """
    Turn encoded sequences into separator-joined strings.

    :param array:  per-position class scores (`onehot=True`) or class indices (`onehot=False`)
    :param onehot: whether the last axis of `array` still has to be reduced with `decode_onehot`
    :return:       array of decoded strings; for `onehot=False`, additionally only the
                   first entry along the remaining last axis is kept
    """
    indices = decode_onehot(array) if onehot else array
    sequences = concat_letter_rows(decode_idx(indices))
    if onehot:
        return sequences
    return np.apply_along_axis(lambda row: row[0], axis=-1, arr=sequences)

In [None]:
# Take 20 single-sample batches from the Eval split.
# The pipeline reshuffles on every iteration and interleaves its files
# non-deterministically, so iterating `eval_ds` twice would yield different
# samples. The samples are therefore cached: the first (complete) iteration
# below fills the cache, and `model.predict` then sees exactly the same samples
# in the same order as `x_eval` / `y_eval`.
eval_ds = merged_datasets[EVAL_TYPE].unbatch().batch(1).take(20).cache()

x_eval, y_eval = unzip(eval_ds.as_numpy_iterator())
y_pred = model.predict(eval_ds)

# although the strings look like they have different lengths, they all have the same length
eval_df = pd.DataFrame(data=zip(decode(y_pred), decode(y_eval, onehot=False)), columns=[PREDICTED, TRUE])

# true sequences without their trailing padding characters and separators
true_without_padding = eval_df[TRUE].str.rstrip(PADDING_CHARACTERS[SEQ] + SEPARATOR)

# cut every prediction to the length of its true sequence plus one symbol
# (the first padding position), mirroring what the masked loss looks at
eval_df[PREDICTED] = eval_df[PREDICTED].combine(
    other=true_without_padding.str.split(SEPARATOR).str.len() + 1,
    func=lambda seq, length: SEPARATOR.join(seq.split(SEPARATOR)[:length])
)

#eval_df = eval_df.applymap(lambda s: s.replace(SEPARATOR, ""))

eval_df[TRUE] = true_without_padding

eval_df

In [None]:
# Print the predicted sequences one per line (nothing is displayed as cell result).
for predicted_sequence in eval_df[PREDICTED]:
    print(predicted_sequence)

In [None]:
# Show the raw target and the raw model output of the first evaluation sample.
y_eval[:1], y_pred[:1]