<a href="https://colab.research.google.com/github/atick-faisal/TAVI/blob/main/src/colab/MultiViewUNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Colab Info

In [1]:
from psutil import virtual_memory

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)


ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

Mon Jan 16 08:16:33 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Google Drive

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Download Dataset

In [3]:
!pip install -U --no-cache-dir gdown --pre
!pip install image-similarity-measures > /dev/null

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gdown
  Downloading gdown-4.6.0-py3-none-any.whl (14 kB)
Installing collected packages: gdown
  Attempting uninstall: gdown
    Found existing installation: gdown 4.4.0
    Uninstalling gdown-4.4.0:
      Successfully uninstalled gdown-4.4.0
Successfully installed gdown-4.6.0


In [11]:
!gdown "12CxxS2w7mprCnVX9Y_Kiti6vEPbTQW2l"
!gdown "1GfS5qapaM5Ci7ykHQW_4B-kWqk_tmvSd"
!unzip -o "TAWSS2ECAP_v1.zip" > /dev/null


Downloading...
From: https://drive.google.com/uc?id=12CxxS2w7mprCnVX9Y_Kiti6vEPbTQW2l
To: /content/UNet_2DCNN.py
100% 58.6k/58.6k [00:00<00:00, 59.2MB/s]
Downloading...
From: https://drive.google.com/uc?id=1GfS5qapaM5Ci7ykHQW_4B-kWqk_tmvSd
To: /content/TAWSS2ECAP_v1.zip
100% 1.06G/1.06G [00:06<00:00, 163MB/s]


# Imports

In [12]:
import os
import cv2
import random
import datetime
import matplotlib
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

from tqdm import tqdm
from numpy.typing import NDArray
from UNet_2DCNN import UNet as UNetSakib
# from image_similarity_measures.quality_metrics import *

matplotlib.use('Agg')
plt.rcParams["font.size"] = 16


# Config

In [13]:
DATASET         = "TAWSS2ECAP"
MODEL_NAME      = "MultiViewUNet"
DATASET_PATH    = "/content/Images/"
TRAIN_DIR       = "Train/"
TEST_DIR        = "Test/"
INPUT_DIR       = "Input/"
TARGET_DIR      = "Target/"
MODEL_PATH      = "/content/drive/MyDrive/Research/TAVI/Models/"
PRED_PATH       = "/content/drive/MyDrive/Research/TAVI/Predictions/"

IMG_SIZE        = 256
BATCH_SIZE      = 16
BUFFER_SIZE     = 1000
VAL_SPLIT       = 0.2
LEARNING_RATE   = 0.001
N_EPOCHS        = 300
PATIENCE        = 30

EXP_NAME        = f"{MODEL_NAME}_I{IMG_SIZE}_B{BATCH_SIZE}_LR{LEARNING_RATE}_{DATASET}"


# Architecture

In [14]:
class UNet:
    def __init__(
        self,
        img_size: int,
        n_channels: int = 3,
        width: int = 32,
        depth: int = 4,
        kernel_size: int = 3
    ):
        self.img_size = img_size
        self.n_channels = n_channels
        self.width = width
        self.depth = depth
        self.kernel_size = kernel_size

    @staticmethod
    def conv(
        x: tf.Tensor,
        filters: int,
        kernel_size: int
    ) -> tf.Tensor:
        for i in range(2):
            x = tf.keras.layers.Conv2D(
                filters=filters,
                kernel_size=kernel_size,
                strides=1,
                padding="same",
                data_format="channels_last",
                dilation_rate=1,
                groups=1,
                activation=None,
                use_bias=True,
                kernel_initializer="glorot_uniform",
                bias_initializer="zeros"
            )(x)

            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.Activation("relu")(x)

        return x

    @staticmethod
    def deconv(x: tf.Tensor, filters: int) -> tf.Tensor:
        x = tf.keras.layers.Conv2DTranspose(
            filters=filters,
            kernel_size=2,
            strides=2,
            padding="same",
            output_padding=None,
            data_format=None,
            dilation_rate=1,
            activation=None,
            use_bias=True,
            kernel_initializer="glorot_uniform",
            bias_initializer="zeros"
        )(x)

        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation("relu")(x)

        return x

    @staticmethod
    def output(x: tf.Tensor) -> tf.Tensor:
        return tf.keras.layers.Conv2D(3, (1, 1), activation="sigmoid")(x)

    @staticmethod
    def pool(x: tf.Tensor, pool_size: int = 2) -> tf.Tensor:
        return tf.keras.layers.MaxPool2D(pool_size)(x)

    def __call__(self) -> tf.keras.Model:
        inputs = tf.keras.layers.Input(
            shape=(self.img_size, self.img_size, self.n_channels)
        )

        # scaled = tf.keras.layers.Rescaling(1./255.0, offset=0)(inputs)

        # ------------------ Downsampling ---------------------
        downsample_layers = []
        downsample_layers.append(
            self.conv(
                x=inputs,
                filters=self.width,
                kernel_size=self.kernel_size
            )
        )
        for i in range(1, self.depth):
            filters = int((2 ** i) * self.width)
            downsample_layers.append(
                self.pool(
                    self.conv(
                        x=downsample_layers[i - 1],
                        filters=filters,
                        kernel_size=self.kernel_size
                    )
                )
            )

        # ------------------- Features --------------------
        n_features = int((2 ** self.depth) * self.width)
        self.features = self.pool(
            self.conv(
                x=downsample_layers[-1],
                filters=n_features,
                kernel_size=self.kernel_size
            )
        )

        # ------------------- Upsampling --------------------
        upsample_layers = []
        upsample_layers.append(self.features)
        for i in range(1, self.depth + 1):
            filters = int((2 ** (self.depth - i)) * self.width)
            upsample_layers.append(
                self.conv(
                    x=tf.keras.layers.concatenate([
                        downsample_layers[self.depth - i],
                        self.deconv(
                            x=upsample_layers[i - 1],
                            filters=filters
                        )
                    ]),
                    filters=filters,
                    kernel_size=self.kernel_size
                )
            )

        # ---------------------- Output -----------------------
        outputs = self.output(upsample_layers[-1])

        return tf.keras.Model(inputs, outputs)


# Loss Functions

In [15]:
def attention_mse(y_true, y_pred):
    _y_true = y_true[y_true != 1.0]
    _y_pred = y_pred[y_true != 1.0]
    squared_difference = tf.square(_y_true - _y_pred)
    return tf.reduce_mean(squared_difference, axis=-1)


def attention_mae(y_true, y_pred):
    _y_true = y_true[y_true != 1.0]
    _y_pred = y_pred[y_true != 1.0]
    squared_difference = tf.abs(_y_true - _y_pred)
    return tf.reduce_mean(squared_difference, axis=-1)


# Utils

In [16]:
def load_data_from_dir(path: str) -> tf.data.Dataset:
    return tf.keras.utils.image_dataset_from_directory(
        directory=path,
        labels=None,
        color_mode='rgb',
        batch_size=BATCH_SIZE,
        image_size=(IMG_SIZE, IMG_SIZE),
        shuffle=True,
        seed=42,
        # validation_split=VAL_SPLIT,
        # subset=subset,
        interpolation='bilinear',
        follow_links=False,
        crop_to_aspect_ratio=False
    )


# Dataloaders

In [17]:
trainX = load_data_from_dir(os.path.join(DATASET_PATH, TRAIN_DIR, INPUT_DIR))
trainY = load_data_from_dir(os.path.join(DATASET_PATH, TRAIN_DIR, TARGET_DIR))
testX = load_data_from_dir(os.path.join(DATASET_PATH, TEST_DIR, INPUT_DIR))
testY = load_data_from_dir(os.path.join(DATASET_PATH, TEST_DIR, TARGET_DIR))

train_ds = tf.data.Dataset.zip((trainX, trainY))
test_ds = tf.data.Dataset.zip((testX, testY))

print(train_ds.element_spec)
print(test_ds.element_spec)


Found 8208 files belonging to 1 classes.
Found 8208 files belonging to 1 classes.
Found 312 files belonging to 1 classes.
Found 312 files belonging to 1 classes.
(TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None))
(TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None))


# Normalization

In [18]:
# def normalize(input_image, output_image):
#     input_image = tf.cast(input_image, tf.float32) / 255.0
#     output_image = tf.cast(input_image, tf.float32) / 255.0
#     return input_image, output_image

In [19]:
normalization_layer = tf.keras.layers.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), normalization_layer(y)))
test_ds = test_ds.map(lambda x, y: (normalization_layer(x), normalization_layer(y)))

# Augmentation

In [20]:
class Augment(tf.keras.layers.Layer):
    def __init__(self, seed=42):
        super().__init__()
        # both use the same seed, so they'll make the same random changes.
        self.augment_inputs = tf.keras.layers.RandomZoom((-0.1, -0.7), seed=seed)
        self.augment_labels = tf.keras.layers.RandomZoom((-0.1, -0.7), seed=seed)

    def call(self, inputs, labels):
        inputs = self.augment_inputs(inputs)
        labels = self.augment_labels(labels)
        return inputs, labels

# Optimization

In [21]:
AUTOTUNE = tf.data.AUTOTUNE

train_batches = (
    train_ds
    .cache()
    # .shuffle(BUFFER_SIZE)
    # .batch(BATCH_SIZE)
    # .repeat()
    # .map(Augment())
    .prefetch(buffer_size=AUTOTUNE)
)

test_batches = (test_ds)


# Training Config

In [22]:
model_path = os.path.join(MODEL_PATH, EXP_NAME)

In [23]:
callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=PATIENCE,
        restore_best_weights=True
    ),
    tf.keras.callbacks.ModelCheckpoint(
        model_path,
        monitor = 'val_loss',
        verbose = 1,
        save_best_only = True,
        save_weights_only = True
    )
]

optimizer = tf.keras.optimizers.Adam(
    learning_rate=LEARNING_RATE
)

model = UNet(IMG_SIZE)()
# model = MobileUNet(IMG_SIZE)()
# model = UNetSakib(
#     length=IMG_SIZE, 
#     width=IMG_SIZE, 
#     model_depth=4, 
#     num_channel=1, 
#     model_width=32, 
#     kernel_size=3, 
#     problem_type='Regression',
#     output_nums=1, 
#     ds=0, 
#     ae=0, 
#     ag=0, 
#     lstm=0, 
#     alpha=1, 
#     feature_number=1024, 
#     is_transconv=True
# ).FSC_Net()

model.compile(
    loss=attention_mse,
    optimizer=optimizer,
    metrics=[attention_mae]
)

try:
    model.load_weights(model_path)
except:
    print("Checkpoint not found")
    pass


Checkpoint not found


# Training

In [24]:
history = model.fit(
    train_batches,
    validation_data=test_batches,
    epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=1
)


Epoch 1/300
Epoch 1: val_loss improved from inf to 0.06334, saving model to /content/drive/MyDrive/Research/TAVI/Models/MultiViewUNet_I256_B16_LR0.001_TAWSS2ECAP
Epoch 2/300
Epoch 2: val_loss improved from 0.06334 to 0.03187, saving model to /content/drive/MyDrive/Research/TAVI/Models/MultiViewUNet_I256_B16_LR0.001_TAWSS2ECAP
Epoch 3/300
Epoch 3: val_loss improved from 0.03187 to 0.02854, saving model to /content/drive/MyDrive/Research/TAVI/Models/MultiViewUNet_I256_B16_LR0.001_TAWSS2ECAP
Epoch 4/300
Epoch 4: val_loss improved from 0.02854 to 0.02537, saving model to /content/drive/MyDrive/Research/TAVI/Models/MultiViewUNet_I256_B16_LR0.001_TAWSS2ECAP
Epoch 5/300
Epoch 5: val_loss improved from 0.02537 to 0.02467, saving model to /content/drive/MyDrive/Research/TAVI/Models/MultiViewUNet_I256_B16_LR0.001_TAWSS2ECAP
Epoch 6/300
Epoch 6: val_loss improved from 0.02467 to 0.02404, saving model to /content/drive/MyDrive/Research/TAVI/Models/MultiViewUNet_I256_B16_LR0.001_TAWSS2ECAP
Epoch 7/

# Save Model

In [25]:
timestamp = datetime.datetime.now().strftime('%b-%d-%I:%M%p')
if not os.path.exists(model_path):
    os.makedirs(model_path)

model.save(os.path.join(model_path, timestamp))



# Save Predictions

In [26]:
test_ds_unbatched = test_batches.unbatch()

pred_path = os.path.join(PRED_PATH, EXP_NAME, timestamp)
if not os.path.exists(pred_path):
    os.makedirs(pred_path)

metrics = pd.DataFrame()

for idx, (input, target) in enumerate(tqdm(test_ds_unbatched)):

    # if not idx % 10 == 0:
    #     continue

    target = tf.squeeze(target).numpy()
    prediction = tf.squeeze(
        model.predict(
            tf.expand_dims(input, axis=0)
            )
        ).numpy()
    prediction[target == 1.0] = 1.0

    # _metrics = pd.DataFrame([[
    #         rmse(target, prediction),
    #         psnr(target, prediction),
    #         ssim(target, prediction),
    #         fsim(target, prediction),
    #         issm(target, prediction),
    #         sre(target, prediction),
    #         sam(target, prediction),
    #         uiq(target, prediction),
    #     ]],
    #     columns=["RMSE", "PSNR", "SSIM", "FSIM", "ISSM", "SRE", "SAM", "UIQ"]
    # )

    # metrics = pd.concat([metrics, _metrics], axis=0)

    plt.figure(figsize=(7, 7))
    plt.imshow(target)
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(os.path.join(pred_path, f"{idx}_T.png"))
    plt.close()

    plt.figure(figsize=(7, 7))
    plt.imshow(prediction)
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(os.path.join(pred_path, f"{idx}_P.png"))
    plt.close()


    # plt.figure(figsize=(12, 4))
    # plt.subplot(1, 3, 1)
    # plt.imshow(input)
    # plt.axis("off")
    # plt.title("INPUT")
    # plt.subplot(1, 3, 2)
    # plt.imshow(target)
    # plt.axis("off")
    # plt.title("TARGET")
    # plt.subplot(1, 3, 3)
    # plt.imshow(prediction)
    # plt.axis("off")
    # plt.title("PREDICTION")
    # plt.tight_layout()
    # plt.savefig(os.path.join(pred_path, f"{idx}.png"))
    # plt.close()  


0it [00:00, ?it/s]



1it [00:03,  3.27s/it]



2it [00:03,  1.52s/it]



3it [00:03,  1.05it/s]



4it [00:04,  1.42it/s]



5it [00:04,  1.79it/s]



6it [00:04,  2.17it/s]



7it [00:05,  2.40it/s]



8it [00:05,  2.68it/s]



9it [00:05,  2.80it/s]



10it [00:06,  2.86it/s]



11it [00:06,  2.90it/s]



12it [00:06,  2.95it/s]



13it [00:07,  2.95it/s]



14it [00:07,  3.14it/s]



15it [00:07,  3.22it/s]



16it [00:07,  3.01it/s]



17it [00:08,  2.87it/s]



18it [00:08,  2.90it/s]



19it [00:09,  2.91it/s]



20it [00:09,  2.88it/s]



21it [00:09,  2.97it/s]



22it [00:09,  3.08it/s]



23it [00:10,  2.99it/s]



24it [00:10,  2.93it/s]



25it [00:11,  2.91it/s]



26it [00:11,  2.95it/s]



27it [00:11,  3.04it/s]



28it [00:12,  3.03it/s]



29it [00:12,  3.17it/s]



30it [00:12,  3.16it/s]



31it [00:12,  3.32it/s]



32it [00:13,  2.86it/s]



33it [00:13,  2.87it/s]



34it [00:13,  3.07it/s]



35it [00:14,  3.18it/s]



36it [00:14,  3.15it/s]



37it [00:14,  3.13it/s]



38it [00:15,  3.10it/s]



39it [00:15,  3.27it/s]



40it [00:15,  3.21it/s]



41it [00:16,  3.20it/s]



42it [00:16,  3.15it/s]



43it [00:16,  3.14it/s]



44it [00:17,  3.07it/s]



45it [00:17,  3.06it/s]



46it [00:17,  3.03it/s]



47it [00:18,  3.06it/s]



48it [00:18,  3.23it/s]



49it [00:18,  2.95it/s]



50it [00:19,  2.99it/s]



51it [00:19,  2.99it/s]



52it [00:19,  3.11it/s]



53it [00:20,  3.17it/s]



54it [00:20,  3.30it/s]



55it [00:20,  3.23it/s]



56it [00:20,  3.38it/s]



57it [00:21,  3.27it/s]



58it [00:21,  3.19it/s]



59it [00:21,  3.28it/s]



60it [00:22,  3.24it/s]



61it [00:22,  3.18it/s]



62it [00:22,  3.14it/s]



63it [00:23,  2.65it/s]



64it [00:23,  2.75it/s]



65it [00:24,  2.66it/s]



66it [00:24,  2.88it/s]



67it [00:24,  3.09it/s]



68it [00:24,  3.10it/s]



69it [00:25,  3.11it/s]



70it [00:25,  3.09it/s]



71it [00:25,  3.03it/s]



72it [00:26,  3.04it/s]



73it [00:26,  3.05it/s]



74it [00:26,  3.18it/s]



75it [00:27,  3.12it/s]



76it [00:27,  3.11it/s]



77it [00:27,  3.07it/s]



78it [00:28,  3.05it/s]



79it [00:28,  3.05it/s]



80it [00:28,  3.06it/s]



81it [00:29,  3.01it/s]



82it [00:29,  3.03it/s]



83it [00:29,  3.13it/s]



84it [00:30,  3.09it/s]



85it [00:30,  3.07it/s]



86it [00:30,  3.11it/s]



87it [00:31,  3.11it/s]



88it [00:31,  3.09it/s]



89it [00:31,  3.07it/s]



90it [00:32,  3.07it/s]



91it [00:32,  3.06it/s]



92it [00:32,  3.07it/s]



93it [00:33,  3.06it/s]



94it [00:33,  3.05it/s]



95it [00:33,  2.72it/s]



96it [00:34,  2.85it/s]



97it [00:34,  2.83it/s]



98it [00:34,  2.89it/s]



99it [00:35,  3.10it/s]



100it [00:35,  3.09it/s]



101it [00:35,  3.15it/s]



102it [00:36,  3.27it/s]



103it [00:36,  3.39it/s]



104it [00:36,  3.46it/s]



105it [00:36,  3.43it/s]



106it [00:37,  3.31it/s]



107it [00:37,  3.42it/s]



108it [00:37,  3.48it/s]



109it [00:38,  3.45it/s]



110it [00:38,  3.21it/s]



111it [00:38,  3.15it/s]



112it [00:39,  3.09it/s]



113it [00:39,  3.06it/s]



114it [00:39,  3.05it/s]



115it [00:40,  3.08it/s]



116it [00:40,  3.10it/s]



117it [00:40,  3.07it/s]



118it [00:40,  3.25it/s]



119it [00:41,  3.36it/s]



120it [00:41,  3.47it/s]



121it [00:41,  3.54it/s]



122it [00:42,  3.55it/s]



123it [00:42,  3.37it/s]



124it [00:42,  3.22it/s]



125it [00:43,  3.16it/s]



126it [00:43,  3.15it/s]



127it [00:43,  2.69it/s]



128it [00:44,  2.80it/s]



129it [00:44,  2.87it/s]



130it [00:44,  2.89it/s]



131it [00:45,  3.04it/s]



132it [00:45,  3.01it/s]



133it [00:45,  3.03it/s]



134it [00:46,  3.07it/s]



135it [00:46,  3.23it/s]



136it [00:46,  3.17it/s]



137it [00:47,  3.03it/s]



138it [00:47,  2.97it/s]



139it [00:47,  3.13it/s]



140it [00:48,  3.23it/s]



141it [00:48,  3.33it/s]



142it [00:48,  3.24it/s]



143it [00:48,  3.38it/s]



144it [00:49,  3.45it/s]



145it [00:49,  3.32it/s]



146it [00:49,  3.24it/s]



147it [00:50,  3.18it/s]



148it [00:50,  3.15it/s]



149it [00:50,  3.12it/s]



150it [00:51,  3.07it/s]



151it [00:51,  3.07it/s]



152it [00:51,  3.07it/s]



153it [00:52,  3.08it/s]



154it [00:52,  3.25it/s]



155it [00:52,  3.22it/s]



156it [00:52,  3.29it/s]



157it [00:53,  3.36it/s]



158it [00:53,  3.24it/s]



159it [00:54,  2.65it/s]



160it [00:54,  2.77it/s]



161it [00:54,  2.85it/s]



162it [00:55,  3.02it/s]



163it [00:55,  3.20it/s]



164it [00:55,  3.33it/s]



165it [00:55,  3.44it/s]



166it [00:56,  3.50it/s]



167it [00:56,  3.51it/s]



168it [00:56,  3.58it/s]



169it [00:57,  3.39it/s]



170it [00:57,  3.38it/s]



171it [00:57,  3.27it/s]



172it [00:57,  3.24it/s]



173it [00:58,  3.17it/s]



174it [00:58,  3.07it/s]



175it [00:59,  3.02it/s]



176it [00:59,  3.03it/s]



177it [00:59,  3.02it/s]



178it [00:59,  3.03it/s]



179it [01:00,  3.07it/s]



180it [01:00,  3.05it/s]



181it [01:00,  3.05it/s]



182it [01:01,  3.24it/s]



183it [01:01,  3.33it/s]



184it [01:01,  3.32it/s]



185it [01:02,  3.24it/s]



186it [01:02,  3.35it/s]



187it [01:02,  3.26it/s]



188it [01:03,  3.27it/s]



189it [01:03,  3.21it/s]



190it [01:03,  3.25it/s]



191it [01:04,  2.69it/s]



192it [01:04,  2.82it/s]



193it [01:04,  2.88it/s]



194it [01:05,  2.93it/s]



195it [01:05,  2.96it/s]



196it [01:05,  3.00it/s]



197it [01:06,  3.15it/s]



198it [01:06,  3.13it/s]



199it [01:06,  3.29it/s]



200it [01:06,  3.37it/s]



201it [01:07,  3.37it/s]



202it [01:07,  3.44it/s]



203it [01:07,  3.32it/s]



204it [01:08,  3.41it/s]



205it [01:08,  3.19it/s]



206it [01:08,  3.13it/s]



207it [01:09,  3.21it/s]



208it [01:09,  3.13it/s]



209it [01:09,  3.10it/s]



210it [01:10,  3.04it/s]



211it [01:10,  3.17it/s]



212it [01:10,  3.18it/s]



213it [01:11,  3.27it/s]



214it [01:11,  3.18it/s]



215it [01:11,  3.09it/s]



216it [01:12,  3.07it/s]



217it [01:12,  3.05it/s]



218it [01:12,  3.01it/s]



219it [01:13,  3.02it/s]



220it [01:13,  3.03it/s]



221it [01:13,  3.18it/s]



222it [01:13,  3.13it/s]



223it [01:14,  2.62it/s]



224it [01:14,  2.74it/s]



225it [01:15,  2.87it/s]



226it [01:15,  2.89it/s]



227it [01:15,  2.96it/s]



228it [01:16,  3.04it/s]



229it [01:16,  3.05it/s]



230it [01:16,  3.07it/s]



231it [01:17,  3.23it/s]



232it [01:17,  3.18it/s]



233it [01:17,  3.34it/s]



234it [01:17,  3.23it/s]



235it [01:18,  3.27it/s]



236it [01:18,  3.39it/s]



237it [01:18,  3.44it/s]



238it [01:19,  3.31it/s]



239it [01:19,  3.24it/s]



240it [01:19,  3.19it/s]



241it [01:20,  3.20it/s]



242it [01:20,  3.13it/s]



243it [01:20,  3.12it/s]



244it [01:21,  3.26it/s]



245it [01:21,  3.18it/s]



246it [01:21,  3.13it/s]



247it [01:22,  3.11it/s]



248it [01:22,  3.08it/s]



249it [01:22,  3.23it/s]



250it [01:22,  3.15it/s]



251it [01:23,  3.12it/s]



252it [01:23,  3.07it/s]



253it [01:23,  3.05it/s]



254it [01:24,  3.05it/s]



255it [01:24,  2.66it/s]



256it [01:25,  2.93it/s]



257it [01:25,  3.03it/s]



258it [01:25,  3.18it/s]



259it [01:25,  3.24it/s]



260it [01:26,  3.13it/s]



261it [01:26,  3.25it/s]



262it [01:26,  3.19it/s]



263it [01:27,  3.10it/s]



264it [01:27,  3.09it/s]



265it [01:27,  3.09it/s]



266it [01:28,  3.11it/s]



267it [01:28,  3.26it/s]



268it [01:28,  3.14it/s]



269it [01:29,  3.20it/s]



270it [01:29,  3.15it/s]



271it [01:29,  3.13it/s]



272it [01:30,  3.24it/s]



273it [01:30,  3.30it/s]



274it [01:30,  3.22it/s]



275it [01:30,  3.17it/s]



276it [01:31,  3.14it/s]



277it [01:31,  3.21it/s]



278it [01:31,  3.37it/s]



279it [01:32,  3.23it/s]



280it [01:32,  3.19it/s]



281it [01:32,  3.15it/s]



282it [01:33,  3.26it/s]



283it [01:33,  3.37it/s]



284it [01:33,  3.27it/s]



285it [01:34,  3.19it/s]



286it [01:34,  3.10it/s]



287it [01:34,  2.61it/s]



288it [01:35,  2.83it/s]



289it [01:35,  2.88it/s]



290it [01:35,  2.94it/s]



291it [01:36,  2.97it/s]



292it [01:36,  3.14it/s]



293it [01:36,  3.29it/s]



294it [01:37,  3.26it/s]



295it [01:37,  3.35it/s]



296it [01:37,  3.22it/s]



297it [01:38,  3.14it/s]



298it [01:38,  3.20it/s]



299it [01:38,  3.29it/s]



300it [01:38,  3.12it/s]



301it [01:39,  3.09it/s]



302it [01:39,  3.02it/s]



303it [01:39,  2.99it/s]



304it [01:40,  3.05it/s]



305it [01:40,  3.03it/s]



306it [01:40,  3.05it/s]



307it [01:41,  3.09it/s]



308it [01:41,  3.24it/s]



309it [01:41,  3.17it/s]



310it [01:42,  3.07it/s]



311it [01:42,  3.06it/s]



312it [01:42,  3.03it/s]


In [27]:
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 6))
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Attention MAE')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.tight_layout()
plt.savefig(os.path.join(model_path, timestamp + ".png"))
plt.close()

In [28]:
model.evaluate(test_batches)



[0.01842138171195984, 0.06518226116895676]