In [None]:
!pip install segmentation-models
!pip install --upgrade tensorflow

In [None]:
import sys

# Make the helper modules shipped as Kaggle input datasets importable
# (fetch_data1 and visualise are imported in the next cell).
sys.path.extend(['/kaggle/input/helper-scripts', '/kaggle/input/fetch1'])

In [None]:
# Select the tf.keras backend for segmentation_models. The environment variable
# must be set before segmentation_models is imported on the next line.
# (No inline comment on the magic line: %env would treat it as part of the value.)
%env SM_FRAMEWORK=tf.keras
import segmentation_models as sm
import tensorflow as tf
# Make every tf.function invocation run eagerly instead of as a traced graph.
# NOTE(review): this is normally a debugging switch and makes training and
# inference noticeably slower; confirm it is still required (e.g. as a
# workaround for a segmentation_models / TensorFlow version mismatch).
tf.config.run_functions_eagerly(True)

from fetch_data1 import fetch  # project helper, found via the sys.path additions in the previous cell
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt
from visualise import visualize, denormalize  # project helpers, also found via the sys.path additions

In [None]:
# ---- Data configuration (all values are passed to fetch() in a later cell) ----
base_dir = '/kaggle/input/augmented2/ISIC 2016 for segmentation (augmented)/'
batch_size = 4
input_size = (224, 224)
num_epochs = 40
# NOTE(review): the meaning of shuffle/ratio/grp is defined by fetch_data1.fetch;
# document them here once confirmed.
shuffle = False
ratio = 1
grp = 0

# ---- Evaluation metrics: predictions are binarised at 0.5 before scoring ----
metrics = [
    sm.metrics.IOUScore(threshold=0.5),
    sm.metrics.FScore(threshold=0.5),
]

# ---- Training callbacks ----
model_save = '2016_extend_best_model.h5'

# Keep only the weights of the epoch with the highest validation IoU.
checkpoint_cb = ModelCheckpoint(
    './' + model_save,
    monitor="val_iou_score",
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    initial_value_threshold=0.0,
)

# Halve the learning rate when validation IoU has not improved for 3 epochs.
lr_plateau_cb = ReduceLROnPlateau(
    monitor="val_iou_score",
    mode='max',
    factor=0.5,
    patience=3,
    verbose=1,
)

callbacks = [checkpoint_cb, lr_plateau_cb]

In [None]:
# Build the datasets with the project helper. The arguments are positional
# because fetch() defines their order. val_paths is used by the inference cell
# to name the saved prediction masks (presumably one path per validation
# example — confirm in fetch_data1).
train_dataset, validation_dataset, val_paths = fetch(
    base_dir,
    input_size,
    grp,
    batch_size,
    shuffle,
    ratio,
)

In [None]:
# U-Net segmentation model with a DenseNet-201 encoder. The input shape is
# derived from the configured input_size (plus 3 image channels) rather than
# repeating the literal, so changing input_size in the config cell is enough.
model = sm.Unet(backbone_name="densenet201",
                input_shape=(*input_size, 3))

# Binary cross-entropy + Dice loss; the IoU/F-score metrics come from the
# config cell. NOTE(review): the very small learning rate suggests this setup
# is meant for fine-tuning the weights loaded in the next cell — confirm.
model.compile(optimizer=Adam(learning_rate=8e-6),
              loss=sm.losses.bce_dice_loss,
              metrics=metrics)

print(len(model.layers))  # total number of Keras layers (encoder + decoder)
model.summary()

In [None]:
# Restore previously trained weights from the "extend-model" Kaggle dataset,
# replacing the weights the model was initialised with.
pretrained_weights_path = '/kaggle/input/extend-model/2016_extend_best_model.h5'
model.load_weights(pretrained_weights_path)

In [None]:
from PIL import Image
import cv2  # NOTE(review): not used in this cell
import os

# Run the restored model over the validation set. Each prediction is shown next
# to its ground truth, and the predicted mask is written to the working directory.
# NOTE(review): zip() pairs one path with each *batch*, and only element 0 of
# each batch is saved. That is only correct if validation_dataset yields batches
# of size 1 — confirm against fetch_data1.fetch. zip() also silently stops at
# the shorter of the two iterables.
for batch, val_path in zip(validation_dataset, val_paths):
    img, gt_mask = batch
    img = img.numpy()
    gt_mask = gt_mask.numpy()
    # Round sigmoid outputs to a 0/1 mask. verbose=0 avoids printing a Keras
    # progress bar for every single batch.
    pr_mask = model.predict(img, verbose=0).round()

    visualize(
        img=denormalize(img.squeeze()),
        gt_mask=gt_mask[..., 0].squeeze(),
        pr_mask=pr_mask[..., 0].squeeze(),
    )

    # First sample, first channel -> 2-D binary mask scaled to {0, 255}.
    # Computed once and reused; no hard-coded spatial size.
    mask_u8 = (pr_mask[0, ..., 0] * 255).astype(np.uint8)
    # A 2-D uint8 array is inferred as mode 'L'; the explicit mode= argument is
    # deprecated in recent Pillow releases.
    mask = Image.fromarray(mask_u8)
    print(np.unique(mask_u8))
    # Name the output after the source file instead of slicing a fixed number of
    # characters off the path, which broke whenever base_dir changed.
    # NOTE(review): if val_path ends in .jpg, PIL writes a lossy JPEG, and the
    # saved mask will contain values other than 0/255 — consider saving as .png.
    mask.save(os.path.basename(str(val_path)))