In [36]:
import gdal
import ogr 

ds = ogr.Open("./test_chips.shp")
layer = ds.GetLayer()

test_chips_coords = []

for i in range(0,250):

    feature = layer.GetFeature(i)

    coords = list(feature)[1:5]

    test_chips_coords.append(dict([("ulx", coords[0]), ("uly", coords[1]), ("lrx", coords[2]), ("lry", coords[3])]))

print(test_chips_coords[0]["ulx"])

-56.61951136331


In [30]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import Input
from tensorflow.keras import Model
from tensorflow.keras import models
from tensorflow.keras import metrics
from tensorflow.keras import optimizers
from tensorflow.keras import losses
from tensorflow.keras import regularizers
from tensorflow.keras import initializers
from tensorflow.keras import backend as K
from tensorflow.keras import callbacks

def parse_data(data):
    
    features_for_data = {
        'b1': tf.io.FixedLenFeature([256,256], tf.float32),
        'b2': tf.io.FixedLenFeature([256,256], tf.float32),
        'b3':tf.io.FixedLenFeature([256,256], tf.float32),
        'ref': tf.io.FixedLenFeature([256,256], tf.float32)
    }
    
    data = tf.io.parse_single_example(data, features_for_data)
    inputsList_data = [data.get(key) for key in ['b1', 'b2', 'b3']]
    stacked_data = tf.stack(inputsList_data, axis=0)
    stacked_data = tf.transpose(stacked_data, [1,2,0])

    inputsList_ref =  [data.get(key) for key in ['ref']]
    stacked_ref = tf.stack(inputsList_ref, axis=0)
    stacked_ref = tf.transpose(stacked_ref, [1,2,0])

    return stacked_data[:,:,:3], stacked_ref[:,:,:]

def get_data():

    training_files = "../data/training_buffered_2x.tfrecord"
    validation_files = "../data/validation_buffered_2x.tfrecord"
    test_files = "../data/test_buffered_2x.tfrecord"
    
    files_set = [training_files, validation_files, test_files]
    data_sets = []
    
    for i in files_set:
        files = tf.io.gfile.glob(i)
        files = tf.data.TFRecordDataset(files, compression_type=None)
        
        data_sets.append(files.map(parse_data))
        
    return data_sets[0], data_sets[1], data_sets[2]

def soft_dice_loss(y_pred, y_true, smooth = 1):
    
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
   
    intersection = K.sum(y_true_f * y_pred_f)
    dice = K.abs(2. * intersection + smooth) / (K.abs(K.sum(K.square(y_true_f))) + K.abs(K.sum(K.square(y_pred_f))) + smooth)
    
    return 1-K.mean(dice)

def recall_m(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

def precision_m(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision

def f1_m(y_true, y_pred):
    precision = precision_m(y_true, y_pred)
    recall = recall_m(y_true, y_pred)
    return 2*((precision*recall)/(precision+recall+K.epsilon()))

best_model_soft_dice = tf.keras.models.load_model("../estrutura_modelo", custom_objects={'recall_m': recall_m, 'f1_m': f1_m, 'precision_m': precision_m, 'loss function': soft_dice_loss}, compile=False)
best_model_soft_dice.compile(loss=soft_dice_loss, metrics=[recall_m, f1_m, precision_m])

training_set, validation_set, test_set = get_data()



In [15]:
import gdal
import osr
import gdalconst
import glob
import time
import numpy as np

from skimage.morphology import skeletonize, binary_closing, rectangle
from skimage.transform import rotate
from PIL import Image
from tqdm.auto import tqdm

from grass_session import Session
import grass.script.core as gs

def water_masks(predicted_images, water_raster_path):
    
    print("Etapa de divisão do mosaico de água em cartas")
    for i in tqdm(predicted_images):
        pred = gdal.Open(i)
        # print(pred)
        xsize = pred.RasterXSize
        ysize = pred.RasterYSize
        
        ulx, xres, xskew, uly, yskew, yres = pred.GetGeoTransform()
        lrx = ulx + (xsize * xres)
        lry = uly + (ysize * yres)
        
        carta_name = i.split("/")[-1].split("_")[0]
        
        outfile = "./mascara_de_agua/{}_agua.tif".format(carta_name)
        
        gdal.Translate(outfile, water_raster_path, projWin=[ulx,uly,lrx,lry], width=xsize, height=ysize, creationOptions=['TFW=YES', 'COMPRESS=LZW'])
        
        pred = None
        
    
def water_masks_application(predicted_images, water_masks):
    
    print("Etapa de aplicação da máscara de água")
    for x,y in tqdm(zip(predicted_images, water_masks)):
        print(x)
        print(y)
        carta_name = x.split("/")[-1].split("_")[0]
        # print(carta_name)
        outfile = "./mascarados/{}_masked.tif".format(carta_name)
 
        os.system('python C:/Users/jonas/anaconda3/envs/research/Scripts/gdal_calc.py -A {x} -B {y} --outfile={outfile} --NoDataValue=0 --calc="A-(B>0)" --creation-option "TFW=YES" --creation-option "COMPRESS=LZW" --quiet --overwrite')

        # startcmd=["python3.exe", "gdal_calc.py", "-A", x, "-B", y, "--outfile" ,outfile, "--NoDataValue", "0", "--calc", "A-(B>0)", "--creation-option", 'TFW=YES', "--creation-option", 'COMPRESS=LZW', "--quiet", "--overwrite"]

        # subprocess.call(startcmd)

        pred = None
        water = None
        
def skeletonize_tif(masked):
    
    print("Etapa de snapping e skeletização do dados do raster")
    for i in tqdm(masked):
        mascarado = gdal.Open(i)
        
        carta_name = i.split("/")[-1].split("_")[0]
        
        raster_array = mascarado.GetRasterBand(1).ReadAsArray()
        
        xsize = mascarado.RasterXSize + 1
        ysize = mascarado.RasterYSize + 1
        
        ulx, xres, xskew, uly, yskew, yres = mascarado.GetGeoTransform()
        lrx = ulx + (xsize * xres)
        lrx = uly + (ysize * yres)
        
        masked_skeleton = skeletonize(raster_array).astype(np.uint8)
        
        dilation = binary_closing(masked_skeleton, rectangle(10,1))
        dilation = binary_closing(dilation, rectangle(1,10))
        dilation = binary_closing(dilation, rotate(rectangle(1,10), 45))
        dilation = binary_closing(dilation, rotate(rectangle(1,10), 135))
        
        dilation = binary_closing(dilation, rotate(rectangle(10,1), 45))
        dilation = binary_closing(dilation, rotate(rectangle(10,1), 135))
        
        array_skeleton = skeletonize(dilation).astype(np.uint8)
        
        output_raster = gdal.GetDriverByName('GTiff').Create('./skeleton/skeleton_{}.tif'.format(carta_name), xsize, ysize, 1, gdalconst.GDT_Byte)
        output_raster.SetGeoTransform(mascarado.GetGeoTransform())
        output_raster.GetRasterBand(1).WriteArray(array_skeleton)
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(4326)
        
        output_raster.SetProjection(srs.ExportToWkt())
        
        output_raster.FlushCache()
        
        mascarado = None
    
def vectorize(raster, carta_name):
    
    gs.run_command("r.in.gdal", input=raster, output="chart_raster", overwrite=True)
    
    gs.run_command("g.region", raster="chart_raster", overwrite=True)
    
    gs.run_command("r.null", map="chart_raster", setnull=[0])
    gs.run_command("r.thin", input="chart_raster", output="chart_thinned", iteration=10, overwrite=True)
    gs.run_command("r.to.vect", input="chart_thinned", output="chart_vector", type="line", overwrite=True)
    
    #Limpeza 1 -> Linhas que se tocam e segmentos maiores que 1km
    gs.run_command("v.generalize", input="chart_vector", output="chart_vector_smoothed", method="snake", iterations=5, threshold=0.5, angle=180, overwrite=True)
    gs.run_command("v.build", map="chart_vector_smoothed", error="error", option="build", overwrite=True)
    gs.run_command("v.to.db", map="chart_vector_smoothed", option="length", type="line", columns="len", units="me")
    gs.run_command("v.select", ainput="chart_vector_smoothed", atype="line", binput="chart_vector_smoothed", btype="line", output="connected_lines", operator="touches", overwrite=True)
    gs.run_command("v.edit", map="chart_vector_smoothed", tool="delete", type="line", where="len < 1000")
    gs.run_command("v.patch", input=["chart_vector_smoothed", "connected_lines"], output="filtered_chart", overwrite=True)
    
    #Limpeza 2 -> Borda das cartas
    gs.run_command("v.select", ainput="charts", atype="area", binput="filtered_chart", btype="line", output="chart", operator="contains", overwrite=True)
    gs.run_command("v.to.lines", input="chart", output="chart_limits", overwrite=True)
    gs.run_command("v.buffer", input="chart_limits", distance=0.0005, output="chart_limits_buffered", overwrite=True)
    gs.run_command("v.select", ainput="filtered_chart", atype="line", binput="chart_limits_buffered", btype="area", output="border_lines", operator="within", overwrite=True)
    gs.run_command("v.select", ainput="filtered_chart", atype="line", binput="border_lines", operator="equals", output="chart_border_cleaned", overwrite=True, flags="r")
    
    #Limpeza 3 -> Borda das cenas
    gs.run_command("v.select", ainput="scenes", atype="area", binput="chart_border_cleaned", btype="line", output="scenes_intersect", operator="intersects", overwrite=True)
    gs.run_command("v.select", ainput="chart_border_cleaned", atype="line", binput="scenes_intersect", btype="area", output="scenes_lines", operator="within", overwrite=True)
#     gs.run_command("v.build", map="scenes_lines", error="error", option="build", overwrite=True)
#     try:
# #         print(gs.parse_command(""))
#         print(gs.parse_command("db.tables", flags="p"))
    gs.run_command("v.db.addtable", map="scenes_lines")
    gs.run_command("v.db.connect", map="scenes_lines", table="scenes_lines", flags="o")
#         gs.run_command("v.db.addcolumn", map="scenes_lines", columns="angle double precision, len double precision")
#         gs.run_command("v.to.db", map="scenes_lines", option="length", type="line", columns="len", units="me", overwrite=True)
#         gs.run_command("v.to.db", map="scenes_lines", option="azimuth", type="line", columns="angle", units="degrees", overwrite=True)
#     except:
#     print(gs.parse_command("db.tables", flags="p"))
    gs.run_command("v.to.db", map="scenes_lines", option="length", type="line", columns="len", units="me", overwrite=True)
    gs.run_command("v.to.db", map="scenes_lines", option="azimuth", type="line", columns="angle", units="degrees", overwrite=True)
    gs.run_command("v.extract", input="scenes_lines", type="line", where="((angle >= 89 and angle <= 91 ) or (angle >= -1 and angle <= 1) or (angle >= 179 and angle <= 181) or (angle >= 269 and angle <= 271) or (angle >= 359 and angle <= 361)) and (len > 500)", output="scenes_lines_filtered", overwrite=True)
#     gs.run_command("v.db.addtable", map="scenes_lines_filtered")
#     gs.run_command("v.db.connect", map="scenes_lines_filtered", table="scenes_lines_filtered", flags="o")
#     gs.run_command("v.extract", input="scenes_lines_filtered", type="line", where="len > 500", output="scenes_lines_filtered_2", overwrite=True)
    gs.run_command("v.select", ainput="chart_border_cleaned", atype="line", binput="scenes_lines_filtered", operator="equals", output="chart_border_cleaned_scenes", overwrite=True, flags="r")

    #Limpeza 4 -> Densidade de estradas
    gs.run_command("v.mkgrid", map="grid", position="region", box=[0.0901404892, 0.0901404892], overwrite=True)
    gs.run_command("v.overlay", ainput="chart_border_cleaned_scenes", atype="line", binput="grid", btype="area", operator="and", output="lines_divided", overwrite=True)
    gs.run_command("v.overlay", ainput="chart_border_cleaned_scenes", atype="line", binput="grid", btype="area", operator="not", output="lines_outside_grid", overwrite=True)
    gs.run_command("v.to.db", map="lines_divided", option="length", type="line", columns="len", units="me")

    try:
        gs.run_command("db.execute", driver="sqlite", sql="DROP TABLE grid_density")
    except:
        print("Não precisa deletar tabela")
    gs.run_command("db.execute", driver="sqlite", sql="CREATE TABLE grid_density as SELECT a.cat, CAST(SUM(b.len) as double precision) as density FROM grid a, lines_divided b WHERE a.cat = b.b_cat group by a.cat")
    gs.run_command("v.db.join", map="grid", column="cat", other_table="grid_density", other_column="cat")
    gs.run_command("v.extract", input="grid", type="area", where="density > 10000", output="grid_chosen", overwrite=True)
    gs.run_command("v.select", ainput="lines_divided", atype="line", binput="grid_chosen", operator="intersects", output="chart_filtered", overwrite=True)
    gs.run_command("v.patch", input=["chart_filtered", "lines_outside_grid"], output="chart_cleaned", overwrite=True)
    gs.run_command("v.out.ogr", input="chart_cleaned", output="./vetor/{}_vector.shp".format(carta_name), format="ESRI_Shapefile", overwrite=True)
    
    
def rasterize(raster, carta_name):
    
    outfile_10 = "./raster/raster_10/{}_raster_10.tif".format(carta_name)
    outfile_1000 = "./raster/raster_1000/{}_raster_1000.tif".format(carta_name)
    
    !gdal_rasterize -burn 1 -of GTiff -a_nodata 0 -co 'TFW=YES' -co 'COMPRESS=LZW' -tr 9.014048920182702478e-05 9.014048920182702478e-05 {i} {outfile_10}
    !gdal_rasterize -burn 1 -of GTiff -a_nodata 0 -co 'TFW=YES' -co 'COMPRESS=LZW' -tr 0.00901404892 0.00901404892 {i} {outfile_1000}
    
def mosaics(vectors, rasters_10, raster_1000):
    
#     gdal.BuildVRT("./mosaicos/mosaic_raster_10.vrt", rasters_10)
#     gdal.BuildVRT("./mosaicos/mosaic_raster_1000.vrt", raster_1000)
    
#     gdal.Translate("./mosaicos/mosaic_raster_10.tif", "./mosaicos/mosaic_raster_10.vrt", outputType=gdalconst.GDT_Byte, creationOptions=['TFW=YES', 'COMPRESS=LZW'])
#     gdal.Translate("./mosaicos/mosaic_raster_1000.tif", "./mosaicos/mosaic_raster_1000.vrt", outputType=gdalconst.GDT_Byte, creationOptions=['TFW=YES', 'COMPRESS=LZW'])
    
    !ogrmerge.py -single -o ./mosaicos/mosaic_vector_amazon.shp {" ".join(vectors)} 
    
    
if __name__ == "__main__":
    
#     Primeira Etapa - Divisão do mosaico de água por cartas
    predicted_images = [glob.glob("./predicted/*.tif")[0].replace("\\", "/")]
    print(predicted_images)
    # skeleton_fixed_list = ["./skeleton/skeleton_" + i.split("/")[-1].split("_")[0] + ".tif" for i in ['SB-18-Z-A', 'SB-19-Y-A', 'NB-20-Y-D', 'SB-19-Z-C', 'SC-19-Y-A']]
    
    water_raster_path = "./auxiliar/aguas_amazonia_buffered.tif"
    water_masks(predicted_images, water_raster_path)
    
#     Segunda Etapa - Aplicação da máscara de água
    water_masked = [glob.glob("./mascara_de_agua/*.tif")[0].replace("\\", "/")]
    print(water_masked)
    # water_masked_fixed = [i for i in water_masked if i is in ]
    water_masks_application(predicted_images, water_masked)
    # water_masks_application(predicted_images, water_fixed_list)
    
#     Terceira Etapa - Snapping e Skeletização
    masked = [glob.glob("./mascarados/*.tif")[0].replace("\\", "/")]
    print(masked)
    skeletonize_tif(masked)
    
#     #Quarta Etapa - Vetorização
    skeletons = glob.glob("./skeleton/*.tif")
    charts = "./auxiliar/Amazonia_legal_cartas.shp"
    scenes = "./auxiliar/grid_sentinel_amazonia_legal_certo.geojson"
    
    with Session(gisdb="/tmp", location="amazonia", create_opts="EPSG:4326"):
        
        gs.set_raise_on_error(True)
        gs.set_capture_stderr(True)
        
        print("Etapa de vetorização dos dados")
                              
        for i in tqdm(skeletons):
            carta_name = i.split("/")[-1].split("_")[-1].split(".")[0]

            gs.run_command("g.remove", flags="f", type="all", pattern="tmp*")
            gs.run_command("v.in.ogr", input=charts, output="charts", overwrite=True)
            gs.run_command("v.in.ogr", input=scenes, output="scenes", overwrite=True)

            vectorize(i, carta_name)
            
    # vectors = glob.glob("./vetor/*.shp")
#     print(len(vectors))

#     #Quinta Etapa - Rasterização
#     print("Etapa de rasterização dos vetores")
#     for i in tqdm(vectors):
#         carta_name = i.split("/")[-1].split("_")[0]
#         rasterize(i, carta_name)
        
    # raster_10 = glob.glob("./raster/raster_10/*.tif")
    # raster_1000 = glob.glob("./raster/raster_1000/*.tif")
    
    #Sexta Etapa - Composição dos mosaicos
    # print("Etapa de montagem dos mosaicos")
    # mosaics(vectors, raster_10, raster_1000)

['./predicted/Test_chip_0.tif']
Etapa de divisão do mosaico de água em cartas


100%|██████████| 1/1 [00:00<00:00, 14.35it/s]


['./mascara_de_agua/Test_agua.tif']
Etapa de aplicação da máscara de água


0it [00:00, ?it/s]

./predicted/Test_chip_0.tif
./mascara_de_agua/Test_agua.tif


1it [00:00,  1.40it/s]


['./mascarados/test_masked.tif']
Etapa de snapping e skeletização do dados do raster


100%|██████████| 1/1 [00:00<00:00, 66.71it/s]


Etapa de vetorização dos dados


  0%|          | 0/1 [00:00<?, ?it/s]


CalledModuleError: Module run None g.remove -f type=all pattern=tmp* ended with error
Process ended with non-zero return code 3221225781. See errors in the (error) output.

In [14]:
import os
import sys
import subprocess

grass7bin_win = 'C:/"Program Files"/"GRASS GIS 7.8"/grass78.bat'
grass7bin = grass7bin_win

os.environ["GRASSBIN"] = grass7bin_win

startcmd = grass7bin + ' --config path'

p = subprocess.Popen(startcmd, shell=True, 
                 stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = p.communicate()

if p.returncode != 0:
 print (sys.stderr, 'ERROR: %s' % err)
 print (sys.stderr, "ERROR: Cannot find GRASS GIS 7.8 start script (%s)" % startcmd)
 sys.exit(-1)
gisbase = out.strip(b'\n\r')
gisbase = gisbase.decode("utf-8") 

In [13]:
os.environ['GISBASE'] = gisbase 
os.environ['PATH'] += os.pathsep + os.path.join(gisbase, 'extrabin')
home = os.path.expanduser("~")
os.environ['PATH'] += os.pathsep + os.path.join(home, '.grass7', 'addons', 'scripts')
gpydir = os.path.join(gisbase, "etc", "Python")
sys.path.append(gpydir)
gisdb = os.path.join(os.path.expanduser("~"), "Documents\grassdata")
os.environ['GISDBASE'] = gisdb

In [8]:
import os
import sys
import subprocess
# grass7bin = r'C:\Program Files\QGIS 3.20.2\bin\grass78.bat'
# 
# query GRASS 7 itself for its GISBASE
# startcmd = [grass7bin, '--config', 'path']

# p = subprocess.Popen(startcmd, shell=False,
#                      stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# out, err = p.communicate()
# if p.returncode != 0:
#     print >>sys.stderr, "ERROR: Cannot find GRASS GIS 7 start script (%s)" % startcmd
#     sys.exit(-1)
# print(out[0:-2])
# gisbase = out[0:-2]
# this could be replaced by using the right gisbase
# directly instead of the executable
# print(out)
# Set GISBASE environment variable
# os.environ['GISBASE'] = gisbase.decode()
os.environ['GRASSBIN'] = 'C:/"Program Files"/"GRASS GIS 7.8"/grass78.bat'
os.environ['GISBASE'] = 'C:/"Program Files"/"GRASS GIS 7.8"'
# define GRASS-Python environment
# gpydir = os.path.join(gisbase.decode(), "etc", "python")
# sys.path.append(gpydir)

# data
# gisdb = os.path.join(os.path.expanduser("~"), "Documents/grassdata")

# specify (existing) location and mapset
# location = "amazonia"
# mapset = "amz"


In [77]:
import osr
import numpy as np

def concat_infos(carta, predicted):
    coords = carta[0]
    print(coords)
    predicted[predicted >= 0.2] = 1
    predicted[predicted < 0.2] = 0
    
    return  predicted, coords

#FUNÇÃO DE GEO-REFERENCIAMENTO DOS RESULTADOS DO MODELO
def save_predicted(inf):
    saved_predicted = []
    # for i in inf:
        
    coords = inf[1]
    # print(i)
    xmin = coords["ulx"]
    ymax = coords["uly"]
#         xres = coords[2]
#         yres = coords[3]
    
    geotransform=(xmin,9.014048920182260666e-05,0,ymax,0, -9.014048920182260666e-05)   

    # output_raster = gdal.GetDriverByName('GTiff').Create('/vsimem/{}_{}.tif'.format(carta_name, num),256, 256, 1 ,gdal.GDT_Byte)  # Open the file
    output_raster = gdal.GetDriverByName('GTiff').Create('./{}.tif'.format("Test_chip_0"),256, 256, 1 ,gdal.GDT_Byte)  # Open the file
    
    output_raster.SetGeoTransform(geotransform)
    srs = osr.SpatialReference()                
    srs.ImportFromEPSG(4326)                    
    
    output_raster.SetProjection( srs.ExportToWkt() )  
    output_raster.GetRasterBand(1).WriteArray(np.reshape(inf[0], (256,256)))
    saved_predicted.append(output_raster)
    output_raster.FlushCache()

    return saved_predicted

In [78]:
predicted_soft_dice_1 = best_model_soft_dice.predict(test_set.skip(0).take(1).batch(1))

infos = concat_infos(test_chips_coords, predicted_soft_dice_1)
# print(infos[1])
save_predicted(infos)

{'ulx': -56.61951136331, 'uly': -9.83166231658, 'lrx': -56.596514492039994, 'lry': -9.854659187849999}


[<osgeo.gdal.Dataset; proxy of <Swig Object of type 'GDALDatasetShadow *' at 0x0000016517D80870> >]

In [None]:
import tensorflow.keras.preprocessing as prep
from skimage.morphology import skeletonize, binary_closing
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from IPython.display import display

font = {'family': 'serif',
        'color':  'black',
        'weight': 'normal',
        'size': 12,
}

for x in range(0,5):
    fig = plt.figure(figsize=(15,15))

    gs = GridSpec(1, 5, figure=fig)
    gs.update(wspace=0.1, hspace=0.02)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[0, 2])
    ax4 = fig.add_subplot(gs[0, 3])
#     ax5 = fig.add_subplot(gs[0, 4])


    for i in test_set.skip(x).take(1):
        
        img_original = prep.image.array_to_img(i[0])
        ax1.imshow(img_original)
        ax1.axis('off')
        if x == 0:
            ax1.set_title("(a)", fontdict=font)
        
        img_original = prep.image.array_to_img(i[0])
        ax2.imshow(img_original.split()[2], cmap=plt.cm.gray)
        ax2.axis('off')
        if x == 0:
            ax2.set_title("(b)", fontdict=font)

        img = prep.image.array_to_img(i[1])
        ax3.imshow(img, cmap=plt.cm.gray)
        ax3.axis('off')
        if x == 0:
            ax3.set_title("(c)", fontdict=font)


    predicted_soft_dice_1 = best_model_soft_dice.predict(test_set.skip(x).take(1).batch(1))

    # predicted_soft_dice_1[predicted_soft_dice_1 < 0.2] = 0
    # predicted_soft_dice_1[predicted_soft_dice_1 > 0.2] = 1
    predicted_soft_dice_1 = prep.image.array_to_img(predicted_soft_dice_1[0])

    ax4.imshow(predicted_soft_dice_1, cmap=plt.cm.gray)
    ax4.axis('off')
    if x == 0:
        ax4.set_title("(d)", fontdict=font)