
Model training with training data generated on the fly
=======================================================

This example demonstrates how to generate training data on the fly, i.e., while the model is being trained.
This is possible when data generation is fast enough to keep up with model training.

Here, the example demonstrates single-source model training similar to [Kuj19]_, but it uses Dataset2, which does not simulate time data. Instead, the cross-spectral matrix (CSM) is calculated directly from the Wishart-distributed source power matrix $Q$.
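To illustrate the idea, the following NumPy sketch forms a CSM from a sampled source power matrix via $C = H Q H^{\mathsf{H}}$, where $H$ is the transfer matrix between sources and microphones. The free-field monopole transfer function, the random microphone layout, the number of snapshots (`dof`) and the helper names `complex_wishart` and `monopole_transfer` are assumptions for illustration; this is not AcouPipe's implementation.

```python
import numpy as np

def complex_wishart(n_sources, dof, scale, rng):
    """Sample Q = X X^H / dof from dof complex Gaussian snapshots (hypothetical helper).

    The expectation of Q equals `scale`, the (n_sources x n_sources) source power matrix.
    """
    L = np.linalg.cholesky(scale)
    # circularly symmetric complex Gaussian snapshots with unit variance
    z = (rng.standard_normal((n_sources, dof))
         + 1j * rng.standard_normal((n_sources, dof))) / np.sqrt(2)
    x = L @ z
    return x @ x.conj().T / dof

def monopole_transfer(mic_pos, src_pos, f, c=343.0):
    """Free-field monopole transfer matrix H (n_mics x n_sources); an assumed propagation model."""
    r = np.linalg.norm(mic_pos[:, None, :] - src_pos[None, :, :], axis=-1)
    k = 2 * np.pi * f / c
    return np.exp(-1j * k * r) / (4 * np.pi * r)

rng = np.random.default_rng(0)
mics = rng.uniform(-0.5, 0.5, size=(16, 3))  # hypothetical planar array
mics[:, 2] = 0.0
src = np.array([[0.1, -0.2, 0.5]])           # one source at 0.5 m distance from the array plane

H = monopole_transfer(mics, src, f=1000.0)
Q = complex_wishart(1, dof=500, scale=np.eye(1), rng=rng)
csm = H @ Q @ H.conj().T                     # (16 x 16) Hermitian cross-spectral matrix
```

Sampling $Q$ instead of simulating time signals avoids computing the FFT-based CSM estimate, which is what makes on-the-fly generation cheap.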


## Build the dataset generator

First, we modify the dataset configuration so that only single-source examples are created, on a smaller grid of $51 \times 51$ points.

```python
import os
# reduce TensorFlow log output; this must be set before tensorflow is imported
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import numpy as np
import tensorflow as tf
from acoupipe.datasets.dataset2 import Dataset2, DEFAULT_GRID, DEFAULT_MICS

# smaller grid: increment of 1/50 of the microphone array aperture
DEFAULT_GRID.increment = 1/50*DEFAULT_MICS.aperture

# dataset generator for single-source cases; only the source map at f=1000 Hz is computed
dataset = Dataset2(max_nsources=1, f=1000, features=["sourcemap"])

# build datasets for training and validation
training_dataset = dataset.get_tf_dataset(split="training", size=100000)
validation_dataset = dataset.get_tf_dataset(split="validation", size=10000)
```

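The $51 \times 51$ size follows from the increment if the grid spans one aperture width along each axis (an assumption here; the default grid extent is not shown in this example): an increment of aperture/50 gives 50 intervals and therefore 51 points per axis. The helper `n_grid_points` and the aperture value below are hypothetical.

```python
def n_grid_points(x_min, x_max, increment):
    """Number of points along one axis of a regular grid, both endpoints included (hypothetical helper)."""
    return int(round((x_max - x_min) / increment)) + 1

aperture = 1.0               # hypothetical aperture in m
increment = 1/50 * aperture  # same rule as DEFAULT_GRID.increment above
print(n_grid_points(-aperture/2, aperture/2, increment))
```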


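The TensorFlow `tf.data` API is used to build the input pipeline from the data generator. Each sample is mapped to a feature (the source map, normalized to its maximum) and a label (the x and y source coordinates followed by the source strength `p2`, normalized by the same factor). The samples are grouped into batches of 16 source cases, and the training dataset is repeated indefinitely.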

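```python
def yield_features_and_labels(data):
    # feature: source map normalized to its maximum value
    feature = data['sourcemap'][0]
    f_max = tf.reduce_max(feature)
    feature /= f_max
    # label: x and y source coordinates, followed by the source strength scaled by the same factor
    label = tf.concat([data['loc'][:2], data['p2'][:, :, 0, 0]/f_max], axis=0)
    return (feature, label)

# map, batch (16 cases per batch), and repeat the training data indefinitely
training_dataset = training_dataset.map(yield_features_and_labels).batch(16).repeat()
validation_dataset = validation_dataset.map(yield_features_and_labels).batch(16)
```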
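The mapping can be mirrored in plain NumPy to make the normalization explicit. The shapes of the hypothetical `data` dict below ((1, 51, 51) source map, (3, 1) location, (1, 1, 1, 1) `p2`) are assumptions chosen so that the concatenation works; the shapes actually returned by AcouPipe may differ.

```python
import numpy as np

def features_and_labels_np(data):
    """NumPy mirror of yield_features_and_labels (sketch with hypothetical input shapes)."""
    feature = data['sourcemap'][0]
    f_max = feature.max()
    feature = feature / f_max  # normalized map, its maximum equals 1
    # x and y coordinates, plus the source strength scaled by the same factor
    label = np.concatenate([data['loc'][:2], data['p2'][:, :, 0, 0] / f_max], axis=0)
    return feature, label

rng = np.random.default_rng(1)
data = {
    'sourcemap': rng.uniform(0.0, 2.0, size=(1, 51, 51)),
    'loc': np.array([[0.1], [-0.2], [0.5]]),
    'p2': np.full((1, 1, 1, 1), 0.8),
}
feature, label = features_and_labels_np(data)
print(feature.shape, feature.max(), label.ravel())
```

Because the strength is divided by the same factor as the map, the strength target is expressed relative to the peak value of the source map.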


Next, we build a model based on the ResNet50V2 architecture (without pretrained weights) and fit it to the generated data. Training may take several hours, depending on the available hardware.

```python
# build model architecture: ResNet50V2 backbone (randomly initialized) with a dense regression head
model = tf.keras.Sequential([
    tf.keras.applications.resnet_v2.ResNet50V2(
        include_top=False,
        weights=None,
        input_shape=(51, 51, 1),
    ),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(3, activation=None),  # outputs: x, y, normalized strength
])

# compile and fit (learning rate 1.5e-3, i.e. the value of 1.5*10e-4)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1.5e-3), loss='mse')
model.fit(training_dataset, validation_data=validation_dataset, epochs=25, steps_per_epoch=1000)
```

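The `.repeat()` on the training dataset follows from the numbers used above: `model.fit` draws 25 × 1000 batches, but one pass over the 100,000-case training split provides only 6,250 batches of 16, so without repetition the input would run out of data during the seventh epoch. The arithmetic, using only the values from this example:

```python
train_size = 100_000   # size of the training split
batch_size = 16
epochs, steps_per_epoch = 25, 1000

batches_per_pass = train_size // batch_size  # batches in one pass over the training split
total_batches = epochs * steps_per_epoch     # batches consumed by model.fit
passes = total_batches * batch_size / train_size
print(batches_per_pass, total_batches, passes)
```

In total, training consumes the equivalent of four passes over the training split.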

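Once trained, the model can predict the source characteristics (location and strength). Below, the predicted location for the first validation sample is plotted on top of its source map, together with the true location.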

```python
import matplotlib.pyplot as plt
from acoular import L_p

# take the first sample of the first validation batch
# (features are normalized source maps, labels are [x, y, normalized strength])
features, labels = next(iter(validation_dataset))
sourcemap = features[:1].numpy()
label = labels[0].numpy()
prediction = model.predict(sourcemap)[0]

plt.figure()
plt.imshow(L_p(sourcemap.squeeze()).T,
           vmax=L_p(sourcemap.max()),
           vmin=L_p(sourcemap.max())-15,
           extent=DEFAULT_GRID.extend(),
           origin="lower")
# the first two outputs are the x and y source coordinates
plt.plot(prediction[0], prediction[1], 'x', label="prediction")
plt.plot(label[0], label[1], 'x', label="label")
plt.colorbar()
plt.legend()
plt.show()
```
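Beyond a single plot, the localization accuracy over a validation batch can be summarized by the distance between predicted and true source positions. The helper `localization_errors` below is a hypothetical NumPy sketch that assumes the output layout used in this example ([x, y, strength]); it is demonstrated on made-up values.

```python
import numpy as np

def localization_errors(predictions, labels):
    """Euclidean x-y distance between predicted and true source positions (hypothetical helper).

    predictions: (n, 3) array [x, y, strength]; labels: (n, 3) or (n, 3, 1) array.
    """
    labels = np.asarray(labels).reshape(len(labels), -1)
    predictions = np.asarray(predictions)
    return np.linalg.norm(predictions[:, :2] - labels[:, :2], axis=1)

pred = np.array([[0.10, -0.20, 0.9], [0.00, 0.30, 0.5]])
true = np.array([[0.13, -0.16, 1.0], [0.00, 0.30, 0.4]])
err = localization_errors(pred, true)
print(err)
```

With the trained model, `model.predict(features)` and the `labels` of a validation batch can be passed in directly.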