# Prototyping an ML Model on Tensorflow Datasets
## Prerequisites

In [1]:
import datetime
import gc
import glob
import json
import os
import random
import shutil
import time
from typing import Iterable, Callable, Dict, Any, Tuple, Optional, List, Union

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.python import keras as K

from mmproteo.utils import log, utils, visualization
from mmproteo.utils.formats.mz import FilteringProcessor
from mmproteo.utils.formats.tf_dataset import Parquet2DatasetFileProcessor
from mmproteo.utils.processing import ItemProcessor

In [2]:
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 1000)

In [3]:
logger = log.DummyLogger(verbose=False)

INFO: Printing to Stdout


## Configuration

In [4]:
pwd

'/tf/workspace/notebooks'

In [5]:
PROJECT = "PXD010000"
DUMP_PATH = os.path.join("..", "dumps", PROJECT)
TRAINING_COLUMNS_DUMP_PATH = os.path.join(DUMP_PATH, "training_columns")
FILES_PATH = os.path.join(TRAINING_COLUMNS_DUMP_PATH, "*_mzmlid.parquet")
STATISTICS_FILE_PATH = os.path.join(TRAINING_COLUMNS_DUMP_PATH, "statistics.parquet")
DATASET_DUMP_PATH = os.path.join(TRAINING_COLUMNS_DUMP_PATH, "tf_datasets")

In [6]:
MZMLID_FILE_PATHS = glob.glob(FILES_PATH)
len(MZMLID_FILE_PATHS)

235

In [7]:
MZMLID_FILE_PATHS[0]

'../dumps/PXD010000/training_columns/Biodiversity_B_fragilis_01_28Jul15_Arwen_14-12-03_mzmlid.parquet'

In [8]:
df = pd.read_parquet(MZMLID_FILE_PATHS[1])
df.head(2)

Unnamed: 0,peptide_sequence,mz_array,intensity_array,species,istrain
21,"[C, K, P, T, S, P, G, R]","[102.0558, 115.05197, 116.971794, 119.907036, 129.1024, 136.06175, 152.05682, 157.84837, 159.22517, 171.11295, 175.119, 175.95169, 199.10796, 202.6932, 215.08527, 228.88432, 232.11212, 244.87819, 286.14047, 307.6665, 312.16718, 329.19223, 360.2081, 378.2132, 385.92047, 400.78918, 401.78973, 416.22388, 422.8325, 440.8446, 441.84528, 517.2766, 614.3271, 615.3258]","[723.529, 569.4288, 659.1485, 599.0097, 19982.768, 4909.943, 771.28937, 596.6283, 593.3602, 1262.0436, 868.29816, 581.3835, 721.64886, 752.1542, 2492.1565, 3854.2283, 1364.17, 615.11633, 746.43365, 1512.8475, 1474.3188, 1069.4283, 762.6549, 744.29315, 925.18164, 7245.0005, 2374.2295, 3248.2861, 4047.135, 21597.44, 5534.1826, 4359.906, 13269.387, 2903.926]",Citrobacter_freundii,Train
70,"[K, H, I, T, A, G, A, K]","[101.1075, 110.07151, 111.04457, 111.619194, 112.050735, 116.972084, 118.967834, 122.29705, 129.05539, 129.10248, 129.11131, 129.92657, 130.08653, 136.06192, 136.07182, 136.07652, 137.06726, 139.98817, 147.11304, 152.05687, 171.00543, 173.09312, 189.01633, 197.12833, 200.14093, 212.10458, 218.14975, 223.15533, 230.11382, 231.12407, 232.88867, 239.08455, 249.13492, 251.15112, 275.1718, 283.13745, 299.95496, 301.1428, 302.81696, 309.96753, 313.8611, 315.81067, 316.8158, 318.8151, 334.8159, 335.81232, 336.81036, 337.8101, 340.79953, 343.80972, 344.8009, 346.20868, 349.20425, 354.82224, 355.82004, 360.81204, 361.81488, 362.81155, 363.81027, 370.8382, 372.79752, 389.83908, 394.83862, 407.8483, 408.7495, 408.8483, 412.79782, 412.8495, 413.26614, 413.85025, 414.26913, 419.80457, 430.79797, 431.7977, 447.25662, 448.2613, 465.94623, 560.34186, 561.3432, 697.4013, 787.23486]","[1244.104, 18248.63, 747.18225, 672.4936, 3284.768, 5824.9575, 1207.1666, 563.56824, 1090.989, 18666.379, 1132.0656, 547.7189, 8010.2773, 9944.686, 717.7685, 909.038, 927.90424, 1259.5803, 9798.942, 12360.792, 777.4666, 711.51215, 1365.1267, 669.6005, 718.6803, 724.85455, 1516.5447, 6849.315, 1172.2983, 11597.979, 882.7782, 954.53986, 1087.7533, 5462.294, 3395.8171, 717.6081, 663.439, 7134.955, 748.11066, 1207.3207, 3609.01, 838.3727, 1179.3096, 1473.9382, 2907.0327, 3263.6355, 4049.7156, 4270.7646, 793.22906, 1597.9222, 4802.7974, 4149.7407, 6089.5537, 7634.4062, 5610.0933, 1050.6061, 957.7547, 6195.684, 1396.489, 866.404, 846.26697, 1433.4541, 1076.0883, 5400.0293, 1063.56, 1220.6185, 1581.0791, 21550.523, 21930.604, 7990.5386, 3053.2961, 754.4101, 2840.6213, 1415.5275, 9367.521, 1103.6198, 957.9087, 6343.26, 795.6793, 2594.3503, 750.4295]",Citrobacter_freundii,Train


In [9]:
del df

In [10]:
SEQ = 'peptide_sequence'
MZ = 'mz_array'
INT = 'intensity_array'

In [11]:
TRAINING_DATA_COLUMNS = [MZ, INT]
TARGET_DATA_COLUMNS = [SEQ]
SPLIT_VALUE_COLUMNS = ['species', 'istrain']

## Calculating Statistics over all MZMLID Files

In [12]:
file_path_count = len(MZMLID_FILE_PATHS)

def get_mzmlid_file_stats(item: Tuple[int, str]) -> Dict[str, Any]:
    idx, path = item
    info_text = f"Processing item {idx + 1}/{file_path_count} '{path}'"
    if idx % 10 == 0:
        logger.info(info_text)
    else:
        logger.debug(info_text)
    df = pd.read_parquet(path)
    max_sequence_length = df[SEQ].str.len().max()
    max_array_length = df[INT].str.len().max()
    alphabet = set.union(*df[SEQ].apply(set))
    item_count = len(df)
    del df
    gc.collect()
    
    return {
        "file_path": path,
        "max_sequence_length": max_sequence_length,
        "max_array_length": max_array_length,
        "alphabet": alphabet,
        "item_count": item_count
    }

if os.path.exists(STATISTICS_FILE_PATH):
    file_stats = pd.read_parquet(STATISTICS_FILE_PATH)
    file_stats.alphabet = file_stats.alphabet.apply(set)
    print(f"loaded previous statistics file '{STATISTICS_FILE_PATH}'")
else:
    file_stats = pd.DataFrame(
        ItemProcessor(
            items=enumerate(MZMLID_FILE_PATHS),
            item_processor=get_mzmlid_file_stats,
            action_name="analyse",
            subject_name="mzmlid file",
            thread_count=0,
            logger=logger
        ).process()
    )
    
    file_stats_writable = file_stats.copy()
    file_stats_writable.alphabet = file_stats_writable.alphabet.apply(list) # cannot store sets
    file_stats_writable.to_parquet(STATISTICS_FILE_PATH)

loaded previous statistics file '../dumps/PXD010000/training_columns/statistics.parquet'


In [13]:
file_stats.head(2)

Unnamed: 0,file_path,max_sequence_length,max_array_length,alphabet,item_count
0,../dumps/PXD010000/training_columns/Biodiversity_B_fragilis_01_28Jul15_Arwen_14-12-03_mzmlid.parquet,50,1845,"{Y, H, R, L, M(Oxidation), W, C, G, E, I, F, S, K, Q, P, N, D, M, A, V, T}",26943
1,../dumps/PXD010000/training_columns/Biodiversity_Cibrobacter_freundii_LB_aerobic_01_01Feb16_Arwen_15-07-13_mzmlid.parquet,50,1697,"{Y, H, R, L, M(Oxidation), W, C, G, E, I, F, S, K, Q, P, N, D, M, A, V, T}",27516


In [14]:
PADDING_LENGTHS = {
    MZ: file_stats.max_array_length.max(),
    INT: file_stats.max_array_length.max(),
    SEQ: file_stats.max_sequence_length.max()
}

In [15]:
print("padding lengths =", PADDING_LENGTHS)

TOTAL_ITEM_COUNT = file_stats.item_count.sum()
print(f"TOTAL_ITEM_COUNT = {TOTAL_ITEM_COUNT}")

ALPHABET = set.union(*file_stats.alphabet)
print(f"ALPHABET = {', '.join(sorted(ALPHABET))}")

padding lengths = {'mz_array': 2354, 'intensity_array': 2354, 'peptide_sequence': 50}
TOTAL_ITEM_COUNT = 5513185
ALPHABET = A, C, D, E, F, G, H, I, K, L, M, M(Oxidation), N, P, Q, R, S, T, V, W, Y


## Data Normalization, Padding, and Conversion to Tensorflow Datasets

In [16]:
def l2_normalize(values: np.ndarray) -> np.ndarray:
    return tf.keras.utils.normalize(x=values, order=2)

def base_peak_normalize(values: np.ndarray) -> np.ndarray:
    return values / values.max(initial=0)

# by Tom, probably
# don't know, what it's based on
def ion_current_normalize(intensities: np.ndarray) -> np.ndarray:
    total_sum = np.sum(intensities**2)
    normalized = intensities/total_sum
    return normalized

NORMALIZATION = {
    INT: base_peak_normalize
}

In [17]:
PADDING_CHARACTERS = {
    SEQ: '_',
    MZ: 0.0,
    INT: 0.0,
}

ALPHABET.add(PADDING_CHARACTERS[SEQ])

In [18]:
char_to_idx = {char: idx for idx, char in enumerate(sorted(ALPHABET))}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
INDEX_ALPHABET = idx_to_char.keys()
char_to_idx

{'A': 0,
 'C': 1,
 'D': 2,
 'E': 3,
 'F': 4,
 'G': 5,
 'H': 6,
 'I': 7,
 'K': 8,
 'L': 9,
 'M': 10,
 'M(Oxidation)': 11,
 'N': 12,
 'P': 13,
 'Q': 14,
 'R': 15,
 'S': 16,
 'T': 17,
 'V': 18,
 'W': 19,
 'Y': 20,
 '_': 21}

In [19]:
Parquet2DatasetFileProcessor(
    training_data_columns=TRAINING_DATA_COLUMNS,
    target_data_columns=TARGET_DATA_COLUMNS,
    padding_lengths=PADDING_LENGTHS,
    padding_characters=PADDING_CHARACTERS,
    column_normalizations=NORMALIZATION,
    dataset_dump_path_prefix=DATASET_DUMP_PATH,
    char_to_idx_mapping_functions={
        SEQ: char_to_idx.get
    },
    item_count=len(MZMLID_FILE_PATHS),
    skip_existing=True,
    split_on_column_values_of=SPLIT_VALUE_COLUMNS,
    logger=logger
).process(parquet_file_paths=MZMLID_FILE_PATHS,
          thread_count=3)[:3]

INFO: Processing item 1/235: '../dumps/PXD010000/training_columns/Biodiversity_B_fragilis_01_28Jul15_Arwen_14-12-03_mzmlid.parquet'
INFO: Processing item 11/235: '../dumps/PXD010000/training_columns/Biodiversity_P_polymyxa_TBS_aerobic_3_17July16_Samwise_16-04-10_mzmlid.parquet'
INFO: Processing item 21/235: '../dumps/PXD010000/training_columns/M_alcali_copp_CH4_B2_T1_09_QE_23Mar18_Oak_18-01-07_mzmlid.parquet'
INFO: Processing item 31/235: '../dumps/PXD010000/training_columns/Cj_media_MH_R4_23Feb15_Arwen_14-12-03_mzmlid.parquet'
INFO: Processing item 41/235: '../dumps/PXD010000/training_columns/Biodiversity_C_Baltica_T240_R2_C_27Jan16_Arwen_15-07-13_mzmlid.parquet'
INFO: Processing item 51/235: '../dumps/PXD010000/training_columns/Biodiversity_M_xanthus_DZ2_plates_1_03May16_Samwise_16-03-32_mzmlid.parquet'
INFO: Processing item 61/235: '../dumps/PXD010000/training_columns/Biodiversity_B_thet_CMgluc_anaerobic_02_01Feb16_Arwen_15-07-13_mzmlid.parquet'
INFO: Processing item 71/235: '../dum

[]

## Loading Tensorflow Datasets

### ... by type

#### ... by training type annotation (abandoned)

In [20]:
TRAINING_DATA_TYPES = {path.split(os.path.sep)[-1] for path in glob.glob(
    os.path.join(
        DATASET_DUMP_PATH, 
        '*',  # filename
        '*',  # species
        '*'   # istrain
    ))}
TRAINING_DATA_TYPES

{'Train'}

#### ... by species annotation

In [21]:
SPECIES = {path.split(os.path.sep)[-2] for path in glob.glob(
    os.path.join(
        DATASET_DUMP_PATH, 
        '*',  # filename
        '*',  # species
        '*'   # istrain
    ))}
SPECIES

{'Acidiphilium_cryptum_JF-5',
 'Agrobacterium_tumefaciens_IAM_12048',
 'Alcaligenes_faecalis',
 'Algoriphagus_marincola_HL-49',
 'Anaerococcus_hydrogenalis_DSM_7454',
 'Bacillus_cereus_ATCC14579',
 'Bacillus_subtilis_168',
 'Bacillus_subtilis_NCIB3610',
 'Bacteroides_fragilis_638R',
 'Bacteroides_thetaiotaomicron_VPI-5482',
 'Bifidobacterium_bifidum_ATCC29521',
 'Bifidobacterium_longum_infantis_ATCC15697',
 'Campylobacter_jejuni',
 'Cellulomonas_gilvus_ATCC13127',
 'Cellulophaga_baltica_18',
 'Chryseobacterium_indologenes',
 'Citrobacter_freundii',
 'Clostridium_ljungdahlii_DMS_13528',
 'Coprococcus_comes_ATCC27758',
 'Cupriavidus_necator_N-1',
 'Cyanobacterium_stanieri',
 'Delftia_acidovorans_SPH1',
 'Dorea_longicatena_DSM13814',
 'Erythrobacter_HL-111',
 'Faecalibacterium_prausnitzii',
 'Fibrobacter_succinogenes_S85',
 'Francisella_novicida_U112',
 'Halomonas_HL-48',
 'Halomonas_HL-93',
 'Lactobacillales_casei',
 'Legionella_pneumophila',
 'Listeria_monocytogenes_10403S',
 'Methylomi

In [22]:
len(SPECIES)

51

### ... with train-test-eval split

In [23]:
SPECIES_SPLITS = {
            "Train": 0.4,
            "Test": 0.5,
            "Eval": 0.6
        }
KEEP_CACHE = True  # currently, there is no cache; the flag only disables benchmarking

In [24]:
def assign_species_randomly(species: List[str], splits: Dict[str, float] = None) -> Dict[str, List[str]]:
    if splits is None:
        splits = {
            "Train": 0.8,
            "Test": 0.94,
            "Eval": 1.0
        }
    
    splits[None] = 0
    
    sorted_splits = sorted(splits.items(), key=lambda tupl: tupl[1])
    
    shuffled_species = list(species)
    random.shuffle(shuffled_species)

    assigned_species = {
        training_data_type: shuffled_species[
            int(sorted_splits[i][1] * len(shuffled_species)):
            int(split * len(shuffled_species))
        ]
        for i, (training_data_type, split) in enumerate(sorted_splits[1:])
    }
    return assigned_species

def flatten(lists: List[List[Any]]) -> List[Any]:
    res = []
    for item in lists:
        res += item
    return res

def find_files_for_assigned_species(
    assigned_species: Dict[str, List[str]],
    file_pattern: str = os.path.join(DATASET_DUMP_PATH, '*', "{specie}", '*')
) -> Dict[str, List[str]]:
    return {
        training_type: flatten(
            [
                glob.glob(file_pattern.format(specie=specie)) for specie in species
            ]
        ) for training_type, species in assigned_species.items()
    }

def store_dataset_file_paths(
    dataset_file_paths: str, 
    output_file) -> str:
    with open(output_file, 'w') as file:
        file.write(visualization.pretty_print_json(dataset_file_paths))

    return output_file

def load_json(file_path: str) -> Dict[str, Any]:
    with open(file_path, 'r') as file:
        return json.loads(file.read())

def print_list_length_in_dict(dic: Dict[str, List[Any]]) -> None:
    for key, list_value in dic.items():
        print(f"#{key} = {len(list_value)}")
        if len(list_value) > 0:
            print(f"e.g.: {list_value[0]}")

dataset_file_path_dump_file = os.path.join(
        DATASET_DUMP_PATH, 
        "dataset_file_paths.json"
    )

if KEEP_CACHE and os.path.exists(dataset_file_path_dump_file):
    dataset_file_paths = load_json(dataset_file_path_dump_file)
    print(f"found dataset file paths dump '{dataset_file_path_dump_file}'")
else:
    assigned_species = assign_species_randomly(SPECIES, splits=SPECIES_SPLITS)
    print("assigned species:")
    print_list_length_in_dict(assigned_species)

    dataset_file_paths = find_files_for_assigned_species(assigned_species)
    store_dataset_file_paths(dataset_file_paths, dataset_file_path_dump_file)
    print(f"dumped dataset file paths into '{dataset_file_path_dump_file}'")

print()
print("assigned dataset files:")
print_list_length_in_dict(dataset_file_paths)

found dataset file paths dump '../dumps/PXD010000/training_columns/tf_datasets/dataset_file_paths.json'

assigned dataset files:
#Train = 89
e.g.: ../dumps/PXD010000/training_columns/tf_datasets/Biodiversity_C_indologenes_LIB_aerobic_02_03May16_Samwise_16-03-32_mzmlid.parquet/Chryseobacterium_indologenes/Train
#Test = 17
e.g.: ../dumps/PXD010000/training_columns/tf_datasets/Biodiversity_A_cryptum_FeTSB_anaerobic_1_01Jun16_Pippin_16-03-39_mzmlid.parquet/Acidiphilium_cryptum_JF-5/Train
#Eval = 29
e.g.: ../dumps/PXD010000/training_columns/tf_datasets/Biodiversity_B_fragilis_CMcarb_anaerobic_01_01Feb16_Arwen_15-07-13_mzmlid.parquet/Bacteroides_fragilis_638R/Train


### Loading corresponding TF datasets

In [25]:
element_spec = ((tf.TensorSpec(shape=(PADDING_LENGTHS[MZ],), dtype=tf.float32), 
  tf.TensorSpec(shape=(PADDING_LENGTHS[INT],), dtype=tf.float32)),
(tf.TensorSpec(shape=(PADDING_LENGTHS[SEQ],), dtype=tf.int8)))
element_spec

((TensorSpec(shape=(2354,), dtype=tf.float32, name=None),
  TensorSpec(shape=(2354,), dtype=tf.float32, name=None)),
 TensorSpec(shape=(50,), dtype=tf.int8, name=None))

In [26]:
merged_datasets = {
    training_data_type: tf.data.Dataset.from_tensor_slices(paths).interleave(lambda path: 
        tf.data.experimental.load(
            path=path, 
            element_spec=element_spec, 
            compression='GZIP'
        ),
                                                                             num_parallel_calls=os.cpu_count(),
                                                                             deterministic=False
                                                                            )
    for training_data_type, paths in dataset_file_paths.items()
}

merged_datasets

{'Train': <ParallelInterleaveDataset shapes: (((2354,), (2354,)), (50,)), types: ((tf.float32, tf.float32), tf.int8)>,
 'Test': <ParallelInterleaveDataset shapes: (((2354,), (2354,)), (50,)), types: ((tf.float32, tf.float32), tf.int8)>,
 'Eval': <ParallelInterleaveDataset shapes: (((2354,), (2354,)), (50,)), types: ((tf.float32, tf.float32), tf.int8)>}

## Configuring Tensorflow Datasets

In [27]:
BATCH_SIZE = 32

# although not all data fits into a 100k buffer, the interleaving should make it sufficiently random
SHUFFLE_BUFFER_SIZE = 2*10**5

### Caching (currently abandoned because of too high RAM usage)

### Preloading

In [28]:
def fill_cache(dataset, name: Optional[str] = None):
    """
    Use a benchmark to once process the whole dataset.
    """
    if name is not None:
        print(f"{name}:")
    display(tfds.benchmark(dataset))
    gc.collect()
    logger.info("filled a cache - waiting 10 seconds")
    print()
    time.sleep(10)
    return dataset

In [29]:
if not KEEP_CACHE:
    merged_datasets = {
        training_data_type: fill_cache(dataset, name=training_data_type)
        for training_data_type, dataset in merged_datasets.items()
    }

### Shuffling, Batching, Prefetching

In [30]:
merged_datasets = {
    training_data_type: dataset
        .shuffle(SHUFFLE_BUFFER_SIZE, reshuffle_each_iteration=True)
        .batch(BATCH_SIZE, drop_remainder=True)
        .prefetch(tf.data.experimental.AUTOTUNE)
    for training_data_type, dataset in merged_datasets.items()
}
merged_datasets

{'Train': <PrefetchDataset shapes: (((32, 2354), (32, 2354)), (32, 50)), types: ((tf.float32, tf.float32), tf.int8)>,
 'Test': <PrefetchDataset shapes: (((32, 2354), (32, 2354)), (32, 50)), types: ((tf.float32, tf.float32), tf.int8)>,
 'Eval': <PrefetchDataset shapes: (((32, 2354), (32, 2354)), (32, 50)), types: ((tf.float32, tf.float32), tf.int8)>}

In [31]:
TRAINING_TYPE = 'Train'
TEST_TYPE = 'Test'
EVAL_TYPE = 'Eval'

## Building the Tensorflow Model

In [32]:
named_input_layers = {col: tf.keras.layers.Input(shape=(PADDING_LENGTHS[col],), name=col) for col in TRAINING_DATA_COLUMNS}
named_input_layers

{'mz_array': <KerasTensor: shape=(None, 2354) dtype=float32 (created by layer 'mz_array')>,
 'intensity_array': <KerasTensor: shape=(None, 2354) dtype=float32 (created by layer 'intensity_array')>}

In [33]:
named_input_layers_list = [ named_input_layers[col] for col in TRAINING_DATA_COLUMNS ]
named_input_layers_list

[<KerasTensor: shape=(None, 2354) dtype=float32 (created by layer 'mz_array')>,
 <KerasTensor: shape=(None, 2354) dtype=float32 (created by layer 'intensity_array')>]

In [34]:
masked_input_layers = {
    col: tf.keras.layers.Masking(mask_value=PADDING_CHARACTERS[col], name=f"masked_{col}")(input_layer)
    for col, input_layer in named_input_layers.items()
}
masked_input_layers

{'mz_array': <KerasTensor: shape=(None, 2354) dtype=float32 (created by layer 'masked_mz_array')>,
 'intensity_array': <KerasTensor: shape=(None, 2354) dtype=float32 (created by layer 'masked_intensity_array')>}

In [35]:
masked_input_layers_list = [ masked_input_layers[col] for col in TRAINING_DATA_COLUMNS ]
masked_input_layers_list

[<KerasTensor: shape=(None, 2354) dtype=float32 (created by layer 'masked_mz_array')>,
 <KerasTensor: shape=(None, 2354) dtype=float32 (created by layer 'masked_intensity_array')>]

In [36]:
class MaskedLoss(K.losses.LossFunctionWrapper):
    def __init__(self, loss_function, masking_value, name='masked_loss', reduction=tf.keras.losses.Reduction.NONE):
        def _masked_loss(y_true, y_pred):
            y_true = tf.squeeze(y_true, name="masked_loss__squeezed_y_true")
            y_pred = tf.squeeze(y_pred, name="masked_loss__squeezed_y_pred")
            #print(y_true)
            #print(y_pred)
            length_mask = tf.equal(y_true, masking_value, name="masked_loss__is_masking_value")
            #print(length_mask)
            length_mask = tf.cast(length_mask, tf.float32, name="masked_loss__is_masking_value_float")
            length_mask = tf.math.subtract(
                tf.constant(
                    value=1, 
                    dtype=tf.float32
                ), length_mask, name="masked_loss__is_masking_value_inverted")
            lengths = tf.math.reduce_sum(length_mask, axis=-1, name="masked_loss__sum_to_get_lengths")
            #print(lengths)
            lengths = tf.math.add(lengths, 1, name="masked_loss__sum_to_include_first_padding") # to also include the first padding character
            #print(lengths)
            mask = tf.sequence_mask(
                lengths=lengths,
                maxlen=y_pred.shape[-2],  # pre-last dimension = padding length; last dimension = one-hot-encoded alphabet
                dtype=tf.float32,
                name="masked_loss__create_sequence_mask"
            )
            #print(mask)
            losses = loss_function(y_true, y_pred)
            #print(losses)
            losses = tf.math.multiply(losses, mask, name="masked_loss__apply_sequence_mask")
            #print(losses)
            summed_losses = tf.math.reduce_sum(losses, axis=-1, name="masked_loss__sum_losses")
            #print(summed_losses)
            average_losses = tf.math.divide_no_nan(summed_losses, lengths, name="masked_loss__average_losses")
            #print(average_losses)
            return average_losses
            
        super(MaskedLoss, self).__init__(_masked_loss, name=name, reduction=reduction)

In [37]:
masked_loss = MaskedLoss(
    loss_function=tf.keras.losses.sparse_categorical_crossentropy,
    masking_value=tf.constant(
        value=char_to_idx[PADDING_CHARACTERS[SEQ]],
        dtype=tf.int8
    )
)

In [38]:
x = masked_input_layers_list[0]
for input_layer in masked_input_layers_list[1:]:
    x = x + input_layer

x = tf.keras.layers.Flatten(name="flattened_masked_inputs")(x)

for _ in range(4):
    x = tf.keras.layers.Dense(2**11)(x)
    x = tf.keras.layers.Dropout(0.1)(x)

x = tf.keras.layers.Dense(PADDING_LENGTHS[SEQ]*len(ALPHABET))(x)

x = tf.reshape(x,(-1, PADDING_LENGTHS[SEQ], len(ALPHABET)))

x = tf.keras.activations.softmax(x)

model = tf.keras.Model(inputs=named_input_layers_list, outputs=x, name='mmproteo')
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=masked_loss,
              metrics=[
                  tf.keras.metrics.SparseCategoricalAccuracy(),
                  tf.keras.metrics.SparseCategoricalCrossentropy()
              ]
             )
model.summary()

Model: "mmproteo"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
mz_array (InputLayer)           [(None, 2354)]       0                                            
__________________________________________________________________________________________________
intensity_array (InputLayer)    [(None, 2354)]       0                                            
__________________________________________________________________________________________________
masked_mz_array (Masking)       (None, 2354)         0           mz_array[0][0]                   
__________________________________________________________________________________________________
masked_intensity_array (Masking (None, 2354)         0           intensity_array[0][0]            
___________________________________________________________________________________________

## Training the Tensorflow Model

In [39]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [40]:
TENSORBOARD_DIR = os.path.join(DUMP_PATH, "tensorboard")
TENSORBOARD_LOG_DIR = os.path.join(TENSORBOARD_DIR, "logs")

In [41]:
# Clear any logs from previous runs
try:
    shutil.rmtree(TENSORBOARD_DIR)
except FileNotFoundError:
    pass

In [42]:
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=os.path.join(TENSORBOARD_LOG_DIR, "fit", datetime.datetime.now().strftime("%Y%m%d-%H%M%S")), 
    histogram_freq=1
)

In [43]:
%tensorboard --logdir $TENSORBOARD_LOG_DIR

In [44]:
model.fit(x=merged_datasets[TRAINING_TYPE].repeat(),
          validation_data=merged_datasets[TEST_TYPE].repeat(), 
          validation_steps=500,
          epochs=30,
          steps_per_epoch=10_000,
          callbacks=[tensorboard_callback]
         )

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30


Epoch 29/30
Epoch 30/30


<tensorflow.python.keras.callbacks.History at 0x7f7864070a90>

## Evaluating the Tensorflow Model

In [45]:
model.evaluate(merged_datasets[EVAL_TYPE].repeat(), steps=int(40000/BATCH_SIZE))



[15.1504545211792, 0.021316999569535255, 17267476480.0]

In [46]:
def unzip(tuple_list: Iterable[Tuple[Any, Any]]) -> Tuple[Iterable[Any], Iterable[Any]]:
    return tuple(zip(*tuple_list))

In [47]:
SEPARATOR = " "
PREDICTED = "predicted"
TRUE = "true"

In [48]:
def decode_onehot(array: np.ndarray) -> np.ndarray:
    return np.argmax(array, axis=-1)

decode_idx: Callable[[np.ndarray], np.ndarray] = np.vectorize(idx_to_char.get)

def concat_letter_rows(array: np.ndarray) -> np.ndarray:
    return np.apply_along_axis(lambda row: SEPARATOR.join(row), axis=-1, arr=array)

def decode(array: np.ndarray, onehot: bool = True):
    if onehot:
        array = decode_onehot(array)
    array = decode_idx(array)
    array = concat_letter_rows(array)
    if not onehot:
        array = np.apply_along_axis(lambda row: row[0], axis=-1, arr=array)
    return array

In [53]:
eval_ds = merged_datasets[EVAL_TYPE].unbatch().batch(1).take(20)

x_eval, y_eval = unzip(eval_ds.as_numpy_iterator())
y_pred = model.predict(eval_ds)

# although the strings look like they have different lengths, they all have the same length
eval_df = pd.DataFrame(data=zip(decode(y_pred), decode(y_eval, onehot=False)), columns=[PREDICTED, TRUE])

eval_df[PREDICTED] = eval_df[PREDICTED].combine(
    other=eval_df[TRUE].str.rstrip(PADDING_CHARACTERS[SEQ] + SEPARATOR).str.split(SEPARATOR).str.len() + 1,
    func=lambda seq, length: SEPARATOR.join(seq.split(SEPARATOR)[:length])
)

#eval_df = eval_df.applymap(lambda s: s.replace(SEPARATOR, ""))

eval_df[TRUE] = eval_df[TRUE].str.rstrip(PADDING_CHARACTERS[SEQ] + SEPARATOR)

eval_df

Unnamed: 0,predicted,true
0,L A D V G G Y S A I A F T M(Oxidation) G F W M(Oxidation) V N V,V A I V D F S T E K P I I Y P N N G W K
1,L A D V G G Y S A I A F T M(Oxidation) G F W M(Oxidation) V N,Y I S S Y I P H N E E A Q M V S I S K
2,L A D V G G Y S A,Q V L D I V T K
3,L A D V G G Y _ A I A F T M(Oxidation) G F W M(Oxidation) V N V A,S P G Y T R E E L F K E L A D L I V E I K
4,L A D V G G Y S A I A F T M(Oxidation) G F W M(Oxidation) V N V A A,I E T G V I H V G D E I E I L G L G E D K K
5,L A D V G G Y _ A I A F T,A K E D F L A D V A K R
6,L A D V G G Y S A I A F T M(Oxidation) G F W,R G F S N E I I E N I H N A Y R
7,L A D V G G Y _ A I A,A D L E K E V A L R
8,L A D V G G Y _ A I A F T M(Oxidation) G F W M(Oxidation) V N,G Q T A F V S S N T N F V M(Oxidation) L N G Q R
9,L A D V G G Y _ A I A F T M(Oxidation) G F W M(Oxidation) V N V A A W E E A,I H Q A V E Q M V E S L D M A A G S T F S F D L Y K


In [52]:
eval_df.predicted.map(print)
None

L A D V G G Y _ A I A F T M(Oxidation) G
L A D V G G Y S A I A F T M(Oxidation) G F W
L A D V G G Y S A I A F T M(Oxidation) G F W M(Oxidation) V N V A A
L A D V G G Y S A I A F T
L A D V G G Y _ A I
L A D V G G Y S
L A D V G G Y S A I A F T M(Oxidation) G F W M(Oxidation) V N
L A D V G G Y _ A I A F T M(Oxidation) G F W M(Oxidation) V N V A A W E
L A D V G G Y S A I A F T M(Oxidation) G F W M(Oxidation) V N V A A W E E A
L A D V G G Y S A I
L A D V G G Y _ A I A F T M(Oxidation) G F W M(Oxidation)
L A D V G G Y _ A I A
L A D V G G Y S A I A F T
L A D V G G Y S A I
L A D V G G Y _ A I A F T M(Oxidation) G F W M(Oxidation) V N V A A W E E A
L A D V G G Y _ A I A F T M(Oxidation) G F W M(Oxidation) V N V A
L A D V G G Y _ A I A F T M(Oxidation) G F W M(Oxidation) V
L A D V G G Y S A I A F T M(Oxidation) G F W M(Oxidation) V N V
L A D V G G Y _ A I A F T M(Oxidation) G F W M(Oxidation) V N V A A W E E A
L A D V G G Y S A I A F T M(Oxidation) G F W M(Oxidation)


In [51]:
y_eval[:1], y_pred[:1]

((array([[ 2,  0,  7, 14,  8,  2, 12, 15, 13,  3, 13,  0,  8, 15, 21, 21,
          21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
          21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
          21, 21]], dtype=int8),),
 array([[[0., 0., 0., ..., 0., 0., 0.],
         [1., 0., 0., ..., 0., 0., 0.],
         [0., 0., 1., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32))