<a href="https://colab.research.google.com/github/atick-faisal/MultiViewUNet-Aneurysm/blob/main/src/training/MultiViewUNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Colab Info


In [1]:
from psutil import virtual_memory

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Not connected to a GPU')
else:
    print(gpu_info)


ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
    print('Not using a high-RAM runtime')
else:
    print('You are using a high-RAM runtime!')


Wed Jun 21 11:49:44 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   54C    P0    27W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Fix for GDrive


In [2]:
!pip install -U --no-cache-dir gdown --pre > /dev/null


# Mount GDrive


In [3]:
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Download and Extract Dataset


In [4]:
!gdown "1NURqp0DrjBpbwI95vJ9bJcZMudvZzLPt"
!unzip -o "AAA_BW_DATASET_r1.zip" > /dev/null


Downloading...
From (uriginal): https://drive.google.com/uc?id=1NURqp0DrjBpbwI95vJ9bJcZMudvZzLPt
From (redirected): https://drive.google.com/uc?id=1NURqp0DrjBpbwI95vJ9bJcZMudvZzLPt&confirm=t&uuid=d791e607-e7c3-4f18-92bf-773a02c8fd8a
To: /content/AAA_BW_DATASET_r1.zip
100% 451M/451M [00:02<00:00, 202MB/s]


# Imports


In [5]:
import os
import datetime
import matplotlib
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

from tqdm import tqdm

matplotlib.use('Agg')
plt.rcParams["font.size"] = 16


# Config


In [6]:
PROBLEM = "TAWSS_2_ECAP"

MODEL_NAME = "MultiViewUNet"
DATASET_PATH = "/content/Images/"
TRAIN_DIR = "Train/"
TEST_DIR = "Test/"
INPUT_DIR = PROBLEM.split("_2_")[0]
TARGET_DIR = PROBLEM.split("_2_")[1]
MODEL_PATH = "/content/drive/MyDrive/Research/TAVI/Models/"
PRED_PATH = "/content/drive/MyDrive/Research/TAVI/Predictions/"
IMG_SIZE = 256
BATCH_SIZE = 16
BUFFER_SIZE = 1000
VAL_SPLIT = 0.2
LEARNING_RATE = 0.001
N_EPOCHS = 10
PATIENCE = 30

EXP_NAME = f"{PROBLEM}_{MODEL_NAME}_I{IMG_SIZE}_B{BATCH_SIZE}_LR{LEARNING_RATE}"


# Architecture


In [7]:
class UNet:
    def __init__(
        self,
        img_size: int,
        n_channels: int = 1,
        width: int = 32,
        depth: int = 4,
        kernel_size: int = 3
    ):
        self.img_size = img_size
        self.n_channels = n_channels
        self.width = width
        self.depth = depth
        self.kernel_size = kernel_size

    @staticmethod
    def conv(
        x: tf.Tensor,
        filters: int,
        kernel_size: int
    ) -> tf.Tensor:
        for i in range(2):
            x = tf.keras.layers.Conv2D(
                filters=filters,
                kernel_size=kernel_size,
                strides=1,
                padding="same",
                data_format="channels_last",
                dilation_rate=1,
                groups=1,
                activation=None,
                use_bias=True,
                kernel_initializer="glorot_uniform",
                bias_initializer="zeros",
            )(x)

            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.Activation("relu")(x)

        return x

    @staticmethod
    def deconv(x: tf.Tensor, filters: int) -> tf.Tensor:
        x = tf.keras.layers.Conv2DTranspose(
            filters=filters,
            kernel_size=2,
            strides=2,
            padding="same",
            output_padding=None,
            data_format=None,
            dilation_rate=1,
            activation=None,
            use_bias=True,
            kernel_initializer="glorot_uniform",
            bias_initializer="zeros",
        )(x)

        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation("relu")(x)

        return x

    @staticmethod
    def output(x: tf.Tensor) -> tf.Tensor:
        return tf.keras.layers.Conv2D(1, (1, 1), activation="sigmoid")(x)

    @staticmethod
    def pool(x: tf.Tensor, pool_size: int = 2) -> tf.Tensor:
        return tf.keras.layers.MaxPool2D(pool_size)(x)

    @staticmethod
    def dropout(x: tf.Tensor, amount: float = 0.5) -> tf.Tensor:
        return tf.keras.layers.Dropout(amount)(x)

    def __call__(self) -> tf.keras.Model:
        inputs = tf.keras.layers.Input(
            shape=(self.img_size, self.img_size, self.n_channels)
        )

        # scaled = tf.keras.layers.Rescaling(1./255.0, offset=0)(inputs)

        # ------------------ Downsampling ---------------------
        downsample_layers = []
        downsample_layers.append(
            self.conv(
                x=inputs,
                filters=self.width,
                kernel_size=self.kernel_size
            )
        )
        for i in range(1, self.depth):
            dropout_amount = 0.2 if i == 1 else 0.5
            filters = int((2 ** i) * self.width)
            downsample_layers.append(
                self.dropout(
                    self.pool(
                        self.conv(
                            x=downsample_layers[i - 1],
                            filters=filters,
                            kernel_size=self.kernel_size
                        )
                    ),
                    amount=dropout_amount
                )
            )

        # ------------------- Features --------------------
        n_features = int((2 ** self.depth) * self.width)
        self.features = self.pool(
            self.conv(
                x=downsample_layers[-1],
                filters=n_features,
                kernel_size=self.kernel_size
            )
        )

        # ------------------- Upsampling --------------------
        upsample_layers = []
        upsample_layers.append(self.features)
        for i in range(1, self.depth + 1):
            filters = int((2 ** (self.depth - i)) * self.width)
            upsample_layers.append(
                self.conv(
                    x=self.dropout(
                        tf.keras.layers.concatenate([
                            downsample_layers[self.depth - i],
                            self.deconv(
                                x=upsample_layers[i - 1],
                                filters=filters
                            )
                        ])
                    ),
                    filters=filters,
                    kernel_size=self.kernel_size
                )
            )

        # ---------------------- Output -----------------------
        outputs = self.output(upsample_layers[-1])

        return tf.keras.Model(inputs, outputs)


In [8]:
# ... MultiResUNet
from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, BatchNormalization, Activation, add
from keras.models import Model, model_from_json
from keras.optimizers import Adam
# from keras.layers.advanced_activations import ELU, LeakyReLU
from keras.utils.vis_utils import plot_model


def conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(1, 1), activation='relu', name=None):
    '''
    2D Convolutional layers

    Arguments:
        x {keras layer} -- input layer
        filters {int} -- number of filters
        num_row {int} -- number of rows in filters
        num_col {int} -- number of columns in filters

    Keyword Arguments:
        padding {str} -- mode of padding (default: {'same'})
        strides {tuple} -- stride of convolution operation (default: {(1, 1)})
        activation {str} -- activation function (default: {'relu'})
        name {str} -- name of the layer (default: {None})

    Returns:
        [keras layer] -- [output layer]
    '''

    x = Conv2D(filters, (num_row, num_col), strides=strides,
               padding=padding, use_bias=False)(x)
    x = BatchNormalization(axis=3, scale=False)(x)

    if(activation == None):
        return x

    x = Activation(activation, name=name)(x)

    return x


def trans_conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(2, 2), name=None):
    '''
    2D Transposed Convolutional layers

    Arguments:
        x {keras layer} -- input layer
        filters {int} -- number of filters
        num_row {int} -- number of rows in filters
        num_col {int} -- number of columns in filters

    Keyword Arguments:
        padding {str} -- mode of padding (default: {'same'})
        strides {tuple} -- stride of convolution operation (default: {(2, 2)})
        name {str} -- name of the layer (default: {None})

    Returns:
        [keras layer] -- [output layer]
    '''

    x = Conv2DTranspose(filters, (num_row, num_col),
                        strides=strides, padding=padding)(x)
    x = BatchNormalization(axis=3, scale=False)(x)

    return x


def MultiResBlock(U, inp, alpha=1.67):
    '''
    MultiRes Block

    Arguments:
        U {int} -- Number of filters in a corrsponding UNet stage
        inp {keras layer} -- input layer

    Returns:
        [keras layer] -- [output layer]
    '''

    W = alpha * U

    shortcut = inp

    shortcut = conv2d_bn(shortcut, int(W*0.167) + int(W*0.333) +
                         int(W*0.5), 1, 1, activation=None, padding='same')

    conv3x3 = conv2d_bn(inp, int(W*0.167), 3, 3,
                        activation='relu', padding='same')

    conv5x5 = conv2d_bn(conv3x3, int(W*0.333), 3, 3,
                        activation='relu', padding='same')

    conv7x7 = conv2d_bn(conv5x5, int(W*0.5), 3, 3,
                        activation='relu', padding='same')

    out = concatenate([conv3x3, conv5x5, conv7x7], axis=3)
    out = BatchNormalization(axis=3)(out)

    out = add([shortcut, out])
    out = Activation('relu')(out)
    out = BatchNormalization(axis=3)(out)

    return out


def ResPath(filters, length, inp):
    '''
    ResPath

    Arguments:
        filters {int} -- [description]
        length {int} -- length of ResPath
        inp {keras layer} -- input layer

    Returns:
        [keras layer] -- [output layer]
    '''

    shortcut = inp
    shortcut = conv2d_bn(shortcut, filters, 1, 1,
                         activation=None, padding='same')

    out = conv2d_bn(inp, filters, 3, 3, activation='relu', padding='same')

    out = add([shortcut, out])
    out = Activation('relu')(out)
    out = BatchNormalization(axis=3)(out)

    for i in range(length-1):

        shortcut = out
        shortcut = conv2d_bn(shortcut, filters, 1, 1,
                             activation=None, padding='same')

        out = conv2d_bn(out, filters, 3, 3, activation='relu', padding='same')

        out = add([shortcut, out])
        out = Activation('relu')(out)
        out = BatchNormalization(axis=3)(out)

    return out


def MultiResUNet(height, width, n_channels):
    '''
    MultiResUNet

    Arguments:
        height {int} -- height of image
        width {int} -- width of image
        n_channels {int} -- number of channels in image

    Returns:
        [keras model] -- MultiResUNet model
    '''

    inputs = Input((height, width, n_channels))

    mresblock1 = MultiResBlock(32, inputs)
    pool1 = MaxPooling2D(pool_size=(2, 2))(mresblock1)
    mresblock1 = ResPath(32, 4, mresblock1)

    mresblock2 = MultiResBlock(32*2, pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(mresblock2)
    mresblock2 = ResPath(32*2, 3, mresblock2)

    mresblock3 = MultiResBlock(32*4, pool2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(mresblock3)
    mresblock3 = ResPath(32*4, 2, mresblock3)

    mresblock4 = MultiResBlock(32*8, pool3)
    pool4 = MaxPooling2D(pool_size=(2, 2))(mresblock4)
    mresblock4 = ResPath(32*8, 1, mresblock4)

    mresblock5 = MultiResBlock(32*16, pool4)

    up6 = concatenate([Conv2DTranspose(
        32*8, (2, 2), strides=(2, 2), padding='same')(mresblock5), mresblock4], axis=3)
    mresblock6 = MultiResBlock(32*8, up6)

    up7 = concatenate([Conv2DTranspose(
        32*4, (2, 2), strides=(2, 2), padding='same')(mresblock6), mresblock3], axis=3)
    mresblock7 = MultiResBlock(32*4, up7)

    up8 = concatenate([Conv2DTranspose(
        32*2, (2, 2), strides=(2, 2), padding='same')(mresblock7), mresblock2], axis=3)
    mresblock8 = MultiResBlock(32*2, up8)

    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(
        2, 2), padding='same')(mresblock8), mresblock1], axis=3)
    mresblock9 = MultiResBlock(32, up9)

    conv10 = conv2d_bn(mresblock9, 3, 1, 1, activation='sigmoid')

    model = Model(inputs=[inputs], outputs=[conv10])

    return model


# def main():

#     # Define the model

#     model = MultiResUnet(128, 128,3)
#     print(model.summary())


# if __name__ == '__main__':
#     main()


In [9]:
# # ... MultiResUNet3D
# from keras.layers import Input, Conv3D, MaxPooling3D, Conv3DTranspose, concatenate, BatchNormalization, Activation, add
# from keras.models import Model, model_from_json
# from keras.optimizers import Adam
# # from keras.layers.advanced_activations import ELU, LeakyReLU
# from keras.utils.vis_utils import plot_model


# def conv3d_bn(x, filters, num_row, num_col, num_z, padding='same', strides=(1, 1, 1), activation='relu', name=None):
#     '''
#     3D Convolutional layers

#     Arguments:
#         x {keras layer} -- input layer
#         filters {int} -- number of filters
#         num_row {int} -- number of rows in filters
#         num_col {int} -- number of columns in filters
#         num_z {int} -- length along z axis in filters

#     Keyword Arguments:
#         padding {str} -- mode of padding (default: {'same'})
#         strides {tuple} -- stride of convolution operation (default: {(1, 1, 1)})
#         activation {str} -- activation function (default: {'relu'})
#         name {str} -- name of the layer (default: {None})

#     Returns:
#         [keras layer] -- [output layer]
#     '''

#     x = Conv3D(filters, (num_row, num_col, num_z),
#                strides=strides, padding=padding, use_bias=False)(x)
#     x = BatchNormalization(axis=4, scale=False)(x)

#     if(activation == None):
#         return x

#     x = Activation(activation, name=name)(x)
#     return x


# def trans_conv3d_bn(x, filters, num_row, num_col, num_z, padding='same', strides=(2, 2, 2), name=None):
#     '''
#     2D Transposed Convolutional layers

#     Arguments:
#         x {keras layer} -- input layer
#         filters {int} -- number of filters
#         num_row {int} -- number of rows in filters
#         num_col {int} -- number of columns in filters
#         num_z {int} -- length along z axis in filters

#     Keyword Arguments:
#         padding {str} -- mode of padding (default: {'same'})
#         strides {tuple} -- stride of convolution operation (default: {(2, 2, 2)})
#         name {str} -- name of the layer (default: {None})

#     Returns:
#         [keras layer] -- [output layer]
#     '''

#     x = Conv3DTranspose(filters, (num_row, num_col, num_z),
#                         strides=strides, padding=padding)(x)
#     x = BatchNormalization(axis=4, scale=False)(x)

#     return x


# def MultiResBlock(U, inp, alpha=1.67):
#     '''
#     MultiRes Block

#     Arguments:
#         U {int} -- Number of filters in a corrsponding UNet stage
#         inp {keras layer} -- input layer

#     Returns:
#         [keras layer] -- [output layer]
#     '''

#     W = alpha * U

#     shortcut = inp

#     shortcut = conv3d_bn(shortcut, int(W*0.167) + int(W*0.333) +
#                          int(W*0.5), 1, 1, 1, activation=None, padding='same')

#     conv3x3 = conv3d_bn(inp, int(W*0.167), 3, 3, 3,
#                         activation='relu', padding='same')

#     conv5x5 = conv3d_bn(conv3x3, int(W*0.333), 3, 3, 3,
#                         activation='relu', padding='same')

#     conv7x7 = conv3d_bn(conv5x5, int(W*0.5), 3, 3, 3,
#                         activation='relu', padding='same')

#     out = concatenate([conv3x3, conv5x5, conv7x7], axis=4)
#     out = BatchNormalization(axis=4)(out)

#     out = add([shortcut, out])
#     out = Activation('relu')(out)
#     out = BatchNormalization(axis=4)(out)

#     return out


# def ResPath(filters, length, inp):
#     '''
#     ResPath

#     Arguments:
#         filters {int} -- [description]
#         length {int} -- length of ResPath
#         inp {keras layer} -- input layer

#     Returns:
#         [keras layer] -- [output layer]
#     '''

#     shortcut = inp
#     shortcut = conv3d_bn(shortcut, filters, 1, 1, 1,
#                          activation=None, padding='same')

#     out = conv3d_bn(inp, filters, 3, 3, 3, activation='relu', padding='same')

#     out = add([shortcut, out])
#     out = Activation('relu')(out)
#     out = BatchNormalization(axis=4)(out)

#     for i in range(length-1):

#         shortcut = out
#         shortcut = conv3d_bn(shortcut, filters, 1, 1, 1,
#                              activation=None, padding='same')

#         out = conv3d_bn(out, filters, 3, 3, 3,
#                         activation='relu', padding='same')

#         out = add([shortcut, out])
#         out = Activation('relu')(out)
#         out = BatchNormalization(axis=4)(out)

#     return out


# def MultiResUnet3D(height, width, z, n_channels):
#     '''
#     MultiResUNet3D

#     Arguments:
#         height {int} -- height of image
#         width {int} -- width of image
#         z {int} -- length along z axis
#         n_channels {int} -- number of channels in image

#     Returns:
#         [keras model] -- MultiResUNet3D model
#     '''

#     inputs = Input((height, width, z, n_channels))

#     mresblock1 = MultiResBlock(32, inputs)
#     pool1 = MaxPooling3D(pool_size=(2, 2, 2))(mresblock1)
#     mresblock1 = ResPath(32, 4, mresblock1)

#     mresblock2 = MultiResBlock(32*2, pool1)
#     pool2 = MaxPooling3D(pool_size=(2, 2, 2))(mresblock2)
#     mresblock2 = ResPath(32*2, 3, mresblock2)

#     mresblock3 = MultiResBlock(32*4, pool2)
#     pool3 = MaxPooling3D(pool_size=(2, 2, 2))(mresblock3)
#     mresblock3 = ResPath(32*4, 2, mresblock3)

#     mresblock4 = MultiResBlock(32*8, pool3)
#     pool4 = MaxPooling3D(pool_size=(2, 2, 2))(mresblock4)
#     mresblock4 = ResPath(32*8, 1, mresblock4)

#     mresblock5 = MultiResBlock(32*16, pool4)

#     up6 = concatenate([Conv3DTranspose(
#         32*8, (2, 2, 2), strides=(2, 2, 2), padding='same')(mresblock5), mresblock4], axis=4)
#     mresblock6 = MultiResBlock(32*8, up6)

#     up7 = concatenate([Conv3DTranspose(
#         32*4, (2, 2, 2), strides=(2, 2, 2), padding='same')(mresblock6), mresblock3], axis=4)
#     mresblock7 = MultiResBlock(32*4, up7)

#     up8 = concatenate([Conv3DTranspose(
#         32*2, (2, 2, 2), strides=(2, 2, 2), padding='same')(mresblock7), mresblock2], axis=4)
#     mresblock8 = MultiResBlock(32*2, up8)

#     up9 = concatenate([Conv3DTranspose(32, (2, 2, 2), strides=(
#         2, 2, 2), padding='same')(mresblock8), mresblock1], axis=4)
#     mresblock9 = MultiResBlock(32, up9)

#     conv10 = conv3d_bn(mresblock9, 3, 1, 1, 1, activation='sigmoid')

#     model = Model(inputs=[inputs], outputs=[conv10])

#     return model


# # def main():

# #     # Define the model

# #     model = MultiResUnet3D(80, 80, 48, 4)
# #     print(model.summary())


# # if __name__ == '__main__':
# #     main()


# Loss Functions / Metrics


In [10]:
@tf.function
def attention_mse(y_true, y_pred):
    _y_true = y_true[y_true != 1.0]
    _y_pred = y_pred[y_true != 1.0]
    squared_difference = tf.square(_y_true - _y_pred)
    return tf.reduce_mean(squared_difference, axis=-1)


@tf.function
def attention_mae(y_true, y_pred):
    _y_true = y_true[y_true != 1.0]
    _y_pred = y_pred[y_true != 1.0]
    squared_difference = tf.abs(_y_true - _y_pred)
    return tf.reduce_mean(squared_difference, axis=-1)


# DataLoader


In [11]:
def load_data_from_dir(path: str) -> tf.data.Dataset:
    return tf.keras.utils.image_dataset_from_directory(
        directory=path,
        labels=None,
        color_mode='grayscale',
        batch_size=BATCH_SIZE,
        image_size=(IMG_SIZE, IMG_SIZE),
        shuffle=False,
        seed=42,
        interpolation='bilinear',
        follow_links=False,
        crop_to_aspect_ratio=False
    )


# Load Dataset


In [12]:
trainX = load_data_from_dir(os.path.join(DATASET_PATH, TRAIN_DIR, INPUT_DIR))
trainY = load_data_from_dir(os.path.join(DATASET_PATH, TRAIN_DIR, TARGET_DIR))
testX = load_data_from_dir(os.path.join(DATASET_PATH, TEST_DIR, INPUT_DIR))
testY = load_data_from_dir(os.path.join(DATASET_PATH, TEST_DIR, TARGET_DIR))

train_ds = tf.data.Dataset.zip((trainX, trainY))
test_ds = tf.data.Dataset.zip((testX, testY))

print(train_ds.element_spec)
print(test_ds.element_spec)


Found 2436 files belonging to 1 classes.
Found 2436 files belonging to 1 classes.
Found 612 files belonging to 1 classes.
Found 612 files belonging to 1 classes.
(TensorSpec(shape=(None, 256, 256, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None, 256, 256, 1), dtype=tf.float32, name=None))
(TensorSpec(shape=(None, 256, 256, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None, 256, 256, 1), dtype=tf.float32, name=None))


# Normalization


In [13]:
normalization_layer = tf.keras.layers.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (
    normalization_layer(x), normalization_layer(y)))
test_ds = test_ds.map(lambda x, y: (
    normalization_layer(x), normalization_layer(y)))


# Augmentation


In [14]:
class Augment(tf.keras.layers.Layer):
    def __init__(self, seed=42):
        super().__init__()
        # both use the same seed, so they'll make the same random changes.
        self.augment_inputs = tf.keras.layers.RandomZoom(
            (-0.1, -0.7), seed=seed)
        self.augment_labels = tf.keras.layers.RandomZoom(
            (-0.1, -0.7), seed=seed)

    def call(self, inputs, labels):
        inputs = self.augment_inputs(inputs)
        labels = self.augment_labels(labels)
        return inputs, labels


# Optimization


In [15]:
AUTOTUNE = tf.data.AUTOTUNE

train_batches = (
    train_ds
    .cache()
    .shuffle(BUFFER_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)

test_batches = (
    test_ds
    .cache()
    .prefetch(buffer_size=AUTOTUNE)
)


# Training Config


In [16]:
model_path = os.path.join(MODEL_PATH, EXP_NAME)

callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=PATIENCE,
        restore_best_weights=True
    ),
    tf.keras.callbacks.ModelCheckpoint(
        model_path,
        monitor='val_loss',
        verbose=1,
        save_best_only=True,
        save_weights_only=True
    )
]

optimizer = tf.keras.optimizers.Adam(
    learning_rate=LEARNING_RATE
)

model = UNet(IMG_SIZE)()
# model = MultiResUNet(IMG_SIZE, IMG_SIZE, 3)

# model.compile(
#     loss=attention_mse,
#     optimizer=optimizer,
#     metrics=[attention_mae]
# )

model.compile(
    loss="mse",
    optimizer=optimizer,
    metrics=["mae"]
)

# try:
#     model.load_weights(model_path)
# except:
#     print("Checkpoint not found")
#     pass


# Training


In [17]:
history = model.fit(
    train_batches,
    validation_data=test_batches,
    epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=1
)


Epoch 1/10
Epoch 1: val_loss improved from inf to 0.01280, saving model to /content/drive/MyDrive/Research/TAVI/Models/TAWSS_2_ECAP_MultiViewUNet_I256_B16_LR0.001
Epoch 2/10
Epoch 2: val_loss did not improve from 0.01280
Epoch 3/10
Epoch 3: val_loss did not improve from 0.01280
Epoch 4/10
Epoch 4: val_loss did not improve from 0.01280
Epoch 5/10
Epoch 5: val_loss did not improve from 0.01280
Epoch 6/10
Epoch 6: val_loss did not improve from 0.01280
Epoch 7/10
Epoch 7: val_loss did not improve from 0.01280
Epoch 8/10
Epoch 8: val_loss did not improve from 0.01280
Epoch 9/10
Epoch 9: val_loss improved from 0.01280 to 0.01160, saving model to /content/drive/MyDrive/Research/TAVI/Models/TAWSS_2_ECAP_MultiViewUNet_I256_B16_LR0.001
Epoch 10/10
Epoch 10: val_loss did not improve from 0.01160


# Save Model


In [18]:
timestamp = datetime.datetime.now().strftime('%b-%d-%I:%M%p')
if not os.path.exists(model_path):
    os.makedirs(model_path)

model.save(os.path.join(model_path, timestamp))




# Save Predictions


In [19]:
test_ds_unbatched = test_batches.unbatch()

pred_path = os.path.join(PRED_PATH, EXP_NAME, timestamp)
if not os.path.exists(pred_path):
    os.makedirs(pred_path)

metrics = pd.DataFrame()

for idx, (input, target) in enumerate(tqdm(test_ds_unbatched)):

    target = tf.squeeze(target).numpy()
    prediction = tf.squeeze(
        model.predict(
            tf.expand_dims(input, axis=0)
        )
    ).numpy()

    # channel_sum = tf.expand_dims(tf.reduce_sum(target, axis=-1), axis=-1)
    # white_mask = tf.reduce_all(tf.equal(channel_sum, 3.0), axis=-1)
    # expanded_mask = tf.expand_dims(white_mask, axis=-1)
    # # expanded_mask = tf.tile(expanded_mask, [1, 1, 3])
    # expanded_mask = tf.tile(expanded_mask, [1, 1])
    # prediction = tf.where(expanded_mask, tf.ones_like(prediction), prediction)

    plt.figure(figsize=(7, 7))
    plt.imshow(target, cmap="gray")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(os.path.join(pred_path, f"{idx}_T.png"))
    plt.close()

    plt.figure(figsize=(7, 7))
    plt.imshow(prediction, cmap="gray")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(os.path.join(pred_path, f"{idx}_P.png"))
    plt.close()



0it [00:00, ?it/s]



1it [00:02,  2.50s/it]



2it [00:02,  1.32s/it]



3it [00:03,  1.05it/s]



4it [00:03,  1.30it/s]



5it [00:04,  1.49it/s]



6it [00:05,  1.58it/s]



7it [00:05,  1.57it/s]



8it [00:06,  1.67it/s]



9it [00:06,  1.78it/s]



10it [00:07,  1.83it/s]



11it [00:07,  1.69it/s]



12it [00:08,  1.55it/s]



13it [00:09,  1.51it/s]



14it [00:10,  1.46it/s]



15it [00:10,  1.41it/s]



16it [00:11,  1.38it/s]



17it [00:12,  1.49it/s]



18it [00:12,  1.62it/s]



19it [00:13,  1.72it/s]



20it [00:13,  1.79it/s]



21it [00:14,  1.87it/s]



22it [00:14,  1.92it/s]



23it [00:15,  1.94it/s]



24it [00:15,  1.94it/s]



25it [00:16,  1.97it/s]



26it [00:16,  2.00it/s]



27it [00:17,  1.97it/s]



28it [00:17,  1.98it/s]



29it [00:18,  2.00it/s]



30it [00:18,  2.01it/s]



31it [00:19,  2.00it/s]



32it [00:19,  2.00it/s]



33it [00:20,  1.98it/s]



34it [00:20,  2.01it/s]



35it [00:21,  2.01it/s]



36it [00:21,  2.04it/s]



37it [00:22,  1.91it/s]



38it [00:22,  1.74it/s]



39it [00:23,  1.61it/s]



40it [00:24,  1.52it/s]



41it [00:25,  1.44it/s]



42it [00:25,  1.37it/s]



43it [00:26,  1.38it/s]



44it [00:27,  1.53it/s]



45it [00:27,  1.66it/s]



46it [00:28,  1.74it/s]



47it [00:28,  1.82it/s]



48it [00:29,  1.84it/s]



49it [00:29,  1.90it/s]



50it [00:30,  1.93it/s]



51it [00:30,  1.96it/s]



52it [00:31,  1.96it/s]



53it [00:31,  1.98it/s]



54it [00:32,  1.96it/s]



55it [00:32,  1.98it/s]



56it [00:33,  1.96it/s]



57it [00:33,  1.98it/s]



58it [00:34,  2.00it/s]



59it [00:34,  1.98it/s]



60it [00:35,  1.98it/s]



61it [00:35,  1.99it/s]



62it [00:36,  1.99it/s]



63it [00:36,  1.87it/s]



64it [00:37,  1.71it/s]



65it [00:38,  1.58it/s]



66it [00:39,  1.49it/s]



67it [00:39,  1.42it/s]



68it [00:40,  1.39it/s]



69it [00:41,  1.39it/s]



70it [00:41,  1.54it/s]



71it [00:42,  1.65it/s]



72it [00:42,  1.71it/s]



73it [00:43,  1.81it/s]



74it [00:43,  1.86it/s]



75it [00:44,  1.93it/s]



76it [00:44,  1.97it/s]



77it [00:45,  2.01it/s]



78it [00:45,  2.02it/s]



79it [00:46,  2.04it/s]



80it [00:46,  2.02it/s]



81it [00:47,  2.04it/s]



82it [00:47,  2.06it/s]



83it [00:48,  2.05it/s]



84it [00:48,  2.05it/s]



85it [00:49,  2.04it/s]



86it [00:49,  2.00it/s]



87it [00:50,  2.00it/s]



88it [00:50,  2.01it/s]



89it [00:51,  2.02it/s]



90it [00:52,  1.31it/s]



91it [00:53,  1.34it/s]



92it [00:54,  1.32it/s]



93it [00:54,  1.33it/s]



94it [00:55,  1.33it/s]



95it [00:56,  1.39it/s]



96it [00:56,  1.52it/s]



97it [00:57,  1.61it/s]



98it [00:57,  1.70it/s]



99it [00:58,  1.74it/s]



100it [00:58,  1.79it/s]



101it [00:59,  1.82it/s]



102it [00:59,  1.87it/s]



103it [01:00,  1.90it/s]



104it [01:00,  1.89it/s]



105it [01:01,  1.90it/s]



106it [01:01,  1.92it/s]



107it [01:02,  1.92it/s]



108it [01:02,  1.91it/s]



109it [01:03,  1.94it/s]



110it [01:03,  1.91it/s]



111it [01:04,  1.91it/s]



112it [01:04,  1.95it/s]



113it [01:05,  1.94it/s]



114it [01:06,  1.86it/s]



115it [01:06,  1.63it/s]



116it [01:07,  1.56it/s]



117it [01:08,  1.44it/s]



118it [01:09,  1.40it/s]



119it [01:09,  1.34it/s]



120it [01:10,  1.35it/s]



121it [01:11,  1.44it/s]



122it [01:11,  1.56it/s]



123it [01:12,  1.67it/s]



124it [01:12,  1.78it/s]



125it [01:13,  1.87it/s]



126it [01:13,  1.90it/s]



127it [01:14,  1.97it/s]



128it [01:14,  1.95it/s]



129it [01:15,  1.98it/s]



130it [01:15,  2.00it/s]



131it [01:16,  1.95it/s]



132it [01:16,  1.94it/s]



133it [01:17,  1.98it/s]



134it [01:17,  1.94it/s]



135it [01:18,  1.92it/s]



136it [01:18,  1.92it/s]



137it [01:19,  1.91it/s]



138it [01:19,  1.91it/s]



139it [01:20,  1.94it/s]



140it [01:21,  1.81it/s]



141it [01:21,  1.68it/s]



142it [01:22,  1.56it/s]



143it [01:23,  1.51it/s]



144it [01:24,  1.41it/s]



145it [01:24,  1.45it/s]



146it [01:25,  1.46it/s]



147it [01:25,  1.60it/s]



148it [01:26,  1.70it/s]



149it [01:26,  1.79it/s]



150it [01:27,  1.85it/s]



151it [01:27,  1.91it/s]



152it [01:28,  1.88it/s]



153it [01:28,  1.92it/s]



154it [01:29,  1.96it/s]



155it [01:29,  2.00it/s]



156it [01:30,  2.01it/s]



157it [01:30,  2.01it/s]



158it [01:31,  1.98it/s]



159it [01:31,  1.97it/s]



160it [01:32,  1.91it/s]



161it [01:32,  1.94it/s]



162it [01:33,  1.95it/s]



163it [01:33,  1.98it/s]



164it [01:34,  2.00it/s]



165it [01:34,  1.98it/s]



166it [01:35,  2.00it/s]



167it [01:36,  1.72it/s]



168it [01:36,  1.55it/s]



169it [01:37,  1.46it/s]



170it [01:38,  1.44it/s]



171it [01:39,  1.40it/s]



172it [01:39,  1.39it/s]



173it [01:40,  1.51it/s]



174it [01:40,  1.63it/s]



175it [01:41,  1.73it/s]



176it [01:41,  1.80it/s]



177it [01:42,  1.87it/s]



178it [01:42,  1.92it/s]



179it [01:43,  1.97it/s]



180it [01:43,  1.99it/s]



181it [01:44,  1.99it/s]



182it [01:44,  1.99it/s]



183it [01:45,  1.99it/s]



184it [01:45,  2.00it/s]



185it [01:46,  2.02it/s]



186it [01:46,  1.99it/s]



187it [01:47,  2.01it/s]



188it [01:47,  2.04it/s]



189it [01:48,  2.06it/s]



190it [01:48,  2.04it/s]



191it [01:49,  2.00it/s]



192it [01:49,  1.99it/s]



193it [01:50,  1.78it/s]



194it [01:51,  1.31it/s]



195it [01:52,  1.30it/s]



196it [01:53,  1.33it/s]



197it [01:54,  1.33it/s]



198it [01:54,  1.36it/s]



199it [01:55,  1.50it/s]



200it [01:55,  1.59it/s]



201it [01:56,  1.66it/s]



202it [01:56,  1.71it/s]



203it [01:57,  1.80it/s]



204it [01:57,  1.82it/s]



205it [01:58,  1.83it/s]



206it [01:58,  1.87it/s]



207it [01:59,  1.92it/s]



208it [01:59,  1.96it/s]



209it [02:00,  1.94it/s]



210it [02:00,  1.98it/s]



211it [02:01,  1.95it/s]



212it [02:01,  2.00it/s]



213it [02:02,  2.00it/s]



214it [02:02,  2.02it/s]



215it [02:03,  2.02it/s]



216it [02:03,  2.04it/s]



217it [02:04,  2.05it/s]



218it [02:05,  1.87it/s]



219it [02:05,  1.74it/s]



220it [02:06,  1.60it/s]



221it [02:07,  1.50it/s]



222it [02:07,  1.44it/s]



223it [02:08,  1.40it/s]



224it [02:09,  1.43it/s]



225it [02:09,  1.56it/s]



226it [02:10,  1.64it/s]



227it [02:10,  1.70it/s]



228it [02:11,  1.79it/s]



229it [02:11,  1.85it/s]



230it [02:12,  1.90it/s]



231it [02:12,  1.92it/s]



232it [02:13,  1.95it/s]



233it [02:13,  1.97it/s]



234it [02:14,  2.00it/s]



235it [02:14,  1.98it/s]



236it [02:15,  1.92it/s]



237it [02:16,  1.95it/s]



238it [02:16,  1.98it/s]



239it [02:16,  1.99it/s]



240it [02:17,  2.00it/s]



241it [02:17,  2.01it/s]



242it [02:18,  1.96it/s]



243it [02:19,  1.91it/s]



244it [02:19,  1.76it/s]



245it [02:20,  1.57it/s]



246it [02:21,  1.53it/s]



247it [02:22,  1.44it/s]



248it [02:22,  1.38it/s]



249it [02:23,  1.33it/s]



250it [02:24,  1.43it/s]



251it [02:24,  1.56it/s]



252it [02:25,  1.62it/s]



253it [02:25,  1.74it/s]



254it [02:26,  1.83it/s]



255it [02:26,  1.85it/s]



256it [02:27,  1.91it/s]



257it [02:27,  1.97it/s]



258it [02:28,  1.94it/s]



259it [02:28,  1.91it/s]



260it [02:29,  1.94it/s]



261it [02:29,  1.99it/s]



262it [02:30,  2.03it/s]



263it [02:30,  2.03it/s]



264it [02:31,  2.03it/s]



265it [02:31,  2.00it/s]



266it [02:32,  1.97it/s]



267it [02:32,  1.94it/s]



268it [02:33,  1.97it/s]



269it [02:33,  1.99it/s]



270it [02:34,  1.81it/s]



271it [02:35,  1.60it/s]



272it [02:36,  1.47it/s]



273it [02:36,  1.43it/s]



274it [02:37,  1.39it/s]



275it [02:38,  1.37it/s]



276it [02:38,  1.47it/s]



277it [02:39,  1.58it/s]



278it [02:39,  1.65it/s]



279it [02:40,  1.73it/s]



280it [02:40,  1.78it/s]



281it [02:41,  1.83it/s]



282it [02:41,  1.88it/s]



283it [02:42,  1.94it/s]



284it [02:42,  1.96it/s]



285it [02:43,  1.94it/s]



286it [02:44,  1.95it/s]



287it [02:44,  1.95it/s]



288it [02:45,  1.96it/s]



289it [02:45,  1.98it/s]



290it [02:46,  1.94it/s]



291it [02:46,  1.94it/s]



292it [02:47,  1.98it/s]



293it [02:47,  2.00it/s]



294it [02:48,  2.00it/s]



295it [02:48,  1.91it/s]



296it [02:49,  1.34it/s]



297it [02:50,  1.35it/s]



298it [02:51,  1.34it/s]



299it [02:52,  1.34it/s]



300it [02:52,  1.36it/s]



301it [02:53,  1.46it/s]



302it [02:53,  1.59it/s]



303it [02:54,  1.68it/s]



304it [02:54,  1.78it/s]



305it [02:55,  1.83it/s]



306it [02:55,  1.83it/s]



307it [02:56,  1.88it/s]



308it [02:56,  1.87it/s]



309it [02:57,  1.88it/s]



310it [02:58,  1.87it/s]



311it [02:58,  1.92it/s]



312it [02:59,  1.95it/s]



313it [02:59,  1.89it/s]



314it [03:00,  1.92it/s]



315it [03:00,  1.92it/s]



316it [03:01,  1.89it/s]



317it [03:01,  1.91it/s]



318it [03:02,  1.90it/s]



319it [03:02,  1.93it/s]



320it [03:03,  1.74it/s]



321it [03:04,  1.59it/s]



322it [03:04,  1.57it/s]



323it [03:05,  1.51it/s]



324it [03:06,  1.49it/s]



325it [03:06,  1.47it/s]



326it [03:07,  1.45it/s]



327it [03:08,  1.57it/s]



328it [03:08,  1.67it/s]



329it [03:09,  1.73it/s]



330it [03:09,  1.82it/s]



331it [03:10,  1.82it/s]



332it [03:10,  1.87it/s]



333it [03:11,  1.91it/s]



334it [03:11,  1.91it/s]



335it [03:12,  1.96it/s]



336it [03:12,  1.98it/s]



337it [03:13,  1.97it/s]



338it [03:13,  1.96it/s]



339it [03:14,  1.97it/s]



340it [03:14,  1.92it/s]



341it [03:15,  1.95it/s]



342it [03:15,  1.91it/s]



343it [03:16,  1.86it/s]



344it [03:16,  1.90it/s]



345it [03:17,  1.90it/s]



346it [03:18,  1.80it/s]



347it [03:18,  1.56it/s]



348it [03:19,  1.43it/s]



349it [03:20,  1.35it/s]



350it [03:21,  1.35it/s]



351it [03:22,  1.32it/s]



352it [03:22,  1.44it/s]



353it [03:23,  1.57it/s]



354it [03:23,  1.68it/s]



355it [03:24,  1.75it/s]



356it [03:24,  1.82it/s]



357it [03:25,  1.87it/s]



358it [03:25,  1.88it/s]



359it [03:26,  1.91it/s]



360it [03:26,  1.87it/s]



361it [03:27,  1.86it/s]



362it [03:27,  1.91it/s]



363it [03:28,  1.93it/s]



364it [03:28,  1.95it/s]



365it [03:29,  1.91it/s]



366it [03:29,  1.97it/s]



367it [03:30,  1.98it/s]



368it [03:30,  1.99it/s]



369it [03:31,  2.00it/s]



370it [03:31,  1.96it/s]



371it [03:32,  1.89it/s]



372it [03:33,  1.67it/s]



373it [03:34,  1.51it/s]



374it [03:34,  1.47it/s]



375it [03:35,  1.44it/s]



376it [03:36,  1.41it/s]



377it [03:36,  1.42it/s]



378it [03:37,  1.51it/s]



379it [03:37,  1.64it/s]



380it [03:38,  1.75it/s]



381it [03:38,  1.83it/s]



382it [03:39,  1.85it/s]



383it [03:39,  1.89it/s]



384it [03:40,  1.84it/s]



385it [03:41,  1.86it/s]



386it [03:41,  1.84it/s]



387it [03:42,  1.83it/s]



388it [03:42,  1.87it/s]



389it [03:43,  1.92it/s]



390it [03:43,  1.92it/s]



391it [03:44,  1.93it/s]



392it [03:44,  1.95it/s]



393it [03:45,  1.93it/s]



394it [03:45,  1.91it/s]



395it [03:46,  1.93it/s]



396it [03:46,  1.95it/s]



397it [03:47,  1.43it/s]



398it [03:48,  1.37it/s]



399it [03:49,  1.32it/s]



400it [03:50,  1.30it/s]



401it [03:51,  1.28it/s]



402it [03:51,  1.32it/s]



403it [03:52,  1.47it/s]



404it [03:52,  1.58it/s]



405it [03:53,  1.69it/s]



406it [03:53,  1.74it/s]



407it [03:54,  1.81it/s]



408it [03:54,  1.86it/s]



409it [03:55,  1.87it/s]



410it [03:55,  1.89it/s]



411it [03:56,  1.93it/s]



412it [03:56,  1.97it/s]



413it [03:57,  1.93it/s]



414it [03:57,  1.97it/s]



415it [03:58,  2.00it/s]



416it [03:58,  2.01it/s]



417it [03:59,  2.00it/s]



418it [03:59,  2.03it/s]



419it [04:00,  2.05it/s]



420it [04:00,  2.01it/s]



421it [04:01,  2.01it/s]



422it [04:01,  1.96it/s]



423it [04:02,  1.69it/s]



424it [04:03,  1.58it/s]



425it [04:04,  1.50it/s]



426it [04:04,  1.45it/s]



427it [04:05,  1.39it/s]



428it [04:06,  1.38it/s]



429it [04:07,  1.48it/s]



430it [04:07,  1.58it/s]



431it [04:08,  1.70it/s]



432it [04:08,  1.79it/s]



433it [04:09,  1.79it/s]



434it [04:09,  1.83it/s]



435it [04:10,  1.87it/s]



436it [04:10,  1.83it/s]



437it [04:11,  1.90it/s]



438it [04:11,  1.93it/s]



439it [04:12,  1.98it/s]



440it [04:12,  2.01it/s]



441it [04:13,  1.96it/s]



442it [04:13,  1.94it/s]



443it [04:14,  1.95it/s]



444it [04:14,  1.94it/s]



445it [04:15,  1.96it/s]



446it [04:15,  1.90it/s]



447it [04:16,  1.88it/s]



448it [04:16,  1.79it/s]



449it [04:17,  1.66it/s]



450it [04:18,  1.56it/s]



451it [04:19,  1.46it/s]



452it [04:19,  1.41it/s]



453it [04:20,  1.38it/s]



454it [04:21,  1.38it/s]



455it [04:21,  1.50it/s]



456it [04:22,  1.60it/s]



457it [04:23,  1.65it/s]



458it [04:23,  1.70it/s]



459it [04:24,  1.79it/s]



460it [04:24,  1.81it/s]



461it [04:25,  1.83it/s]



462it [04:25,  1.90it/s]



463it [04:26,  1.88it/s]



464it [04:26,  1.92it/s]



465it [04:27,  1.93it/s]



466it [04:27,  1.96it/s]



467it [04:28,  1.96it/s]



468it [04:28,  1.99it/s]



469it [04:29,  1.94it/s]



470it [04:29,  1.90it/s]



471it [04:30,  1.92it/s]



472it [04:30,  1.90it/s]



473it [04:31,  1.87it/s]



474it [04:32,  1.60it/s]



475it [04:33,  1.45it/s]



476it [04:33,  1.37it/s]



477it [04:34,  1.33it/s]



478it [04:35,  1.30it/s]



479it [04:36,  1.35it/s]



480it [04:36,  1.46it/s]



481it [04:37,  1.58it/s]



482it [04:37,  1.68it/s]



483it [04:38,  1.73it/s]



484it [04:38,  1.80it/s]



485it [04:39,  1.87it/s]



486it [04:39,  1.90it/s]



487it [04:40,  1.89it/s]



488it [04:40,  1.90it/s]



489it [04:41,  1.91it/s]



490it [04:41,  1.96it/s]



491it [04:42,  1.95it/s]



492it [04:42,  1.95it/s]



493it [04:43,  1.95it/s]



494it [04:43,  1.92it/s]



495it [04:44,  1.93it/s]



496it [04:44,  1.95it/s]



497it [04:45,  1.93it/s]



498it [04:45,  1.96it/s]



499it [04:46,  1.79it/s]



500it [04:47,  1.69it/s]



501it [04:48,  1.57it/s]



502it [04:49,  1.24it/s]



503it [04:50,  1.25it/s]



504it [04:50,  1.30it/s]



505it [04:51,  1.44it/s]



506it [04:51,  1.58it/s]



507it [04:52,  1.68it/s]



508it [04:52,  1.77it/s]



509it [04:53,  1.84it/s]



510it [04:53,  1.89it/s]



511it [04:54,  1.91it/s]



512it [04:54,  1.94it/s]



513it [04:55,  1.92it/s]



514it [04:55,  1.97it/s]



515it [04:56,  1.95it/s]



516it [04:56,  1.93it/s]



517it [04:57,  1.93it/s]



518it [04:57,  1.90it/s]



519it [04:58,  1.86it/s]



520it [04:58,  1.86it/s]



521it [04:59,  1.88it/s]



522it [04:59,  1.93it/s]



523it [05:00,  1.94it/s]



524it [05:01,  1.86it/s]



525it [05:01,  1.64it/s]



526it [05:02,  1.51it/s]



527it [05:03,  1.42it/s]



528it [05:04,  1.40it/s]



529it [05:04,  1.39it/s]



530it [05:05,  1.36it/s]



531it [05:06,  1.52it/s]



532it [05:06,  1.64it/s]



533it [05:07,  1.71it/s]



534it [05:07,  1.76it/s]



535it [05:08,  1.81it/s]



536it [05:08,  1.83it/s]



537it [05:09,  1.91it/s]



538it [05:09,  1.96it/s]



539it [05:10,  1.98it/s]



540it [05:10,  1.93it/s]



541it [05:11,  1.98it/s]



542it [05:11,  1.92it/s]



543it [05:12,  1.92it/s]



544it [05:12,  1.93it/s]



545it [05:13,  1.96it/s]



546it [05:13,  1.97it/s]



547it [05:14,  1.95it/s]



548it [05:14,  1.96it/s]



549it [05:15,  2.00it/s]



550it [05:15,  1.81it/s]



551it [05:16,  1.66it/s]



552it [05:17,  1.56it/s]



553it [05:18,  1.49it/s]



554it [05:18,  1.46it/s]



555it [05:19,  1.40it/s]



556it [05:20,  1.43it/s]



557it [05:20,  1.57it/s]



558it [05:21,  1.66it/s]



559it [05:21,  1.77it/s]



560it [05:22,  1.83it/s]



561it [05:22,  1.89it/s]



562it [05:23,  1.89it/s]



563it [05:23,  1.93it/s]



564it [05:24,  1.93it/s]



565it [05:24,  1.96it/s]



566it [05:25,  1.98it/s]



567it [05:25,  1.95it/s]



568it [05:26,  1.91it/s]



569it [05:26,  1.90it/s]



570it [05:27,  1.91it/s]



571it [05:27,  1.91it/s]



572it [05:28,  1.93it/s]



573it [05:28,  1.95it/s]



574it [05:29,  1.97it/s]



575it [05:30,  1.96it/s]



576it [05:30,  1.78it/s]



577it [05:31,  1.60it/s]



578it [05:32,  1.51it/s]



579it [05:32,  1.45it/s]



580it [05:33,  1.42it/s]



581it [05:34,  1.43it/s]



582it [05:35,  1.45it/s]



583it [05:35,  1.58it/s]



584it [05:36,  1.66it/s]



585it [05:36,  1.75it/s]



586it [05:37,  1.83it/s]



587it [05:37,  1.90it/s]



588it [05:38,  1.89it/s]



589it [05:38,  1.90it/s]



590it [05:39,  1.95it/s]



591it [05:39,  1.99it/s]



592it [05:40,  2.01it/s]



593it [05:40,  2.01it/s]



594it [05:41,  1.96it/s]



595it [05:41,  1.99it/s]



596it [05:42,  1.97it/s]



597it [05:42,  1.96it/s]



598it [05:43,  1.98it/s]



599it [05:43,  1.99it/s]



600it [05:44,  1.96it/s]



601it [05:44,  2.00it/s]



602it [05:45,  1.88it/s]



603it [05:45,  1.72it/s]



604it [05:46,  1.59it/s]



605it [05:47,  1.22it/s]



606it [05:48,  1.19it/s]



607it [05:49,  1.22it/s]



608it [05:50,  1.36it/s]



609it [05:50,  1.49it/s]



610it [05:51,  1.56it/s]



611it [05:51,  1.67it/s]



612it [05:52,  1.74it/s]


# Loss Curve


In [20]:
try:
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(8, 6))
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.ylabel('Attention MAE')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.tight_layout()
    plt.savefig(os.path.join(model_path, timestamp + ".png"))
    plt.close()
except:
    print("Model did not finish training")


# Metrics


In [21]:
model.evaluate(test_batches)




[0.022757774218916893, 0.03971075266599655]