<a href="https://colab.research.google.com/github/atick-faisal/TAVI/blob/main/src/colab/MultiViewUNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Google Drive

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Download Dataset

In [2]:
!gdown "1NWZPFHtwIOe3MbpuFtc7qJh3Q9AWrvSd"
!unzip -o "Dataset_v12_TAWSS.zip" > /dev/null


Downloading...
From: https://drive.google.com/uc?id=1NWZPFHtwIOe3MbpuFtc7qJh3Q9AWrvSd
To: /content/Dataset_v12_TAWSS.zip
100% 723M/723M [00:02<00:00, 297MB/s]


# Imports

In [21]:
!pip install git+https://github.com/tensorflow/examples.git
!pip install image-similarity-measures

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting image-similarity-measures
  Downloading image_similarity_measures-0.3.5-py3-none-any.whl (9.1 kB)
Collecting pyfftw
  Downloading pyFFTW-0.13.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.7 MB)
[K     |████████████████████████████████| 1.7 MB 4.1 MB/s 
[?25hCollecting phasepack
  Downloading phasepack-1.5.tar.gz (15 kB)
Collecting rasterio
  Downloading rasterio-1.2.10-cp37-cp37m-manylinux1_x86_64.whl (19.3 MB)
[K     |████████████████████████████████| 19.3 MB 51.7 MB/s 
Collecting snuggs>=1.4.1
  Downloading snuggs-1.4.7-py3-none-any.whl (5.4 kB)
Collecting cligj>=0.5
  Downloading cligj-0.7.2-py3-none-any.whl (7.1 kB)
Collecting affine
  Downloading affine-2.3.1-py2.py3-none-any.whl (16 kB)
Collecting click-plugins
  Downloading click_plugins-1.1.1-py2.py3-none-any.whl (7.5 kB)
Building wheels for collected packages: phasepack
  Building wheel for phasepack (s

In [23]:
import os
import cv2
import random
import datetime
import matplotlib
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

from tqdm import tqdm
from numpy.typing import NDArray
from tensorflow_examples.models.pix2pix import pix2pix
from image_similarity_measures.quality_metrics import *

matplotlib.use('Agg')
plt.rcParams["font.size"] = 16


# Config

In [5]:
DATASET         = "v12"
MODEL_NAME      = "MobileUNet"
DATASET_PATH    = "/content/Images/"
TRAIN_DIR       = "Train/"
TEST_DIR        = "Val/"
INPUT_DIR       = "Input/"
TARGET_DIR      = "Target/"
MODEL_PATH      = "/content/drive/MyDrive/Research/TAVI/Models/"
PRED_PATH       = "/content/drive/MyDrive/Research/TAVI/Predictions/"

IMG_SIZE        = 128
BATCH_SIZE      = 8
BUFFER_SIZE     = 1000
VAL_SPLIT       = 0.2
LEARNING_RATE   = 0.001
N_EPOCHS        = 30
PATIENCE        = 30

EXP_NAME        = f"{MODEL_NAME}_I{IMG_SIZE}_B{BATCH_SIZE}_LR{LEARNING_RATE}_{DATASET}"


# Architecture

In [6]:
class UNet:
    def __init__(
        self,
        img_size: int,
        n_channels: int = 3,
        width: int = 32,
        depth: int = 4,
        kernel_size: int = 3
    ):
        self.img_size = img_size
        self.n_channels = n_channels
        self.width = width
        self.depth = depth
        self.kernel_size = kernel_size

    @staticmethod
    def conv(
        x: tf.Tensor,
        filters: int,
        kernel_size: int
    ) -> tf.Tensor:
        for i in range(2):
            x = tf.keras.layers.Conv2D(
                filters=filters,
                kernel_size=kernel_size,
                strides=1,
                padding="same",
                data_format="channels_last",
                dilation_rate=1,
                groups=1,
                activation=None,
                use_bias=True,
                kernel_initializer="glorot_uniform",
                bias_initializer="zeros"
            )(x)

            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.Activation("relu")(x)

        return x

    @staticmethod
    def deconv(x: tf.Tensor, filters: int) -> tf.Tensor:
        x = tf.keras.layers.Conv2DTranspose(
            filters=filters,
            kernel_size=2,
            strides=2,
            padding="same",
            output_padding=None,
            data_format=None,
            dilation_rate=1,
            activation=None,
            use_bias=True,
            kernel_initializer="glorot_uniform",
            bias_initializer="zeros"
        )(x)

        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation("relu")(x)

        return x

    @staticmethod
    def output(x: tf.Tensor) -> tf.Tensor:
        return tf.keras.layers.Conv2D(3, (1, 1), activation="sigmoid")(x)

    @staticmethod
    def pool(x: tf.Tensor, pool_size: int = 2) -> tf.Tensor:
        return tf.keras.layers.MaxPool2D(pool_size)(x)

    def __call__(self) -> tf.keras.Model:
        inputs = tf.keras.layers.Input(
            shape=(self.img_size, self.img_size, self.n_channels)
        )

        # scaled = tf.keras.layers.Rescaling(1./255.0, offset=0)(inputs)

        # ------------------ Downsampling ---------------------
        downsample_layers = []
        downsample_layers.append(
            self.conv(
                x=inputs,
                filters=self.width,
                kernel_size=self.kernel_size
            )
        )
        for i in range(1, self.depth):
            filters = int((2 ** i) * self.width)
            downsample_layers.append(
                self.pool(
                    self.conv(
                        x=downsample_layers[i - 1],
                        filters=filters,
                        kernel_size=self.kernel_size
                    )
                )
            )

        # ------------------- Features --------------------
        n_features = int((2 ** self.depth) * self.width)
        self.features = self.pool(
            self.conv(
                x=downsample_layers[-1],
                filters=n_features,
                kernel_size=self.kernel_size
            )
        )

        # ------------------- Upsampling --------------------
        upsample_layers = []
        upsample_layers.append(self.features)
        for i in range(1, self.depth + 1):
            filters = int((2 ** (self.depth - i)) * self.width)
            upsample_layers.append(
                self.conv(
                    x=tf.keras.layers.concatenate([
                        downsample_layers[self.depth - i],
                        self.deconv(
                            x=upsample_layers[i - 1],
                            filters=filters
                        )
                    ]),
                    filters=filters,
                    kernel_size=self.kernel_size
                )
            )

        # ---------------------- Output -----------------------
        outputs = self.output(upsample_layers[-1])

        return tf.keras.Model(inputs, outputs)


In [7]:
class MobileUNet:
    def __init__(
        self,
        img_size: int,
    ):
        self.img_shape = (img_size, img_size, 3)

        self.base_model = tf.keras.applications.MobileNetV2(
            input_shape=self.img_shape, 
            include_top=False
        )

        # Use the activations of these layers
        self.layer_names = [
            'block_1_expand_relu',   # 64x64
            'block_3_expand_relu',   # 32x32
            'block_6_expand_relu',   # 16x16
            'block_13_expand_relu',  # 8x8
            'block_16_project',      # 4x4
        ]
        self.base_model_outputs = [self.base_model.get_layer(name).output for name in self.layer_names]

        # Create the feature extraction model
        self.down_stack = tf.keras.Model(inputs=self.base_model.input, outputs=self.base_model_outputs)

        self.down_stack.trainable = True

        self.up_stack = [
            pix2pix.upsample(512, 3),  # 4x4 -> 8x8
            pix2pix.upsample(256, 3),  # 8x8 -> 16x16
            pix2pix.upsample(128, 3),  # 16x16 -> 32x32
            pix2pix.upsample(64, 3),   # 32x32 -> 64x64
            pix2pix.upsample(32, 3),   # 32x32 -> 64x64
        ]

    def __call__(self):
        inputs = tf.keras.layers.Input(shape=self.img_shape)

        # Downsampling through the model
        skips = self.down_stack(inputs)
        x = skips[-1]
        skips = reversed(skips[:-1])

        # Upsampling and establishing the skip connections
        for up, skip in zip(self.up_stack, skips):
            x = up(x)
            concat = tf.keras.layers.Concatenate()
            x = concat([x, skip])

        # This is the last layer of the model
        x = tf.keras.layers.Conv2DTranspose(
            filters=16, kernel_size=3, strides=2,
            padding='same')(x)

        x = tf.keras.layers.Conv2D(3, (1, 1), activation="sigmoid")(x)

        # x = last(x)

        return tf.keras.Model(inputs=inputs, outputs=x)

    

# Loss Functions

In [8]:
def attention_mse(y_true, y_pred):
    _y_true = y_true[y_true != 1.0]
    _y_pred = y_pred[y_true != 1.0]
    squared_difference = tf.square(_y_true - _y_pred)
    return tf.reduce_mean(squared_difference, axis=-1)


def attention_mae(y_true, y_pred):
    _y_true = y_true[y_true != 1.0]
    _y_pred = y_pred[y_true != 1.0]
    squared_difference = tf.abs(_y_true - _y_pred)
    return tf.reduce_mean(squared_difference, axis=-1)


# Utils

In [9]:
def load_data_from_dir(path: str) -> tf.data.Dataset:
    return tf.keras.utils.image_dataset_from_directory(
        directory=path,
        labels=None,
        color_mode='rgb',
        batch_size=BATCH_SIZE,
        image_size=(IMG_SIZE, IMG_SIZE),
        shuffle=True,
        seed=42,
        # validation_split=VAL_SPLIT,
        # subset=subset,
        interpolation='bilinear',
        follow_links=False,
        crop_to_aspect_ratio=False
    )


# Dataloaders

In [10]:
trainX = load_data_from_dir(os.path.join(DATASET_PATH, TRAIN_DIR, INPUT_DIR))
trainY = load_data_from_dir(os.path.join(DATASET_PATH, TRAIN_DIR, TARGET_DIR))
testX = load_data_from_dir(os.path.join(DATASET_PATH, TEST_DIR, INPUT_DIR))
testY = load_data_from_dir(os.path.join(DATASET_PATH, TEST_DIR, TARGET_DIR))

train_ds = tf.data.Dataset.zip((trainX, trainY))
test_ds = tf.data.Dataset.zip((testX, testY))

print(train_ds.element_spec)
print(test_ds.element_spec)


Found 6624 files belonging to 1 classes.
Found 6624 files belonging to 1 classes.
Found 552 files belonging to 1 classes.
Found 552 files belonging to 1 classes.
(TensorSpec(shape=(None, 128, 128, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 128, 128, 3), dtype=tf.float32, name=None))
(TensorSpec(shape=(None, 128, 128, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 128, 128, 3), dtype=tf.float32, name=None))


# Normalization

In [11]:
# def normalize(input_image, output_image):
#     input_image = tf.cast(input_image, tf.float32) / 255.0
#     output_image = tf.cast(input_image, tf.float32) / 255.0
#     return input_image, output_image

In [12]:
normalization_layer = tf.keras.layers.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), normalization_layer(y)))
test_ds = test_ds.map(lambda x, y: (normalization_layer(x), normalization_layer(y)))

# Augmentation

In [13]:
class Augment(tf.keras.layers.Layer):
    def __init__(self, seed=42):
        super().__init__()
        # both use the same seed, so they'll make the same random changes.
        self.augment_inputs = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
        self.augment_labels = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)

    def call(self, inputs, labels):
        inputs = self.augment_inputs(inputs)
        labels = self.augment_labels(labels)
        return inputs, labels

# Optimization

In [14]:
AUTOTUNE = tf.data.AUTOTUNE

train_batches = (
    train_ds
    .cache()
    # .shuffle(BUFFER_SIZE)
    # .batch(BATCH_SIZE)
    # .repeat()
    .map(Augment())
    .prefetch(buffer_size=AUTOTUNE)
)

test_batches = (test_ds)


# Training Config

In [15]:
callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=PATIENCE,
        restore_best_weights=True
    )
]

optimizer = tf.keras.optimizers.Adam(
    learning_rate=LEARNING_RATE
)

# model = UNet(IMG_SIZE)()
model = MobileUNet(IMG_SIZE)()

model.compile(
    loss=attention_mse,
    optimizer=optimizer,
    metrics=['mean_squared_error']
)


# Training

In [16]:
history = model.fit(
    train_batches,
    validation_data=test_batches,
    epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=1
)


Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


# Save Model

In [17]:
timestamp = datetime.datetime.now().strftime('%b-%d-%I:%M%p')
model_path = os.path.join(MODEL_PATH, EXP_NAME)
if not os.path.exists(model_path):
    os.makedirs(model_path)

model.save(os.path.join(model_path, timestamp))



# Save Predictions

In [28]:
test_ds_unbatched = test_batches.unbatch()

pred_path = os.path.join(PRED_PATH, EXP_NAME, timestamp)
if not os.path.exists(pred_path):
    os.makedirs(pred_path)

metrics = pd.DataFrame()

for idx, (input, target) in enumerate(tqdm(test_ds_unbatched)):

    target = target.numpy()
    prediction = tf.squeeze(
        model.predict(
            tf.expand_dims(input, axis=0)
            )
        ).numpy()
    prediction[target == 1.0] = 1.0

    _metrics = pd.DataFrame([[
            rmse(target, prediction),
            psnr(target, prediction),
            ssim(target, prediction),
            fsim(target, prediction),
            issm(target, prediction),
            sre(target, prediction),
            sam(target, prediction),
            uiq(target, prediction),
        ]],
        columns=["RMSE", "PSNR", "SSIM", "FSIM", "ISSM", "SRE", "SAM", "UIQ"]
    )

    metrics = pd.concat([metrics, _metrics], axis=0)

    if not idx % 40 == 0:
        continue

    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.imshow(input)
    plt.axis("off")
    plt.title("INPUT")
    plt.subplot(1, 3, 2)
    plt.imshow(target)
    plt.axis("off")
    plt.title("TARGET")
    plt.subplot(1, 3, 3)
    plt.imshow(prediction)
    plt.axis("off")
    plt.title("PREDICTION")
    plt.tight_layout()
    plt.savefig(os.path.join(pred_path, f"{idx}.png"))
    plt.close()  


0it [00:00, ?it/s]



1it [00:05,  5.31s/it]



2it [00:09,  4.89s/it]



3it [00:14,  4.83s/it]



4it [00:19,  4.75s/it]



5it [00:24,  4.74s/it]



6it [00:28,  4.70s/it]



7it [00:33,  4.68s/it]



8it [00:37,  4.66s/it]



9it [00:42,  4.64s/it]



10it [00:47,  4.62s/it]



11it [00:51,  4.62s/it]



12it [00:56,  4.63s/it]



13it [01:00,  4.62s/it]



14it [01:05,  4.63s/it]



15it [01:10,  4.63s/it]



16it [01:14,  4.62s/it]



17it [01:19,  4.62s/it]



18it [01:24,  4.68s/it]



19it [01:28,  4.66s/it]



20it [01:33,  4.64s/it]



21it [01:38,  4.62s/it]



22it [01:42,  4.62s/it]



23it [01:47,  4.63s/it]



24it [01:51,  4.63s/it]



25it [01:56,  4.65s/it]



26it [02:01,  4.64s/it]



27it [02:05,  4.63s/it]



28it [02:10,  4.62s/it]



29it [02:15,  4.60s/it]



30it [02:19,  4.60s/it]



31it [02:24,  4.64s/it]



32it [02:28,  4.62s/it]



33it [02:33,  4.62s/it]



34it [02:38,  4.60s/it]



35it [02:42,  4.60s/it]



36it [02:47,  4.60s/it]



37it [02:51,  4.60s/it]



38it [02:56,  4.63s/it]



39it [03:01,  4.63s/it]



40it [03:05,  4.61s/it]



41it [03:10,  4.67s/it]



42it [03:15,  4.67s/it]



43it [03:19,  4.65s/it]



44it [03:24,  4.65s/it]



45it [03:29,  4.68s/it]



46it [03:33,  4.65s/it]



47it [03:38,  4.64s/it]



48it [03:43,  4.62s/it]



49it [03:47,  4.60s/it]



50it [03:52,  4.59s/it]



51it [03:56,  4.62s/it]



52it [04:01,  4.64s/it]



53it [04:06,  4.63s/it]



54it [04:10,  4.62s/it]



55it [04:15,  4.61s/it]



56it [04:19,  4.61s/it]



57it [04:24,  4.62s/it]



58it [04:29,  4.64s/it]



59it [04:33,  4.65s/it]



60it [04:38,  4.63s/it]



61it [04:43,  4.61s/it]



62it [04:47,  4.61s/it]



63it [04:52,  4.60s/it]



64it [04:56,  4.60s/it]



65it [05:01,  4.67s/it]



66it [05:06,  4.64s/it]



67it [05:10,  4.63s/it]



68it [05:15,  4.63s/it]



69it [05:20,  4.62s/it]



70it [05:24,  4.62s/it]



71it [05:29,  4.64s/it]



72it [05:34,  4.62s/it]



73it [05:38,  4.62s/it]



74it [05:43,  4.61s/it]



75it [05:47,  4.63s/it]



76it [05:52,  4.63s/it]



77it [05:57,  4.63s/it]



78it [06:02,  4.73s/it]



79it [06:06,  4.69s/it]



80it [06:11,  4.67s/it]



81it [06:16,  4.72s/it]



82it [06:20,  4.69s/it]



83it [06:25,  4.67s/it]



84it [06:30,  4.69s/it]



85it [06:34,  4.66s/it]



86it [06:39,  4.65s/it]



87it [06:43,  4.63s/it]



88it [06:48,  4.63s/it]



89it [06:53,  4.63s/it]



90it [06:57,  4.62s/it]



91it [07:02,  4.67s/it]



92it [07:07,  4.65s/it]



93it [07:11,  4.65s/it]



94it [07:16,  4.65s/it]



95it [07:21,  4.66s/it]



96it [07:25,  4.65s/it]



97it [07:30,  4.66s/it]



98it [07:35,  4.65s/it]



99it [07:39,  4.63s/it]



100it [07:44,  4.61s/it]



101it [07:48,  4.60s/it]



102it [07:54,  4.85s/it]



103it [07:58,  4.76s/it]



104it [08:03,  4.75s/it]



105it [08:08,  4.72s/it]



106it [08:12,  4.68s/it]



107it [08:17,  4.65s/it]



108it [08:21,  4.64s/it]



109it [08:26,  4.63s/it]



110it [08:31,  4.64s/it]



111it [08:35,  4.65s/it]



112it [08:40,  4.64s/it]



113it [08:45,  4.65s/it]



114it [08:49,  4.63s/it]



115it [08:54,  4.62s/it]



116it [08:58,  4.62s/it]



117it [09:03,  4.69s/it]



118it [09:08,  4.66s/it]



119it [09:12,  4.63s/it]



120it [09:17,  4.61s/it]



121it [09:22,  4.65s/it]



122it [09:26,  4.63s/it]



123it [09:31,  4.71s/it]



124it [09:36,  4.67s/it]



125it [09:40,  4.63s/it]



126it [09:45,  4.62s/it]



127it [09:50,  4.61s/it]



128it [09:54,  4.62s/it]



129it [09:59,  4.64s/it]



130it [10:04,  4.69s/it]



131it [10:08,  4.65s/it]



132it [10:13,  4.64s/it]



133it [10:17,  4.64s/it]



134it [10:22,  4.64s/it]



135it [10:27,  4.63s/it]



136it [10:31,  4.63s/it]



137it [10:36,  4.63s/it]



138it [10:41,  4.61s/it]



139it [10:45,  4.60s/it]



140it [10:50,  4.58s/it]



141it [10:54,  4.56s/it]



142it [10:59,  4.56s/it]



143it [11:04,  4.62s/it]



144it [11:08,  4.58s/it]



145it [11:13,  4.59s/it]



146it [11:17,  4.60s/it]



147it [11:22,  4.59s/it]



148it [11:26,  4.60s/it]



149it [11:31,  4.60s/it]



150it [11:36,  4.66s/it]



151it [11:40,  4.62s/it]



152it [11:45,  4.59s/it]



153it [11:49,  4.57s/it]



154it [11:54,  4.57s/it]



155it [11:58,  4.55s/it]



156it [12:03,  4.59s/it]



157it [12:08,  4.57s/it]



158it [12:12,  4.58s/it]



159it [12:17,  4.57s/it]



160it [12:21,  4.57s/it]



161it [12:26,  4.63s/it]



162it [12:31,  4.63s/it]



163it [12:36,  4.66s/it]



164it [12:40,  4.66s/it]



165it [12:45,  4.64s/it]



166it [12:49,  4.62s/it]



167it [12:54,  4.61s/it]



168it [12:59,  4.62s/it]



169it [13:03,  4.62s/it]



170it [13:08,  4.63s/it]



171it [13:12,  4.62s/it]



172it [13:17,  4.62s/it]



173it [13:22,  4.61s/it]



174it [13:26,  4.61s/it]



175it [13:31,  4.60s/it]



176it [13:35,  4.59s/it]



177it [13:40,  4.64s/it]



178it [13:45,  4.63s/it]



179it [13:49,  4.63s/it]



180it [13:54,  4.63s/it]



181it [13:59,  4.64s/it]



182it [14:03,  4.64s/it]



183it [14:08,  4.66s/it]



184it [14:13,  4.64s/it]



185it [14:17,  4.63s/it]



186it [14:22,  4.62s/it]



187it [14:26,  4.63s/it]



188it [14:31,  4.67s/it]



189it [14:36,  4.65s/it]



190it [14:41,  4.70s/it]



191it [14:45,  4.67s/it]



192it [14:50,  4.65s/it]



193it [14:55,  4.64s/it]



194it [14:59,  4.62s/it]



195it [15:04,  4.61s/it]



196it [15:08,  4.66s/it]



197it [15:13,  4.67s/it]



198it [15:18,  4.67s/it]



199it [15:22,  4.67s/it]



200it [15:27,  4.65s/it]



201it [15:32,  4.70s/it]



202it [15:36,  4.67s/it]



203it [15:41,  4.72s/it]



204it [15:46,  4.67s/it]



205it [15:50,  4.65s/it]



206it [15:55,  4.64s/it]



207it [16:00,  4.63s/it]



208it [16:04,  4.64s/it]



209it [16:09,  4.65s/it]



210it [16:14,  4.65s/it]



211it [16:18,  4.63s/it]



212it [16:23,  4.62s/it]



213it [16:28,  4.64s/it]



214it [16:32,  4.64s/it]



215it [16:37,  4.66s/it]



216it [16:42,  4.67s/it]



217it [16:46,  4.66s/it]



218it [16:51,  4.65s/it]



219it [16:55,  4.64s/it]



220it [17:00,  4.63s/it]



221it [17:05,  4.62s/it]



222it [17:09,  4.63s/it]



223it [17:14,  4.68s/it]



224it [17:19,  4.69s/it]



225it [17:24,  4.69s/it]



226it [17:28,  4.67s/it]



227it [17:33,  4.65s/it]



228it [17:37,  4.64s/it]



229it [17:42,  4.68s/it]



230it [17:47,  4.65s/it]



231it [17:51,  4.64s/it]



232it [17:56,  4.66s/it]



233it [18:01,  4.66s/it]



234it [18:05,  4.65s/it]



235it [18:10,  4.65s/it]



236it [18:15,  4.67s/it]



237it [18:19,  4.64s/it]



238it [18:24,  4.63s/it]



239it [18:29,  4.63s/it]



240it [18:33,  4.63s/it]



241it [18:38,  4.66s/it]



242it [18:43,  4.71s/it]



243it [18:47,  4.67s/it]



244it [18:52,  4.64s/it]



245it [18:56,  4.61s/it]



246it [19:01,  4.60s/it]



247it [19:06,  4.59s/it]



248it [19:10,  4.59s/it]



249it [19:15,  4.65s/it]



250it [19:20,  4.65s/it]



251it [19:24,  4.65s/it]



252it [19:29,  4.69s/it]



253it [19:34,  4.69s/it]



254it [19:38,  4.67s/it]



255it [19:43,  4.70s/it]



256it [19:48,  4.67s/it]



257it [19:52,  4.65s/it]



258it [19:57,  4.63s/it]



259it [20:01,  4.62s/it]



260it [20:06,  4.62s/it]



261it [20:11,  4.61s/it]



262it [20:16,  4.68s/it]



263it [20:20,  4.66s/it]



264it [20:25,  4.63s/it]



265it [20:29,  4.65s/it]



266it [20:34,  4.65s/it]



267it [20:39,  4.66s/it]



268it [20:43,  4.67s/it]



269it [20:48,  4.65s/it]



270it [20:53,  4.63s/it]



271it [20:57,  4.61s/it]



272it [21:02,  4.61s/it]



273it [21:06,  4.62s/it]



274it [21:11,  4.61s/it]



275it [21:16,  4.70s/it]



276it [21:20,  4.66s/it]



277it [21:25,  4.64s/it]



278it [21:30,  4.62s/it]



279it [21:34,  4.63s/it]



280it [21:39,  4.62s/it]



281it [21:44,  4.69s/it]



282it [21:48,  4.67s/it]



283it [21:53,  4.65s/it]



284it [21:58,  4.65s/it]



285it [22:02,  4.65s/it]



286it [22:07,  4.64s/it]



287it [22:11,  4.62s/it]



288it [22:16,  4.67s/it]



289it [22:21,  4.65s/it]



290it [22:25,  4.64s/it]



291it [22:30,  4.63s/it]



292it [22:35,  4.61s/it]



293it [22:39,  4.62s/it]



294it [22:44,  4.61s/it]



295it [22:49,  4.66s/it]



296it [22:53,  4.64s/it]



297it [22:58,  4.64s/it]



298it [23:02,  4.62s/it]



299it [23:07,  4.60s/it]



300it [23:12,  4.60s/it]



301it [23:16,  4.61s/it]



302it [23:21,  4.63s/it]



303it [23:26,  4.65s/it]



304it [23:30,  4.65s/it]



305it [23:35,  4.64s/it]



306it [23:39,  4.63s/it]



307it [23:44,  4.63s/it]



308it [23:49,  4.64s/it]



309it [23:53,  4.62s/it]



310it [23:58,  4.61s/it]



311it [24:02,  4.60s/it]



312it [24:07,  4.60s/it]



313it [24:12,  4.61s/it]



314it [24:16,  4.60s/it]



315it [24:21,  4.65s/it]



316it [24:26,  4.64s/it]



317it [24:30,  4.69s/it]



318it [24:35,  4.66s/it]



319it [24:40,  4.65s/it]



320it [24:44,  4.65s/it]



321it [24:49,  4.75s/it]



322it [24:54,  4.71s/it]



323it [24:59,  4.68s/it]



324it [25:03,  4.66s/it]



325it [25:08,  4.64s/it]



326it [25:12,  4.62s/it]



327it [25:17,  4.62s/it]



328it [25:22,  4.67s/it]



329it [25:26,  4.66s/it]



330it [25:31,  4.65s/it]



331it [25:36,  4.63s/it]



332it [25:40,  4.62s/it]



333it [25:45,  4.61s/it]



334it [25:50,  4.66s/it]



335it [25:54,  4.68s/it]



336it [25:59,  4.67s/it]



337it [26:04,  4.67s/it]



338it [26:08,  4.66s/it]



339it [26:13,  4.65s/it]



340it [26:17,  4.63s/it]



341it [26:22,  4.66s/it]



342it [26:27,  4.65s/it]



343it [26:31,  4.64s/it]



344it [26:36,  4.63s/it]



345it [26:41,  4.62s/it]



346it [26:45,  4.62s/it]



347it [26:50,  4.65s/it]



348it [26:55,  4.67s/it]



349it [26:59,  4.64s/it]



350it [27:04,  4.62s/it]



351it [27:08,  4.62s/it]



352it [27:13,  4.60s/it]



353it [27:18,  4.61s/it]



354it [27:22,  4.63s/it]



355it [27:27,  4.64s/it]



356it [27:32,  4.64s/it]



357it [27:36,  4.62s/it]



358it [27:41,  4.61s/it]



359it [27:45,  4.60s/it]



360it [27:50,  4.60s/it]



361it [27:55,  4.68s/it]



362it [27:59,  4.65s/it]



363it [28:04,  4.63s/it]



364it [28:09,  4.62s/it]



365it [28:13,  4.60s/it]



366it [28:18,  4.60s/it]



367it [28:22,  4.63s/it]



368it [28:27,  4.63s/it]



369it [28:32,  4.63s/it]



370it [28:36,  4.63s/it]



371it [28:41,  4.63s/it]



372it [28:46,  4.63s/it]



373it [28:50,  4.62s/it]



374it [28:55,  4.64s/it]



375it [28:59,  4.62s/it]



376it [29:04,  4.61s/it]



377it [29:09,  4.61s/it]



378it [29:13,  4.61s/it]



379it [29:18,  4.59s/it]



380it [29:23,  4.65s/it]



381it [29:27,  4.66s/it]



382it [29:32,  4.65s/it]



383it [29:36,  4.64s/it]



384it [29:41,  4.62s/it]



385it [29:46,  4.62s/it]



386it [29:50,  4.60s/it]



387it [29:55,  4.63s/it]



388it [30:00,  4.62s/it]



389it [30:04,  4.63s/it]



390it [30:09,  4.63s/it]



391it [30:13,  4.65s/it]



392it [30:18,  4.64s/it]



393it [30:23,  4.68s/it]



394it [30:27,  4.66s/it]



395it [30:32,  4.65s/it]



396it [30:37,  4.63s/it]



397it [30:41,  4.63s/it]



398it [30:46,  4.62s/it]



399it [30:50,  4.61s/it]



400it [30:55,  4.67s/it]



401it [31:00,  4.71s/it]



402it [31:05,  4.68s/it]



403it [31:09,  4.66s/it]



404it [31:14,  4.64s/it]



405it [31:18,  4.61s/it]



406it [31:23,  4.63s/it]



407it [31:28,  4.64s/it]



408it [31:33,  4.66s/it]



409it [31:37,  4.66s/it]



410it [31:42,  4.64s/it]



411it [31:46,  4.63s/it]



412it [31:51,  4.62s/it]



413it [31:56,  4.63s/it]



414it [32:00,  4.62s/it]



415it [32:05,  4.63s/it]



416it [32:09,  4.62s/it]



417it [32:14,  4.63s/it]



418it [32:19,  4.61s/it]



419it [32:23,  4.62s/it]



420it [32:28,  4.64s/it]



421it [32:33,  4.63s/it]



422it [32:37,  4.61s/it]



423it [32:42,  4.62s/it]



424it [32:46,  4.62s/it]



425it [32:51,  4.64s/it]



426it [32:56,  4.63s/it]



427it [33:01,  4.67s/it]



428it [33:05,  4.64s/it]



429it [33:10,  4.63s/it]



430it [33:14,  4.62s/it]



431it [33:19,  4.62s/it]



432it [33:23,  4.60s/it]



433it [33:28,  4.63s/it]



434it [33:33,  4.63s/it]



435it [33:37,  4.63s/it]



436it [33:42,  4.62s/it]



437it [33:47,  4.61s/it]



438it [33:51,  4.60s/it]



439it [33:56,  4.60s/it]



440it [34:01,  4.77s/it]



441it [34:06,  4.79s/it]



442it [34:10,  4.75s/it]



443it [34:15,  4.71s/it]



444it [34:20,  4.67s/it]



445it [34:24,  4.66s/it]



446it [34:29,  4.71s/it]



447it [34:34,  4.67s/it]



448it [34:38,  4.64s/it]



449it [34:43,  4.63s/it]



450it [34:47,  4.61s/it]



451it [34:52,  4.61s/it]



452it [34:57,  4.60s/it]



453it [35:01,  4.65s/it]



454it [35:06,  4.64s/it]



455it [35:11,  4.63s/it]



456it [35:15,  4.62s/it]



457it [35:20,  4.62s/it]



458it [35:24,  4.62s/it]



459it [35:29,  4.65s/it]



460it [35:34,  4.66s/it]



461it [35:38,  4.65s/it]



462it [35:43,  4.63s/it]



463it [35:48,  4.62s/it]



464it [35:52,  4.61s/it]



465it [35:57,  4.61s/it]



466it [36:02,  4.64s/it]



467it [36:06,  4.62s/it]



468it [36:11,  4.61s/it]



469it [36:15,  4.60s/it]



470it [36:20,  4.60s/it]



471it [36:24,  4.59s/it]



472it [36:29,  4.62s/it]



473it [36:34,  4.69s/it]



474it [36:39,  4.66s/it]



475it [36:43,  4.65s/it]



476it [36:48,  4.65s/it]



477it [36:53,  4.65s/it]



478it [36:57,  4.64s/it]



479it [37:02,  4.68s/it]



480it [37:07,  4.66s/it]



481it [37:11,  4.70s/it]



482it [37:16,  4.68s/it]



483it [37:21,  4.71s/it]



484it [37:25,  4.68s/it]



485it [37:30,  4.69s/it]



486it [37:35,  4.72s/it]



487it [37:39,  4.69s/it]



488it [37:44,  4.67s/it]



489it [37:49,  4.65s/it]



490it [37:53,  4.64s/it]



491it [37:58,  4.62s/it]



492it [38:03,  4.64s/it]



493it [38:07,  4.64s/it]



494it [38:12,  4.64s/it]



495it [38:16,  4.63s/it]



496it [38:21,  4.62s/it]



497it [38:26,  4.61s/it]



498it [38:30,  4.63s/it]



499it [38:35,  4.66s/it]



500it [38:40,  4.64s/it]



501it [38:44,  4.62s/it]



502it [38:49,  4.62s/it]



503it [38:53,  4.61s/it]



504it [38:58,  4.60s/it]



505it [39:03,  4.62s/it]



506it [39:07,  4.61s/it]



507it [39:12,  4.61s/it]



508it [39:16,  4.62s/it]



509it [39:21,  4.61s/it]



510it [39:26,  4.59s/it]



511it [39:30,  4.63s/it]



512it [39:35,  4.65s/it]



513it [39:40,  4.63s/it]



514it [39:44,  4.62s/it]



515it [39:49,  4.61s/it]



516it [39:53,  4.61s/it]



517it [39:58,  4.60s/it]



518it [40:03,  4.65s/it]



519it [40:07,  4.63s/it]



520it [40:12,  4.63s/it]



521it [40:17,  4.69s/it]



522it [40:21,  4.66s/it]



523it [40:26,  4.64s/it]



524it [40:31,  4.64s/it]



525it [40:36,  4.71s/it]



526it [40:40,  4.67s/it]



527it [40:45,  4.64s/it]



528it [40:49,  4.64s/it]



529it [40:54,  4.64s/it]



530it [40:59,  4.64s/it]



531it [41:03,  4.65s/it]



532it [41:08,  4.64s/it]



533it [41:12,  4.62s/it]



534it [41:17,  4.61s/it]



535it [41:22,  4.61s/it]



536it [41:26,  4.60s/it]



537it [41:31,  4.61s/it]



538it [41:36,  4.68s/it]



539it [41:40,  4.66s/it]



540it [41:45,  4.64s/it]



541it [41:49,  4.62s/it]



542it [41:54,  4.61s/it]



543it [41:59,  4.60s/it]



544it [42:03,  4.59s/it]



545it [42:08,  4.66s/it]



546it [42:13,  4.66s/it]



547it [42:17,  4.67s/it]



548it [42:22,  4.65s/it]



549it [42:27,  4.65s/it]



550it [42:31,  4.64s/it]



551it [42:36,  4.70s/it]



552it [42:41,  4.64s/it]


In [19]:
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 6))
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Attention MAE')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.tight_layout()
plt.savefig(os.path.join(model_path, timestamp + ".png"))
plt.close()

In [20]:
history

<keras.callbacks.History at 0x7f9656023ad0>

In [29]:
metrics

Unnamed: 0,RMSE,PSNR,SSIM,FSIM,ISSM,SRE,SAM,UIQ
0,0.000005,106.330986,1.0,0.986872,0.0,40.349758,0.393067,0.864831
0,0.000006,104.309990,1.0,0.990317,0.0,40.868201,0.616709,0.848967
0,0.000008,101.961196,1.0,0.979266,0.0,40.913385,0.784501,0.865688
0,0.000005,106.366634,1.0,0.984267,0.0,38.712033,0.349486,0.917367
0,0.000007,102.680523,1.0,0.974752,0.0,38.107613,0.480382,0.889561
...,...,...,...,...,...,...,...,...
0,0.000005,106.801753,1.0,0.990090,0.0,40.926667,0.372600,0.892624
0,0.000005,105.848274,1.0,0.984536,0.0,41.105306,0.470067,0.876452
0,0.000006,104.830704,1.0,0.980286,0.0,39.831557,0.458005,0.843419
0,0.000006,104.464157,1.0,0.986190,0.0,38.832027,0.454618,0.855725
