In [None]:
import itertools
import warnings
import time
import os
import uuid
import gc
import random as rn

from logging import getLogger, Formatter, StreamHandler, FileHandler, INFO
from typing import List, NoReturn, Union, Tuple, Optional, Text, Generic, Callable, Dict
from tqdm import tqdm_notebook as tqdm
from contextlib import contextmanager
from joblib import Parallel, delayed
from pathlib import Path

import pywt
import h5py
import vaex
#vaex.multithreading.thread_count_default = 8
import vaex.ml
import dask.dataframe as dd
import dask.array as da
import einops as eo
import numpy as np
import pandas as pd
import scipy as sp
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_addons as tfa
import scipy.stats as stats
import cufflinks as cf
import plotly.offline as pyo
import plotly.graph_objs as go
import matplotlib.pyplot as plt

import logging
keras = tf.keras
layers = keras.layers

from IPython.display import display
from dask_ml.preprocessing import OneHotEncoder, LabelEncoder, StandardScaler, MinMaxScaler
from dask.distributed import Client
from sklearn.pipeline import Pipeline
from sklearn.utils import column_or_1d
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import f1_score, cohen_kappa_score, mean_squared_error
from sklearn.model_selection import KFold, GroupKFold
from sklearn import preprocessing
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback, ReduceLROnPlateau, LearningRateScheduler
from tensorflow.keras.losses import binary_crossentropy, categorical_crossentropy, mean_squared_error
from tensorflow.keras.optimizers import Adam, RMSprop, SGD
from tensorflow.keras.utils import Sequence, to_categorical
from tensorflow.keras import losses, models, optimizers
from tensorflow.keras import backend as K
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras import constraints
from tensorflow.keras.utils import get_custom_objects
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
from scipy.fft import fft, fftfreq, fftshift
from dask.distributed import Client, LocalCluster

from dask_utils import (swapaxes_shuffle, shuffle_blocks_together,
                        stack_interleave_flatten, chunk_generator)

#cluster = LocalCluster(n_workers=1, threads_per_worker=1)
#c = Client(cluster)

# set plotly in notebook mode
init_notebook_mode(connected=True)
# likewise cufflinks for offline use
cf.go_offline()

# silence all warnings for the rest of the notebook
# NOTE(review): this also hides deprecation / numerical warnings - confirm that is wanted
warnings.simplefilter('ignore')
warnings.filterwarnings('ignore')
# widen the pandas display limits used when rendering DataFrames
pd.set_option('display.max_columns', 1000)
pd.set_option('display.max_rows', 500)
%matplotlib inline

# module-level logger; fancy_batch_generator writes its debug messages to it
log = logging.getLogger(__name__)

# Load radio signal burst data

This example uses the "DeepSig RADIOML 2018.01A" open dataset provided at:

https://www.deepsig.ai/datasets

The dataset includes both synthetic simulated channel effects and over-the-air recordings of 24 digital and analog modulation types, and has been heavily validated.

This dataset was used for [Over-the-air deep learning based radio signal classification](https://arxiv.org/pdf/1712.04578.pdf) published 2017 in IEEE Journal of Selected Topics in Signal Processing, which provides additional details and description of the dataset.

Data are stored in hdf5 format as complex floating point values, with 2 million examples, each 1024 samples long.

These include a number of high-order modulations (QAM256 and APSK256), which are used in the real world in very high-SNR, low-fading channel environments such as on [impulsive satellite links](https://www.researchgate.net/publication/280972230_Transmission_parameters_optimization_and_receiver_architectures_for_DVB-S2X_systems) (e.g. DVB-S2X).

In [None]:
# raw dataset and the cached copy of the shuffled bursts
data_path = '../data/2018.01/GOLD_XYZ_OSC.0001_1024.hdf5'
X_shf_path = '../data/shf.hdf5' # precalculated shuffled sequences

# named types of radio signal, in order; the notebook looks names up by the
# argmax of the one-hot rows of Y, so position i is the name for class index i
classes = pd.Series([
    '32PSK', '16APSK', '32QAM', 'FM', 'GMSK', '32APSK', 'OQPSK', '8ASK',
    'BPSK', '8PSK', 'AM-SSB-SC', '4ASK', '16PSK', '64APSK', '128QAM', '128APSK',
    'AM-DSB-SC', 'AM-SSB-WC', '64QAM', 'QPSK', '256QAM', 'AM-DSB-WC', 'OOK', '16QAM',
])

Load the arrays for each 1024-sample-long example burst into dask arrays.

- X represents the two-dimensional (I, Q) values for every sample in the burst
- Y represents the one-hot encoded category (corresponding to the signal modulation types above) for each burst
- Z represents the signal-to-noise-ratio (SNR) for each burst

In [None]:
# open the raw file read-only; the handle is kept open because the dask arrays
# created below read from its datasets lazily
f = h5py.File(data_path, "r")
print("datasets found: ", list(f.keys()))
for name, dataset in f.items():
    print(name, "has shape:", dataset.shape)
# read arrays with Dask
burst_chunks = ('auto', 1024, 2)
X = da.from_array(f['X'], chunks=burst_chunks) # samples * timesteps * features (I, Q)
# Y and Z reuse X's chunking along axis 0 so that blocks line up one-to-one
rows_per_block = X.chunks[0]
Y = da.from_array(f['Y'], chunks=(rows_per_block, -1)) # samples * one-hot-encoded class
Z = da.from_array(f['Z'], chunks=(rows_per_block, -1)) # samples * SNR

# Visualise data

See how the samples are arranged by classification/SNR across the dataset

In [None]:
# one row per burst: class index (argmax of the one-hot row) and SNR
burst_labels = da.argmax(Y, axis=1).compute()
burst_snrs = Z.compute().ravel()
df_clf_snr = pd.DataFrame(
    np.column_stack([burst_labels, burst_snrs]),
    columns=['label', 'SNR'],
)
# visualise SNR monotonically increasing for each classification across dataset
axs = df_clf_snr.groupby('label')['SNR'].plot()

Cut on SNR >= a minimum value (the paper reports <10% accuracy for the high negative SNR samples).

In [None]:
# lowest SNR kept; the comparison is inclusive
min_snr = 8
df_clf_snr_filtered = df_clf_snr[df_clf_snr['SNR'] >= min_snr]

In [None]:
# 49152 burst samples per class with SNR >= 8
# 4096 samples per SNR per label
n_snr_levels = len(df_clf_snr_filtered['SNR'].unique())
print(n_snr_levels, "unique SNR values")
df_clf_snr_filtered.reset_index().groupby(['label', 'SNR']).describe()

Pick out samples with SNR in range

In [None]:
# the filtered frame kept its original index, i.e. the row positions in X/Y/Z
kept_rows = df_clf_snr_filtered.index.values
X_filtered, Y_filtered, Z_filtered = (arr[kept_rows] for arr in (X, Y, Z))

Grab one signal sample from each class for visualisations

In [None]:
# renumber the filtered rows 0..n-1 (so the 'index' column addresses the
# *_filtered arrays), then keep the first row of every label
first_samples = (
    df_clf_snr_filtered
    .reset_index(drop=True)
    .reset_index()
    .groupby(['label'])
    .first()
)
first_sample_indices = first_samples['index'].values

X_samples, Y_samples, Z_samples = (
    arr[first_sample_indices].compute()
    for arr in (X_filtered, Y_filtered, Z_filtered)
)

# 24 samples total of sequences of shape 1024 (samples) * 2 (components)
for sample_array in (X_samples, Y_samples, Z_samples):
    print(sample_array.shape)

From the paper: 

"We use a second B210 (with a separate free-running LO) to receive these transmissions in the lab, over a relatively benign indoor wireless channel on the 900MHz ISM band."

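I'll assume this means the sampling frequency is 900MHz. (Caveat: the quote names the 900MHz ISM *band*, which reads as the RF carrier rather than the sample rate, so treat the absolute time and frequency scales in the plots below as nominal.)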

In [None]:
# NOTE(review): 900 MHz is the ISM band quoted from the paper, presumably the RF
# carrier rather than the baseband sample rate - TODO confirm. It only sets the
# scale of the time / frequency axes in the plots.
sampling_rate = 9.0e8
N_samples = X_samples.shape[1]
T = 1.0 / sampling_rate  # sample period
# sample n is taken at n*T; np.linspace(0, N*T, N) would include the endpoint
# and therefore space the points N*T/(N-1) apart instead of T
t = np.arange(N_samples) * T

Visualise I/Q values over time in 3D for each class

In [None]:
# restrict which data to plot (crowded 3D plots are messy)
n_classes_to_visualise = 3
# undersample by a factor of show_nth_sample for clarity
show_nth_sample = 16
# plot and superimpose each example sequence
traces = [
    go.Scatter3d(
        name=classes[first_samples.index][i],
        x=X_samples[i, ::show_nth_sample, 0],
        # the time axis must be undersampled exactly like I and Q, otherwise
        # the three coordinate arrays have different lengths and point k is
        # drawn at time t[k] instead of t[k * show_nth_sample]
        y=t[::show_nth_sample],
        z=X_samples[i, ::show_nth_sample, 1],
        marker=dict(
            size=4,
            colorscale='Viridis',
        ),
        line=dict(
            width=2
        )
    )
    for i in range(X_samples[:n_classes_to_visualise].shape[0])
]
fig = go.Figure(data=traces)
# set figure layout
fig.update_layout(
    title="I/Q plot over sample sequences of various signal types",
    autosize=False,
    scene=dict(
        xaxis_title='I',
        yaxis_title='time',
        zaxis_title='Q',
        aspectratio = dict( x=1, y=3., z=1 ),
        aspectmode = 'manual'
    ),
)
fig.show()

## in-phase component

### time domain

In [None]:
show_nth_sample = 1
n_classes_to_visualise = 4

# in-phase component (feature 0) of the first example burst of each class
traces = [
    go.Scatter(
        name=classes[first_samples.index][i],
        # stride the time axis together with the signal so the two stay the
        # same length if show_nth_sample is ever raised above 1
        x=t[::show_nth_sample],
        y=X_samples[i, ::show_nth_sample, 0],
        marker=dict(
            size=4,
            colorscale='Viridis',
        ),
        line=dict(
            width=2
        )
    )
    for i in range(X_samples[:n_classes_to_visualise].shape[0])
]
fig = go.Figure(data=traces)
fig.update_layout(
    title="In-phase component over time",
    xaxis_title='time',
    yaxis_title='I',
)
fig.show()

### frequency domain

https://docs.scipy.org/doc/scipy/reference/tutorial/fft.html

In [None]:
n_classes_to_visualise = 4
show_nth_freq_component = 1

# frequency axis: FFT bin frequencies, shifted so that 0 sits in the centre.
# It is identical for every burst, so it is computed once outside the loop
# (the previous np.linspace assignment was dead code, overwritten immediately).
xf = fftshift(fftfreq(N_samples, T))

traces = []
for ix in range(n_classes_to_visualise):
    # magnitude of fourier components of the complex signal I + jQ,
    # normalised by the number of samples
    yf = fft(X_samples[ix, :, 0] + 1j*X_samples[ix, :, 1])
    yf = 1/N_samples * np.abs(fftshift(yf))
    traces.append(
        go.Scatter(
            name=classes[first_samples.index][ix],
            x=xf[::show_nth_freq_component],
            y=yf[::show_nth_freq_component],
            marker=dict(
                size=4,
                colorscale='Viridis',
            ),
            line=dict(
                width=2
            )
        )
    )
fig = go.Figure(data=traces)
fig.show()

Verify that the swapaxes shuffle reorders samples as anticipated, traversing the permutations of label/SNR first.

In [None]:
# labels and SNRs in the order produced by the shuffle; show the first 32 rows
shuffled_labels = da.argmax(swapaxes_shuffle(Y_filtered), axis=1).compute()
shuffled_snrs = da.squeeze(swapaxes_shuffle(Z_filtered)).compute()
pd.DataFrame(
    np.column_stack([shuffled_labels, shuffled_snrs]),
    columns=['label', 'SNR'],
).head(32)

Perform the swapaxes shuffling on each array to intersperse examples.

In [None]:
# for the burst samples themselves, go through an intermediate hdf5 file to avoid dask scheduler
# having a fit and blowing up RAM...
# NOTE: the cached file is reused whenever it exists, whatever min_snr is now;
# delete it after changing the filtering.
shf_file = Path(X_shf_path)
if not shf_file.exists():
    shf_file.parent.mkdir(parents=True, exist_ok=True)
    # write to the same path that is read back below; a bare 'shf.hdf5' would
    # land in the working directory and the open() that follows would fail
    swapaxes_shuffle(X_filtered).to_hdf5(X_shf_path, '/X')
f = h5py.File(X_shf_path, "r")
print("datasets found: ", list(f.keys()))
for k in f.keys():
    print(k, "has shape:", f[k].shape)
# read arrays with Dask
X_filtered = da.from_array(f['X'], chunks=('auto', -1, -1)) # samples * timesteps * features (I, Q)

# do classifications and SNRs regularly
# rechunk() takes the whole chunk spec as ONE argument; passed separately, the
# -1 is consumed as rechunk's `threshold` parameter rather than as the chunk
# size of axis 1
row_chunks = X_filtered.chunks[0]
Y_filtered = swapaxes_shuffle(Y_filtered).rechunk((row_chunks, -1))
Z_filtered = swapaxes_shuffle(Z_filtered).rechunk((row_chunks, -1))

print(X_filtered.numblocks[0], "blocks of size", X_filtered.chunksize[0])
# blocks are consumed pairwise later on, so the boundaries along axis 0 must be identical
assert X_filtered.chunks[0] == Y_filtered.chunks[0] == Z_filtered.chunks[0]

In [None]:
# sum every burst over axis 1 (the time-step axis) and materialise the result in memory
# NOTE(review): X_sum is not referenced anywhere else in the visible notebook - confirm
# this reduction over the whole filtered dataset is still needed
X_sum = X_filtered.sum(axis=1).compute()

Split the dask blocks into 53 train / 10 validation / 10 test (the test split takes every block after the first 63).

In [None]:
# block ranges of the three splits; the same ranges are applied to X, Y and Z
# so that samples, labels and SNRs stay aligned
train_blocks, val_blocks, test_blocks = slice(None, 53), slice(53, 63), slice(63, None)
X_train, X_val, X_test = (X_filtered.blocks[part] for part in (train_blocks, val_blocks, test_blocks))
Y_train, Y_val, Y_test = (Y_filtered.blocks[part] for part in (train_blocks, val_blocks, test_blocks))
Z_train, Z_val, Z_test = (Z_filtered.blocks[part] for part in (train_blocks, val_blocks, test_blocks))

# Preprocessing

In [None]:
def make_X_pipeline():
    """
    Build the preprocessing pipeline for the input signal samples.

    Returns
    -------
    :obj:`sklearn.pipeline.Pipeline`
        A pipeline with a single step named ``'scaler'`` holding a
        ``StandardScaler`` (Z-score rescaling).
    """
    return Pipeline([('scaler', StandardScaler())])

Perform preprocessing before shuffling to save computations

In [None]:
# apply z-score scaling to samples (~2m)
X_pipeline = make_X_pipeline()
X_train = X_pipeline.fit_transform(X_train)
X_val = X_pipeline.transform(X_val)
X_test = X_pipeline.transform(X_test)

Shuffle and prepare batch generation from delayed arrays

In [None]:
def fancy_batch_generator(X:da.Array,
                          Y:da.Array,
                          batch_size:int,
                          augger=None,
                          client=None,
                          seed:int=42,
                          shuffle_blocks_every_epoch:bool=False,
                          shuffle_within_blocks:bool=True,
                          float32=True):
    """
    Endlessly generates (X, Y) batches from a pair of dask arrays.
    Proceeds chunk by chunk through the dask arrays (via `chunk_generator`),
    slicing batches of `batch_size` rows out of each in-memory chunk, and
    starts over once every chunk has been consumed.

    Parameters
    ----------
    X: :obj:`dask.array.Array`
        Inputs, batched along axis 0
    Y: :obj:`dask.array.Array`
        Targets, batched along axis 0 (same row order as `X`)
    batch_size: int
        The number of rows in each yielded batch
    augger: callable, optional
        Augmentation applied to every batch through `augment_all`
    client: :obj:`distributed.Client`, optional
        Dask distributed client handed to `augment_all`. When omitted, a new
        `Client()` is created and closed again when the generator finishes.
    seed: int, optional
        Seed for numpy's global RNG. Pass None to leave the RNG untouched.
    shuffle_blocks_every_epoch: bool
        Forwarded to `chunk_generator` as `shuffle_blocks`
    shuffle_within_blocks: bool
        Flags whether to shuffle the rows within each chunk
    float32: bool
        Flags whether to cast both arrays of each batch to float32

    Yields
    ------
    tuple of :obj:`numpy.ndarray`:
        (Xb, Yb). Each chunk yields `len(chunk) // batch_size` full batches;
        leftover rows are dropped, except that a chunk shorter than
        `batch_size` still yields one (short) batch.
    """
    # remember whether the client is ours, so that a caller-supplied client is
    # never closed and the flag is always bound
    close_client_after = client is None
    if close_client_after:
        # NOTE(review): a client is created even when `augger` is None and it
        # is never used; kept because creating it may also change which
        # scheduler dask uses for the chunk computations - confirm before removing
        client = Client()
    # seed numpy (0 is a valid seed, so test against None rather than truthiness)
    if seed is not None:
        np.random.seed(seed)
    epochs = 0
    # A generator is terminated by GeneratorExit thrown at the `yield` (on
    # close() or garbage collection), never by StopIteration, so the clean-up
    # has to live in a `finally` to run at all.
    try:
        # -- loop over whole dask array, forever
        while True:
            # get chunk generator for the larger-than-memory arrays
            chunk_gen = chunk_generator([X, Y], shuffle_blocks=shuffle_blocks_every_epoch)
            # -- loop over chunks: these are now numpy arrays
            for i, (X_chunk, Y_chunk) in enumerate(chunk_gen):
                log.debug(f"Dask Chunk: {i+1}\n")
                n_samples_chunk = X_chunk.shape[0]
                # always emit at least one batch per chunk (as the original
                # do-while loop did), otherwise a run of short chunks would
                # spin forever without yielding
                n_batches_chunk = max(1, n_samples_chunk // batch_size)
                index = np.arange(n_samples_chunk)
                # optionally shuffle inplace
                if shuffle_within_blocks:
                    np.random.shuffle(index)
                # -- loop over batches in chunk
                for batches in range(n_batches_chunk):
                    log.debug(f"Batch: {batches}")
                    start_point = batches * batch_size
                    inds = index[start_point:start_point+batch_size]
                    Xb, Yb = X_chunk[inds], Y_chunk[inds]
                    # augment this batch
                    if augger is not None:
                        # NOTE(review): `augment_all` is neither defined nor
                        # imported in the visible notebook - confirm where it
                        # comes from before passing an `augger`
                        Xb, Yb = augment_all(Xb, Yb, augger, client)
                    if float32:
                        Xb = Xb.astype('float32')
                        Yb = Yb.astype('float32')
                    yield Xb, Yb
            # -- epoch end
            epochs += 1
            log.debug(f"Finished generating epoch: {epochs}!\n")
    finally:
        # clean up the autocreated client once the consumer stops iterating
        if close_client_after:
            client.close()

In [None]:
# test block generator
for Xb_train, Yb_train in chunk_generator([X_train, Y_train]):
    assert Xb_train.shape[0] == Yb_train.shape[0]
    break
#print("should see most if not all classes in the first block with sufficient shuffling:")
#print("classes in first block: ", np.unique(np.argmax(Yb_train, axis=1)))
# if not, increase "repeats" in synchronised pseudoshuffle. this incrases memory usage and computation time.

In [None]:
# test batch generator
for X_batch, Y_batch in fancy_batch_generator(X_train, Y_train, batch_size=64):
    assert X_batch.shape[0] == Y_batch.shape[0]
    break

In [None]:
# shapes of the batch drawn in the previous cell
print(X_batch.shape, Y_batch.shape)

# WaveNet classifier

In [None]:
def scaled_dot_product_attention(q, k, v, mask=None, attention_dropout=0., trainable=False):
    """Calculate the attention weights.
    q, k, v must have matching leading dimensions.
    k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
    The mask has different shapes depending on its type(padding or look ahead)
    but it must be broadcastable for addition.

    Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable
          to (..., seq_len_q, seq_len_k). Defaults to None.
    attention_dropout: fraction of attention weights to DROP (0. disables).
    trainable: when truthy, dropout is applied to the attention weights.

    Returns:
    output, attention_weights
    """

    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

    # scale matmul_qk; cast to the logits' dtype so non-float32 inputs also work
    dk = tf.cast(tf.shape(k)[-1], matmul_qk.dtype)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # add the mask to the scaled tensor: masked positions get a very large
    # negative logit and therefore ~0 weight after the softmax
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)

    # softmax is normalized on the last axis (seq_len_k) so that the scores
    # add up to 1.
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)
    if trainable:
        # tf.nn.dropout takes the DROP rate; passing `1.0 - attention_dropout`
        # positionally (a keep-probability) would drop 90% of the weights for
        # a nominal dropout of 0.1
        attention_weights = tf.nn.dropout(attention_weights, rate=attention_dropout)
    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

    return output, attention_weights

In [None]:
class MultiHeadAttention(tf.keras.layers.Layer):
    """
    Multi-head scaled dot-product attention.

    Projects queries, keys and values with dense layers of width `d_model`,
    splits the projections into `num_heads` heads of width
    `d_model // num_heads`, runs `scaled_dot_product_attention` per head,
    concatenates the heads and applies a final dense projection.

    NOTE(review): the layer's `trainable` flag is what switches attention
    dropout on, so with the default `trainable=True` dropout is also applied
    when the layer is called for inference - confirm this is intended.
    """

    def __init__(self,
                 d_model,
                 num_heads,
                 attention_dropout=0.1,
                 trainable=True,
                 name='MultiHeadAttention',
                 **kwargs):
        """
        Adapted from Google's Transformer implementation at:
            https://www.tensorflow.org/tutorials/text/transformer#masking

        Extra keyword arguments are forwarded to the Keras base layer; this is
        needed so that the layer can be rebuilt from `get_config()`.
        """
        super(MultiHeadAttention, self).__init__(name=name, trainable=trainable, **kwargs)
        self.num_heads = num_heads
        self.d_model = d_model
        self.attention_dropout = attention_dropout

        assert d_model % self.num_heads == 0, "d_model must be divisible by num_heads"

        self.depth = d_model // self.num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)

    def get_config(self):
        """Return the constructor arguments so the layer can be saved and re-created."""
        config = super(MultiHeadAttention, self).get_config()
        config.update({
            'd_model': self.d_model,
            'num_heads': self.num_heads,
            'attention_dropout': self.attention_dropout,
        })
        return config

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        """Return `(output, attention_weights)` for values `v`, keys `k` and queries `q`."""
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask, trainable=self.trainable, attention_dropout=self.attention_dropout)

        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

        concat_attention = tf.reshape(scaled_attention,
                                      (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

        return output, attention_weights

In [None]:
# test SA layer
#hidden_size = 2
#batch_size = 8
#X0_batch = X0[:batch_size]
#sa = MultiHeadAttention(hidden_size, num_heads=1)
#sa(X0_batch, X0_batch, X0_batch)

In [None]:
class Mish(layers.Activation):
    '''
    Activation layer wrapper used to expose Mish under a named activation.
    .. math::
        mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))
    The function to apply is supplied by the caller as `activation`; this
    class only forwards it to `layers.Activation` and tags the instance with
    the name 'Mish'.
    Shape:
        - Input: Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.
        - Output: Same shape as the input.
    Examples:
        >>> X = Activation('Mish', name="conv1_act")(X_input)
    '''

    def __init__(self, activation, **kwargs):
        super().__init__(activation, **kwargs)
        # give the instance a __name__, as a plain activation function would have
        self.__name__ = 'Mish'


def mish(inputs):
    """Element-wise Mish: ``inputs * tanh(softplus(inputs))``."""
    gate = tf.math.tanh(tf.math.softplus(inputs))
    return inputs * gate

# register the activation under the name 'Mish' so that layers can request it
# with activation='Mish'
get_custom_objects()['Mish'] = Mish(mish)

In [None]:
def WaveNetResidualConv1D(num_filters, kernel_size, stacked_layer):
    """
    Return a function that appends a stack of gated, dilated Conv1D layers
    to a tensor.

    The stack has `stacked_layer` stages with dilation rates 1, 2, 4, ...
    Each stage multiplies a sigmoid-activated convolution with a
    Mish-activated convolution of the same input, passes the product through
    a width-1 convolution, feeds that to the next stage and also adds it to a
    running sum that starts at the block input. The running sum is returned.
    """
    dilation_rates = [2**i for i in range(stacked_layer)]

    def build_residual_block(l_input):
        running_sum = l_input
        stage_input = l_input
        for rate in dilation_rates:
            gate = layers.Conv1D(
                num_filters, kernel_size, dilation_rate=rate,
                padding='same', activation='sigmoid'
            )(stage_input)
            signal = layers.Conv1D(
                num_filters, kernel_size, dilation_rate=rate,
                padding='same', activation='Mish'
            )(stage_input)
            gated = layers.Multiply()([gate, signal])
            stage_input = layers.Conv1D(num_filters, 1, padding='same')(gated)
            running_sum = layers.Add()([running_sum, stage_input])
        return running_sum

    return build_residual_block

def WaveNetClassifier(shape_, output_dim=24, num_filters=16, kernel_size=3,
                      stacked_layers=(12, 8, 4, 1)):
    """
    Build the (uncompiled) WaveNet-style sequence classifier.

    Parameters
    ----------
    shape_: tuple
        Shape of one input example, without the batch axis
    output_dim: int, optional
        Number of classes, i.e. width of the softmax output
    num_filters: int, optional
        Filters in the first stage; the count doubles with every stage
    kernel_size: int, optional
        Kernel size of the dilated convolutions
    stacked_layers: sequence of int, optional
        Number of dilated stages in each residual block, one entry per block

    Returns
    -------
    :obj:`tf.keras.Model`
        The model, NOT compiled: the caller chooses loss and optimiser. (The
        compile statements that used to follow the first `return` were
        unreachable and have been removed.)
    """
    l_input = layers.Input(shape=(shape_))
    x = l_input
    for stage, n_stacked in enumerate(stacked_layers):
        stage_filters = num_filters * 2**stage
        # width-1 convolution to reach this stage's filter count, then the residual stack
        x = layers.Conv1D(stage_filters, 1, padding='same')(x)
        x = WaveNetResidualConv1D(stage_filters, kernel_size, n_stacked)(x)
    # collapse sequences to output class vector
    # NOTE(review): ReLU and dropout act on the class scores right before the
    # softmax - confirm this head is intended
    x = layers.Conv1D(output_dim, 3, padding='same', activation='relu')(x)
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dropout(0.1)(x)
    outputs = layers.Softmax()(x)
    return models.Model(inputs=[l_input], outputs=[outputs])

In [None]:
# build the classifier; the input shape is X_train's shape without its leading (sample) axis
model = WaveNetClassifier(X_train.shape[1:])
model.summary()

In [None]:
# training hyper-parameters
epochs = 10
batch_size = 64
optimiser = 'adam'
lr_init = 1e-4
lr_reduce_factor = 0.5
lr_reduce_patience = 2
lr_min = 1e-6
patience = 3

# whole batches per epoch / per validation pass
train_steps = X_train.shape[0]//batch_size
val_steps = X_val.shape[0]//batch_size
# integer division: this count is later used to size the fit() queue, which
# needs a whole number
batches_per_block = X_train.blocks[0].shape[0] // batch_size
print(train_steps, "train steps")
print(batches_per_block, "batches per block") # don't want to queue up more than one extra block for RAM

In [None]:
# interpret optimizer
if optimiser == 'sgd':
    opt = tf.keras.optimizers.SGD(
        learning_rate=lr_init, momentum=0.85, nesterov=False
    )
elif optimiser == 'adam':
    opt = tf.keras.optimizers.Adam(
        learning_rate=lr_init, beta_1=0.9, beta_2=0.999, amsgrad=False
    ) # check out RADAM?
else:
    # report the requested name; `opt` is not bound on this branch, so
    # formatting it would raise NameError instead of this ValueError
    raise ValueError(f"Optimiser {optimiser} not understood")
# wrap the chosen optimiser in tensorflow-addons' SWA wrapper
opt = tfa.optimizers.SWA(opt)

# specify training directory to save weights and metrics for this loss_fn and data ID
# within models_dir; the uuid gives every run of this cell a fresh directory
models_dir = Path('../models')
project_name = Path(f'WaveNetClassifier_{uuid.uuid4()}')
training_dir = Path(models_dir) / project_name
training_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# -- callbacks
# every callback watches the same quantity
monitor = 'val_loss'

# early stopping
early_stopping_cb = tf.keras.callbacks.EarlyStopping(monitor, patience=patience)

# reduce the learning rate on plateaus
reduce_lr_cb = tf.keras.callbacks.ReduceLROnPlateau(
    monitor=monitor,
    factor=lr_reduce_factor,
    patience=lr_reduce_patience,
    min_lr=lr_min,
)

# set up tensorboard to record metrics in a subdirectory
tb_pth = training_dir / Path("metrics/")
tb_cb = tf.keras.callbacks.TensorBoard(log_dir=str(tb_pth), update_freq=50)

callbacks = [early_stopping_cb, reduce_lr_cb, tb_cb]

In [None]:
# set up checkpoints in the training directory
# the file name records epoch, training loss/accuracy and validation loss/accuracy
cp_fmt = 'cp-e{epoch:02d}-l{loss:.5f}-a{accuracy:.4f}'
suffix = '-vl{val_loss:.5f}-va{val_accuracy:.4f}.ckpt'
cp_fmt = cp_fmt + suffix
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=str(training_dir / Path(cp_fmt)), # saved_model
    monitor=monitor,
    save_best_only=True,
    save_weights_only=True,
    verbose=1
)
# register the checkpoint callback: it was created but never handed to fit(),
# so no weights were ever written. Any ModelCheckpoint left over from a
# previous run of this cell is dropped first, which keeps the cell re-runnable.
callbacks = [
    cb for cb in callbacks
    if not isinstance(cb, tf.keras.callbacks.ModelCheckpoint)
] + [cp_callback]

In [None]:
# compile with categorical cross-entropy (Y rows are one-hot) and the optimiser built above
model.compile(loss=losses.CategoricalCrossentropy(), optimizer=opt, metrics=['accuracy'])

In [None]:
# quickly check: pull the first training block into memory and decode the
# one-hot labels back into class names
X_train_blk_0, Y_train_blk_0 = X_train.blocks[0].compute(), Y_train.blocks[0].compute()
blk_0_label_ids = np.argmax(Y_train_blk_0, axis=1)
Y_train_blk_0_cls = classes.loc[blk_0_label_ids].tolist()

In [None]:
show_nth_sample = 1
# number of bursts drawn: the first rows of the block, not one burst per class
n_classes_to_visualise = 24

# in-phase component (feature 0) of the first bursts of the first training block
traces = [
    go.Scatter(
        name=Y_train_blk_0_cls[i],
        # stride the time axis together with the signal so the two stay the
        # same length if show_nth_sample is ever raised above 1
        x=t[::show_nth_sample],
        y=X_train_blk_0[i, ::show_nth_sample, 0],
        marker=dict(
            size=4,
            colorscale='Viridis',
        ),
        line=dict(
            width=2
        )
    )
    for i in range(X_train_blk_0[:n_classes_to_visualise].shape[0])
]
fig = go.Figure(data=traces)
fig.show()

In [None]:
# train from the endless batch generators; steps_per_epoch / validation_steps
# tell Keras where one pass over the data ends
model.fit(
    fancy_batch_generator(X_train, Y_train, batch_size=batch_size, shuffle_blocks_every_epoch=False, shuffle_within_blocks=False),
    steps_per_epoch=train_steps, # steps per epoch
    epochs=epochs,
    validation_data=fancy_batch_generator(X_val, Y_val, batch_size=batch_size),
    validation_steps=val_steps,
    # the queue size must be a positive whole number; batches_per_block may be
    # a float when it comes from a true division
    max_queue_size=max(1, int(batches_per_block) - 1),
    callbacks=callbacks
)

In [None]:
# NOTE(review): presumably a size estimate expressed in units of 1024**3 (GiB) - TODO confirm.
# `^` is bitwise XOR in Python (1024^3 == 1027); exponentiation is `**`.
estimate_gib = 1024 * 1024 * 16 / (1024**3)
estimate_gib

In [None]:
#c(X0_batch)