# Training an ML Model on Tensorflow Datasets
## Prerequisites

In [1]:
import json
import os
from typing import Callable

import numpy as np
import pandas as pd
import tensorflow as tf
from mmproteo.utils import log, paths, visualization
from mmproteo.utils.formats.tf_dataset import DatasetLoader
from mmproteo.utils.ml import callbacks, evaluation, layers, losses

In [2]:
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 1000)

In [3]:
logger = log.DummyLogger(verbose=False)

INFO: Printing to Stdout


## Configuration

In [4]:
%pwd

'/tf/workspace/notebooks'

In [5]:
PROJECT = "PXD010000"
DUMP_PATH = os.path.join("..", "dumps", PROJECT)
TRAINING_COLUMNS_DUMP_PATH = os.path.join(DUMP_PATH, "training_columns")
FILES_PATH = os.path.join(TRAINING_COLUMNS_DUMP_PATH, "*_mzmlid.parquet")
STATISTICS_FILE_PATH = os.path.join(TRAINING_COLUMNS_DUMP_PATH, "statistics.parquet")
DATASET_DUMP_PATH = os.path.join(TRAINING_COLUMNS_DUMP_PATH, "tf_datasets")
PROCESSING_FILE_PATH = os.path.join(DATASET_DUMP_PATH, "processing_info.json")

In [6]:
SEQ = 'peptide_sequence'

In [7]:
with open(PROCESSING_FILE_PATH, 'r') as file:
    PROCESSING_INFO = json.loads(file.read())
PROCESSING_INFO

{'padding_characters': {'peptide_sequence': '_',
  'mz_array': 0.0,
  'intensity_array': 0.0},
 'padding_lengths': {'mz_array': 2354,
  'intensity_array': 2354,
  'peptide_sequence': 50},
 'idx_to_char': {'0': 'A',
  '1': 'C',
  '2': 'D',
  '3': 'E',
  '4': 'F',
  '5': 'G',
  '6': 'H',
  '7': 'I',
  '8': 'K',
  '9': 'L',
  '10': 'M',
  '11': 'M(Oxidation)',
  '12': 'N',
  '13': 'P',
  '14': 'Q',
  '15': 'R',
  '16': 'S',
  '17': 'T',
  '18': 'V',
  '19': 'W',
  '20': 'Y',
  '21': '_'},
 'normalization': {'intensity_array': '<function base_peak_normalize at 0x7fa6046d5158>'},
 'split_value_columns': ['species', 'istrain'],
 'training_data_columns': ['mz_array', 'intensity_array'],
 'target_data_columns': ['peptide_sequence']}

In [8]:
idx_to_char = {int(idx): char for idx, char in PROCESSING_INFO["idx_to_char"].items()}
char_to_idx = {char: idx for idx, char in idx_to_char.items()}

## Loading Tensorflow Datasets

In [9]:
KEEP_CACHE = True

In [10]:
TRAINING_TYPE = 'Train'
TEST_TYPE = 'Test'
EVAL_TYPE = 'Eval'

In [11]:
dataset_file_paths = paths.assign_wildcard_paths_to_splits_grouped_by_path_position_value(
    wildcard_path = os.path.join(
        DATASET_DUMP_PATH, 
        '*',  # filename
        '*',  # species
        '*'   # istrain
    ),
    path_position = -2,
    splits = {
            TRAINING_TYPE: 0.4,
            TEST_TYPE: 0.5,
            EVAL_TYPE: 0.6
        },
    paths_dump_file = os.path.join(
            DATASET_DUMP_PATH,
            "dataset_file_paths.json"
        ),
    skip_existing = KEEP_CACHE,
    logger = logger
)

print()
print("assigned dataset files:")
visualization.print_list_length_in_dict(dataset_file_paths)

INFO: found file paths dump '../dumps/PXD010000/training_columns/tf_datasets/dataset_file_paths.json'

assigned dataset files:
#Train = 89
e.g.: ../dumps/PXD010000/training_columns/tf_datasets/Biodiversity_C_indologenes_LIB_aerobic_02_03May16_Samwise_16-03-32_mzmlid.parquet/Chryseobacterium_indologenes/Train
#Test = 17
e.g.: ../dumps/PXD010000/training_columns/tf_datasets/Biodiversity_A_cryptum_FeTSB_anaerobic_1_01Jun16_Pippin_16-03-39_mzmlid.parquet/Acidiphilium_cryptum_JF-5/Train
#Eval = 29
e.g.: ../dumps/PXD010000/training_columns/tf_datasets/Biodiversity_B_fragilis_CMcarb_anaerobic_01_01Feb16_Arwen_15-07-13_mzmlid.parquet/Bacteroides_fragilis_638R/Train


### Loading corresponding TF datasets

In [12]:
element_spec = (
    tuple(tf.TensorSpec(shape=(PROCESSING_INFO['padding_lengths'][col], ), dtype=tf.float32)
     for col in PROCESSING_INFO['training_data_columns']),
    tuple(tf.TensorSpec(shape=(PROCESSING_INFO['padding_lengths'][col], ), dtype=tf.int8)
     for col in PROCESSING_INFO['target_data_columns'])
)
element_spec

((TensorSpec(shape=(2354,), dtype=tf.float32, name=None),
  TensorSpec(shape=(2354,), dtype=tf.float32, name=None)),
 (TensorSpec(shape=(50,), dtype=tf.int8, name=None),))

In [13]:
BATCH_SIZE=32

In [14]:
datasets = DatasetLoader(
    element_spec=element_spec,
    batch_size=BATCH_SIZE,
    shuffle_buffer_size=100_000,
    keep_cache=KEEP_CACHE,
    logger=logger
).load_datasets_by_type(dataset_file_paths)
datasets



{'Train': <BatchDataset shapes: (((32, 2354), (32, 2354)), ((32, 50),)), types: ((tf.float32, tf.float32), (tf.int8,))>,
 'Test': <BatchDataset shapes: (((32, 2354), (32, 2354)), ((32, 50),)), types: ((tf.float32, tf.float32), (tf.int8,))>,
 'Eval': <BatchDataset shapes: (((32, 2354), (32, 2354)), ((32, 50),)), types: ((tf.float32, tf.float32), (tf.int8,))>}

## Building the Tensorflow Model

In [15]:
input_layers_list, masked_input_layers_list = layers.create_masked_input_layers(
    [
        layers.InputLayerConfiguration(
            name=col,
            shape=PROCESSING_INFO['padding_lengths'][col],
            mask_value=PROCESSING_INFO['padding_characters'][col]
        )
        for col in PROCESSING_INFO['training_data_columns']
    ]
)
print(input_layers_list)
print(masked_input_layers_list)

[<KerasTensor: shape=(None, 2354) dtype=float32 (created by layer 'mz_array')>, <KerasTensor: shape=(None, 2354) dtype=float32 (created by layer 'intensity_array')>]
[<KerasTensor: shape=(None, 2354) dtype=float32 (created by layer 'masked_mz_array')>, <KerasTensor: shape=(None, 2354) dtype=float32 (created by layer 'masked_intensity_array')>]


In [16]:
masked_loss = losses.MaskedLoss(
    loss_function=tf.keras.losses.sparse_categorical_crossentropy,
    masking_value=tf.constant(
        value=char_to_idx[PROCESSING_INFO['padding_characters'][SEQ]],
        dtype=tf.int8
    )
)

In [17]:
x = tf.stack(
    values=masked_input_layers_list, 
    axis=-1,
)

x = tf.keras.layers.Flatten(name="flattened_masked_inputs")(x)

for _ in range(4):
    x = tf.keras.layers.Dense(2**11)(x)
    #x = tf.keras.layers.Dropout(0.1)(x)

x = tf.keras.layers.Dense(PROCESSING_INFO['padding_lengths'][SEQ]*len(idx_to_char))(x)

x = tf.reshape(x,(-1, PROCESSING_INFO['padding_lengths'][SEQ], len(idx_to_char)))

x = tf.keras.activations.softmax(x)

model = tf.keras.Model(inputs=input_layers_list, outputs=x, name=f"mmproteo_dense_{utils.get_current_time_str()}")
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=masked_loss,
              metrics=[
                  tf.keras.metrics.SparseCategoricalAccuracy(),
                  tf.keras.metrics.SparseCategoricalCrossentropy()
              ]
             )

Model: "mmproteo"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
mz_array (InputLayer)           [(None, 2354)]       0                                            
__________________________________________________________________________________________________
intensity_array (InputLayer)    [(None, 2354)]       0                                            
__________________________________________________________________________________________________
masked_mz_array (Masking)       (None, 2354)         0           mz_array[0][0]                   
__________________________________________________________________________________________________
masked_intensity_array (Masking (None, 2354)         0           intensity_array[0][0]            
___________________________________________________________________________________________

In [None]:
MODEL_PATH = os.path.join(DUMP_PATH, "models", model.name)
MODEL_PATH

In [None]:
utils.ensure_dir_exists(MODEL_PATH)

In [None]:
tf.keras.utils.plot_model(
    model=model,
    to_file=os.path.join(MODEL_PATH, "model.png"),
    show_shapes=True
)

In [None]:
with open(os.path.join(MODEL_PATH, "summary.txt"), 'w') as file:
    def write_lines(line: str) -> None:
        file.write(line)
        file.write("\n")
    model.summary(print_fn=write_lines)
model.summary()

In [None]:
with open(os.path.join(MODEL_PATH, "model.json"), 'w') as file:
    file.write(model.to_json())

In [None]:
with open(os.path.join(MODEL_PATH, "model.yaml"), 'w') as file:
    file.write(model.to_yaml())

## Training the Tensorflow Model

In [18]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [19]:
TENSORBOARD_LOG_DIR = os.path.join(DUMP_PATH, "tensorboard", "logs")
os.path.realpath(TENSORBOARD_LOG_DIR)

'/tf/workspace/dumps/PXD010000/tensorboard/logs'

In [20]:
%tensorboard --logdir $TENSORBOARD_LOG_DIR --bind_all

Reusing TensorBoard on port 6006 (pid 109), started 1 day, 0:11:47 ago. (Use '!kill 109' to kill it.)

In [21]:
model.fit(x=datasets[TRAINING_TYPE].repeat(),
          validation_data=datasets[TEST_TYPE].repeat(), 
          validation_steps=500,
          epochs=1,
          steps_per_epoch=1_000,
          callbacks=[
              callbacks.create_tensorboard_callback(
                  tensorboard_log_dir = TENSORBOARD_LOG_DIR,
                  keep_logs = KEEP_CACHE
              )
          ]
         )



<tensorflow.python.keras.callbacks.History at 0x7fccc1fce4e0>

## Evaluating the Tensorflow Model

In [22]:
decode_idx: Callable[[np.ndarray], np.ndarray] = np.vectorize(idx_to_char.get)

In [23]:
evaluator = evaluation.SequenceEvaluator(
    dataset=datasets[EVAL_TYPE],
    decode_func=decode_idx,
    batch_size=BATCH_SIZE,
    separator=" ",
    padding_character=PROCESSING_INFO['padding_characters'][SEQ],
)

In [24]:
evaluator.evaluate_model(model)



[15.40013313293457, 0.033735498785972595, 11811841.0]

In [25]:
eval_df, (x_eval, y_eval, y_pred) = evaluator.evaluate_model_visually(
    model=model,
    sample_size=20,
    keep_separator=True,
)
eval_df

Unnamed: 0,predicted,true
0,W I D F K A K P Y Q R M Q,V K A G E V I V K I P R
1,W I E F K A K P Y Q I M Q Q Q G W S Y T P A,Y G N T H V V A G Y I H G A E P K G E L K
2,W I D F K A K P Y Q I,N G L G I V Y V Q R
3,W I D F K A K P Y Q I M,I V H T V I E S D R R
4,W I D F K A K P Y Q R M Q Q Q G W S Y T P T,S Y G N H I I E P A S G E L A S H L V G K
5,W I E F K A K P,M(Oxidation) I Q V E S R
6,W I E F K A K P Y Q I M Q,S I A G E V K P D F A K
7,W S D F K Y K P R Q R Y Q Q Q G Y S Y T P T L,T L H P D N M(Oxidation) A A G P A S Y G M T D T M G R
8,W I E F K A K P Y Q I M Q Q Q G W S Y T P A,T E T G A P V V T K Q P T A S K K D G V K
9,W I A F K A K P Y,I G I F Y A A K


In [26]:
eval_df.predicted.map(print)
None

W I D F K A K P Y Q R M Q
W I E F K A K P Y Q I M Q Q Q G W S Y T P A
W I D F K A K P Y Q I
W I D F K A K P Y Q I M
W I D F K A K P Y Q R M Q Q Q G W S Y T P T
W I E F K A K P
W I E F K A K P Y Q I M Q
W S D F K Y K P R Q R Y Q Q Q G Y S Y T P T L
W I E F K A K P Y Q I M Q Q Q G W S Y T P A
W I A F K A K P Y
W I E F K A K P Y Q I
W I A F K A K P Y Q I M Q Q Q G
W I E F K A K P Y Q I M Q Q Q G W S Y T P T L W M F N S
W I D F K A K P Y Q
W I D F K A K P Y Q I M Q Q Q G W S Y
W I E F K A K P Y Q I M Q Q
W I A F K A K P Y Q I M Q Q Q G W
W I E F K A K P Y Q I M Q Q
W I E F K A K P Y Q I M Q Q Q G W S Y T
W I E F K A K P Y Q I M Q Q Q G W S Y T P T L W
