# Making the model

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from keras import layers

import io
import imageio
from IPython.display import Image, display
from ipywidgets import widgets, Layout, HBox

In [2]:
def conv_block(x, n_filters):
    """Run ``x`` through two identical convolution stages.

    Each stage is a 3x3 ``Conv2D`` with ``padding='same'`` and ReLU
    activation, followed by ``BatchNormalization``.

    Args:
        x: Input Keras tensor.
        n_filters: Number of filters used by both convolutions.

    Returns:
        The Keras tensor produced by the second batch-normalization layer.
    """
    for _ in range(2):
        activated = layers.Conv2D(
            filters=n_filters, kernel_size=(3, 3), padding='same', activation='relu',
        )(x)
        x = layers.BatchNormalization()(activated)

    return x

In [3]:
def down_samp(x, n_filters):
    """Encoder step: convolve, then max-pool.

    Args:
        x: Input Keras tensor.
        n_filters: Number of filters passed to ``conv_block``.

    Returns:
        A ``(features, pooled)`` tuple. ``features`` is the ``conv_block``
        output before pooling (what ``UNet`` later feeds to ``up_samp`` as the
        skip connection); ``pooled`` is that tensor after ``MaxPooling2D``
        with ``padding='same'`` and the layer's default pool size.
    """
    features = conv_block(x, n_filters)
    pooled = layers.MaxPooling2D(padding='same')(features)
    return features, pooled

In [4]:
def up_samp(x, skip, n_filters):
    """Decoder step: upsample, merge with the skip tensor, then convolve.

    Args:
        x: Keras tensor coming from the previous (coarser) stage.
        skip: Encoder tensor to concatenate with the upsampled ``x``; its
            spatial size must match the upsampled tensor for the
            concatenation to succeed.
        n_filters: Number of filters for the transposed convolution and the
            following ``conv_block``.

    Returns:
        The Keras tensor produced by ``conv_block`` on the merged features.
    """
    # A 2x2 transposed convolution with stride 2.
    upsampled = layers.Conv2DTranspose(
        filters=n_filters, kernel_size=2, strides=2, padding='same',
    )(x)
    merged = layers.Concatenate()([upsampled, skip])
    return conv_block(merged, n_filters)

In [5]:
def UNet(input_shape, out_c, n_filters=8):
    """Build a four-level U-Net as a Keras functional model.

    The encoder applies ``down_samp`` four times with ``n_filters`` times
    1, 2, 4 and 8 filters, the bottleneck is a ``conv_block`` with
    ``n_filters * 16`` filters, and the decoder applies ``up_samp`` four times
    with the encoder widths in reverse order, each fed the matching encoder
    features. A final 1x1 convolution (no activation) maps to ``out_c``
    channels.

    Args:
        input_shape: Shape of one input sample, passed to ``keras.Input``.
        out_c: Number of output channels.
        n_filters: Filter count of the first encoder level.

    Returns:
        An uncompiled ``keras.Model`` named ``'u-net'``.

    Note:
        NOTE(review): with four stride-2 poolings and four stride-2 transposed
        convolutions, spatial dimensions that are not multiples of 16 will
        presumably fail at the concatenation in ``up_samp`` -- confirm before
        using other input sizes.
    """
    inputs = keras.Input(input_shape)

    # Encoder: remember the pre-pooling features of every level.
    skips = []
    current = inputs
    for multiplier in (1, 2, 4, 8):
        features, current = down_samp(current, n_filters * multiplier)
        skips.append(features)

    # Bottleneck at the coarsest resolution.
    current = conv_block(current, n_filters * 16)

    # Decoder: walk back up, pairing each level with its encoder features.
    for multiplier, skip in zip((8, 4, 2, 1), reversed(skips)):
        current = up_samp(current, skip, n_filters * multiplier)

    outputs = layers.Conv2D(out_c, (1, 1), padding='same')(current)

    return keras.Model(inputs, outputs, name='u-net')
    

# The Data

In [6]:
# Load the simulation samples and the boundary array from disk.
# NOTE(review): the file names suggest shapes (512, 64, 64, 64, 3) and
# (64, 64); the shapes printed further down are consistent with that.
fpath = 'data/sim_np/size64/sim_512x64x64x64x3.npy'
dataset = np.load(fpath)
fpath = 'data/sim_np/size64/bound_64x64.npy'
boundary = np.load(fpath)

# Split the samples (axis 0) 90% / 10% into train and validation sets.
# A dedicated, seeded generator makes the split identical on every
# Restart & Run All; change SPLIT_SEED to draw a different split.
SPLIT_SEED = 42
split_rng = np.random.default_rng(SPLIT_SEED)
indexes = split_rng.permutation(dataset.shape[0])
n_train = int(0.9 * dataset.shape[0])
train_index = indexes[:n_train]
val_index = indexes[n_train:]
# Fancy indexing copies the selected samples into new arrays.
train_dataset = dataset[train_index]
val_dataset = dataset[val_index]

# Normalize the data to the 0-1 range.
# train_dataset = train_dataset / 255
# val_dataset = val_dataset / 255

# We'll define a helper that builds next-frame training pairs: for n frames,
# `x` is frames 0 to n - 2 with the boundary array appended as an extra
# channel, and `y` is frames 1 to n - 1.
def create_shifted_frames(data, boundary, dtype=np.float16):
    """Build next-frame prediction pairs from simulation sequences.

    For every sample, frame ``j`` (with ``boundary`` appended as one extra
    trailing channel) becomes an input and frame ``j + 1`` becomes its target.

    Args:
        data: Array of shape ``(samples, frames, height, width, channels)``.
        boundary: Array of shape ``(height, width)``; the same array is
            appended to every input frame.
        dtype: dtype of the returned arrays. Defaults to ``np.float16``, the
            value that was previously hard-coded.

    Returns:
        A tuple ``(x, y)`` where ``x`` has shape
        ``(samples, frames - 1, height, width, channels + 1)`` and ``y`` has
        shape ``(samples, frames - 1, height, width, channels)``.

    Raises:
        ValueError: If ``data`` is not 5-D, has no frames, or ``boundary``
            does not match the spatial shape of ``data``.
    """
    data = np.asarray(data)
    boundary = np.asarray(boundary)

    if data.ndim != 5:
        raise ValueError(
            "data must have shape (samples, frames, height, width, channels); "
            "got %r" % (data.shape,)
        )
    n_samples, n_frames, height, width, n_channels = data.shape
    if n_frames < 1:
        raise ValueError("data must contain at least one frame per sample")
    # Checked explicitly because the slice assignment below would otherwise
    # silently broadcast a wrongly shaped boundary.
    if boundary.shape != (height, width):
        raise ValueError(
            "boundary shape %r does not match frame shape %r"
            % (boundary.shape, (height, width))
        )

    x = np.empty((n_samples, n_frames - 1, height, width, n_channels + 1), dtype)
    y = np.empty((n_samples, n_frames - 1, height, width, n_channels), dtype)

    # Whole-array slice assignments replace the per-sample / per-frame loop.
    x[..., :n_channels] = data[:, :-1]   # frames 0 .. n-2
    x[..., n_channels] = boundary        # broadcast over samples and frames
    y[...] = data[:, 1:]                 # frames 1 .. n-1

    return x, y


# Build the (input, target) frame pairs for both splits.
x_train, y_train = create_shifted_frames(train_dataset, boundary)
x_val, y_val = create_shifted_frames(val_dataset, boundary)

# Report the resulting array shapes as a quick sanity check.
print(f"Training Dataset Shapes: {x_train.shape}, {y_train.shape}")
print(f"Validation Dataset Shapes: {x_val.shape}, {y_val.shape}")

Training Dataset Shapes: (460, 63, 64, 64, 4), (460, 63, 64, 64, 3)
Validation Dataset Shapes: (52, 63, 64, 64, 4), (52, 63, 64, 64, 3)


# Training

In [7]:
def unroll_frames(x):
    """Merge the first two axes of a 5-D array into one.

    An array of shape ``(a, b, c, d, e)`` is reshaped to ``(a * b, c, d, e)``,
    so every frame of every sample becomes an independent item.

    Args:
        x: Array-like exposing ``shape`` and ``reshape``.

    Returns:
        The result of ``x.reshape`` with the two leading axes flattened.
    """
    dims = x.shape
    merged = dims[0] * dims[1]
    return x.reshape((merged, dims[2], dims[3], dims[4]))

In [8]:
# Flatten the (sample, frame) axes so that every frame pair is an independent
# training example for the 2-D network.
X, Y, X_val, Y_val = (
    unroll_frames(frames) for frames in (x_train, y_train, x_val, y_val)
)

In [9]:
# Define modifiable training hyperparameters.
epochs = 10
batch_size = 128

# Callbacks, both driven by the validation loss:
# - EarlyStopping halts training after 2 epochs without improvement.
#   restore_best_weights=True rolls the model back to its best epoch, so the
#   model saved afterwards does not carry weights from the later, worse epochs.
# - ReduceLROnPlateau (patience=0) lowers the learning rate as soon as an
#   epoch fails to improve.
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=2, restore_best_weights=True,
)
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=0)

# Create the model to train. The input shape and output channel count are read
# from the prepared arrays instead of being hard-coded, so they cannot drift
# out of sync with the data.
model = UNet(X.shape[1:], Y.shape[-1], 8)
model.compile(
    loss=keras.losses.MeanSquaredError(), optimizer=keras.optimizers.Adam(),
)

# Fit the model to the flattened frame pairs.
history = model.fit(
    X,
    Y,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(X_val, Y_val),
    callbacks=[early_stopping, reduce_lr],
)

Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB



2022-04-14 16:45:47.070830: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-04-14 16:45:47.074036: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


Epoch 1/10


2022-04-14 16:46:13.825807: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2022-04-14 16:46:16.316281: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.




2022-04-14 16:46:42.765476: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [10]:
# Persist the trained model. The path has no file extension; in the recorded
# run this produced a directory containing an `assets` folder.
model.save('models/keras/unet_mseloss')

2022-04-14 16:49:51.477446: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


INFO:tensorflow:Assets written to: models/keras/unet_mseloss/assets


In [11]:
# Display the per-epoch metrics recorded by `model.fit`; in the recorded run
# this is a dict with 'loss', 'val_loss' and 'lr' lists.
history.history

{'loss': [1.002987027168274,
  0.5214625000953674,
  0.4473508596420288,
  0.40680235624313354,
  0.3790105879306793,
  0.3629036247730255,
  0.3502279818058014,
  0.33727866411209106,
  0.32802167534828186,
  0.3186872899532318],
 'val_loss': [0.8418557643890381,
  0.5421290993690491,
  0.4538145959377289,
  0.42091190814971924,
  0.38995546102523804,
  0.36787837743759155,
  0.3639078140258789,
  0.3372558057308197,
  0.32422807812690735,
  0.31440332531929016],
 'lr': [0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001]}