In [1]:
import os
os.environ["OMP_NUM_THREADS"] = "12"
# os.environ["TF_NUM_INTRAOP_THREADS"] = "12"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from IPython.display import clear_output
# import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import multiprocessing
# import albumentations as A
from tensorflow.python.keras import backend as K
from tqdm.auto import tqdm
import random
# from PIL import Image
# import cv2
from tensorflow.keras.layers import Dropout, Conv2D, Conv2DTranspose, MaxPooling2D, Input, concatenate
from multiprocessing import Pool
# Shared tf.data autotuning sentinel (same value as tf.data.experimental.AUTOTUNE).
AUTOTUNE = tf.data.AUTOTUNE
num_cpus = multiprocessing.cpu_count()

# Keep every physical CPU device visible to TensorFlow.
my_devices = tf.config.experimental.list_physical_devices(device_type='CPU')
tf.config.experimental.set_visible_devices(devices=my_devices, device_type='CPU')

# Pin TensorFlow's inter-op and intra-op CPU thread pools to one thread each.
# NOTE(review): OMP_NUM_THREADS is set to 12 above while these pools are set
# to 1 — confirm this combination is intended.
tf.config.threading.set_inter_op_parallelism_threads(1)
tf.config.threading.set_intra_op_parallelism_threads(1)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Expose the first GPU as a single logical device capped at 4100 MB so
# TensorFlow does not grab all of its memory.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.set_logical_device_configuration(
            gpus[0],
            [tf.config.LogicalDeviceConfiguration(memory_limit=4100)]
        )
    except RuntimeError as e:
        # Virtual devices must be configured before the GPU is initialized;
        # re-running this cell in a live kernel lands here instead of crashing.
        print(e)
logical_gpus = tf.config.list_logical_devices('GPU')
print(len(gpus), "Physical GPU,", len(logical_gpus), "Logical GPUs")
cpus = tf.config.list_physical_devices('CPU')
logical_cpus = tf.config.list_logical_devices('CPU')
print(len(cpus), "Physical CPU,", len(logical_cpus), "Logical CPUs")

1 Physical GPU, 1 Logical GPUs
1 Physical CPU, 1 Logical CPUs


In [3]:
# gpus = tf.config.experimental.list_physical_devices('GPU')
# if gpus:
#   try:
#     for gpu in gpus:
#       tf.config.experimental.set_memory_growth(gpu, True)
#   except RuntimeError as e:
#     print(e)

In [3]:
# Square input resolution (pixels) shared by the data pipeline and the models.
IMG_HEIGHT = IMG_WIDTH = 384
# Filter widths indexed by network depth, doubling at each level
# (used for the UNet++ transposed convolutions).
nb_filter = [32 << level for level in range(5)]
def dice_coef(y_true, y_pred, smooth=1.):
    """Soft Dice coefficient between a target mask and predicted probabilities.

    Both tensors are flattened over every axis (batch included), so a single
    scalar is returned for the whole batch.

    Args:
        y_true: ground-truth mask; cast to ``y_pred``'s dtype before use.
        y_pred: predicted probabilities with as many elements as ``y_true``.
        smooth: additive smoothing that keeps the ratio defined (equal to 1)
            when both masks are empty. Defaults to 1.

    Returns:
        Scalar tensor ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    # Public tf ops instead of the private `tensorflow.python.keras` backend;
    # the cast lets integer/boolean masks be multiplied with float predictions.
    y_pred = tf.convert_to_tensor(y_pred)
    y_true_f = tf.reshape(tf.cast(y_true, y_pred.dtype), [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (
        tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth
    )

def bce_dice_loss(y_true, y_pred):
    """Segmentation loss: half the binary cross-entropy minus the Dice coefficient.

    The Dice term is subtracted directly (not ``1 - dice``), so the value can
    be negative; only its gradient matters for optimisation.
    """
    bce_term = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    dice_term = dice_coef(y_true, y_pred)
    return bce_term * 0.5 - dice_term

In [4]:
# Collect (image, mask) path pairs from two source folders. Pairing is purely
# positional after sorting, so every image folder must list exactly as many
# files as its mask folder.
IMAGE_EXTENSIONS = ('.png', '.jpg')


def list_image_files(root):
    """Return sorted paths of every .png/.jpg file found recursively under ``root``."""
    return sorted(
        os.path.join(dirpath, name)
        for dirpath, _, filenames in os.walk(root)
        for name in filenames
        if os.path.splitext(name)[1] in IMAGE_EXTENSIONS
    )


img, mask = [], []
for X_path, Y_path in (("./input2", "./Output2"), ("./images2", "./masks2")):
    images_in_dir = list_image_files(X_path)
    masks_in_dir = list_image_files(Y_path)
    # A count mismatch means the positional image/mask pairing is wrong.
    assert len(images_in_dir) == len(masks_in_dir), (
        f"{X_path} has {len(images_in_dir)} images but {Y_path} has {len(masks_in_dir)} masks"
    )
    img += images_in_dir
    mask += masks_in_dir

# Fixed seed so the same train/val split is produced after a kernel restart.
SPLIT_SEED = 42
image_list_train, image_list_val, mask_list_train, mask_list_val = train_test_split(
    img, mask, test_size=0.2, shuffle=True, random_state=SPLIT_SEED
)

In [5]:
# def process_image(arg):
#     image_path, mask_path, i = arg
#     x = Image.open(image_path)
#     y = Image.open(mask_path)
#     for j in range(2):
#         transform = A.Compose([
#             A.HorizontalFlip(p=0.5),
#             A.ShiftScaleRotate(border_mode=cv2.BORDER_CONSTANT, 
#                                 scale_limit=0.3,
#                                 rotate_limit=(10, 30),
#                                 p=0.7),
#             # A.GridDistortion(p=0.5),
#             A.OpticalDistortion(p=0.5),
#             A.GaussianBlur(p=0.5),
#             A.Equalize(p=0.5),
#             A.RandomBrightnessContrast(p=0.5),
#             A.RandomGamma(p=0.5)
#         ])
#         transformed = transform(image=np.array(x), mask=np.array(y))

#         image_trans = transformed['image']
#         mask_trans = transformed['mask']
#         x = Image.fromarray(image_trans)
#         y = Image.fromarray(mask_trans)
#         x.save(f'./input2/{i}v{j}.jpg')
#         y.save(f'./Output2/{i}v{j}.png', 'PNG')

# if __name__ == '__main__':
#     img = sorted([str(os.path.join(dp, f)) for dp, dn, filenames in os.walk(X_path) for f in filenames if os.path.splitext(f)[1] == '.png' or os.path.splitext(f)[1] == '.jpg'])
#     mask = sorted([str(os.path.join(dp, f)) for dp, dn, filenames in os.walk(Y_path) for f in filenames if os.path.splitext(f)[1] == '.png' or os.path.splitext(f)[1] == '.jpg'])
#     args_list = [(image_path, mask_path, i) for i, (image_path, mask_path) in enumerate(zip(img, mask))]
#     with Pool(processes=multiprocessing.cpu_count()) as pool:
#         list(tqdm(pool.imap(process_image, args_list), total=len(args_list)))

In [5]:
@tf.function
def load_image_and_mask(image_path, mask_path):
    """Read one (image, mask) pair from disk and prepare it for the model.

    Args:
        image_path: scalar string tensor, path to an image file.
        mask_path: scalar string tensor, path to the colour mask file.

    Returns:
        image: float32 ``(IMG_HEIGHT, IMG_WIDTH, 3)`` scaled by 1/255
            (bicubic resampling can overshoot [0, 1] slightly).
        mask: float32 ``(IMG_HEIGHT, IMG_WIDTH, 1)``; 1 where any mask channel
            is non-zero, 0 where the pixel is pure black.
    """
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, (IMG_HEIGHT, IMG_WIDTH), method="bicubic")
    image = tf.cast(image, dtype=tf.float32) / 255.

    mask = tf.io.read_file(mask_path)
    mask = tf.image.decode_png(mask, channels=3)
    # Nearest-neighbour keeps label values exact; bilinear would blend boundary
    # pixels into small non-zero values that the `== 0` test counts as foreground.
    mask = tf.image.resize(mask, (IMG_HEIGHT, IMG_WIDTH), method="nearest")
    background = tf.reduce_all(mask == 0, axis=-1, keepdims=True)
    # tf.logical_not rather than Python `not`: `not tensor` only works when
    # AutoGraph rewrites it and fails outright in eager execution.
    mask = tf.cast(tf.logical_not(background), dtype=tf.float32)

    return image, mask

In [6]:
batch_size = 3
SHUFFLE_BUFFER = 100000

# Use only the training split so validation samples are never trained on.
# Paths are shuffled before decoding, which keeps the shuffle buffer cheap.
dataset_train = (
    tf.data.Dataset.from_tensor_slices((image_list_train, mask_list_train))
    .shuffle(buffer_size=SHUFFLE_BUFFER)
    .map(load_image_and_mask, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

# Validation order does not affect metrics, so it is not shuffled.
dataset_val = (
    tf.data.Dataset.from_tensor_slices((image_list_val, mask_list_val))
    .map(load_image_and_mask, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


In [8]:
class SegmentationCallback(tf.keras.callbacks.Callback):
    """Training-time visual monitor and periodic checkpointer.

    Every ``n_batches`` batches it runs the model on every .png/.jpg under
    ``test_path`` and draws three rows per image:
      * the predicted mask;
      * the image with pixels whose prediction is < 0.9 painted white;
      * the input image.

    Every 1000 batches (from batch 1000 on) it saves the model to
    ``save_model/model{batch}``. ``batch`` is Keras' index within the current
    epoch, so with several epochs a checkpoint of the same index is overwritten.
    """

    def __init__(self, n_batches, test_path="/home/denis/Изображения/Веб-камера"):
        """
        Args:
            n_batches: visualisation period, in training batches.
            test_path: folder scanned recursively for test images. The default
                is machine-specific; pass your own folder elsewhere.
        """
        # Initialise the base Keras Callback state.
        super().__init__()
        self.n_batches = n_batches
        self.img_test = sorted(
            os.path.join(dp, f)
            for dp, dn, filenames in os.walk(test_path)
            for f in filenames
            if os.path.splitext(f)[1] in ('.png', '.jpg')
        )
        # The image path doubles as the "mask" path: only the image is displayed.
        self.dataset_test = (
            tf.data.Dataset.from_tensor_slices((self.img_test, self.img_test))
            .map(load_image_and_mask)
            .batch(batch_size)
            .prefetch(AUTOTUNE)
        )

    def on_train_batch_end(self, batch, logs=None):
        if batch % self.n_batches == 0:
            self._show_predictions()
        if batch % 1000 == 0 and batch >= 1000:
            self.model.save(f"save_model/model{batch}")

    def _show_predictions(self):
        """Plot predicted mask / masked image / input for every test image."""
        n_cols = batch_size * len(self.dataset_test)
        fig = plt.figure(figsize=(36, 13))
        for k, (x, _) in enumerate(self.dataset_test):
            y_pred = self.model.predict(x, verbose=0)
            for i in range(y_pred.shape[0]):
                col = i + k * batch_size + 1

                ax = fig.add_subplot(3, n_cols, col)
                ax.imshow(y_pred[i])
                ax.set_title('Predicted mask')
                ax.axis('off')

                # Paint pixels the model is not confident about (< 0.9) white.
                overlay = tf.where(y_pred[i] < 0.9, tf.ones_like(x[i]), x[i])
                overlay = tf.image.resize(overlay, size=(IMG_HEIGHT, IMG_WIDTH), method='area').numpy()
                ax = fig.add_subplot(3, n_cols, col + n_cols)
                ax.imshow(overlay, cmap="gray")
                ax.set_title('Predicted image')
                ax.axis('off')

                ax = fig.add_subplot(3, n_cols, col + 2 * n_cols)
                ax.imshow(x[i], cmap="gray")
                ax.set_title('Image')
                ax.axis('off')
        clear_output()
        plt.show()
        # Release the figure; this runs repeatedly during training.
        plt.close(fig)
            

In [8]:
# for x,y in dataset_train:
#     fig = plt.figure(figsize=(36,13))
#     for i in range(len(x)):
#         # clear_output(wait=True)
#         fig.add_subplot(2, len(x), i+1)
#         plt.imshow(y[i], cmap="gray")
#         plt.axis('off')
#         fig.add_subplot(2, len(x), i+len(x)+1)
#         plt.imshow(x[i], cmap="gray")
#         plt.axis('off')
#         plt.show()

In [9]:
with strategy.scope():
    def convolution_block(
        block_input,
        num_filters=256,
        kernel_size=3,
        dilation_rate=1,
        padding="same",
        use_bias=False,
    ):
        """Conv2D -> BatchNormalization -> ReLU, the basic DeepLab building unit.

        Args:
            block_input: 4-D feature tensor to convolve.
            num_filters: number of output channels.
            kernel_size: convolution window size.
            dilation_rate: atrous (dilation) rate of the convolution.
            padding: Conv2D padding mode.
            use_bias: whether the convolution adds a bias (BatchNorm follows).

        Returns:
            The ReLU-activated output tensor.
        """
        conv_layer = tf.keras.layers.Conv2D(
            num_filters,
            kernel_size=kernel_size,
            dilation_rate=dilation_rate,
            padding=padding,
            use_bias=use_bias,
            kernel_initializer=tf.keras.initializers.HeNormal(),
        )
        features = conv_layer(block_input)
        features = tf.keras.layers.BatchNormalization()(features)
        return tf.nn.relu(features)


    def DilatedSpatialPyramidPooling(dspp_input):
        """Atrous Spatial Pyramid Pooling head used by DeepLabV3+.

        Runs five parallel branches on a batch-normalized copy of the input and
        fuses their concatenation with a 1x1 ``convolution_block``:
          * image-level pooling: average-pool over the whole spatial extent,
            1x1 conv (with bias), then nearest upsampling back to input size;
          * a 1x1 convolution;
          * 3x3 dilated convolutions with rates 6, 12 and 18.

        Args:
            dspp_input: 4-D feature tensor; ``shape[-3]``/``shape[-2]`` are used
                as height/width (assumes channels-last with static spatial
                dims — TODO confirm for other inputs).

        Returns:
            Tensor with 256 channels (``convolution_block``'s default filters).
        """
        dims = dspp_input.shape
        # NOTE(review): extra BatchNorm on the backbone features before the
        # branches — confirm this is intended.
        dspp_input = tf.keras.layers.BatchNormalization()(dspp_input)
        # Pool window equals the full spatial size -> one value per channel.
        x = tf.keras.layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
        x = convolution_block(x, kernel_size=1, use_bias=True)
        # Broadcast the pooled features back to the input's spatial size.
        out_pool = tf.keras.layers.UpSampling2D(
            size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="nearest"
        )(x)

        out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
        out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
        out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
        out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)

        x = tf.keras.layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
        output = convolution_block(x, kernel_size=1)
        return output
    
    def DeeplabV3Plus(input_shape, num_classes):
        """Build a DeepLabV3+ segmentation model on an ImageNet ResNet50 backbone.

        Args:
            input_shape: ``(height, width, channels)`` of the input images.
            num_classes: number of output channels, each with its own sigmoid
                (1 gives a binary mask).

        Returns:
            An uncompiled ``tf.keras.Model`` whose output is upsampled back to
            the input height/width with ``num_classes`` channels.
        """
        model_input = tf.keras.Input(shape=input_shape)
        resnet50 = tf.keras.applications.ResNet50(
            weights="imagenet", include_top=False, input_tensor=model_input
        )
        height, width = input_shape[0], input_shape[1]

        # High-level backbone features go through the ASPP head.
        x = resnet50.get_layer("conv4_block6_2_relu").output
        x = DilatedSpatialPyramidPooling(x)

        # Upsample the ASPP output to input/4 so it can be concatenated with the
        # low-level branch. Height and width use their own dimension so
        # non-square inputs work.
        input_a = tf.keras.layers.UpSampling2D(
            size=(height // 4 // x.shape[1], width // 4 // x.shape[2]),
            interpolation="bilinear",
        )(x)
        # Low-level features, reduced to 48 channels by a 1x1 convolution.
        input_b = resnet50.get_layer("conv2_block3_2_relu").output
        input_b = convolution_block(input_b, num_filters=48, kernel_size=1)

        x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])
        x = convolution_block(x)
        x = tf.keras.layers.Dropout(0.2)(x)
        x = convolution_block(x)
        # Final upsampling to full input resolution.
        x = tf.keras.layers.UpSampling2D(
            size=(height // x.shape[1], width // x.shape[2]),
            interpolation="bilinear",
        )(x)
        model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same", activation="sigmoid")(x)
        return tf.keras.Model(inputs=model_input, outputs=model_output)
    # Resume from a saved checkpoint; bce_dice_loss is registered so that
    # load_model can resolve the custom loss by name.
    # (To start from scratch instead: model = DeeplabV3Plus((IMG_HEIGHT, IMG_WIDTH, 3), 1).)
    custom_objects = dict(bce_dice_loss=bce_dice_loss)
    with tf.keras.utils.custom_object_scope(custom_objects):
        model = tf.keras.models.load_model("save_model/model26000")

    # Continue training: constant small learning rate, Adam with an EMA of the
    # weights (momentum 0.5), combined BCE/Dice loss, binary IoU metric.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=3*1e-5,
            use_ema=True,
            ema_momentum=0.5,
        ),
        loss=bce_dice_loss,
        metrics=tf.keras.metrics.BinaryIoU(),
    )

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Redu

In [9]:
with strategy.scope():
    def unetplus(input=(384, 384, 256) , activation='elu'):
        """Build a UNet++ (nested U-Net) single-output segmentation model.

        Four 2x2 max-poolings give five resolution levels with 32/64/128/256/512
        filters. Each decoder node concatenates a stride-2 transposed-conv
        upsampling of the node one level below with all earlier nodes at its
        own level (dense skip pathways), then applies two 3x3 convolutions.
        Only the final full-resolution node (conv1_5) feeds the 1x1 sigmoid
        output. Transposed-conv widths come from the module-level ``nb_filter``.
        Only three Dropout(0.5) layers are active: after conv3_2, and between
        the two convolutions of conv1_4 and of conv3_3.

        Args:
            input: input shape ``(height, width, channels)``; note the name
                shadows the builtin ``input`` inside this function.
                NOTE(review): the default has 256 channels while
                ``load_image_and_mask`` yields 3-channel images — confirm.
                Height/width presumably must be divisible by 16 so upsampled
                tensors line up with their skip connections.
            activation: activation used by every 3x3 convolution.

        Returns:
            Uncompiled ``tf.keras.Model`` with one 1-channel sigmoid output
            layer named ``output_4``.
        """
        inputs = Input(input)

        # Encoder level 0 (c1) and level 1 (c2).
        c1 = Conv2D(32, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (inputs)
        c1 = Conv2D(32, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (c1)
        p1 = MaxPooling2D((2, 2), strides=(2, 2)) (c1)

        c2 = Conv2D(64, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (p1)
        c2 = Conv2D(64, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (c2)
        p2 = MaxPooling2D((2, 2), strides=(2, 2)) (c2)

        # Level-0 decoder node from c2 + c1.
        up1_2 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up12', padding='same')(c2)
        conv1_2 = concatenate([up1_2, c1], name='merge12', axis=3)
        c3 = Conv2D(32, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv1_2)
        c3 = Conv2D(32, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (c3)

        # Encoder level 2.
        conv3_1 = Conv2D(128, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (p2)
        conv3_1 = Conv2D(128, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv3_1)
        pool3 = MaxPooling2D((2, 2), strides=(2, 2), name='pool3')(conv3_1)

        # Level-1 decoder node from conv3_1 + c2.
        up2_2 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up22', padding='same')(conv3_1)
        conv2_2 = concatenate([up2_2, c2], name='merge22', axis=3)
        conv2_2 = Conv2D(64, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv2_2)
        conv2_2 = Conv2D(64, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv2_2)

        # Level-0 decoder node from conv2_2 + c1 + c3.
        up1_3 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up13', padding='same')(conv2_2)
        conv1_3 = concatenate([up1_3, c1, c3], name='merge13', axis=3)
        conv1_3 = Conv2D(32, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv1_3)
        conv1_3 = Conv2D(32, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv1_3)

        # Encoder level 3.
        conv4_1 = Conv2D(256, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (pool3)
        conv4_1 = Conv2D(256, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv4_1)
        pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='pool4')(conv4_1)

        # Level-2 decoder node from conv4_1 + conv3_1.
        up3_2 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up32', padding='same')(conv4_1)
        conv3_2 = concatenate([up3_2, conv3_1], name='merge32', axis=3)
        conv3_2 = Conv2D(128, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv3_2)
        conv3_2 = Conv2D(128, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv3_2)
        conv3_2 = Dropout(0.5) (conv3_2)

        # Level-1 decoder node from conv3_2 + c2 + conv2_2.
        up2_3 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up23', padding='same')(conv3_2)
        conv2_3 = concatenate([up2_3, c2, conv2_2], name='merge23', axis=3)
        conv2_3 = Conv2D(64, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv2_3)
        conv2_3 = Conv2D(64, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv2_3)

        # Level-0 decoder node from conv2_3 + c1 + c3 + conv1_3.
        up1_4 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up14', padding='same')(conv2_3)
        conv1_4 = concatenate([up1_4, c1, c3, conv1_3], name='merge14', axis=3)
        conv1_4 = Conv2D(32, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv1_4)
        conv1_4 = Dropout(0.5) (conv1_4)
        conv1_4 = Conv2D(32, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv1_4)

        # Bottleneck (level 4).
        conv5_1 = Conv2D(512, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (pool4)
        conv5_1 = Conv2D(512, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv5_1)

        # Level-3 decoder node from conv5_1 + conv4_1.
        up4_2 = Conv2DTranspose(nb_filter[3], (2, 2), strides=(2, 2), name='up42', padding='same')(conv5_1)
        conv4_2 = concatenate([up4_2, conv4_1], name='merge42', axis=3)
        conv4_2 = Conv2D(256, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv4_2)
        conv4_2 = Conv2D(256, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv4_2)

        # Level-2 decoder node from conv4_2 + conv3_1 + conv3_2.
        up3_3 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up33', padding='same')(conv4_2)
        conv3_3 = concatenate([up3_3, conv3_1, conv3_2], name='merge33', axis=3)
        conv3_3 = Conv2D(128, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv3_3)
        conv3_3 = Dropout(0.5) (conv3_3)
        conv3_3 = Conv2D(128, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv3_3)

        # Level-1 decoder node from conv3_3 + c2 + conv2_2 + conv2_3.
        up2_4 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up24', padding='same')(conv3_3)
        conv2_4 = concatenate([up2_4, c2, conv2_2, conv2_3], name='merge24', axis=3)
        conv2_4 = Conv2D(64, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv2_4)
        conv2_4 = Conv2D(64, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv2_4)

        # Final level-0 decoder node from conv2_4 + every earlier level-0 node.
        up1_5 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up15', padding='same')(conv2_4)
        conv1_5 = concatenate([up1_5, c1, c3, conv1_3, conv1_4], name='merge15', axis=3)
        conv1_5 = Conv2D(32, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv1_5)
        conv1_5 = Conv2D(32, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same') (conv1_5)

        # 1x1 convolution to a single sigmoid probability channel.
        nestnet_output_4 = Conv2D(1, (1, 1), activation='sigmoid', kernel_initializer = 'he_normal',  name='output_4', padding='same')(conv1_5)

        return tf.keras.Model([inputs], [nestnet_output_4])
    model2 = unetplus()
    # model2 is only built here, not compiled; previous compile settings:
    # scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
    #         initial_learning_rate=1e-5,
    #         decay_steps=1000,
    #         decay_rate=0.8
    #     )
    # model2.compile(
    #     optimizer=tf.keras.optimizers.Adam(learning_rate=scheduler),
    #     loss=tf.keras.losses.BinaryCrossentropy(),
    #     metrics=tf.keras.metrics.BinaryIoU(),
    # )

In [5]:
# Up-weight foreground pixels. Keras `class_weight` only supports targets of
# rank <= 2 and raises for (batch, H, W, 1) masks, so the same weighting is
# applied as a per-pixel sample-weight map instead.
class_weights = {0: 1., 1: 1.5}


def add_pixel_weights(images, masks):
    """Attach a per-pixel weight map built from ``class_weights`` to a batch."""
    weights = tf.where(masks > 0.5, class_weights[1], class_weights[0])
    return images, masks, weights


dataset_train_weighted = dataset_train.map(add_pixel_weights, num_parallel_calls=AUTOTUNE)

# `workers` / `use_multiprocessing` only apply to generator/Sequence inputs
# and are ignored for tf.data pipelines, so they are not passed.
history = model.fit(
    dataset_train_weighted,
    validation_data=dataset_val,
    epochs=1,
    callbacks=[SegmentationCallback(n_batches=100)],
)

In [6]:
# Compare saved checkpoints visually on the test image folder.
# NOTE(review): machine-specific absolute path — adjust for your environment.
test_Path = "/home/denis/Изображения/Веб-камера"
img_test = sorted([os.path.join(dp, f) for dp, dn, filenames in os.walk(test_Path) for f in filenames if os.path.splitext(f)[1] in ('.png', '.jpg')])
# Each image doubles as its own "mask"; only the image part is displayed.
dataset_test = tf.data.Dataset.from_tensor_slices((img_test, img_test)).map(load_image_and_mask).batch(batch_size)
custom_objects = {'bce_dice_loss': bce_dice_loss}


def plot_checkpoint_predictions(checkpoint_model, images):
    """Show predicted mask, masked image and input for one batch of images.

    Uses local names only, so the global ``mask`` path list is not clobbered.
    """
    y_pred = checkpoint_model.predict(images, verbose=0)
    n_images = y_pred.shape[0]
    fig = plt.figure(figsize=(16, 9))
    for i in range(n_images):
        ax = fig.add_subplot(3, n_images, i + 1)
        ax.imshow(y_pred[i], cmap="gray")
        ax.set_title('Predicted mask')
        ax.axis('off')

        # Paint pixels the model is not confident about (< 0.9) white.
        overlay = tf.where(y_pred[i] < 0.9, tf.ones_like(images[i]), images[i])
        overlay = tf.image.resize(overlay, size=(IMG_HEIGHT, IMG_WIDTH), method='area').numpy()
        ax = fig.add_subplot(3, n_images, i + n_images + 1)
        ax.imshow(overlay, cmap="gray")
        ax.set_title('Predicted image')
        ax.axis('off')

        ax = fig.add_subplot(3, n_images, i + 2 * n_images + 1)
        ax.imshow(images[i], cmap="gray")
        ax.set_title('Image')
        ax.axis('off')
    plt.show()
    # Free the figure: this runs once per batch per checkpoint.
    plt.close(fig)


for j in range(1000, 31001, 1000):
    checkpoint_dir = f"save_model/model{j}"
    if not os.path.isdir(checkpoint_dir):
        print(f"Skipping missing checkpoint {checkpoint_dir}")
        continue
    with tf.keras.utils.custom_object_scope(custom_objects):
        model = tf.keras.models.load_model(checkpoint_dir)
    for x, _ in dataset_test:
        plot_checkpoint_predictions(model, x)

In [8]:
# Export the end-of-training checkpoint to ./mod as a TF SavedModel.
# It is re-compiled with the built-in "bce" loss — presumably so loading ./mod
# later does not require the custom bce_dice_loss (TODO confirm).
custom_objects = dict(bce_dice_loss=bce_dice_loss)
model = tf.keras.models.load_model(
    "save_end_model/model12000",
    custom_objects=custom_objects,
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(
        learning_rate=3*1e-5,
        use_ema=True,
        ema_momentum=0.5,
    ),
    loss="bce",
    metrics=tf.keras.metrics.BinaryIoU(),
)
model.save("mod", save_format='tf')

INFO:tensorflow:Assets written to: mod/assets


INFO:tensorflow:Assets written to: mod/assets
