<a href="https://colab.research.google.com/github/atick-faisal/MultiViewUNet-TAVI/blob/main/src/training/MultiViewUNetColab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Colab Info

In [None]:
from psutil import virtual_memory

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Not connected to a GPU')
else:
    print(gpu_info)


ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
    print('Not using a high-RAM runtime')
else:
    print('You are using a high-RAM runtime!')


Mon Nov 20 16:13:00 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   51C    P8    10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Fix for GDrive

In [None]:
!pip install -U --no-cache-dir gdown --pre > /dev/null


# Mount GDrive

In [None]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


# Download and Extract Dataset

In [None]:
!gdown "1xBO079FPIeE7T5VVsFwc8QeZxAfS4J9O"
!unzip -o "TAVI_REG_r17.zip" > /dev/null


Downloading...
From (uriginal): https://drive.google.com/uc?id=1xBO079FPIeE7T5VVsFwc8QeZxAfS4J9O
From (redirected): https://drive.google.com/uc?id=1xBO079FPIeE7T5VVsFwc8QeZxAfS4J9O&confirm=t&uuid=daacd8ec-e5cd-47f4-ae3d-3147ded9e601
To: /content/TAVI_REG_r17.zip
100% 370M/370M [00:03<00:00, 93.0MB/s]


# Imports

In [None]:
import os
import datetime
import matplotlib
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm

matplotlib.use('Agg')
plt.rcParams["font.size"] = 16

# Config

In [None]:
PROBLEM = "Curvature_2_Stress"

MODEL_NAME = "MultiViewUNet"
DATASET_PATH = "/content/Images/"
TRAIN_DIR = "Train/"
TEST_DIR = "Test/"
INPUT_DIR = PROBLEM.split("_2_")[0]
TARGET_DIR = PROBLEM.split("_2_")[1]
MODEL_PATH = "/content/drive/MyDrive/Research/TAVI/Models/"
PRED_PATH = "/content/drive/MyDrive/Research/TAVI/Predictions/"
IMG_SIZE = 256
BATCH_SIZE = 16
BUFFER_SIZE = 1000
VAL_SPLIT = 0.2
LEARNING_RATE = 0.001
N_EPOCHS = 300
PATIENCE = 30

EXP_NAME = f"{PROBLEM}_{MODEL_NAME}_I{IMG_SIZE}_B{BATCH_SIZE}_LR{LEARNING_RATE}"


# Architecture

In [None]:
class UNet:
    def __init__(
        self,
        img_size: int,
        n_channels: int = 3,
        width: int = 32,
        depth: int = 4,
        kernel_size: int = 3
    ):
        self.img_size = img_size
        self.n_channels = n_channels
        self.width = width
        self.depth = depth
        self.kernel_size = kernel_size

    @staticmethod
    def conv(
        x: tf.Tensor,
        filters: int,
        kernel_size: int
    ) -> tf.Tensor:
        for i in range(2):
            x = tf.keras.layers.Conv2D(
                filters=filters,
                kernel_size=kernel_size,
                strides=1,
                padding="same",
                data_format="channels_last",
                dilation_rate=1,
                groups=1,
                activation=None,
                use_bias=True,
                kernel_initializer="glorot_uniform",
                bias_initializer="zeros",
            )(x)

            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.Activation("relu")(x)

        return x

    @staticmethod
    def deconv(x: tf.Tensor, filters: int) -> tf.Tensor:
        x = tf.keras.layers.Conv2DTranspose(
            filters=filters,
            kernel_size=2,
            strides=2,
            padding="same",
            output_padding=None,
            data_format=None,
            dilation_rate=1,
            activation=None,
            use_bias=True,
            kernel_initializer="glorot_uniform",
            bias_initializer="zeros",
        )(x)

        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation("relu")(x)

        return x

    @staticmethod
    def output(x: tf.Tensor) -> tf.Tensor:
        return tf.keras.layers.Conv2D(3, (1, 1), activation="sigmoid")(x)

    @staticmethod
    def pool(x: tf.Tensor, pool_size: int = 2) -> tf.Tensor:
        return tf.keras.layers.MaxPool2D(pool_size)(x)

    @staticmethod
    def dropout(x: tf.Tensor, amount: float = 0.5) -> tf.Tensor:
        return tf.keras.layers.Dropout(amount)(x)

    def __call__(self) -> tf.keras.Model:
        inputs = tf.keras.layers.Input(
            shape=(self.img_size, self.img_size, self.n_channels)
        )

        # scaled = tf.keras.layers.Rescaling(1./255.0, offset=0)(inputs)

        # ------------------ Downsampling ---------------------
        downsample_layers = []
        downsample_layers.append(
            self.conv(
                x=inputs,
                filters=self.width,
                kernel_size=self.kernel_size
            )
        )
        for i in range(1, self.depth):
            dropout_amount = 0.2 if i == 1 else 0.5
            filters = int((2 ** i) * self.width)
            downsample_layers.append(
                self.dropout(
                    self.pool(
                        self.conv(
                            x=downsample_layers[i - 1],
                            filters=filters,
                            kernel_size=self.kernel_size
                        )
                    ),
                    amount=dropout_amount
                )
            )

        # ------------------- Features --------------------
        n_features = int((2 ** self.depth) * self.width)
        self.features = self.pool(
            self.conv(
                x=downsample_layers[-1],
                filters=n_features,
                kernel_size=self.kernel_size
            )
        )

        # ------------------- Upsampling --------------------
        upsample_layers = []
        upsample_layers.append(self.features)
        for i in range(1, self.depth + 1):
            filters = int((2 ** (self.depth - i)) * self.width)
            upsample_layers.append(
                self.conv(
                    x=self.dropout(
                        tf.keras.layers.concatenate([
                            downsample_layers[self.depth - i],
                            self.deconv(
                                x=upsample_layers[i - 1],
                                filters=filters
                            )
                        ])
                    ),
                    filters=filters,
                    kernel_size=self.kernel_size
                )
            )

        # ---------------------- Output -----------------------
        outputs = self.output(upsample_layers[-1])

        return tf.keras.Model(inputs, outputs)


# Loss Functions / Metrics

In [None]:
def attention_mse(y_true, y_pred):
    _y_true = y_true[y_true != 1.0]
    _y_pred = y_pred[y_true != 1.0]
    squared_difference = tf.square(_y_true - _y_pred)
    return tf.reduce_mean(squared_difference, axis=-1)


def attention_mae(y_true, y_pred):
    _y_true = y_true[y_true != 1.0]
    _y_pred = y_pred[y_true != 1.0]
    squared_difference = tf.abs(_y_true - _y_pred)
    return tf.reduce_mean(squared_difference, axis=-1)


# DataLoader

In [None]:
def load_data_from_dir(path: str) -> tf.data.Dataset:
    return tf.keras.utils.image_dataset_from_directory(
        directory=path,
        labels=None,
        color_mode='rgb',
        batch_size=BATCH_SIZE,
        image_size=(IMG_SIZE, IMG_SIZE),
        shuffle=False,
        seed=42,
        interpolation='bilinear',
        follow_links=False,
        crop_to_aspect_ratio=False
    )


# Load Dataset

In [None]:
trainX = load_data_from_dir(os.path.join(DATASET_PATH, TRAIN_DIR, INPUT_DIR))
trainY = load_data_from_dir(os.path.join(DATASET_PATH, TRAIN_DIR, TARGET_DIR))
testX = load_data_from_dir(os.path.join(DATASET_PATH, TEST_DIR, INPUT_DIR))
testY = load_data_from_dir(os.path.join(DATASET_PATH, TEST_DIR, TARGET_DIR))

train_ds = tf.data.Dataset.zip((trainX, trainY))
test_ds = tf.data.Dataset.zip((testX, testY))

print(train_ds.element_spec)
print(test_ds.element_spec)


Found 432 files belonging to 1 classes.
Found 432 files belonging to 1 classes.
Found 96 files belonging to 1 classes.
Found 96 files belonging to 1 classes.
(TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None))
(TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None))


# Normalization

In [None]:
normalization_layer = tf.keras.layers.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (
    normalization_layer(x), normalization_layer(y)))
test_ds = test_ds.map(lambda x, y: (
    normalization_layer(x), normalization_layer(y)))


# Augmentation

In [None]:
class Augment(tf.keras.layers.Layer):
    def __init__(self, seed=42):
        super().__init__()
        # both use the same seed, so they'll make the same random changes.
        self.augment_inputs = tf.keras.layers.RandomZoom(
            (-0.1, -0.7), seed=seed)
        self.augment_labels = tf.keras.layers.RandomZoom(
            (-0.1, -0.7), seed=seed)

    def call(self, inputs, labels):
        inputs = self.augment_inputs(inputs)
        labels = self.augment_labels(labels)
        return inputs, labels


# Optimization

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

train_batches = (
    train_ds
    .cache()
    .shuffle(BUFFER_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)

test_batches = (
    test_ds
    .cache()
    .prefetch(buffer_size=AUTOTUNE)
)


# Training Config

In [None]:
model_path = os.path.join(MODEL_PATH, EXP_NAME)

callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=PATIENCE,
        restore_best_weights=True
    ),
    tf.keras.callbacks.ModelCheckpoint(
        model_path,
        monitor='val_loss',
        verbose=1,
        save_best_only=True,
        save_weights_only=True
    )
]

optimizer = tf.keras.optimizers.Adam(
    learning_rate=LEARNING_RATE
)

model = UNet(IMG_SIZE)()

model.compile(
    loss=attention_mse,
    optimizer=optimizer,
    metrics=[attention_mae]
)

# try:
#     model.load_weights(model_path)
# except:
#     print("Checkpoint not found")
#     pass


# Training

In [15]:
history = model.fit(
    train_batches,
    validation_data=test_batches,
    epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=1
)


Epoch 1/300
Epoch 1: val_loss improved from inf to 0.09253, saving model to /content/drive/MyDrive/Research/TAVI/Models/Curvature_2_Stress_MultiViewUNet_I256_B16_LR0.001
Epoch 2/300
Epoch 2: val_loss improved from 0.09253 to 0.07223, saving model to /content/drive/MyDrive/Research/TAVI/Models/Curvature_2_Stress_MultiViewUNet_I256_B16_LR0.001
Epoch 3/300
Epoch 3: val_loss improved from 0.07223 to 0.06256, saving model to /content/drive/MyDrive/Research/TAVI/Models/Curvature_2_Stress_MultiViewUNet_I256_B16_LR0.001
Epoch 4/300
Epoch 4: val_loss improved from 0.06256 to 0.06130, saving model to /content/drive/MyDrive/Research/TAVI/Models/Curvature_2_Stress_MultiViewUNet_I256_B16_LR0.001
Epoch 5/300
Epoch 5: val_loss improved from 0.06130 to 0.05619, saving model to /content/drive/MyDrive/Research/TAVI/Models/Curvature_2_Stress_MultiViewUNet_I256_B16_LR0.001
Epoch 6/300
Epoch 6: val_loss did not improve from 0.05619
Epoch 7/300
Epoch 7: val_loss did not improve from 0.05619
Epoch 8/300
Epoc

# Save Model

In [16]:
timestamp = datetime.datetime.now().strftime('%b-%d-%I:%M%p')
if not os.path.exists(model_path):
    os.makedirs(model_path)

model.save(os.path.join(model_path, timestamp))

In [17]:
model_path

'/content/drive/MyDrive/Research/TAVI/Models/Curvature_2_Stress_MultiViewUNet_I256_B16_LR0.001'

# Save Predictions

In [18]:
test_ds_unbatched = test_batches.unbatch()

pred_path = os.path.join(PRED_PATH, EXP_NAME, timestamp)
if not os.path.exists(pred_path):
    os.makedirs(pred_path)

metrics = pd.DataFrame()

for idx, (input, target) in enumerate(tqdm(test_ds_unbatched)):

    target = tf.squeeze(target)
    prediction = tf.squeeze(
        model.predict(
            tf.expand_dims(input, axis=0)
        )
    )

    channel_sum = tf.expand_dims(tf.reduce_sum(target, axis=-1), axis=-1)
    white_mask = tf.reduce_all(tf.equal(channel_sum, 3.0), axis=-1)
    expanded_mask = tf.expand_dims(white_mask, axis=-1)
    expanded_mask = tf.tile(expanded_mask, [1, 1, 3])
    # expanded_mask = tf.tile(expanded_mask, [1, 1])
    prediction = tf.where(expanded_mask, tf.ones_like(prediction), prediction)

    target_pil = Image.fromarray(np.array(target * 255.0, dtype=np.uint8))
    prediction_pil = Image.fromarray(np.array(prediction * 255.0, dtype=np.uint8))

    target_pil.save(os.path.join(pred_path, f"{idx}_T.png"))
    prediction_pil.save(os.path.join(pred_path, f"{idx}_P.png"))


0it [00:00, ?it/s]



1it [00:03,  3.26s/it]



2it [00:03,  1.41s/it]



3it [00:03,  1.22it/s]



4it [00:03,  1.86it/s]



5it [00:03,  2.57it/s]



6it [00:03,  3.25it/s]



7it [00:03,  4.10it/s]



8it [00:04,  4.92it/s]



9it [00:04,  5.24it/s]



10it [00:04,  5.49it/s]



11it [00:04,  6.13it/s]



12it [00:04,  6.64it/s]



13it [00:04,  6.63it/s]



14it [00:04,  7.20it/s]



15it [00:05,  6.95it/s]



16it [00:05,  6.79it/s]



17it [00:05,  7.31it/s]



18it [00:05,  7.74it/s]



19it [00:05,  7.97it/s]



20it [00:05,  7.71it/s]



21it [00:05,  7.82it/s]



22it [00:05,  8.21it/s]



23it [00:06,  7.58it/s]



24it [00:06,  7.96it/s]



25it [00:06,  7.19it/s]



26it [00:06,  6.61it/s]



27it [00:06,  7.02it/s]



28it [00:06,  7.06it/s]



29it [00:06,  7.44it/s]



30it [00:07,  7.82it/s]



31it [00:07,  7.38it/s]



32it [00:07,  7.45it/s]



33it [00:07,  7.79it/s]



34it [00:07,  7.44it/s]



35it [00:07,  6.90it/s]



36it [00:07,  7.32it/s]



37it [00:08,  6.56it/s]



38it [00:08,  6.99it/s]



39it [00:08,  7.31it/s]



40it [00:08,  7.75it/s]



41it [00:08,  7.80it/s]



42it [00:08,  8.05it/s]



43it [00:08,  8.37it/s]



44it [00:08,  8.65it/s]



45it [00:09,  7.99it/s]



46it [00:09,  8.22it/s]



47it [00:09,  8.29it/s]



48it [00:09,  8.56it/s]



49it [00:09,  8.67it/s]



50it [00:09,  7.37it/s]



51it [00:09,  6.88it/s]



52it [00:10,  6.37it/s]



53it [00:10,  6.03it/s]



54it [00:10,  5.52it/s]



55it [00:10,  5.37it/s]



56it [00:10,  5.35it/s]



57it [00:11,  5.35it/s]



58it [00:11,  5.17it/s]



59it [00:11,  5.24it/s]



60it [00:11,  5.16it/s]



61it [00:11,  5.53it/s]



62it [00:11,  5.49it/s]



63it [00:12,  5.40it/s]



64it [00:12,  5.34it/s]



65it [00:12,  5.11it/s]



66it [00:12,  5.16it/s]



67it [00:12,  5.18it/s]



68it [00:13,  5.24it/s]



69it [00:13,  5.72it/s]



70it [00:13,  5.88it/s]



71it [00:13,  5.70it/s]



72it [00:13,  5.72it/s]



73it [00:13,  5.54it/s]



74it [00:14,  5.44it/s]



75it [00:14,  5.90it/s]



76it [00:14,  6.56it/s]



77it [00:14,  7.08it/s]



78it [00:14,  6.73it/s]



79it [00:14,  7.33it/s]



80it [00:14,  7.81it/s]



81it [00:15,  7.47it/s]



82it [00:15,  7.20it/s]



83it [00:15,  7.55it/s]



84it [00:15,  7.89it/s]



85it [00:15,  7.45it/s]



86it [00:15,  7.74it/s]



87it [00:15,  8.06it/s]



88it [00:15,  7.33it/s]



89it [00:16,  7.75it/s]



90it [00:16,  8.11it/s]



91it [00:16,  8.36it/s]



92it [00:16,  7.58it/s]



93it [00:16,  7.76it/s]



94it [00:16,  8.04it/s]



95it [00:16,  8.15it/s]



96it [00:16,  5.67it/s]


In [19]:
# test_ds_unbatched = test_batches.unbatch()

# pred_path = os.path.join(PRED_PATH, EXP_NAME, timestamp)
# if not os.path.exists(pred_path):
#     os.makedirs(pred_path)

# metrics = pd.DataFrame()

# for idx, (input, target) in enumerate(tqdm(test_ds_unbatched)):

#     target = tf.squeeze(target).numpy()
#     prediction = tf.squeeze(
#         model.predict(
#             tf.expand_dims(input, axis=0)
#         )
#     )

#     channel_sum = tf.expand_dims(tf.reduce_sum(target, axis=-1), axis=-1)
#     white_mask = tf.reduce_all(tf.equal(channel_sum, 3.0), axis=-1)
#     expanded_mask = tf.expand_dims(white_mask, axis=-1)
#     expanded_mask = tf.tile(expanded_mask, [1, 1, 3])
#     prediction = tf.where(expanded_mask, tf.ones_like(prediction), prediction)

#     plt.figure(figsize=(7, 14))
#     plt.subplot(2, 1, 1)
#     plt.imshow(target)
#     plt.axis("off")
#     plt.tight_layout()

#     plt.subplot(2, 1, 2)
#     plt.imshow(prediction)
#     plt.axis("off")
#     plt.tight_layout()
#     plt.savefig(os.path.join(pred_path, f"{idx}.png"))
#     plt.close()


# Loss Curve

In [20]:
try:
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(8, 6))
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.ylabel('Attention MAE')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.tight_layout()
    plt.savefig(os.path.join(model_path, timestamp + ".png"))
    plt.close()
except:
    print("Model did not finish training")


# Metrics

In [21]:
model.evaluate(test_batches)




[0.010710776783525944, 0.0552198700606823]