<a href="https://colab.research.google.com/github/atick-faisal/TAVI/blob/main/src/training/MultiViewUNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Colab Info

In [1]:
from psutil import virtual_memory

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Not connected to a GPU')
else:
    print(gpu_info)


ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
    print('Not using a high-RAM runtime')
else:
    print('You are using a high-RAM runtime!')


Sun Apr  2 23:14:32 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   67C    P8    11W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Fix for GDrive

In [2]:
!pip install -U --no-cache-dir gdown --pre > /dev/null


# Mount GDrive

In [3]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


# Download and Extract Dataset

In [4]:
!gdown "1Hkn-xh9gqjWOCkFA0OwzvWPDRsIPwdSv"
!unzip -o "AAA_DATASET_v7.zip" > /dev/null


Downloading...
From (uriginal): https://drive.google.com/uc?id=1Hkn-xh9gqjWOCkFA0OwzvWPDRsIPwdSv
From (redirected): https://drive.google.com/uc?id=1Hkn-xh9gqjWOCkFA0OwzvWPDRsIPwdSv&confirm=t&uuid=8e64e37a-a14a-48a9-8d8c-0fed60dbbc27
To: /content/AAA_DATASET_v7.zip
100% 3.42G/3.42G [00:29<00:00, 118MB/s]


# Imports

In [5]:
import os
import datetime
import matplotlib
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

from tqdm import tqdm

matplotlib.use('Agg')
plt.rcParams["font.size"] = 16

# Config

In [6]:
PROBLEM = "TAWSS_2_ECAP"

MODEL_NAME = "MultiViewUNet"
DATASET_PATH = "/content/Images/"
TRAIN_DIR = "Train/"
TEST_DIR = "Test/"
INPUT_DIR = PROBLEM.split("_2_")[0]
TARGET_DIR = PROBLEM.split("_2_")[1]
MODEL_PATH = "/content/drive/MyDrive/Research/TAVI/Models/"
PRED_PATH = "/content/drive/MyDrive/Research/TAVI/Predictions/"
IMG_SIZE = 256
BATCH_SIZE = 16
BUFFER_SIZE = 1000
VAL_SPLIT = 0.2
LEARNING_RATE = 0.001
N_EPOCHS = 300
PATIENCE = 30

EXP_NAME = f"{PROBLEM}_{MODEL_NAME}_I{IMG_SIZE}_B{BATCH_SIZE}_LR{LEARNING_RATE}"


# Architecture

In [7]:
class UNet:
    def __init__(
        self,
        img_size: int,
        n_channels: int = 3,
        width: int = 32,
        depth: int = 4,
        kernel_size: int = 3
    ):
        self.img_size = img_size
        self.n_channels = n_channels
        self.width = width
        self.depth = depth
        self.kernel_size = kernel_size

    @staticmethod
    def conv(
        x: tf.Tensor,
        filters: int,
        kernel_size: int
    ) -> tf.Tensor:
        for i in range(2):
            x = tf.keras.layers.Conv2D(
                filters=filters,
                kernel_size=kernel_size,
                strides=1,
                padding="same",
                data_format="channels_last",
                dilation_rate=1,
                groups=1,
                activation=None,
                use_bias=True,
                kernel_initializer="glorot_uniform",
                bias_initializer="zeros",
            )(x)

            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.Activation("relu")(x)

        return x

    @staticmethod
    def deconv(x: tf.Tensor, filters: int) -> tf.Tensor:
        x = tf.keras.layers.Conv2DTranspose(
            filters=filters,
            kernel_size=2,
            strides=2,
            padding="same",
            output_padding=None,
            data_format=None,
            dilation_rate=1,
            activation=None,
            use_bias=True,
            kernel_initializer="glorot_uniform",
            bias_initializer="zeros",
        )(x)

        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation("relu")(x)

        return x

    @staticmethod
    def output(x: tf.Tensor) -> tf.Tensor:
        return tf.keras.layers.Conv2D(3, (1, 1), activation="sigmoid")(x)

    @staticmethod
    def pool(x: tf.Tensor, pool_size: int = 2) -> tf.Tensor:
        return tf.keras.layers.MaxPool2D(pool_size)(x)

    @staticmethod
    def dropout(x: tf.Tensor, amount: float = 0.5) -> tf.Tensor:
        return tf.keras.layers.Dropout(amount)(x)

    def __call__(self) -> tf.keras.Model:
        inputs = tf.keras.layers.Input(
            shape=(self.img_size, self.img_size, self.n_channels)
        )

        # scaled = tf.keras.layers.Rescaling(1./255.0, offset=0)(inputs)

        # ------------------ Downsampling ---------------------
        downsample_layers = []
        downsample_layers.append(
            self.conv(
                x=inputs,
                filters=self.width,
                kernel_size=self.kernel_size
            )
        )
        for i in range(1, self.depth):
            dropout_amount = 0.2 if i == 1 else 0.5
            filters = int((2 ** i) * self.width)
            downsample_layers.append(
                self.dropout(
                    self.pool(
                        self.conv(
                            x=downsample_layers[i - 1],
                            filters=filters,
                            kernel_size=self.kernel_size
                        )
                    ),
                    amount=dropout_amount
                )
            )

        # ------------------- Features --------------------
        n_features = int((2 ** self.depth) * self.width)
        self.features = self.pool(
            self.conv(
                x=downsample_layers[-1],
                filters=n_features,
                kernel_size=self.kernel_size
            )
        )

        # ------------------- Upsampling --------------------
        upsample_layers = []
        upsample_layers.append(self.features)
        for i in range(1, self.depth + 1):
            filters = int((2 ** (self.depth - i)) * self.width)
            upsample_layers.append(
                self.conv(
                    x=self.dropout(
                        tf.keras.layers.concatenate([
                            downsample_layers[self.depth - i],
                            self.deconv(
                                x=upsample_layers[i - 1],
                                filters=filters
                            )
                        ])
                    ),
                    filters=filters,
                    kernel_size=self.kernel_size
                )
            )

        # ---------------------- Output -----------------------
        outputs = self.output(upsample_layers[-1])

        return tf.keras.Model(inputs, outputs)


# Loss Functions / Metrics

In [8]:
def attention_mse(y_true, y_pred):
    _y_true = y_true[y_true != 1.0]
    _y_pred = y_pred[y_true != 1.0]
    squared_difference = tf.square(_y_true - _y_pred)
    return tf.reduce_mean(squared_difference, axis=-1)


def attention_mae(y_true, y_pred):
    _y_true = y_true[y_true != 1.0]
    _y_pred = y_pred[y_true != 1.0]
    squared_difference = tf.abs(_y_true - _y_pred)
    return tf.reduce_mean(squared_difference, axis=-1)


# DataLoader

In [9]:
def load_data_from_dir(path: str) -> tf.data.Dataset:
    return tf.keras.utils.image_dataset_from_directory(
        directory=path,
        labels=None,
        color_mode='rgb',
        batch_size=BATCH_SIZE,
        image_size=(IMG_SIZE, IMG_SIZE),
        shuffle=False,
        seed=42,
        interpolation='bilinear',
        follow_links=False,
        crop_to_aspect_ratio=False
    )


# Load Dataset

In [10]:
trainX = load_data_from_dir(os.path.join(DATASET_PATH, TRAIN_DIR, INPUT_DIR))
trainY = load_data_from_dir(os.path.join(DATASET_PATH, TRAIN_DIR, TARGET_DIR))
testX = load_data_from_dir(os.path.join(DATASET_PATH, TEST_DIR, INPUT_DIR))
testY = load_data_from_dir(os.path.join(DATASET_PATH, TEST_DIR, TARGET_DIR))

train_ds = tf.data.Dataset.zip((trainX, trainY))
test_ds = tf.data.Dataset.zip((testX, testY))

print(train_ds.element_spec)
print(test_ds.element_spec)


Found 8208 files belonging to 1 classes.
Found 8208 files belonging to 1 classes.
Found 312 files belonging to 1 classes.
Found 312 files belonging to 1 classes.
(TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None))
(TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None))


# Normalization

In [11]:
normalization_layer = tf.keras.layers.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (
    normalization_layer(x), normalization_layer(y)))
test_ds = test_ds.map(lambda x, y: (
    normalization_layer(x), normalization_layer(y)))


# Augmentation

In [12]:
class Augment(tf.keras.layers.Layer):
    def __init__(self, seed=42):
        super().__init__()
        # both use the same seed, so they'll make the same random changes.
        self.augment_inputs = tf.keras.layers.RandomZoom(
            (-0.1, -0.7), seed=seed)
        self.augment_labels = tf.keras.layers.RandomZoom(
            (-0.1, -0.7), seed=seed)

    def call(self, inputs, labels):
        inputs = self.augment_inputs(inputs)
        labels = self.augment_labels(labels)
        return inputs, labels


# Optimization

In [13]:
AUTOTUNE = tf.data.AUTOTUNE

train_batches = (
    train_ds
    .cache()
    .shuffle(BUFFER_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)

test_batches = (
    test_ds
    .cache()
    .prefetch(buffer_size=AUTOTUNE)
)


# Training Config

In [14]:
model_path = os.path.join(MODEL_PATH, EXP_NAME)

callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=PATIENCE,
        restore_best_weights=True
    ),
    tf.keras.callbacks.ModelCheckpoint(
        model_path,
        monitor='val_loss',
        verbose=1,
        save_best_only=True,
        save_weights_only=True
    )
]

optimizer = tf.keras.optimizers.Adam(
    learning_rate=LEARNING_RATE
)

model = UNet(IMG_SIZE)()

model.compile(
    loss=attention_mse,
    optimizer=optimizer,
    metrics=[attention_mae]
)

# try:
#     model.load_weights(model_path)
# except:
#     print("Checkpoint not found")
#     pass


# Training

In [15]:
history = model.fit(
    train_batches,
    validation_data=test_batches,
    epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=1
)


Epoch 1/300
Epoch 1: val_loss improved from inf to 0.03865, saving model to /content/drive/MyDrive/Research/TAVI/Models/TAWSS_2_ECAP_MultiViewUNet_I256_B16_LR0.001
Epoch 2/300
Epoch 2: val_loss improved from 0.03865 to 0.03303, saving model to /content/drive/MyDrive/Research/TAVI/Models/TAWSS_2_ECAP_MultiViewUNet_I256_B16_LR0.001
Epoch 3/300
Epoch 3: val_loss did not improve from 0.03303
Epoch 4/300
Epoch 4: val_loss improved from 0.03303 to 0.03278, saving model to /content/drive/MyDrive/Research/TAVI/Models/TAWSS_2_ECAP_MultiViewUNet_I256_B16_LR0.001
Epoch 5/300
Epoch 5: val_loss did not improve from 0.03278
Epoch 6/300
Epoch 6: val_loss did not improve from 0.03278
Epoch 7/300
Epoch 7: val_loss did not improve from 0.03278
Epoch 8/300
Epoch 8: val_loss improved from 0.03278 to 0.02756, saving model to /content/drive/MyDrive/Research/TAVI/Models/TAWSS_2_ECAP_MultiViewUNet_I256_B16_LR0.001
Epoch 9/300
Epoch 9: val_loss did not improve from 0.02756
Epoch 10/300
Epoch 10: val_loss impro

# Save Model

In [16]:
timestamp = datetime.datetime.now().strftime('%b-%d-%I:%M%p')
if not os.path.exists(model_path):
    os.makedirs(model_path)

model.save(os.path.join(model_path, timestamp))



# Save Predictions

In [17]:
test_ds_unbatched = test_batches.unbatch()

pred_path = os.path.join(PRED_PATH, EXP_NAME, timestamp)
if not os.path.exists(pred_path):
    os.makedirs(pred_path)

metrics = pd.DataFrame()

for idx, (input, target) in enumerate(tqdm(test_ds_unbatched)):

    target = tf.squeeze(target).numpy()
    prediction = tf.squeeze(
        model.predict(
            tf.expand_dims(input, axis=0)
        )
    ).numpy()
    prediction[target == 1.0] = 1.0

    plt.figure(figsize=(7, 7))
    plt.imshow(target)
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(os.path.join(pred_path, f"{idx}_T.png"))
    plt.close()

    plt.figure(figsize=(7, 7))
    plt.imshow(prediction)
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(os.path.join(pred_path, f"{idx}_P.png"))
    plt.close()


0it [00:00, ?it/s]



1it [00:03,  3.51s/it]



2it [00:04,  1.89s/it]



3it [00:04,  1.22s/it]



4it [00:05,  1.10it/s]



5it [00:05,  1.36it/s]



6it [00:06,  1.56it/s]



7it [00:06,  1.74it/s]



8it [00:06,  1.89it/s]



9it [00:07,  2.12it/s]



10it [00:07,  2.29it/s]



11it [00:07,  2.43it/s]



12it [00:08,  2.40it/s]



13it [00:08,  2.52it/s]



14it [00:09,  2.60it/s]



15it [00:09,  2.65it/s]



16it [00:09,  2.64it/s]



17it [00:10,  2.53it/s]



18it [00:10,  2.50it/s]



19it [00:11,  2.46it/s]



20it [00:11,  2.52it/s]



21it [00:11,  2.59it/s]



22it [00:12,  2.66it/s]



23it [00:12,  2.72it/s]



24it [00:12,  2.62it/s]



25it [00:13,  2.51it/s]



26it [00:13,  2.42it/s]



27it [00:14,  2.35it/s]



28it [00:14,  2.27it/s]



29it [00:15,  2.26it/s]



30it [00:15,  2.27it/s]



31it [00:16,  2.26it/s]



32it [00:16,  2.27it/s]



33it [00:16,  2.35it/s]



34it [00:17,  2.41it/s]



35it [00:17,  2.41it/s]



36it [00:18,  2.41it/s]



37it [00:18,  2.40it/s]



38it [00:18,  2.39it/s]



39it [00:19,  2.51it/s]



40it [00:19,  2.62it/s]



41it [00:20,  2.69it/s]



42it [00:20,  2.64it/s]



43it [00:20,  2.56it/s]



44it [00:21,  2.63it/s]



45it [00:21,  2.69it/s]



46it [00:21,  2.67it/s]



47it [00:22,  2.72it/s]



48it [00:22,  2.63it/s]



49it [00:23,  2.56it/s]



50it [00:23,  2.51it/s]



51it [00:23,  2.53it/s]



52it [00:24,  2.48it/s]



53it [00:24,  2.45it/s]



54it [00:25,  2.55it/s]



55it [00:25,  2.51it/s]



56it [00:25,  2.47it/s]



57it [00:26,  2.43it/s]



58it [00:26,  2.39it/s]



59it [00:27,  2.36it/s]



60it [00:27,  2.38it/s]



61it [00:28,  2.38it/s]



62it [00:28,  2.34it/s]



63it [00:28,  2.31it/s]



64it [00:29,  2.30it/s]



65it [00:29,  2.41it/s]



66it [00:30,  2.40it/s]



67it [00:30,  2.44it/s]



68it [00:30,  2.52it/s]



69it [00:31,  2.61it/s]



70it [00:31,  2.68it/s]



71it [00:31,  2.70it/s]



72it [00:32,  2.60it/s]



73it [00:32,  2.55it/s]



74it [00:33,  2.50it/s]



75it [00:33,  2.49it/s]



76it [00:34,  2.49it/s]



77it [00:34,  2.47it/s]



78it [00:34,  2.47it/s]



79it [00:35,  2.55it/s]



80it [00:35,  2.65it/s]



81it [00:35,  2.58it/s]



82it [00:36,  2.53it/s]



83it [00:36,  2.49it/s]



84it [00:37,  2.61it/s]



85it [00:37,  2.52it/s]



86it [00:38,  2.46it/s]



87it [00:38,  2.42it/s]



88it [00:38,  2.53it/s]



89it [00:39,  2.63it/s]



90it [00:39,  2.44it/s]



91it [00:40,  2.41it/s]



92it [00:40,  2.38it/s]



93it [00:40,  2.35it/s]



94it [00:41,  2.34it/s]



95it [00:41,  2.34it/s]



96it [00:42,  2.33it/s]



97it [00:42,  2.35it/s]



98it [00:42,  2.49it/s]



99it [00:43,  2.51it/s]



100it [00:43,  2.47it/s]



101it [00:44,  2.44it/s]



102it [00:44,  2.41it/s]



103it [00:45,  2.39it/s]



104it [00:45,  1.99it/s]



105it [00:46,  2.11it/s]



106it [00:46,  2.20it/s]



107it [00:46,  2.25it/s]



108it [00:47,  2.41it/s]



109it [00:47,  2.44it/s]



110it [00:48,  2.44it/s]



111it [00:48,  2.51it/s]



112it [00:48,  2.58it/s]



113it [00:49,  2.63it/s]



114it [00:49,  2.69it/s]



115it [00:49,  2.71it/s]



116it [00:50,  2.73it/s]



117it [00:50,  2.77it/s]



118it [00:51,  2.71it/s]



119it [00:51,  2.78it/s]



120it [00:51,  2.70it/s]



121it [00:52,  2.72it/s]



122it [00:52,  2.57it/s]



123it [00:53,  2.45it/s]



124it [00:53,  2.40it/s]



125it [00:53,  2.36it/s]



126it [00:54,  2.31it/s]



127it [00:54,  2.30it/s]



128it [00:55,  2.31it/s]



129it [00:55,  2.46it/s]



130it [00:56,  2.44it/s]



131it [00:56,  2.42it/s]



132it [00:56,  2.40it/s]



133it [00:57,  2.40it/s]



134it [00:57,  2.37it/s]



135it [00:58,  2.46it/s]



136it [00:58,  2.43it/s]



137it [00:58,  2.40it/s]



138it [00:59,  2.39it/s]



139it [00:59,  2.36it/s]



140it [01:00,  2.33it/s]



141it [01:00,  2.32it/s]



142it [01:01,  2.29it/s]



143it [01:01,  2.32it/s]



144it [01:01,  2.33it/s]



145it [01:02,  2.39it/s]



146it [01:02,  2.48it/s]



147it [01:03,  2.53it/s]



148it [01:03,  2.58it/s]



149it [01:03,  2.66it/s]



150it [01:04,  2.55it/s]



151it [01:04,  2.48it/s]



152it [01:05,  2.48it/s]



153it [01:05,  2.43it/s]



154it [01:05,  2.41it/s]



155it [01:06,  2.37it/s]



156it [01:06,  2.32it/s]



157it [01:07,  2.30it/s]



158it [01:07,  2.28it/s]



159it [01:08,  2.25it/s]



160it [01:08,  2.27it/s]



161it [01:09,  2.30it/s]



162it [01:09,  2.28it/s]



163it [01:09,  2.33it/s]



164it [01:10,  2.32it/s]



165it [01:10,  2.33it/s]



166it [01:11,  2.43it/s]



167it [01:11,  2.56it/s]



168it [01:11,  2.52it/s]



169it [01:12,  2.47it/s]



170it [01:12,  2.46it/s]



171it [01:13,  2.56it/s]



172it [01:13,  2.50it/s]



173it [01:13,  2.61it/s]



174it [01:14,  2.51it/s]



175it [01:14,  2.48it/s]



176it [01:15,  2.45it/s]



177it [01:15,  2.41it/s]



178it [01:15,  2.41it/s]



179it [01:16,  2.49it/s]



180it [01:16,  2.59it/s]



181it [01:17,  2.65it/s]



182it [01:17,  2.53it/s]



183it [01:17,  2.47it/s]



184it [01:18,  2.41it/s]



185it [01:18,  2.32it/s]



186it [01:19,  2.32it/s]



187it [01:19,  2.32it/s]



188it [01:20,  2.32it/s]



189it [01:20,  2.28it/s]



190it [01:20,  2.30it/s]



191it [01:21,  2.29it/s]



192it [01:21,  2.42it/s]



193it [01:22,  2.41it/s]



194it [01:22,  2.50it/s]



195it [01:22,  2.58it/s]



196it [01:23,  2.56it/s]



197it [01:23,  2.49it/s]



198it [01:24,  2.45it/s]



199it [01:24,  2.42it/s]



200it [01:24,  2.41it/s]



201it [01:25,  2.38it/s]



202it [01:25,  2.38it/s]



203it [01:26,  2.38it/s]



204it [01:26,  2.36it/s]



205it [01:27,  2.35it/s]



206it [01:27,  2.41it/s]



207it [01:28,  1.99it/s]



208it [01:28,  2.15it/s]



209it [01:29,  2.23it/s]



210it [01:29,  2.22it/s]



211it [01:29,  2.23it/s]



212it [01:30,  2.23it/s]



213it [01:30,  2.22it/s]



214it [01:31,  2.20it/s]



215it [01:31,  2.11it/s]



216it [01:32,  2.03it/s]



217it [01:32,  2.08it/s]



218it [01:33,  2.11it/s]



219it [01:33,  2.15it/s]



220it [01:34,  2.21it/s]



221it [01:34,  2.23it/s]



222it [01:34,  2.23it/s]



223it [01:35,  2.26it/s]



224it [01:35,  2.29it/s]



225it [01:36,  2.31it/s]



226it [01:36,  2.32it/s]



227it [01:37,  2.32it/s]



228it [01:37,  2.32it/s]



229it [01:37,  2.34it/s]



230it [01:38,  2.32it/s]



231it [01:38,  2.40it/s]



232it [01:39,  2.50it/s]



233it [01:39,  2.44it/s]



234it [01:40,  2.42it/s]



235it [01:40,  2.54it/s]



236it [01:40,  2.49it/s]



237it [01:41,  2.46it/s]



238it [01:41,  2.51it/s]



239it [01:41,  2.61it/s]



240it [01:42,  2.57it/s]



241it [01:42,  2.52it/s]



242it [01:43,  2.49it/s]



243it [01:43,  2.58it/s]



244it [01:43,  2.52it/s]



245it [01:44,  2.59it/s]



246it [01:44,  2.61it/s]



247it [01:45,  2.49it/s]



248it [01:45,  2.42it/s]



249it [01:45,  2.39it/s]



250it [01:46,  2.36it/s]



251it [01:46,  2.35it/s]



252it [01:47,  2.34it/s]



253it [01:47,  2.31it/s]



254it [01:48,  2.35it/s]



255it [01:48,  2.35it/s]



256it [01:48,  2.36it/s]



257it [01:49,  2.48it/s]



258it [01:49,  2.55it/s]



259it [01:50,  2.60it/s]



260it [01:50,  2.53it/s]



261it [01:50,  2.63it/s]



262it [01:51,  2.56it/s]



263it [01:51,  2.48it/s]



264it [01:52,  2.47it/s]



265it [01:52,  2.46it/s]



266it [01:52,  2.58it/s]



267it [01:53,  2.67it/s]



268it [01:53,  2.74it/s]



269it [01:53,  2.59it/s]



270it [01:54,  2.64it/s]



271it [01:54,  2.56it/s]



272it [01:55,  2.62it/s]



273it [01:55,  2.70it/s]



274it [01:55,  2.62it/s]



275it [01:56,  2.69it/s]



276it [01:56,  2.76it/s]



277it [01:56,  2.75it/s]



278it [01:57,  2.63it/s]



279it [01:57,  2.50it/s]



280it [01:58,  2.42it/s]



281it [01:58,  2.36it/s]



282it [01:59,  2.31it/s]



283it [01:59,  2.30it/s]



284it [01:59,  2.31it/s]



285it [02:00,  2.29it/s]



286it [02:00,  2.30it/s]



287it [02:01,  2.34it/s]



288it [02:01,  2.33it/s]



289it [02:02,  2.33it/s]



290it [02:02,  2.46it/s]



291it [02:02,  2.60it/s]



292it [02:03,  2.66it/s]



293it [02:03,  2.57it/s]



294it [02:04,  2.50it/s]



295it [02:04,  2.46it/s]



296it [02:04,  2.45it/s]



297it [02:05,  2.52it/s]



298it [02:05,  2.62it/s]



299it [02:05,  2.56it/s]



300it [02:06,  2.61it/s]



301it [02:06,  2.56it/s]



302it [02:07,  2.53it/s]



303it [02:07,  2.52it/s]



304it [02:07,  2.63it/s]



305it [02:08,  2.56it/s]



306it [02:08,  2.51it/s]



307it [02:09,  2.09it/s]



308it [02:09,  2.18it/s]



309it [02:10,  2.31it/s]



310it [02:10,  2.40it/s]



311it [02:10,  2.41it/s]



312it [02:11,  2.37it/s]


# Loss Curve

In [18]:
try:
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(8, 6))
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.ylabel('Attention MAE')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.tight_layout()
    plt.savefig(os.path.join(model_path, timestamp + ".png"))
    plt.close()
except:
    print("Model did not finish training")


# Metrics

In [19]:
model.evaluate(test_batches)




[0.01738385483622551, 0.07014688104391098]