<a href="https://colab.research.google.com/github/atick-faisal/TAVI/blob/main/src/training/MultiViewUNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Colab Info

In [1]:
from psutil import virtual_memory

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Not connected to a GPU')
else:
    print(gpu_info)


ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
    print('Not using a high-RAM runtime')
else:
    print('You are using a high-RAM runtime!')


Tue Mar 28 10:46:00 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   69C    P8    11W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Fix for GDrive

In [2]:
!pip install -U --no-cache-dir gdown --pre > /dev/null


# Mount GDrive

In [3]:
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Download and Extract Dataset

In [4]:
!gdown "1XjSQgkT7XVZTxNSPtK3MbmPqrX8tLY3a"
!unzip -o "AAA_DATASET_v1.zip" > /dev/null


Downloading...
From (uriginal): https://drive.google.com/uc?id=1XjSQgkT7XVZTxNSPtK3MbmPqrX8tLY3a
From (redirected): https://drive.google.com/uc?id=1XjSQgkT7XVZTxNSPtK3MbmPqrX8tLY3a&confirm=t&uuid=039145d7-876b-44a1-9df3-048a4580bdbb
To: /content/AAA_DATASET_v1.zip
100% 3.16G/3.16G [00:17<00:00, 180MB/s]


# Imports

In [5]:
import os
import datetime
import matplotlib
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

from tqdm import tqdm

matplotlib.use('Agg')
plt.rcParams["font.size"] = 16

# Config

In [6]:
PROBLEM = "TAWSS_2_ECAP"

MODEL_NAME = "MultiViewUNet"
DATASET_PATH = "/content/Images/"
TRAIN_DIR = "Train/"
TEST_DIR = "Test/"
INPUT_DIR = PROBLEM.split("_2_")[0]
TARGET_DIR = PROBLEM.split("_2_")[1]
MODEL_PATH = "/content/drive/MyDrive/Research/TAVI/Models/"
PRED_PATH = "/content/drive/MyDrive/Research/TAVI/Predictions/"
IMG_SIZE = 256
BATCH_SIZE = 16
BUFFER_SIZE = 1000
VAL_SPLIT = 0.2
LEARNING_RATE = 0.001
N_EPOCHS = 300
PATIENCE = 30

EXP_NAME = f"{PROBLEM}_{MODEL_NAME}_I{IMG_SIZE}_B{BATCH_SIZE}_LR{LEARNING_RATE}"


# Architecture

In [7]:
class UNet:
    def __init__(
        self,
        img_size: int,
        n_channels: int = 3,
        width: int = 32,
        depth: int = 4,
        kernel_size: int = 3
    ):
        self.img_size = img_size
        self.n_channels = n_channels
        self.width = width
        self.depth = depth
        self.kernel_size = kernel_size

    @staticmethod
    def conv(
        x: tf.Tensor,
        filters: int,
        kernel_size: int
    ) -> tf.Tensor:
        for i in range(2):
            x = tf.keras.layers.Conv2D(
                filters=filters,
                kernel_size=kernel_size,
                strides=1,
                padding="same",
                data_format="channels_last",
                dilation_rate=1,
                groups=1,
                activation=None,
                use_bias=True,
                kernel_initializer="glorot_uniform",
                bias_initializer="zeros",
            )(x)

            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.Activation("relu")(x)

        return x

    @staticmethod
    def deconv(x: tf.Tensor, filters: int) -> tf.Tensor:
        x = tf.keras.layers.Conv2DTranspose(
            filters=filters,
            kernel_size=2,
            strides=2,
            padding="same",
            output_padding=None,
            data_format=None,
            dilation_rate=1,
            activation=None,
            use_bias=True,
            kernel_initializer="glorot_uniform",
            bias_initializer="zeros",
        )(x)

        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation("relu")(x)

        return x

    @staticmethod
    def output(x: tf.Tensor) -> tf.Tensor:
        return tf.keras.layers.Conv2D(3, (1, 1), activation="sigmoid")(x)

    @staticmethod
    def pool(x: tf.Tensor, pool_size: int = 2) -> tf.Tensor:
        return tf.keras.layers.MaxPool2D(pool_size)(x)

    @staticmethod
    def dropout(x: tf.Tensor, amount: float = 0.5) -> tf.Tensor:
        return tf.keras.layers.Dropout(amount)(x)

    def __call__(self) -> tf.keras.Model:
        inputs = tf.keras.layers.Input(
            shape=(self.img_size, self.img_size, self.n_channels)
        )

        # scaled = tf.keras.layers.Rescaling(1./255.0, offset=0)(inputs)

        # ------------------ Downsampling ---------------------
        downsample_layers = []
        downsample_layers.append(
            self.conv(
                x=inputs,
                filters=self.width,
                kernel_size=self.kernel_size
            )
        )
        for i in range(1, self.depth):
            dropout_amount = 0.2 if i == 1 else 0.5
            filters = int((2 ** i) * self.width)
            downsample_layers.append(
                self.dropout(
                    self.pool(
                        self.conv(
                            x=downsample_layers[i - 1],
                            filters=filters,
                            kernel_size=self.kernel_size
                        )
                    ),
                    amount=dropout_amount
                )
            )

        # ------------------- Features --------------------
        n_features = int((2 ** self.depth) * self.width)
        self.features = self.pool(
            self.conv(
                x=downsample_layers[-1],
                filters=n_features,
                kernel_size=self.kernel_size
            )
        )

        # ------------------- Upsampling --------------------
        upsample_layers = []
        upsample_layers.append(self.features)
        for i in range(1, self.depth + 1):
            filters = int((2 ** (self.depth - i)) * self.width)
            upsample_layers.append(
                self.conv(
                    x=self.dropout(
                        tf.keras.layers.concatenate([
                            downsample_layers[self.depth - i],
                            self.deconv(
                                x=upsample_layers[i - 1],
                                filters=filters
                            )
                        ])
                    ),
                    filters=filters,
                    kernel_size=self.kernel_size
                )
            )

        # ---------------------- Output -----------------------
        outputs = self.output(upsample_layers[-1])

        return tf.keras.Model(inputs, outputs)


# Loss Functions / Metrics

In [8]:
def attention_mse(y_true, y_pred):
    _y_true = y_true[y_true != 1.0]
    _y_pred = y_pred[y_true != 1.0]
    squared_difference = tf.square(_y_true - _y_pred)
    return tf.reduce_mean(squared_difference, axis=-1)


def attention_mae(y_true, y_pred):
    _y_true = y_true[y_true != 1.0]
    _y_pred = y_pred[y_true != 1.0]
    squared_difference = tf.abs(_y_true - _y_pred)
    return tf.reduce_mean(squared_difference, axis=-1)


# DataLoader

In [9]:
def load_data_from_dir(path: str) -> tf.data.Dataset:
    return tf.keras.utils.image_dataset_from_directory(
        directory=path,
        labels=None,
        color_mode='rgb',
        batch_size=BATCH_SIZE,
        image_size=(IMG_SIZE, IMG_SIZE),
        shuffle=False,
        seed=42,
        interpolation='bilinear',
        follow_links=False,
        crop_to_aspect_ratio=False
    )


# Load Dataset

In [10]:
trainX = load_data_from_dir(os.path.join(DATASET_PATH, TRAIN_DIR, INPUT_DIR))
trainY = load_data_from_dir(os.path.join(DATASET_PATH, TRAIN_DIR, TARGET_DIR))
testX = load_data_from_dir(os.path.join(DATASET_PATH, TEST_DIR, INPUT_DIR))
testY = load_data_from_dir(os.path.join(DATASET_PATH, TEST_DIR, TARGET_DIR))

train_ds = tf.data.Dataset.zip((trainX, trainY))
test_ds = tf.data.Dataset.zip((testX, testY))

print(train_ds.element_spec)
print(test_ds.element_spec)


Found 7308 files belonging to 1 classes.
Found 7308 files belonging to 1 classes.
Found 612 files belonging to 1 classes.
Found 612 files belonging to 1 classes.
(TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None))
(TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None))


# Normalization

In [11]:
normalization_layer = tf.keras.layers.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (
    normalization_layer(x), normalization_layer(y)))
test_ds = test_ds.map(lambda x, y: (
    normalization_layer(x), normalization_layer(y)))


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


# Augmentation

In [12]:
class Augment(tf.keras.layers.Layer):
    def __init__(self, seed=42):
        super().__init__()
        # both use the same seed, so they'll make the same random changes.
        self.augment_inputs = tf.keras.layers.RandomZoom(
            (-0.1, -0.7), seed=seed)
        self.augment_labels = tf.keras.layers.RandomZoom(
            (-0.1, -0.7), seed=seed)

    def call(self, inputs, labels):
        inputs = self.augment_inputs(inputs)
        labels = self.augment_labels(labels)
        return inputs, labels


# Optimization

In [13]:
AUTOTUNE = tf.data.AUTOTUNE

train_batches = (
    train_ds
    .cache()
    .shuffle(BUFFER_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)

test_batches = (
    test_ds
    .cache()
    .prefetch(buffer_size=AUTOTUNE)
)


# Training Config

In [14]:
model_path = os.path.join(MODEL_PATH, EXP_NAME)

callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=PATIENCE,
        restore_best_weights=True
    ),
    tf.keras.callbacks.ModelCheckpoint(
        model_path,
        monitor='val_loss',
        verbose=1,
        save_best_only=True,
        save_weights_only=True
    )
]

optimizer = tf.keras.optimizers.Adam(
    learning_rate=LEARNING_RATE
)

model = UNet(IMG_SIZE)()

model.compile(
    loss=attention_mse,
    optimizer=optimizer,
    metrics=[attention_mae]
)

# try:
#     model.load_weights(model_path)
# except:
#     print("Checkpoint not found")
#     pass


# Training

In [15]:
history = model.fit(
    train_batches,
    validation_data=test_batches,
    epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=1
)


Epoch 1/300
Epoch 1: val_loss improved from inf to 0.05305, saving model to /content/drive/MyDrive/Research/TAVI/Models/TAWSS_2_ECAP_MultiViewUNet_I256_B16_LR0.001
Epoch 2/300
Epoch 2: val_loss improved from 0.05305 to 0.03071, saving model to /content/drive/MyDrive/Research/TAVI/Models/TAWSS_2_ECAP_MultiViewUNet_I256_B16_LR0.001
Epoch 3/300
Epoch 3: val_loss improved from 0.03071 to 0.02880, saving model to /content/drive/MyDrive/Research/TAVI/Models/TAWSS_2_ECAP_MultiViewUNet_I256_B16_LR0.001
Epoch 4/300
Epoch 4: val_loss improved from 0.02880 to 0.02871, saving model to /content/drive/MyDrive/Research/TAVI/Models/TAWSS_2_ECAP_MultiViewUNet_I256_B16_LR0.001
Epoch 5/300
Epoch 5: val_loss did not improve from 0.02871
Epoch 6/300
Epoch 6: val_loss improved from 0.02871 to 0.02784, saving model to /content/drive/MyDrive/Research/TAVI/Models/TAWSS_2_ECAP_MultiViewUNet_I256_B16_LR0.001
Epoch 7/300
Epoch 7: val_loss improved from 0.02784 to 0.02705, saving model to /content/drive/MyDrive/Re

# Save Model

In [16]:
timestamp = datetime.datetime.now().strftime('%b-%d-%I:%M%p')
if not os.path.exists(model_path):
    os.makedirs(model_path)

model.save(os.path.join(model_path, timestamp))



# Save Predictions

In [17]:
test_ds_unbatched = test_batches.unbatch()

pred_path = os.path.join(PRED_PATH, EXP_NAME, timestamp)
if not os.path.exists(pred_path):
    os.makedirs(pred_path)

metrics = pd.DataFrame()

for idx, (input, target) in enumerate(tqdm(test_ds_unbatched)):

    target = tf.squeeze(target).numpy()
    prediction = tf.squeeze(
        model.predict(
            tf.expand_dims(input, axis=0)
        )
    ).numpy()
    prediction[target == 1.0] = 1.0

    plt.figure(figsize=(7, 7))
    plt.imshow(target)
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(os.path.join(pred_path, f"{idx}_T.png"))
    plt.close()

    plt.figure(figsize=(7, 7))
    plt.imshow(prediction)
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(os.path.join(pred_path, f"{idx}_P.png"))
    plt.close()


0it [00:00, ?it/s]



1it [00:02,  2.55s/it]



2it [00:02,  1.23s/it]



3it [00:03,  1.24it/s]



4it [00:03,  1.72it/s]



5it [00:03,  2.08it/s]



6it [00:04,  2.37it/s]



7it [00:04,  2.60it/s]



8it [00:04,  2.74it/s]



9it [00:04,  3.04it/s]



10it [00:05,  3.06it/s]



11it [00:05,  3.08it/s]



12it [00:05,  3.07it/s]



13it [00:06,  3.03it/s]



14it [00:06,  2.89it/s]



15it [00:06,  2.85it/s]



16it [00:07,  3.09it/s]



17it [00:07,  3.08it/s]



18it [00:07,  3.03it/s]



19it [00:08,  3.09it/s]



20it [00:08,  3.31it/s]



21it [00:08,  3.23it/s]



22it [00:09,  3.30it/s]



23it [00:09,  3.20it/s]



24it [00:09,  3.14it/s]



25it [00:10,  3.11it/s]



26it [00:10,  3.08it/s]



27it [00:10,  3.05it/s]



28it [00:11,  3.06it/s]



29it [00:11,  3.05it/s]



30it [00:11,  3.06it/s]



31it [00:12,  3.07it/s]



32it [00:12,  3.07it/s]



33it [00:12,  3.04it/s]



34it [00:12,  3.07it/s]



35it [00:13,  3.06it/s]



36it [00:13,  3.02it/s]



37it [00:13,  3.06it/s]



38it [00:14,  3.07it/s]



39it [00:14,  3.06it/s]



40it [00:14,  3.06it/s]



41it [00:15,  3.08it/s]



42it [00:15,  3.07it/s]



43it [00:15,  3.07it/s]



44it [00:16,  3.08it/s]



45it [00:16,  3.28it/s]



46it [00:16,  3.42it/s]



47it [00:17,  3.33it/s]



48it [00:17,  3.27it/s]



49it [00:17,  3.19it/s]



50it [00:18,  3.39it/s]



51it [00:18,  3.31it/s]



52it [00:18,  3.26it/s]



53it [00:18,  3.18it/s]



54it [00:19,  3.15it/s]



55it [00:19,  3.08it/s]



56it [00:19,  3.12it/s]



57it [00:20,  3.11it/s]



58it [00:20,  3.33it/s]



59it [00:20,  3.51it/s]



60it [00:21,  3.39it/s]



61it [00:21,  3.30it/s]



62it [00:21,  3.21it/s]



63it [00:22,  3.16it/s]



64it [00:22,  3.10it/s]



65it [00:22,  3.09it/s]



66it [00:23,  3.05it/s]



67it [00:23,  3.02it/s]



68it [00:23,  2.99it/s]



69it [00:24,  2.96it/s]



70it [00:24,  2.99it/s]



71it [00:24,  2.99it/s]



72it [00:25,  2.95it/s]



73it [00:25,  3.01it/s]



74it [00:25,  3.03it/s]



75it [00:26,  3.02it/s]



76it [00:26,  3.06it/s]



77it [00:26,  3.01it/s]



78it [00:27,  2.99it/s]



79it [00:27,  2.96it/s]



80it [00:27,  3.00it/s]



81it [00:28,  2.99it/s]



82it [00:28,  3.03it/s]



83it [00:28,  3.06it/s]



84it [00:29,  2.98it/s]



85it [00:29,  2.95it/s]



86it [00:29,  2.96it/s]



87it [00:30,  3.20it/s]



88it [00:30,  3.40it/s]



89it [00:30,  3.24it/s]



90it [00:30,  3.18it/s]



91it [00:31,  3.28it/s]



92it [00:31,  3.24it/s]



93it [00:31,  3.20it/s]



94it [00:32,  3.19it/s]



95it [00:32,  3.14it/s]



96it [00:32,  3.13it/s]



97it [00:33,  3.11it/s]



98it [00:33,  3.19it/s]



99it [00:33,  3.13it/s]



100it [00:34,  3.10it/s]



101it [00:34,  3.10it/s]



102it [00:34,  3.16it/s]



103it [00:35,  3.10it/s]



104it [00:35,  3.32it/s]



105it [00:35,  3.24it/s]



106it [00:35,  3.18it/s]



107it [00:36,  3.02it/s]



108it [00:36,  3.04it/s]



109it [00:37,  3.02it/s]



110it [00:37,  3.04it/s]



111it [00:37,  3.01it/s]



112it [00:38,  3.02it/s]



113it [00:38,  3.05it/s]



114it [00:38,  3.07it/s]



115it [00:38,  3.04it/s]



116it [00:39,  1.95it/s]



117it [00:40,  2.18it/s]



118it [00:40,  2.41it/s]



119it [00:40,  2.63it/s]



120it [00:41,  2.76it/s]



121it [00:41,  2.82it/s]



122it [00:41,  2.86it/s]



123it [00:42,  3.09it/s]



124it [00:42,  3.03it/s]



125it [00:42,  3.02it/s]



126it [00:43,  3.17it/s]



127it [00:43,  3.10it/s]



128it [00:43,  3.05it/s]



129it [00:44,  3.03it/s]



130it [00:44,  3.01it/s]



131it [00:44,  2.96it/s]



132it [00:45,  3.17it/s]



133it [00:45,  3.12it/s]



134it [00:45,  3.05it/s]



135it [00:46,  3.03it/s]



136it [00:46,  3.24it/s]



137it [00:46,  3.15it/s]



138it [00:46,  3.17it/s]



139it [00:47,  3.37it/s]



140it [00:47,  3.51it/s]



141it [00:47,  3.34it/s]



142it [00:48,  3.24it/s]



143it [00:48,  3.18it/s]



144it [00:48,  3.34it/s]



145it [00:48,  3.50it/s]



146it [00:49,  3.51it/s]



147it [00:49,  3.24it/s]



148it [00:50,  3.06it/s]



149it [00:50,  2.93it/s]



150it [00:50,  2.83it/s]



151it [00:51,  2.85it/s]



152it [00:51,  2.87it/s]



153it [00:51,  2.93it/s]



154it [00:52,  2.91it/s]



155it [00:52,  2.95it/s]



156it [00:52,  3.18it/s]



157it [00:52,  3.34it/s]



158it [00:53,  3.10it/s]



159it [00:53,  3.09it/s]



160it [00:54,  3.08it/s]



161it [00:54,  3.05it/s]



162it [00:54,  3.07it/s]



163it [00:54,  3.07it/s]



164it [00:55,  3.04it/s]



165it [00:55,  3.06it/s]



166it [00:55,  3.08it/s]



167it [00:56,  3.07it/s]



168it [00:56,  3.07it/s]



169it [00:56,  3.06it/s]



170it [00:57,  3.20it/s]



171it [00:57,  3.16it/s]



172it [00:57,  3.10it/s]



173it [00:58,  3.15it/s]



174it [00:58,  3.06it/s]



175it [00:58,  2.99it/s]



176it [00:59,  3.04it/s]



177it [00:59,  3.05it/s]



178it [00:59,  3.06it/s]



179it [01:00,  3.07it/s]



180it [01:00,  3.08it/s]



181it [01:00,  3.07it/s]



182it [01:01,  3.28it/s]



183it [01:01,  3.22it/s]



184it [01:01,  3.26it/s]



185it [01:02,  3.21it/s]



186it [01:02,  3.17it/s]



187it [01:02,  3.11it/s]



188it [01:03,  3.08it/s]



189it [01:03,  3.04it/s]



190it [01:03,  3.01it/s]



191it [01:04,  3.06it/s]



192it [01:04,  3.01it/s]



193it [01:04,  3.08it/s]



194it [01:04,  3.11it/s]



195it [01:05,  3.08it/s]



196it [01:05,  3.06it/s]



197it [01:05,  3.08it/s]



198it [01:06,  3.30it/s]



199it [01:06,  3.21it/s]



200it [01:06,  3.18it/s]



201it [01:07,  3.36it/s]



202it [01:07,  3.32it/s]



203it [01:07,  3.51it/s]



204it [01:08,  3.40it/s]



205it [01:08,  3.26it/s]



206it [01:08,  2.53it/s]



207it [01:09,  2.66it/s]



208it [01:09,  2.76it/s]



209it [01:09,  2.80it/s]



210it [01:10,  3.07it/s]



211it [01:10,  3.05it/s]



212it [01:10,  3.05it/s]



213it [01:11,  3.05it/s]



214it [01:11,  3.04it/s]



215it [01:11,  3.05it/s]



216it [01:12,  3.03it/s]



217it [01:12,  3.01it/s]



218it [01:12,  3.00it/s]



219it [01:13,  3.02it/s]



220it [01:13,  3.05it/s]



221it [01:13,  3.04it/s]



222it [01:14,  3.04it/s]



223it [01:14,  3.04it/s]



224it [01:14,  3.24it/s]



225it [01:15,  3.15it/s]



226it [01:15,  3.37it/s]



227it [01:15,  3.26it/s]



228it [01:16,  3.18it/s]



229it [01:16,  3.13it/s]



230it [01:16,  3.08it/s]



231it [01:17,  3.06it/s]



232it [01:17,  3.02it/s]



233it [01:17,  3.00it/s]



234it [01:18,  2.97it/s]



235it [01:18,  3.00it/s]



236it [01:18,  2.97it/s]



237it [01:19,  2.98it/s]



238it [01:19,  3.04it/s]



239it [01:19,  3.26it/s]



240it [01:19,  3.10it/s]



241it [01:20,  3.01it/s]



242it [01:20,  2.98it/s]



243it [01:21,  2.94it/s]



244it [01:21,  3.02it/s]



245it [01:21,  3.19it/s]



246it [01:21,  3.33it/s]



247it [01:22,  3.23it/s]



248it [01:22,  3.19it/s]



249it [01:22,  3.15it/s]



250it [01:23,  3.08it/s]



251it [01:23,  3.08it/s]



252it [01:23,  3.19it/s]



253it [01:24,  3.40it/s]



254it [01:24,  3.32it/s]



255it [01:24,  3.31it/s]



256it [01:24,  3.23it/s]



257it [01:25,  3.16it/s]



258it [01:25,  3.13it/s]



259it [01:25,  3.06it/s]



260it [01:26,  3.07it/s]



261it [01:26,  3.02it/s]



262it [01:26,  3.02it/s]



263it [01:27,  3.03it/s]



264it [01:27,  3.03it/s]



265it [01:28,  2.92it/s]



266it [01:28,  2.94it/s]



267it [01:28,  2.92it/s]



268it [01:29,  2.94it/s]



269it [01:29,  2.95it/s]



270it [01:29,  2.93it/s]



271it [01:30,  2.92it/s]



272it [01:30,  2.97it/s]



273it [01:30,  2.98it/s]



274it [01:31,  3.02it/s]



275it [01:31,  2.90it/s]



276it [01:31,  2.92it/s]



277it [01:32,  2.92it/s]



278it [01:32,  3.06it/s]



279it [01:32,  3.07it/s]



280it [01:32,  3.21it/s]



281it [01:33,  3.18it/s]



282it [01:33,  3.14it/s]



283it [01:33,  3.08it/s]



284it [01:34,  3.09it/s]



285it [01:34,  3.11it/s]



286it [01:34,  3.29it/s]



287it [01:35,  3.21it/s]



288it [01:35,  3.17it/s]



289it [01:35,  3.13it/s]



290it [01:36,  3.34it/s]



291it [01:36,  3.38it/s]



292it [01:36,  3.20it/s]



293it [01:37,  3.14it/s]



294it [01:37,  3.20it/s]



295it [01:37,  3.36it/s]



296it [01:37,  3.42it/s]



297it [01:38,  2.58it/s]



298it [01:38,  2.70it/s]



299it [01:39,  2.93it/s]



300it [01:39,  3.17it/s]



301it [01:39,  3.05it/s]



302it [01:40,  3.05it/s]



303it [01:40,  3.07it/s]



304it [01:40,  3.01it/s]



305it [01:41,  3.04it/s]



306it [01:41,  3.02it/s]



307it [01:41,  3.01it/s]



308it [01:42,  3.03it/s]



309it [01:42,  2.93it/s]



310it [01:42,  2.96it/s]



311it [01:43,  2.95it/s]



312it [01:43,  2.96it/s]



313it [01:43,  2.98it/s]



314it [01:44,  2.97it/s]



315it [01:44,  2.99it/s]



316it [01:44,  3.02it/s]



317it [01:45,  3.01it/s]



318it [01:45,  2.96it/s]



319it [01:45,  2.98it/s]



320it [01:46,  3.00it/s]



321it [01:46,  3.03it/s]



322it [01:46,  3.01it/s]



323it [01:47,  3.24it/s]



324it [01:47,  3.40it/s]



325it [01:47,  3.46it/s]



326it [01:47,  3.32it/s]



327it [01:48,  3.49it/s]



328it [01:48,  3.35it/s]



329it [01:48,  3.24it/s]



330it [01:49,  3.21it/s]



331it [01:49,  3.17it/s]



332it [01:49,  3.12it/s]



333it [01:50,  3.13it/s]



334it [01:50,  3.03it/s]



335it [01:50,  3.03it/s]



336it [01:51,  3.05it/s]



337it [01:51,  3.06it/s]



338it [01:51,  3.27it/s]



339it [01:51,  3.41it/s]



340it [01:52,  3.31it/s]



341it [01:52,  3.24it/s]



342it [01:52,  3.34it/s]



343it [01:53,  3.26it/s]



344it [01:53,  3.45it/s]



345it [01:53,  3.53it/s]



346it [01:53,  3.64it/s]



347it [01:54,  3.75it/s]



348it [01:54,  3.85it/s]



349it [01:54,  3.52it/s]



350it [01:55,  3.37it/s]



351it [01:55,  3.21it/s]



352it [01:55,  3.08it/s]



353it [01:56,  3.03it/s]



354it [01:56,  2.98it/s]



355it [01:56,  2.99it/s]



356it [01:57,  2.97it/s]



357it [01:57,  2.96it/s]



358it [01:57,  2.94it/s]



359it [01:58,  2.96it/s]



360it [01:58,  2.93it/s]



361it [01:58,  2.91it/s]



362it [01:59,  3.16it/s]



363it [01:59,  3.38it/s]



364it [01:59,  3.27it/s]



365it [02:00,  3.20it/s]



366it [02:00,  3.16it/s]



367it [02:00,  3.02it/s]



368it [02:01,  3.26it/s]



369it [02:01,  3.46it/s]



370it [02:01,  3.36it/s]



371it [02:01,  3.49it/s]



372it [02:02,  3.35it/s]



373it [02:02,  3.22it/s]



374it [02:02,  3.12it/s]



375it [02:03,  3.09it/s]



376it [02:03,  3.24it/s]



377it [02:03,  3.14it/s]



378it [02:04,  3.11it/s]



379it [02:04,  3.10it/s]



380it [02:04,  3.25it/s]



381it [02:05,  3.20it/s]



382it [02:05,  3.17it/s]



383it [02:05,  3.16it/s]



384it [02:06,  3.08it/s]



385it [02:06,  3.09it/s]



386it [02:06,  3.11it/s]



387it [02:07,  2.54it/s]



388it [02:07,  2.82it/s]



389it [02:07,  2.86it/s]



390it [02:08,  2.90it/s]



391it [02:08,  2.94it/s]



392it [02:08,  2.91it/s]



393it [02:09,  2.89it/s]



394it [02:09,  2.90it/s]



395it [02:09,  2.90it/s]



396it [02:10,  2.94it/s]



397it [02:10,  2.96it/s]



398it [02:10,  2.95it/s]



399it [02:11,  2.97it/s]



400it [02:11,  2.96it/s]



401it [02:11,  2.95it/s]



402it [02:12,  3.20it/s]



403it [02:12,  3.18it/s]



404it [02:12,  3.14it/s]



405it [02:13,  3.33it/s]



406it [02:13,  3.28it/s]



407it [02:13,  3.45it/s]



408it [02:13,  3.53it/s]



409it [02:14,  3.40it/s]



410it [02:14,  3.31it/s]



411it [02:14,  3.40it/s]



412it [02:15,  3.57it/s]



413it [02:15,  3.70it/s]



414it [02:15,  3.48it/s]



415it [02:15,  3.36it/s]



416it [02:16,  3.53it/s]



417it [02:16,  3.62it/s]



418it [02:16,  3.43it/s]



419it [02:17,  3.36it/s]



420it [02:17,  3.32it/s]



421it [02:17,  3.26it/s]



422it [02:18,  3.20it/s]



423it [02:18,  3.43it/s]



424it [02:18,  3.32it/s]



425it [02:18,  3.25it/s]



426it [02:19,  3.20it/s]



427it [02:19,  3.16it/s]



428it [02:19,  3.14it/s]



429it [02:20,  3.16it/s]



430it [02:20,  3.07it/s]



431it [02:20,  3.00it/s]



432it [02:21,  3.00it/s]



433it [02:21,  2.94it/s]



434it [02:21,  2.99it/s]



435it [02:22,  2.99it/s]



436it [02:22,  3.02it/s]



437it [02:22,  3.00it/s]



438it [02:23,  3.00it/s]



439it [02:23,  3.01it/s]



440it [02:23,  3.00it/s]



441it [02:24,  2.94it/s]



442it [02:24,  2.95it/s]



443it [02:24,  2.96it/s]



444it [02:25,  3.00it/s]



445it [02:25,  3.21it/s]



446it [02:25,  3.15it/s]



447it [02:26,  3.07it/s]



448it [02:26,  3.05it/s]



449it [02:26,  3.01it/s]



450it [02:27,  2.99it/s]



451it [02:27,  3.00it/s]



452it [02:27,  2.99it/s]



453it [02:28,  3.00it/s]



454it [02:28,  2.98it/s]



455it [02:28,  2.97it/s]



456it [02:29,  3.14it/s]



457it [02:29,  3.14it/s]



458it [02:29,  3.10it/s]



459it [02:30,  3.27it/s]



460it [02:30,  3.20it/s]



461it [02:30,  3.16it/s]



462it [02:31,  3.34it/s]



463it [02:31,  3.25it/s]



464it [02:31,  3.19it/s]



465it [02:32,  3.15it/s]



466it [02:32,  3.14it/s]



467it [02:32,  3.10it/s]



468it [02:32,  3.11it/s]



469it [02:33,  3.10it/s]



470it [02:33,  3.10it/s]



471it [02:33,  3.07it/s]



472it [02:34,  3.09it/s]



473it [02:34,  3.31it/s]



474it [02:34,  3.44it/s]



475it [02:35,  3.51it/s]



476it [02:35,  3.24it/s]



477it [02:35,  3.16it/s]



478it [02:36,  3.06it/s]



479it [02:36,  3.05it/s]



480it [02:36,  3.06it/s]



481it [02:37,  2.41it/s]



482it [02:37,  2.54it/s]



483it [02:38,  2.64it/s]



484it [02:38,  2.71it/s]



485it [02:38,  2.80it/s]



486it [02:39,  3.06it/s]



487it [02:39,  3.16it/s]



488it [02:39,  3.14it/s]



489it [02:39,  3.12it/s]



490it [02:40,  3.12it/s]



491it [02:40,  3.26it/s]



492it [02:40,  3.44it/s]



493it [02:41,  3.56it/s]



494it [02:41,  3.48it/s]



495it [02:41,  3.33it/s]



496it [02:42,  3.23it/s]



497it [02:42,  3.18it/s]



498it [02:42,  3.16it/s]



499it [02:42,  3.11it/s]



500it [02:43,  3.10it/s]



501it [02:43,  3.09it/s]



502it [02:43,  3.10it/s]



503it [02:44,  3.32it/s]



504it [02:44,  3.19it/s]



505it [02:44,  3.15it/s]



506it [02:45,  3.13it/s]



507it [02:45,  3.11it/s]



508it [02:45,  3.06it/s]



509it [02:46,  3.31it/s]



510it [02:46,  3.24it/s]



511it [02:46,  3.43it/s]



512it [02:46,  3.57it/s]



513it [02:47,  3.41it/s]



514it [02:47,  3.31it/s]



515it [02:47,  3.25it/s]



516it [02:48,  3.15it/s]



517it [02:48,  3.04it/s]



518it [02:48,  2.99it/s]



519it [02:49,  2.96it/s]



520it [02:49,  2.90it/s]



521it [02:49,  2.96it/s]



522it [02:50,  2.95it/s]



523it [02:50,  2.81it/s]



524it [02:51,  2.80it/s]



525it [02:51,  2.62it/s]



526it [02:51,  2.71it/s]



527it [02:52,  2.80it/s]



528it [02:52,  2.87it/s]



529it [02:52,  3.11it/s]



530it [02:53,  3.10it/s]



531it [02:53,  3.08it/s]



532it [02:53,  3.31it/s]



533it [02:53,  3.40it/s]



534it [02:54,  3.27it/s]



535it [02:54,  3.30it/s]



536it [02:54,  3.23it/s]



537it [02:55,  3.21it/s]



538it [02:55,  3.18it/s]



539it [02:55,  3.18it/s]



540it [02:56,  3.17it/s]



541it [02:56,  3.17it/s]



542it [02:56,  3.09it/s]



543it [02:57,  3.11it/s]



544it [02:57,  3.11it/s]



545it [02:57,  3.09it/s]



546it [02:58,  3.09it/s]



547it [02:58,  3.11it/s]



548it [02:58,  3.28it/s]



549it [02:59,  3.22it/s]



550it [02:59,  3.13it/s]



551it [02:59,  3.35it/s]



552it [02:59,  3.28it/s]



553it [03:00,  3.23it/s]



554it [03:00,  3.17it/s]



555it [03:00,  3.10it/s]



556it [03:01,  3.05it/s]



557it [03:01,  3.03it/s]



558it [03:01,  2.92it/s]



559it [03:02,  2.96it/s]



560it [03:02,  2.94it/s]



561it [03:02,  2.94it/s]



562it [03:03,  2.91it/s]



563it [03:03,  2.94it/s]



564it [03:04,  2.98it/s]



565it [03:04,  2.97it/s]



566it [03:04,  2.89it/s]



567it [03:05,  2.96it/s]



568it [03:05,  2.54it/s]



569it [03:05,  2.81it/s]



570it [03:06,  2.87it/s]



571it [03:06,  2.93it/s]



572it [03:06,  2.96it/s]



573it [03:07,  3.20it/s]



574it [03:07,  3.31it/s]



575it [03:07,  3.25it/s]



576it [03:07,  3.23it/s]



577it [03:08,  3.18it/s]



578it [03:08,  3.14it/s]



579it [03:08,  3.33it/s]



580it [03:09,  3.26it/s]



581it [03:09,  3.21it/s]



582it [03:09,  3.17it/s]



583it [03:10,  3.30it/s]



584it [03:10,  3.24it/s]



585it [03:10,  3.44it/s]



586it [03:10,  3.56it/s]



587it [03:11,  3.45it/s]



588it [03:11,  3.29it/s]



589it [03:11,  3.45it/s]



590it [03:12,  3.34it/s]



591it [03:12,  3.28it/s]



592it [03:12,  3.22it/s]



593it [03:13,  3.20it/s]



594it [03:13,  3.38it/s]



595it [03:13,  3.29it/s]



596it [03:14,  3.22it/s]



597it [03:14,  3.42it/s]



598it [03:14,  3.32it/s]



599it [03:14,  3.30it/s]



600it [03:15,  3.22it/s]



601it [03:15,  3.13it/s]



602it [03:15,  3.11it/s]



603it [03:16,  3.09it/s]



604it [03:16,  3.09it/s]



605it [03:16,  3.02it/s]



606it [03:17,  2.97it/s]



607it [03:17,  2.99it/s]



608it [03:17,  2.94it/s]



609it [03:18,  2.95it/s]



610it [03:18,  3.19it/s]



611it [03:18,  3.15it/s]



612it [03:19,  3.07it/s]


# Loss Curve

In [18]:
try:
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(8, 6))
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.ylabel('Attention MAE')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.tight_layout()
    plt.savefig(os.path.join(model_path, timestamp + ".png"))
    plt.close()
except:
    print("Model did not finish training")


# Metrics

In [19]:
model.evaluate(test_batches)




[0.017388025298714638, 0.07057692855596542]