Hello fellow Kagglers,

This notebook demonstrates the training process on a TPU in Tensorflow.

Thanks to the use of a [TPU (Tensor Processing Unit)](https://cloud.google.com/tpu) training takes about an hour.

The TFRecord dataset contains cropped images sized 1344x768, created in [this notebook](https://www.kaggle.com/code/markwijkhuizen/rsna-preprocessing-tfrecords-640x512-dataset).

20% of the data is used for validation, which reaches ~0.20 pF1 with the best threshold.

**Things that did not work for me:**

* [SigmoidFocalCrossEntropy](https://www.tensorflow.org/addons/api_docs/python/tfa/losses/SigmoidFocalCrossEntropy)
* Increasing model size to for example EfficientNetV2S

**Things that did work for me:**

* Class weights: give minority class weight of 10
* Training on TPU instead of GPU: larger batch size (16x2->16x8) giving larger probability of having positive sample in batch
* Cropping Images
* Using Cropped Image Ratio

I enjoy this competition and will update this notebook frequently, stay tuned!

**V2**

* Cropped images in 1344x768 resolution
* EfficientNetV2T
* Added augmentations
* Single image model instead of both CC and MLO views as input

[Inference Notebook](https://www.kaggle.com/markwijkhuizen/rsna-efficientnetv2-inference-tensorflow)

In [None]:
# The Kaggle Tensorflow version is old and does not contain EfficientNetV2: get it from pip package
!pip install -qq /kaggle/input/kerasefficientnetv2/keras_efficientnet_v2-1.2.2-py3-none-any.whl

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import matplotlib as mpl

from tqdm.notebook import tqdm
from multiprocessing import cpu_count
from kaggle_datasets import KaggleDatasets
from sklearn.model_selection import train_test_split

import keras_efficientnet_v2
import os
import time
import pickle
import math
import random
import sys
import cv2
import gc

print(f'Tensorflow Version: {tf.__version__}')
print(f'Python Version: {sys.version}')

# Mixed Precision Policy

In [None]:
# float32 or mixed_float16 (mixed precision: compute float16, variable float32)
# TPU is fast enough and has enough memory to use float32
# The policy is set globally, so it applies to every Keras layer created afterwards.
policy = tf.keras.mixed_precision.Policy('float32')
tf.keras.mixed_precision.set_global_policy(policy)

# Print the active policy so the notebook log records which dtypes were used
print(f'Compute dtype: {tf.keras.mixed_precision.global_policy().compute_dtype}')
print(f'Variable dtype: {tf.keras.mixed_precision.global_policy().variable_dtype}')

# Matplotlib Config

In [None]:
# Matplotlib global settings: reset to library defaults first so re-running
# this cell always yields the same configuration, then enlarge label fonts.
mpl.rcParams.update(mpl.rcParamsDefault)
mpl.rcParams['xtick.labelsize'] = 16
mpl.rcParams['ytick.labelsize'] = 16
mpl.rcParams['axes.labelsize'] = 18
mpl.rcParams['axes.titlesize'] = 24

# Config

In [None]:
# Detect hardware, return appropriate distribution strategy
try:
    TPU = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
    print('Running on TPU ', TPU.master())
except ValueError:
    # No TPU could be resolved: fall back to the default strategy below.
    # NOTE(review): the message says GPU, but this branch is also taken on CPU-only machines.
    print('Running on GPU')
    TPU = None

if TPU:
    IS_TPU = True
    # Connect to the TPU cluster and initialize it before creating the strategy
    tf.config.experimental_connect_to_cluster(TPU)
    tf.tpu.experimental.initialize_tpu_system(TPU)
    STRATEGY = tf.distribute.experimental.TPUStrategy(TPU)
else:
    IS_TPU = False
    STRATEGY = tf.distribute.get_strategy() # default distribution strategy in Tensorflow. Works on CPU and single GPU.

# Number of replicas the strategy keeps in sync; used to scale the global batch size
N_REPLICAS = STRATEGY.num_replicas_in_sync
print(f'N_REPLICAS: {N_REPLICAS}, IS_TPU: {IS_TPU}')

In [None]:
# For TPUs the dataset needs to be stored in Google Cloud Storage.
# Retrieve the Google Cloud location of the TFRecords dataset; the TFRecord
# files are later globbed from this path.
GCS_DS_PATH = KaggleDatasets().get_gcs_path('rsna-preprocessing-tfrecords-640x512-dataset-pub')

In [None]:
SEED = 43
DEBUG = False

# Image dimensions
IMG_HEIGHT = 1344
IMG_WIDTH = 768
N_CHANNELS = 1
# Derived from N_CHANNELS so the model input cannot drift from the decoded image shape
INPUT_SHAPE = (IMG_HEIGHT, IMG_WIDTH, N_CHANNELS)
# Number of samples assumed per TFRecord file (used to derive steps per epoch)
N_SAMPLES_TFRECORDS = 548

# Peak Learning Rate
LR_MAX = 8e-4

N_WARMUP_EPOCHS = 2
N_EPOCHS = 15

# Global batch size: 8 samples per replica
BATCH_SIZE = 8 * N_REPLICAS

# Is Interactive Flag and Corresponding Verbosity Method
# .get() avoids a KeyError when the notebook runs outside Kaggle, where the
# variable is not set; in that case the run is treated as non-interactive.
IS_INTERACTIVE = os.environ.get('KAGGLE_KERNEL_RUN_TYPE') == 'Interactive'
VERBOSE = 1 if IS_INTERACTIVE else 2

# Tensorflow AUTO flag
AUTO = tf.data.experimental.AUTOTUNE

print(f'BATCH_SIZE: {BATCH_SIZE}')

# Seed

In [None]:
# Seed all random number generators
def seed_everything(seed=SEED):
    """Seed Python's `random`, NumPy and TensorFlow, and set PYTHONHASHSEED.

    Args:
        seed: integer seed applied to every generator (defaults to SEED).
    """
    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
        seed_fn(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

seed_everything()

# Train

In [None]:
# Train DataFrame (competition metadata, one row per image)
train = pd.read_csv('/kaggle/input/rsna-breast-cancer-detection/train.csv')

display(train.head())
# DataFrame.info() prints its summary itself and returns None, so it must not
# be wrapped in display() (that would render a stray "None").
train.info()

# Utility Functions

In [None]:
# Short TensorFlow random-integer helper
def tf_rand_int(minval, maxval, dtype=tf.int64):
    """Draw one scalar random integer tensor from [minval, maxval).

    Both bounds are cast to `dtype` before sampling.
    """
    lower = tf.cast(minval, dtype)
    upper = tf.cast(maxval, dtype)
    return tf.random.uniform(shape=(), minval=lower, maxval=upper, dtype=dtype)

# chance of 1 in k
def one_in(k):
    """Return a boolean tensor that is True when a random draw from [0, k) equals 0."""
    draw = tf_rand_int(0, k)
    return 0 == draw

# Dataset

In [None]:
# Function to benchmark the dataset
def benchmark_dataset(dataset, num_epochs=3, n_steps_per_epoch=10, bs=BATCH_SIZE):
    """Time how fast `dataset` yields batches and print the throughput.

    For each of `num_epochs` passes, `n_steps_per_epoch + 1` batches are taken.
    The clock starts once the first batch has arrived, so the first batch is
    excluded from the measurement and `n_steps_per_epoch` steps are timed.

    Args:
        dataset: dataset yielding `(inputs, labels)` where `inputs['image']`
            holds the image batch.
        num_epochs: number of timed passes.
        n_steps_per_epoch: number of timed batches per pass.
        bs: batch size, only used to convert steps/s into images/s.
    """
    for epoch_num in range(num_epochs):
        epoch_start = None
        for idx, (inputs, labels) in enumerate(dataset.take(n_steps_per_epoch + 1)):
            if idx == 0:
                # Start timing only after the first batch has been produced
                epoch_start = time.perf_counter()
            elif idx == 1 and epoch_num == 0:
                # Report batch shapes/dtypes once
                image = inputs['image']
                print(f'image shape: {image.shape}, labels shape: {labels.shape}, image dtype: {image.dtype}, labels dtype: {labels.dtype}')

        if epoch_start is None:
            # The dataset yielded nothing: there is no timing to report
            print(f'epoch {epoch_num}: dataset yielded no batches')
            continue

        epoch_t = time.perf_counter() - epoch_start
        mean_step_t = round(epoch_t / n_steps_per_epoch * 1000, 1)
        # Guard against a step time that rounds to 0.0 ms (would divide by zero)
        n_imgs_per_s = int(1 / (mean_step_t / 1000) * bs) if mean_step_t > 0 else 0
        print(f'epoch {epoch_num} took: {round(epoch_t, 2)} sec, mean step duration: {mean_step_t}ms, images/s: {n_imgs_per_s}')

In [None]:
# Plots a batch of images
def show_batch(dataset, rows=16, cols=1):
    """Plot the first `rows` samples of the first batch of `dataset`.

    The channel axis of the image batch is moved to position 1, so column `c`
    of the figure shows channel `c` of each sample.

    Args:
        dataset: dataset yielding `(inputs, targets)` with `inputs['image']`
            being a 4-D image batch whose last axis is the channel axis.
        rows: number of samples to plot; the batch must hold at least this many.
        cols: number of channels to plot per sample.
    """
    inputs, targets = next(iter(dataset))
    images = np.moveaxis(inputs['image'], 3, 1)
    # squeeze=False always returns a 2-D axes array, so indexing works for any
    # rows/cols combination (including a single row or several columns).
    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(cols*6, rows*6), squeeze=False)
    for r in range(rows):
        for c in range(cols):
            img = images[r,c]
            axes[r, c].imshow(img)
            if c == 0:
                # Title only the first column of each row with the sample's target
                target = targets[r]
                axes[r, c].set_title(f'target: {target}', fontsize=12, pad=16)

    plt.show()

In [None]:
# Decodes the TFRecords
def decode_image(record_bytes):
    """Parse one serialized example into `({'image': uint8 image}, target)`.

    The record holds a PNG-encoded image, an int64 target and an int64
    patient_id; the patient_id is parsed but not returned.
    """
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'target': tf.io.FixedLenFeature([], tf.int64),
        'patient_id': tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(record_bytes, feature_spec)

    # Decode the PNG bytes, then pin the static shape (explicit reshape needed for TPU)
    decoded = tf.io.decode_png(example['image'], channels=N_CHANNELS)
    decoded = tf.reshape(decoded, [IMG_HEIGHT, IMG_WIDTH, N_CHANNELS])

    return { 'image': decoded }, example['target']

In [None]:
def augment_image(X, y):
    # Training-time augmentation of a single sample.
    # X: dict with key 'image'; y: target, passed through unchanged.
    # Returns ({'image': augmented uint8 image of size IMG_HEIGHT x IMG_WIDTH}, y).
    image = X['image']
    
    # Random Brightness
    image = tf.image.random_brightness(image, 0.10)
    
    # Random Contrast
    image = tf.image.random_contrast(image, 0.90, 1.10)
    
    # Random JPEG Quality
    image = tf.image.random_jpeg_quality(image, 75, 100)
    
    # Random crop keeping 75%-100% of each side (i.e. cropping away at most 25%);
    # the same ratio is used for height and width so the aspect ratio is kept
    ratio = tf.random.uniform([], 0.75, 1.00)
    img_height_crop = tf.cast(ratio * IMG_HEIGHT, tf.int32)
    img_width_crop = tf.cast(ratio * IMG_WIDTH, tf.int32)
    # Random vertical offset for the crop; the horizontal offset is fixed at 0,
    # so the crop always starts at the left image edge
    img_height_offset = tf_rand_int(0, IMG_HEIGHT - img_height_crop)
    img_width_offset = 0
    # Crop And Resize back to the original resolution
    image = tf.slice(image, [img_height_offset, img_width_offset, 0], [img_height_crop, img_width_crop, N_CHANNELS])
    image = tf.image.resize(image, [IMG_HEIGHT, IMG_WIDTH], method=tf.image.ResizeMethod.BILINEAR)
    # Clip pixel values in range [0,255] to prevent underflow/overflow
    image = tf.clip_by_value(image, 0, 255)
    image = tf.cast(image, tf.uint8)
    
    return { 'image': image }, y

In [None]:
# Undersample majority class (0/negative) by randomly dropping them
def undersample_majority(X, y):
    """Filter predicate: keep every positive sample and ~34% of the negatives.

    Dropping roughly 2/3 of the negative samples raises the share of positive
    samples by about a factor 3.

    Args:
        X: sample inputs (unused).
        y: scalar target tensor, 1 for positive.

    Returns:
        Scalar boolean tensor, True when the sample should be kept.
    """
    keep_negative = tf.random.uniform([]) > 0.66
    # Explicit tensor op instead of Python `or`, which only works on tensors
    # when AutoGraph rewrites it.
    return tf.math.logical_or(y == 1, keep_negative)

In [None]:
# TFRecord file paths, sorted so the order (and therefore the seeded
# train/validation split made from this list) is stable across runs
TFRECORDS_FILE_PATHS = sorted(tf.io.gfile.glob(f'{GCS_DS_PATH}/*.tfrecords'))
print(f'Found {len(TFRECORDS_FILE_PATHS)} TFRecords')

In [None]:
# Train/validation split: 80% of the TFRecord *files* go to training.
# The split is made per file, not per sample.
# NOTE(review): whether one patient can appear in both splits depends on how
# the TFRecords were written — confirm against the preprocessing notebook.
TFRECORDS_TRAIN, TFRECORDS_VAL = train_test_split(TFRECORDS_FILE_PATHS, train_size=0.80, random_state=SEED, shuffle=True)
print(f'# TFRECORDS_TRAIN: {len(TFRECORDS_TRAIN)}, # TFRECORDS_VAL: {len(TFRECORDS_VAL)}')

In [None]:
def get_dataset(tfrecords, bs=BATCH_SIZE, val=False, debug=True):
    """Build the tf.data input pipeline for a list of GZIP-compressed TFRecords.

    Args:
        tfrecords: TFRecord file paths to read.
        bs: batch size.
        val: when True, build the validation pipeline: no undersampling, no
            augmentation, no shuffling, no repeat, and the final partial batch
            is kept.
        debug: when True, the training pipeline skips shuffling.

    Returns:
        A batched, prefetched dataset yielding `({'image': ...}, target)`.
    """
    # Read and decode the records in parallel
    records = tf.data.TFRecordDataset(tfrecords, num_parallel_reads=AUTO, compression_type='GZIP')
    pipeline = records.map(decode_image, num_parallel_calls=AUTO)

    is_training = not val
    if is_training:
        # Element order does not matter for training, so allow non-deterministic order
        unordered = tf.data.Options()
        unordered.experimental_deterministic = False

        pipeline = pipeline.filter(undersample_majority)
        pipeline = pipeline.map(augment_image, num_parallel_calls=AUTO)
        pipeline = pipeline.with_options(unordered)
        if not debug:
            pipeline = pipeline.shuffle(1024)
        # Training data repeats indefinitely; epoch length is set via steps_per_epoch
        pipeline = pipeline.repeat()

    # Only training drops the incomplete final batch
    pipeline = pipeline.batch(bs, drop_remainder=is_training)
    return pipeline.prefetch(AUTO)

In [None]:
# Get Train/Validation datasets (debug=False enables shuffling for training)
train_dataset = get_dataset(TFRECORDS_TRAIN, val=False, debug=False)
val_dataset = get_dataset(TFRECORDS_VAL, val=True, debug=False)

# Steps per epoch assume every TFRecord file holds N_SAMPLES_TFRECORDS samples.
# The floor division ignores a final partial batch, and the training count is
# based on the number of samples before undersampling.
TRAIN_STEPS_PER_EPOCH = len(TFRECORDS_TRAIN) * N_SAMPLES_TFRECORDS // BATCH_SIZE
VAL_STEPS_PER_EPOCH = len(TFRECORDS_VAL) * N_SAMPLES_TFRECORDS // BATCH_SIZE
print(f'TRAIN_STEPS_PER_EPOCH: {TRAIN_STEPS_PER_EPOCH}, VAL_STEPS_PER_EPOCH: {VAL_STEPS_PER_EPOCH}')

In [None]:
# Sanity check, image and label statistics of one training batch
# (get_dataset is called with its default debug=True, i.e. without shuffling)
X_batch, y_batch = next(iter(get_dataset(TFRECORDS_TRAIN, val=False)))
image = X_batch['image'].numpy()
print(f'image shape: {image.shape}, y_batch shape: {y_batch.shape}')
print(f'image dtype: {image.dtype}, y_batch dtype: {y_batch.dtype}')
print(f'image min: {image.min():.2f}, max: {image.max():.2f}')

In [None]:
# Benchmark the training input pipeline (undersampling + augmentation, no shuffling
# because get_dataset defaults to debug=True)
benchmark_dataset(get_dataset(TFRECORDS_TRAIN, val=False))

In [None]:
# Show what we will be training on: one batch of 16 augmented training images,
# matching show_batch's default of 16 rows
show_batch(get_dataset(TFRECORDS_TRAIN, bs=16, val=False))

# Class Imbalance

In [None]:
# Label distribution of the training pipeline after undersampling the majority
# class, estimated from the first N batches
N = 128
train_labels = []
for _, labels in tqdm(get_dataset(TFRECORDS_TRAIN, val=False).take(N), total=N):
    train_labels.extend(labels.numpy().tolist())

# Relative and absolute label frequencies side by side
display(
    pd.concat(
        (
            pd.Series(train_labels).value_counts(normalize=True).to_frame('Train Label Ratio'),
            pd.Series(train_labels).value_counts().to_frame('Train Label Count'),
        ),
        axis=1,
    )
)

In [None]:
# Label distribution of the full validation set (no undersampling is applied)
val_labels = []
for _, labels in tqdm(get_dataset(TFRECORDS_VAL, val=True), total=VAL_STEPS_PER_EPOCH):
    val_labels.extend(labels.numpy().tolist())

# Relative and absolute label frequencies side by side
display(
    pd.concat(
        (
            pd.Series(val_labels).value_counts(normalize=True).to_frame('Val Label Ratio'),
            pd.Series(val_labels).value_counts().to_frame('Val Label Count'),
        ),
        axis=1,
    )
)

# pF1 Metric

inspiration: [RSNA-BCD: EfficientNet [TF][TPU-1VM][Train]](https://www.kaggle.com/code/awsaf49/rsna-bcd-efficientnet-tf-tpu-1vm-train#Metric)

The source implementation is, however, buggy: it is a moving average that does not reset each epoch. The implementation below does reset each epoch.

In [None]:
class pF1(tf.keras.metrics.Metric):
    """Probabilistic F1 score accumulated over all batches since the last reset.

    Three running sums are tracked:
      * tc: number of positive labels,
      * tp: sum of predicted probabilities on positive samples,
      * fp: sum of predicted probabilities on negative samples.
    """

    def __init__(self, name='pF1', **kwargs):
        super(pF1, self).__init__(name=name, **kwargs)
        self.tc = self.add_weight(name='tc', initializer='zeros')
        self.tp = self.add_weight(name='tp', initializer='zeros')
        self.fp = self.add_weight(name='fp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add one batch to the running sums; `sample_weight` is ignored."""
        self.tc.assign_add(tf.cast(tf.reduce_sum(y_true), tf.float32))
        self.tp.assign_add(tf.cast(tf.reduce_sum((y_pred[y_true == 1])), tf.float32))
        self.fp.assign_add(tf.cast(tf.reduce_sum((y_pred[y_true == 0])), tf.float32))

    def result(self):
        """Return 2 * P * R / (P + R), or 0 whenever a denominator is zero.

        divide_no_nan replaces the Python-level `if` on tensors and also
        covers the case tp == 0 with tc > 0 and fp > 0, where precision and
        recall are both 0 and a plain division would yield NaN.
        """
        precision = tf.math.divide_no_nan(self.tp, self.tp + self.fp)
        recall = tf.math.divide_no_nan(self.tp, self.tc)
        return tf.math.divide_no_nan(2 * (precision * recall), precision + recall)

    def reset_state(self):
        """Zero the running sums so every epoch starts from scratch.

        This must be a method of the class; when nested inside `result` it is
        never defined on the metric.
        """
        self.tc.assign(0.0)
        self.tp.assign(0.0)
        self.fp.assign(0.0)

# Model

In [None]:
def normalize(image):
    """Turn a batch of single-channel images into normalized 3-channel float input.

    The channel axis (axis 3) is repeated 3 times because the pretrained
    EfficientNetV2 models expect 3-channel images; the result is cast to
    float32 and normalized with Keras' `preprocess_input` in 'torch' mode
    (ImageNet mean/std).
    """
    rgb_like = tf.cast(tf.repeat(image, repeats=3, axis=3), tf.float32)
    return tf.keras.applications.imagenet_utils.preprocess_input(rgb_like, mode='torch')

In [None]:
def get_model():
    """Build and compile the EfficientNetV2T binary classifier inside STRATEGY.scope().

    The model takes uint8 images of shape INPUT_SHAPE under the input name
    'image' (matching the dataset's dictionary key) and outputs one sigmoid
    probability per image.

    Returns:
        The compiled tf.keras Model.
    """
    # Verify Mixed Policy Settings
    print(f'Compute dtype: {tf.keras.mixed_precision.global_policy().compute_dtype}')
    print(f'Variable dtype: {tf.keras.mixed_precision.global_policy().variable_dtype}')
    
    with STRATEGY.scope():
        # Set seed for deterministic weights initialization
        seed_everything()
        
        # Inputs, note the names are equal to the dictionary keys in the dataset
        image = tf.keras.layers.Input(INPUT_SHAPE, name='image', dtype=tf.uint8)
        
        # Normalize Input (applied exactly once; a second call would only add
        # unused ops to the graph)
        image_norm = normalize(image)

        # CNN Prediction in range [0,1]
        outputs = keras_efficientnet_v2.EfficientNetV2T(
            input_shape=[IMG_HEIGHT, IMG_WIDTH, 3],
            pretrained='imagenet',
            num_classes=1,
            classifier_activation='sigmoid',
            dropout=0.30,
        )(image_norm)

        # We will use the famous Adam optimizer for fast learning
        optimizer = tf.optimizers.Adam(learning_rate=LR_MAX, epsilon=1e-7, clipnorm=10.0)

        # Loss: the model already outputs probabilities, hence from_logits=False
        loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)
        
        # Metrics
        metrics = [
            pF1(),
            tfa.metrics.F1Score(num_classes=1, threshold=0.50),
            tf.keras.metrics.Precision(),
            tf.keras.metrics.Recall(),
            tf.keras.metrics.AUC(),
            tf.keras.metrics.BinaryAccuracy(),
        ]

        model = tf.keras.models.Model(inputs=image, outputs=outputs)
        
        model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

        return model

In [None]:
# Pretrained File Path: '/kaggle/input/sartorius-training-dataset/model.h5'
# NOTE(review): the path above is not used anywhere in this cell — confirm whether it is stale.
# Drop any previous Keras state so re-running this cell starts from a clean session
tf.keras.backend.clear_session()
# enable XLA optimizations
tf.config.optimizer.set_jit(True)

model = get_model()

In [None]:
# Plot model summary
model.summary()

In [None]:
# Model architecture
tf.keras.utils.plot_model(model, show_shapes=True, show_dtype=True, show_layer_names=True, expand_nested=False)

# Weight Initialization

In [None]:
# Validation metrics of the freshly initialized (not yet fine-tuned) model,
# evaluated for VAL_STEPS_PER_EPOCH batches as a baseline
_ = model.evaluate(
        get_dataset(TFRECORDS_VAL, val=True),
        verbose=VERBOSE,
        steps=VAL_STEPS_PER_EPOCH,
    )

In [None]:
# Validation output baseline: predictions of the initialized model on the
# first 128 validation batches, flattened with squeeze()
val_preds = model.predict(
        get_dataset(TFRECORDS_VAL, val=True),
        verbose=VERBOSE,
        steps=128,
    ).squeeze()

In [None]:
# Initialized model validation predictions: should not be saturated (all 0/1)
display(pd.Series(val_preds).describe().to_frame('Value'))

In [None]:
# Histogram of the initialized model's validation predictions, with x ticks
# every 0.1 from 0 to 1
plt.figure(figsize=(15,8))
plt.title(f'Validation Predictions Initialized Model')
pd.Series(val_preds).plot(kind='hist')
plt.xticks(np.arange(0, 1.1, 0.1))
plt.grid()
plt.show()

# Learning Rate Scheduler

In [None]:
# Learning rate scheduler with exponential warmup and cosine decay
def lrfn(current_step, num_warmup_steps, lr_max, num_cycles=0.50, num_training_steps=N_EPOCHS):
    """Learning rate for `current_step`.

    Warmup: while `current_step < num_warmup_steps` the rate is
    `lr_max * 0.1 ** (steps left until the end of warmup)`, i.e. it grows by a
    factor 10 per step and reaches `lr_max` at `num_warmup_steps`.
    Decay: afterwards a cosine with `num_cycles` cycles over the remaining
    steps, floored at 0 and scaled by `lr_max`.
    """
    steps_until_peak = num_warmup_steps - current_step
    if steps_until_peak > 0:
        return lr_max * 0.10 ** steps_until_peak

    # Fraction of the post-warmup phase completed (denominator is at least 1)
    decay_span = float(max(1, num_training_steps - num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / decay_span
    cosine = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    return max(0.0, cosine) * lr_max

In [None]:
# Plot the learning rate scheduler
def plot_lr_schedule(lr_schedule, epochs):
    """Plot a per-epoch learning rate schedule with an annotated value per point.

    Args:
        lr_schedule: list of learning rates, one per epoch (must be a list,
            it is concatenated with padding entries for plotting).
        epochs: number of epochs shown on the x axis (1-based).
    """
    plt.figure(figsize=(20, 10))
    # Pad with None on both sides so the first value is drawn at x=1
    plt.plot([None] + lr_schedule + [None])
    # X Labels
    x_ticks = np.arange(1, epochs + 1)
    x_axis_labels = [i if epochs <= 40 or i % 5 == 0 or i == 1 else None for i in range(1, epochs + 1)]
    plt.xlim([1, epochs])
    plt.xticks(x_ticks, x_axis_labels) # set tick step to 1 and let x axis start at 1
    
    # Increase y-limit for better readability
    plt.ylim([0, max(lr_schedule) * 1.1])
    
    # Title
    schedule_info = f'start: {lr_schedule[0]:.1E}, max: {max(lr_schedule):.1E}, final: {lr_schedule[-1]:.1E}'
    plt.title(f'Step Learning Rate Schedule, {schedule_info}', size=18, pad=12)
    
    # Vertical offset of the annotations, identical for every point
    offset_y = (max(lr_schedule) - min(lr_schedule)) * 0.02

    # Plot Learning Rates
    for idx, val in enumerate(lr_schedule):
        # Value comparison (==) instead of identity (is) for the last epoch
        if epochs <= 40 or idx % 5 == 0 or idx == epochs - 1:
            if idx == 0:
                # First point has no predecessor; handled first so that
                # lr_schedule[idx - 1] never wraps around to the last element
                ha = 'right'
            elif idx < len(lr_schedule) - 1:
                # Rising schedule: label to the left of the point, otherwise to the right
                ha = 'right' if lr_schedule[idx - 1] < val else 'left'
            else:
                ha = 'left'
            plt.plot(idx + 1, val, 'o', color='black');
            plt.annotate(f'{val:.1E}', xy=(idx + 1, val + offset_y), size=12, ha=ha)
    
    plt.xlabel('Epoch', size=16, labelpad=5)
    plt.ylabel('Learning Rate', size=16, labelpad=5)
    plt.grid()
    plt.show()

# Learning rate per epoch: one lrfn value for each of the N_EPOCHS epochs
# (N_WARMUP_EPOCHS warmup epochs, then half a cosine cycle)
LR_SCHEDULE = [lrfn(step, num_warmup_steps=N_WARMUP_EPOCHS, lr_max=LR_MAX, num_cycles=0.50) for step in range(N_EPOCHS)]
plot_lr_schedule(LR_SCHEDULE, epochs=N_EPOCHS)

In [None]:
# Learning Rate Callback: looks up the precomputed rate in LR_SCHEDULE.
# The lambda's `step` argument indexes LR_SCHEDULE, which has one entry per epoch.
lr_callback = tf.keras.callbacks.LearningRateScheduler(lambda step: LR_SCHEDULE[step], verbose=1)

# Training

In [None]:
# Train for N_EPOCHS epochs of TRAIN_STEPS_PER_EPOCH steps each (the training
# dataset repeats indefinitely) and validate on the full validation set.
# class_weight gives the positive class (1) a weight of 16 relative to the
# negative class (0).
history = model.fit(
        train_dataset,
        steps_per_epoch = TRAIN_STEPS_PER_EPOCH,
        validation_data = val_dataset,
        epochs = N_EPOCHS,
        verbose = VERBOSE,
        callbacks = [
            lr_callback,
        ],
        class_weight = {
            0:  1.0,
            1: 16.0,
        },
    )

In [None]:
# Save model weights for inference
model.save_weights('model.h5')

# F1 By Threshold

In [None]:
# Get true labels and predictions for validation set
y_true_val = []
y_pred_val = []
for X_batch, y_batch in tqdm(get_dataset(TFRECORDS_VAL, val=True), total=VAL_STEPS_PER_EPOCH):
    y_true_val += y_batch.numpy().tolist()
    # reshape(-1) always yields a 1-D array, so tolist() returns a list even
    # when the final batch holds a single sample (squeeze() would collapse it
    # to a scalar and `+=` would fail)
    y_pred_val += model.predict_on_batch(X_batch).reshape(-1).tolist()

In [None]:
# source: https://www.kaggle.com/code/sohier/probabilistic-f-score
# Competition Leaderboard Metric
def pfbeta(labels, predictions, beta=1):
    """Probabilistic F-beta score.

    Predictions are clipped to [0, 1]; their sum over positive samples is the
    "probabilistic" true-positive count and their sum over negative samples the
    false-positive count.

    Args:
        labels: sequence of truthy/falsy ground-truth labels.
        predictions: sequence of scores, indexable like `labels`.
        beta: weight of recall relative to precision (1 gives pF1).

    Returns:
        The probabilistic F-beta score, or 0 when it is undefined (no positive
        labels, no prediction mass) or when precision or recall is 0.
    """
    y_true_count = 0
    ctp = 0
    cfp = 0

    for idx in range(len(labels)):
        prediction = min(max(predictions[idx], 0), 1)
        if (labels[idx]):
            y_true_count += 1
            ctp += prediction
        else:
            cfp += prediction

    # Precision and/or recall are undefined here; report 0 instead of dividing by zero
    if y_true_count == 0 or (ctp + cfp) == 0:
        return 0

    beta_squared = beta * beta
    c_precision = ctp / (ctp + cfp)
    c_recall = ctp / y_true_count
    if (c_precision > 0 and c_recall > 0):
        result = (1 + beta_squared) * (c_precision * c_recall) / (beta_squared * c_precision + c_recall)
        return result
    else:
        return 0

In [None]:
# pF1 of the binarized validation predictions for every threshold in [0.00, 1.00]
thresholds = np.arange(0, 1.01, 0.01)
# Explicit array so the element-wise comparison does not rely on NumPy's
# reflected comparison between a Python list and a NumPy scalar
y_pred_val_arr = np.asarray(y_pred_val)
pf1_by_threshold = []
for t in tqdm(thresholds):
    pf1_by_threshold.append(
        pfbeta(y_true_val, y_pred_val_arr > t)
    )
    
plt.figure(figsize=(15,8))
plt.title('F1 By Threshold', size=24)
plt.plot(pf1_by_threshold, label='F1 Score')

arg_max = np.argmax(pf1_by_threshold)
val_max = np.max(pf1_by_threshold)
# Report the threshold at the best score, not the leftover loop variable
# (which is always the last threshold, 1.00)
best_threshold = thresholds[arg_max]
plt.scatter(arg_max, val_max, color='red', label=f'Best Threshold {best_threshold:.2f}, pF1 Score: {val_max:.2f}')

# The x axis is the threshold index (0..100); label it with the threshold values
plt.xticks(np.arange(0, 110, 10), [f'{tick:.2f}' for tick in np.arange(0, 1.1, 0.1)])
plt.yticks(np.arange(0, 1.1, 0.1))
plt.xlim(0, 100)
plt.ylim(0, 1)
plt.xlabel('Threshold')
plt.ylabel('pF1 Score')
plt.legend(fontsize=12)
plt.grid()
plt.show()

# Training History

In [None]:
def plot_history_metric(metric, f_best=np.argmax, ylim=None, yscale=None, yticks=None):
    """Plot one metric from the global `history` per epoch, marking the best epochs.

    Args:
        metric: key in `history.history`; `val_<metric>` is plotted as well
            when the history holds validation entries.
        f_best: function returning the index of the best value
            (np.argmax for scores, np.argmin for losses).
        ylim: optional y-axis limits.
        yscale: optional y-axis scale passed to `plt.yscale`.
        yticks: optional y-axis tick positions.
    """
    plt.figure(figsize=(20, 10))

    train_values = history.history[metric]
    n_epochs = len(train_values)
    has_val = 'val' in ''.join(history.history.keys())

    # Tick positions: every epoch for short runs, otherwise 1, 5, 10, 15, ...
    if n_epochs > 20:
        tick_positions = [1, 5] + [10 + 5 * idx for idx in range((n_epochs - 10) // 5 + 1)]
    else:
        tick_positions = np.arange(1, n_epochs + 1)

    # 1-based epoch numbers for the x axis
    epoch_axis = np.arange(1, n_epochs+1)

    # Validation curve first, then the training curve
    if has_val:
        val_values = history.history[f'val_{metric}']
        best_val_idx = f_best(val_values)
        plt.plot(epoch_axis, val_values, label=f'val')

    plt.plot(epoch_axis, train_values, label=f'train')

    # Mark the best epoch of each curve
    best_train_idx = f_best(train_values)
    plt.scatter(best_train_idx + 1, train_values[best_train_idx], color='red', s=75, marker='o', label=f'train_best')
    if has_val:
        plt.scatter(best_val_idx + 1, val_values[best_val_idx], color='purple', s=75, marker='o', label=f'val_best')

    plt.title(f'Model {metric}', fontsize=24, pad=10)
    plt.ylabel(metric, fontsize=20, labelpad=10)

    # Optional y-axis configuration
    if ylim:
        plt.ylim(ylim)
    if yscale is not None:
        plt.yscale(yscale)
    if yticks is not None:
        plt.yticks(yticks, fontsize=16)

    plt.xlabel('epoch', fontsize=20, labelpad=10)
    plt.tick_params(axis='x', labelsize=8)
    plt.xticks(tick_positions, fontsize=16) # set tick step to 1 and let x axis start at 1
    plt.yticks(fontsize=16)

    plt.legend(prop={'size': 10})
    plt.grid()
    plt.show()

In [None]:
# Loss per epoch; lower is better, so the best epoch is picked with np.argmin
plot_history_metric('loss', f_best=np.argmin)

In [None]:
# 'pF1' history entry per epoch, y axis fixed to [0, 1] with ticks every 0.1
plot_history_metric('pF1', ylim=[0,1], yticks=np.arange(0.0, 1.1, 0.1))

In [None]:
# 'f1_score' history entry per epoch, y axis fixed to [0, 1] with ticks every 0.1
plot_history_metric('f1_score', ylim=[0,1], yticks=np.arange(0.0, 1.1, 0.1))

In [None]:
# 'precision' history entry per epoch, y axis fixed to [0, 1] with ticks every 0.1
plot_history_metric('precision', ylim=[0,1], yticks=np.arange(0.0, 1.1, 0.1))

In [None]:
# 'recall' history entry per epoch, y axis fixed to [0, 1] with ticks every 0.1
plot_history_metric('recall', ylim=[0,1], yticks=np.arange(0.0, 1.1, 0.1))

In [None]:
# 'auc' history entry per epoch, y axis fixed to [0, 1] with ticks every 0.1
plot_history_metric('auc', ylim=[0,1], yticks=np.arange(0.0, 1.1, 0.1))

In [None]:
# 'binary_accuracy' history entry per epoch, y axis fixed to [0, 1] with ticks every 0.1
plot_history_metric('binary_accuracy', ylim=[0,1], yticks=np.arange(0.0, 1.1, 0.1))