In [None]:
# Dataset management packages
from spivutils.synthetic_datasets.spid import load_data
from spivutils.batch_generators.keras_generator import batch_data
from spivutils.common_tools.pre_processors import preprocess_data
from spivutils.common_tools.operations import normalization, vectoraddition, thresholding

# General purpose packages
import random

import numpy as np

# Model packages
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Conv3D, MaxPooling3D, Flatten, Dense, Reshape

In [None]:
# Force CPU usage: hide every GPU from TensorFlow. This has to run before the
# TensorFlow runtime is initialised (i.e. before any op or model is created);
# afterwards set_visible_devices raises a RuntimeError.
tf.config.set_visible_devices([], 'GPU')

# Seed Python, NumPy and TensorFlow so weight initialisation and any random
# shuffling are repeatable across Restart Kernel -> Run All.
seed = 42
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

In [None]:
# load_data() returns three (x, y) splits: training, validation and test.
train_split, valid_split, test_split = load_data()
train_x, train_y = train_split
valid_x, valid_y = valid_split
test_x, test_y = test_split

In [None]:
batch_size = 10  # samples per batch handed to the model
chunk_size = 20  # samples kept from the start of each split (presumably to keep this run small)

# Keep only the first `chunk_size` samples of the training and validation splits
train_x_chunk, train_y_chunk = train_x[:chunk_size], train_y[:chunk_size]
valid_x_chunk, valid_y_chunk = valid_x[:chunk_size], valid_y[:chunk_size]

# Wrap each chunk in a batch generator so data is served batch by batch
# rather than all at once (avoids memory overload)
train_batch = batch_data(train_x_chunk, train_y_chunk, batch_size)
valid_batch = batch_data(valid_x_chunk, valid_y_chunk, batch_size)

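## Pre-processing

Register the same input (x) and target (y) pre-processing operations on the training and validation batch generators, so both splits are transformed identically.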

In [None]:
# Register one shared pre-processing pipeline on both the training and
# validation generators. Defining it once keeps the two splits from drifting
# apart. Operations are registered in the order listed.
# NOTE(review): add_*_preprocessing_operation presumably appends, so re-running
# this cell without first re-running the batching cell would register every
# operation twice. Re-create the generators before re-running.
x_operations = (thresholding, normalization)
y_operations = (vectoraddition,)

for generator in (train_batch, valid_batch):
    for operation in x_operations:
        generator.add_x_preprocessing_operation(operation)
    for operation in y_operations:
        generator.add_y_preprocessing_operation(operation)

In [None]:
# Infer per-sample shapes from the first training batch. Indexing the
# generator presumably re-generates and re-pre-processes the batch, so fetch
# it once and reuse it for both shapes.
first_batch = train_batch[0]
input_shape = first_batch[0][0].shape   # one input sample (batch axis dropped)
output_shape = first_batch[1][0].shape  # one target sample (batch axis dropped)

# Model hyperparameters
filters = 32             # number of Conv3D feature maps
units = 64               # width of the hidden Dense layer
kernel_size = (1,10,10)  # Conv3D kernel extent along each of the three spatial axes
pool_size = (1,5,5)      # MaxPooling3D window along each of the three spatial axes

In [None]:
# Model: a single Conv3D + MaxPooling3D feature stage, then a small fully
# connected head whose output is reshaped back to the per-sample target shape.
model = Sequential()

# Feature extraction
model.add(Conv3D(filters = filters, kernel_size = kernel_size, activation = 'relu', input_shape = input_shape))
model.add(MaxPooling3D(pool_size = pool_size))

model.add(Flatten())

# Head: hidden layer, then one linear unit per target element.
# np.prod returns a NumPy scalar (and the float 1.0 for an empty shape), so
# cast it to a plain Python int before passing it to Dense as `units`.
model.add(Dense(units = units, activation = 'relu'))
model.add(Dense(int(np.prod(output_shape))))

# Restore the flat prediction to the target's shape
model.add(Reshape(output_shape))

# Linear output + mean squared error: a regression set-up
model.compile(loss = 'mean_squared_error', optimizer = 'adam')

In [None]:
# Print the layer-by-layer summary (output shapes and parameter counts) of the compiled model
model.summary()

In [None]:
# Train on the pre-processed generators. Keeping the returned History object
# preserves the per-epoch loss / val_loss values for later inspection, and it
# stops the cell from echoing the object's repr as its output.
epochs = 1
history = model.fit(train_batch, validation_data = valid_batch, epochs = epochs)