In [1]:
import sys
if ".." not in sys.path:
    sys.path.append("..")

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import wandb
import functools

from src.display import showarray
from src.mask import MaskGenerator
from src.datagen import DatasetFillGenerator
from src.augmenters import masked_channel_augmenter, masked_split_augmenter
from src.builders.unet import UNETBuilder
from src.builders.pcunet import PCUNETBuilder
from src.loss import (
    MaskedMAE, 
    MaskedGaussedSobelMAE, 
    GaussedSobelMAE,
    SSIMLoss, 
    CombinedLoss
)
from src.metrics import dice_coef, ssim_coef
from src.layers.pconv import PConv2D

In [2]:
# Enable on-demand memory growth for every visible GPU device.
gpus = tf.config.experimental.list_physical_devices("GPU")
for device in gpus:
    tf.config.experimental.set_memory_growth(device, True)

In [3]:
# Run-level hyper-parameters; the dict is handed to wandb.init below.
config = {
    "learning_rate": 0.002,
    "epochs": 16,
    "batch_size": 8,
    "mask_gen_degree": "HEAVY",
    "mask_gen_min_width": 5,
    "mask_gen_max_width": 12,
    "use_partial_conv": True,
}

# Keyword arguments later forwarded to the model builder.
model_params = {
    "n_filters": 16,
    "n_blocks": 2,
    "n_convs": 1,
    "activation": "gelu",
    "dropout_rate": 0.2,
}

# Every available loss term, keyed by the same names used in `loss_weights`.
loss_dict = {
    "masked_mae": MaskedMAE(),
    "masked_gaussed_sobel_mae": MaskedGaussedSobelMAE(),
    "gaussed_sobel_mae": GaussedSobelMAE(),
    "mae": tf.keras.losses.MeanAbsoluteError(),
    "ssim": SSIMLoss(),
}

# Weight of each loss term; a weight of 0 switches the term off.
loss_weights = {
    "masked_mae": 0,
    "masked_gaussed_sobel_mae": 0,
    "gaussed_sobel_mae": 0.2,
    "mae": 0.4,
    "ssim": 0.4,
}

# Keep only the terms with a strictly positive weight, as (loss_fn, weight) pairs.
loss_config = {
    name: (fn, loss_weights[name])
    for name, fn in loss_dict.items()
    if loss_weights[name] > 0
}

# Flatten model and loss settings into the tracked config under prefixed keys.
config.update({f"unet_{name}": value for name, value in model_params.items()})
config.update({f"loss_{name}_weight": weight for name, weight in loss_weights.items()})
wandb.init(project="cv3B-ii-ae-unet", entity="put_dl_team", config=config)

IMAGE_SIZE = (256, 256)
CHANNELS = 3
# Without partial convolutions the input gets one extra channel
# (presumably the mask, given `masked_channel_augmenter` -- confirm in src.augmenters).
effective_channels = CHANNELS if wandb.config["use_partial_conv"] else CHANNELS + 1

IM_SHAPE = (*IMAGE_SIZE, effective_channels)

BATCH_SIZE = wandb.config["batch_size"]
# Read the mask generator options back from wandb.config.
MASK_GEN_PARAM = {
    option: wandb.config[f"mask_gen_{option}"]
    for option in ("degree", "min_width", "max_width")
}

mask_generator = MaskGenerator(*IMAGE_SIZE, CHANNELS, **MASK_GEN_PARAM)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msbartekt[0m ([33mput_dl_team[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
def scale(tensor: tf.Tensor, divisor: float = 255.0) -> tf.Tensor:
    """Divide ``tensor`` element-wise by ``divisor``.

    Args:
        tensor: Values to rescale.
        divisor: Denominator; the default of 255.0 is what the dataset
            ``map`` calls in this notebook rely on.

    Returns:
        The result of ``tensor / divisor``.
    """
    return tensor / divisor

def recast_to_image(tensor: tf.Tensor) -> np.ndarray:
    """Convert a single rank-3 image to a uint8 array of at most 3 channels.

    The first three channels of the last axis are kept, multiplied by 255
    and cast to ``uint8``.

    Args:
        tensor: A single image indexed as ``[rows, cols, channels]``; values are
            expected to be nominally in [0, 1] (the inverse of ``scale``).

    Returns:
        A ``numpy`` array of dtype ``uint8`` containing the first three channels.
    """
    rgb = tensor[:, :, :3] * 255
    # Clamp before casting: network predictions are not guaranteed to stay in
    # [0, 1], and out-of-range floats cast to uint8 can wrap around and show up
    # as speckle artefacts in the logged images. In-range values are unaffected.
    rgb = tf.clip_by_value(rgb, 0.0, 255.0)
    return tf.cast(rgb, tf.uint8).numpy()


# Training and validation subsets come from one directory via a seeded 90/10 split.
ds_train, ds_valid = tf.keras.preprocessing.image_dataset_from_directory(
    directory="../data/1-8size",
    label_mode=None,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=True,
    seed=42,
    validation_split=0.1,
    subset="both",
)
# The test set is read one image per batch, in file order.
ds_test = tf.keras.preprocessing.image_dataset_from_directory(
    directory="../data/test",
    label_mode=None,
    image_size=IMAGE_SIZE,
    batch_size=1,
    shuffle=False,
)

# Apply `scale` (division by 255 by default) to every dataset.
ds_train, ds_valid, ds_test = (
    dataset.map(scale) for dataset in (ds_train, ds_valid, ds_test)
)

# Keep up to five images of one validation batch for visual progress checks.
for first_batch in ds_valid.take(1):
    showcase_images = first_batch[:5]

# Choose the augmenter / builder pair that matches the convolution type.
if wandb.config["use_partial_conv"]:
    # The partial-conv builder is configured without "n_convs" and needs the
    # PConv2D layer class instead.
    model_params.pop("n_convs", None)
    model_params["pconv_class"] = PConv2D
    augmenter_fn, builder_class = masked_split_augmenter, PCUNETBuilder
else:
    augmenter_fn, builder_class = masked_channel_augmenter, UNETBuilder

# Bind the shared mask generator so the augmenter can be called with images only.
dataset_image_augmenter = functools.partial(augmenter_fn, mask_generator=mask_generator)


class ImageFillCallback(tf.keras.callbacks.Callback):
    """Log a masked / original / predicted image strip to W&B after each epoch.

    The showcase images are augmented once at construction time, so every
    epoch is visualised on the same masked inputs.

    Args:
        model: Model used to produce the filled images.
        showcase_images: Batch of images passed to ``augmenter``.
        augmenter: Callable returning ``(model_input, target_images)`` for a
            batch. ``model_input`` is either a single batch or a 2-element
            tuple/list, which is unpacked as ``(masked_images, masks)``.
        n_samples: Maximum number of rows in the logged strip (default 5,
            matching the previously hard-coded value).
    """

    def __init__(self, model, showcase_images, augmenter, n_samples=5):
        # Initialise the base Callback state before attaching the model.
        super().__init__()
        # Attach the model through the Callback API so the callback also works
        # when `on_epoch_end` is invoked manually, outside of `fit`.
        self.set_model(model)
        self.input_data, self.showcase_images = augmenter(showcase_images)
        self.masked_images = self.input_data
        self.masks = None
        # Check the container type, not only the length: a plain batch tensor
        # holding exactly two images also has len() == 2 and must not be
        # unpacked as (images, masks).
        if isinstance(self.input_data, (tuple, list)) and len(self.input_data) == 2:
            self.masked_images, self.masks = self.input_data
        # Never index past the end of a short showcase batch.
        self.n_samples = min(n_samples, len(self.showcase_images))

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the fixed showcase inputs and log the comparison image."""
        nn_filled = self.model.predict(self.input_data)
        rows = []
        for i in range(self.n_samples):
            masked = recast_to_image(self.masked_images[i])
            original_image = recast_to_image(self.showcase_images[i])
            filled_image = recast_to_image(nn_filled[i])
            # One row per sample: masked input | ground truth | prediction.
            rows.append(np.concatenate([masked, original_image, filled_image], axis=1))
        if not rows:
            # Nothing to show; np.concatenate would raise on an empty list.
            return
        wandb.log({"sample_fill": wandb.Image(np.concatenate(rows, axis=0))})


# Seed NumPy's global RNG before the generators are created.
np.random.seed(42)

# One fill generator per split, all sharing the same augmenter.
train_generator, valid_generator, test_generator = (
    DatasetFillGenerator(dataset, dataset_image_augmenter)
    for dataset in (ds_train, ds_valid, ds_test)
)

Found 2188 files belonging to 1 classes.
Using 1970 files for training.
Using 218 files for validation.
Found 141 files belonging to 1 classes.


In [5]:
# Instantiate the selected builder class, passing IM_SHAPE for both shape
# arguments and `model_params` as keyword options, then build the model and
# print its layer summary.
builder = builder_class(IM_SHAPE, IM_SHAPE, **model_params)
model = builder.build()
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 input_2 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 p_conv2d (PConv2D)             ((None, 256, 256, 1  880         ['input_1[0][0]',                
                                6),                               'input_2[0][0]']            

In [6]:
# Combine the enabled loss terms and compile with Adam at the tracked learning rate.
loss = CombinedLoss(loss_config)
optimizer = tf.keras.optimizers.Adam(learning_rate=wandb.config["learning_rate"])
model.compile(optimizer=optimizer, loss=loss, metrics=[dice_coef, ssim_coef])

In [7]:
# W&B logging plus a per-epoch visual check on the fixed showcase images.
training_callbacks = [
    wandb.keras.WandbCallback(),
    ImageFillCallback(model, showcase_images, dataset_image_augmenter),
]

model.fit(
    train_generator,
    epochs=wandb.config["epochs"],
    validation_data=valid_generator,
    callbacks=training_callbacks,
)



Epoch 1/16



INFO:tensorflow:Assets written to: c:\Users\Bartosz\PycharmProjects\CV3_PROJECT\cv-image-inpainting\notebooks\wandb\run-20230125_145207-he8db57w\files\model-best\assets


INFO:tensorflow:Assets written to: c:\Users\Bartosz\PycharmProjects\CV3_PROJECT\cv-image-inpainting\notebooks\wandb\run-20230125_145207-he8db57w\files\model-best\assets
[34m[1mwandb[0m: Adding directory to artifact (c:\Users\Bartosz\PycharmProjects\CV3_PROJECT\cv-image-inpainting\notebooks\wandb\run-20230125_145207-he8db57w\files\model-best)... Done. 0.0s


Epoch 2/16



INFO:tensorflow:Assets written to: c:\Users\Bartosz\PycharmProjects\CV3_PROJECT\cv-image-inpainting\notebooks\wandb\run-20230125_145207-he8db57w\files\model-best\assets


INFO:tensorflow:Assets written to: c:\Users\Bartosz\PycharmProjects\CV3_PROJECT\cv-image-inpainting\notebooks\wandb\run-20230125_145207-he8db57w\files\model-best\assets
[34m[1mwandb[0m: Adding directory to artifact (c:\Users\Bartosz\PycharmProjects\CV3_PROJECT\cv-image-inpainting\notebooks\wandb\run-20230125_145207-he8db57w\files\model-best)... Done. 0.0s


Epoch 3/16

KeyboardInterrupt: 

In [9]:
# Run the showcase callback once by hand to log a sample fill outside of `fit`.
callback = ImageFillCallback(model, showcase_images, dataset_image_augmenter)
callback.on_epoch_end(epoch=0)



In [None]:
# `evaluate` returns the loss followed by the compiled metrics, in compile order.
test_result = model.evaluate(test_generator)
test_loss, test_dice, test_ssim = test_result
wandb.log({"test_loss": test_loss, "test_dice": test_dice, "test_ssim": test_ssim})



In [None]:
# Dedicated mask generator for the qualitative test showcase.
test_mask_generator = MaskGenerator(*IMAGE_SIZE, CHANNELS, degree="HEAVY", min_width=10, max_width=24)

# Re-bind the augmenter to the test mask generator. Previously
# `test_mask_generator` was created but never used, so the showcase silently
# reused the training mask generator. Keywords given to functools.partial
# override those already stored on the wrapped partial.
test_augmenter = functools.partial(dataset_image_augmenter, mask_generator=test_mask_generator)

test_generator = DatasetFillGenerator(
    ds_test,
    test_augmenter,
    shuffle=False
)

all_joint = []
for i in range(5):
    input_data, showcase_image = test_generator[i]
    nn_filled = model.predict(input_data)
    masked_image = input_data
    # Only a 2-element tuple/list is an (images, masks) pair; checking the
    # length alone would also match a plain batch of exactly two images.
    if isinstance(input_data, (tuple, list)) and len(input_data) == 2:
        masked_image = input_data[0]
    # Index 0 takes the first image of each returned batch.
    masked = recast_to_image(masked_image[0])
    original_image = recast_to_image(showcase_image[0])
    filled_image = recast_to_image(nn_filled[0])
    # One row per sample: masked input | ground truth | prediction.
    joint = np.concatenate([masked, original_image, filled_image], axis=1)
    all_joint.append(joint)
all_joint = np.concatenate(all_joint, axis=0)
wandb.log({"test_fill_result": wandb.Image(all_joint)})



In [None]:
# Save the final model inside the W&B run directory, then close the run
# with a success exit code.
model_dir = os.path.join(wandb.run.dir, "model")
model.save(model_dir)
wandb.finish(exit_code=0)



INFO:tensorflow:Assets written to: c:\Users\Bartosz\PycharmProjects\CV3_PROJECT\cv-image-inpainting\notebooks\wandb\run-20230125_132348-zumeq7sa\files\model\assets


INFO:tensorflow:Assets written to: c:\Users\Bartosz\PycharmProjects\CV3_PROJECT\cv-image-inpainting\notebooks\wandb\run-20230125_132348-zumeq7sa\files\model\assets


0,1
dice_coef,▁▆▇▇▇██████▇▇██▇
epoch,▁▁▂▂▃▃▄▄▅▅▆▆▇▇██
loss,█▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁
ssim_coef,▁▆▇▇▇▇██████████
test_dice,▁
test_loss,▁
test_ssim,▁
val_dice_coef,▃▁▂▆▇▄▄▁▅▄▄▆▆█▆▅
val_loss,█▆▅▄▄▃▂▂▂▂▁▂▂▁▁▁
val_ssim_coef,▁▃▅▅▅▆▇▇▇▇█▇▇███

0,1
best_epoch,15.0
best_val_loss,0.03014
dice_coef,0.64636
epoch,15.0
loss,0.03098
ssim_coef,0.93824
test_dice,0.63876
test_loss,0.03092
test_ssim,0.93865
val_dice_coef,0.64524
