In [None]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import os
from tensorflow.keras import layers
# Execute tf.function bodies eagerly so the model code can be debugged as
# plain Python. `tf.config.run_functions_eagerly` is the non-deprecated name
# of `tf.config.experimental_run_functions_eagerly`.
# NOTE(review): eager execution is usually much slower than graph mode;
# consider disabling this once debugging is finished.
tf.config.run_functions_eagerly(True)

In [None]:
class Double_conv(layers.Layer):
  """Two stacked 3x3 'same'-padded convolutions, each followed by BatchNorm + ReLU.

  Args:
    n_channels: number of output filters used by both convolutions.
  """
  def __init__(self,n_channels):
    super(Double_conv,self).__init__()
    self.conv1 = layers.Conv2D(n_channels,(3,3),padding='same')
    self.conv2 = layers.Conv2D(n_channels,(3,3),padding='same')
    # One BatchNormalization per convolution: the two feature maps have
    # different statistics, so they must not share gamma/beta or the
    # moving mean/variance.
    self.bn1 = layers.BatchNormalization()
    self.bn2 = layers.BatchNormalization()
    # ReLU holds no weights, so a single instance can be reused.
    self.act = layers.Activation('relu')

  def call(self,input_dime,training=False):
    """Apply conv -> BN -> ReLU twice.

    Args:
      input_dime: input feature map (the convolutions are 2-D, so a
        4-D tensor is expected).
      training: forwarded to the BatchNormalization layers so that they use
        batch statistics while training and moving averages otherwise.

    Returns:
      Feature map with `n_channels` channels and unchanged height/width.
    """
    x = self.conv1(input_dime)
    x = self.bn1(x,training=training)
    x = self.act(x)

    x = self.conv2(x)
    x = self.bn2(x,training=training)
    x = self.act(x)

    return x

In [None]:
class convtranspose2d(layers.Layer):
  """Thin wrapper around `layers.Conv2DTranspose` (2x2 kernel, 'same' padding).

  Args:
    n_channels: number of output filters.
    strides: stride of the transposed convolution. The default (1,1) keeps
      the existing behaviour, which leaves height and width unchanged; pass
      (2,2) to double the spatial resolution.
  """
  def __init__(self,n_channels,strides=(1,1)):
    super(convtranspose2d,self).__init__()
    self.convt = layers.Conv2DTranspose(n_channels,kernel_size=(2,2),strides=strides,padding='same')

  def call(self,input_tensor,training=False):
    """Run the transposed convolution on `input_tensor` and return the result."""
    x = self.convt(input_tensor,training=training)
    return x

In [None]:
def crp_img(tensor,target):
  """Center-crop `tensor` so its height and width equal those of `target`.

  Both arguments are treated as 4-D with the spatial axes at positions 1 and 2
  (the slice below indexes exactly four axes). Height and width are handled
  independently, and the result always has exactly the target size, even when
  the size difference is odd.

  Args:
    tensor: tensor to crop; must expose `get_shape().as_list()`.
    target: tensor whose static height/width define the crop size.

  Returns:
    The central `target`-sized window of `tensor` (batch and channel axes
    untouched).

  Raises:
    ValueError: if `tensor` is spatially smaller than `target`, since a crop
      cannot enlarge it.
  """
  src_shape = tensor.get_shape().as_list()
  dst_shape = target.get_shape().as_list()
  src_h, src_w = src_shape[1], src_shape[2]
  dst_h, dst_w = dst_shape[1], dst_shape[2]

  if src_h < dst_h or src_w < dst_w:
    raise ValueError(
        'cannot crop a %sx%s tensor to the larger size %sx%s'
        % (src_h, src_w, dst_h, dst_w))

  # Offset of the centered window; with an odd difference the extra
  # row/column is dropped from the bottom/right.
  top = (src_h - dst_h) // 2
  left = (src_w - dst_w) // 2

  return tensor[:,top:top+dst_h,left:left+dst_w,:]

In [None]:
class Unet(keras.Model):
  """U-Net style encoder/decoder that outputs a 1-channel sigmoid map.

  Encoder: six `Double_conv` stages (32..1024 filters); the first five are
  each followed by 2x2 max-pooling and dropout, the sixth is the bottleneck.
  Decoder: five stages that upsample by 2, apply a transposed convolution,
  concatenate the matching encoder feature map and fuse with `Double_conv`.
  A sixth, full-resolution stage concatenates the raw input before the
  final 1x1 sigmoid convolution.

  Args:
    input_dim: expected input size; not read by the constructor and kept only
      so existing callers keep working.

  NOTE(review): with five 2x poolings the input height/width should be
  divisible by 32 so that skip connections line up exactly - confirm for any
  input size other than 640x640.
  """
  def __init__(self,input_dim):
    super(Unet,self).__init__()
    #Initialization ENCODER
    self.down1 = Double_conv(32)
    self.down2 = Double_conv(64)
    self.down3 = Double_conv(128)
    self.down4 = Double_conv(256)
    self.down5 = Double_conv(512)
    self.down6 = Double_conv(1024)
    self.maxp  = layers.MaxPooling2D((2,2),strides=2)
    self.dout  = layers.Dropout(0.1)

    #Initialization Decoder
    # Weight-free 2x upsampling that undoes one max-pooling step; the
    # transposed convolutions below are applied after it.
    self.upsample = layers.UpSampling2D((2,2))
    self.up1  = convtranspose2d(512)
    self.up2  = convtranspose2d(256)
    self.up3  = convtranspose2d(128)
    self.up4  = convtranspose2d(64)
    self.up5  = convtranspose2d(32)
    self.up6  = convtranspose2d(16)
    self.dc1  = Double_conv(512)
    self.dc2  = Double_conv(256)
    self.dc3  = Double_conv(128)
    self.dc4  = Double_conv(64)
    self.dc5  = Double_conv(32)
    self.dc6  = Double_conv(16)

    # Output head. It is created here (not inside call) so that its weights
    # are created once, tracked by the model and updated by the optimizer.
    self.out_conv = layers.Conv2D(1,(1,1),padding='same',activation='sigmoid')

  def call(self,input_dim,training=False):
    """Run the network on a batch of images.

    Args:
      input_dim: input image batch (the layers are 2-D, so a 4-D
        channels-last tensor is expected).
      training: forwarded to dropout and to the `Double_conv` blocks.

    Returns:
      Tensor with one channel of sigmoid activations at input resolution.
    """
    #Encoder
    skips = []
    x = input_dim
    for down in (self.down1,self.down2,self.down3,self.down4,self.down5):
      x = down(x,training=training)
      skips.append(x)            # pre-pooling features feed the decoder
      x = self.maxp(x)
      x = self.dout(x,training=training)

    # Bottleneck
    x = self.down6(x,training=training)
    x = self.dout(x,training=training)

    #Decoder: every stage consumes the output of the previous stage, so the
    # deep features actually reach the output.
    ups = (self.up1,self.up2,self.up3,self.up4,self.up5)
    fuses = (self.dc1,self.dc2,self.dc3,self.dc4,self.dc5)
    for up,fuse,skip in zip(ups,fuses,reversed(skips)):
      x = self.upsample(x)
      x = up(x,training=training)
      x = layers.concatenate([x,crp_img(skip,x)])
      x = fuse(x,training=training)

    # Final stage at full resolution: no upsampling, the raw input is
    # concatenated as an extra skip connection.
    x = self.up6(x,training=training)
    x = layers.concatenate([x,crp_img(input_dim,x)])
    x = self.dc6(x,training=training)

    return self.out_conv(x)

In [None]:
# Instantiate the segmentation model.
# NOTE(review): `input_dim` does not appear to be read by Unet.__init__; it
# only records the expected (height, width, channels) of the inputs.
z = Unet(input_dim=(640,640,3))

In [None]:
#z.build([1,640,640,3])

In [None]:
#z.summary()

In [None]:
my_callbacks = [
    # Stop once the validation loss has not improved for 5 epochs.
    # The metric key is 'val_loss' (underscore); 'val loss' is never logged.
    tf.keras.callbacks.EarlyStopping(monitor='val_loss',patience=5),
    # Keep the weights of the epoch with the best validation accuracy.
    # The model is compiled with metrics=["accuracy"], so the logged key is
    # 'val_accuracy'; `monitor` only has an effect with save_best_only=True.
    tf.keras.callbacks.ModelCheckpoint(
        "check.h5",
        monitor='val_accuracy',
        mode='max',
        save_best_only=True,
        save_weights_only=True,
        verbose=1),
]

In [None]:
import os 
from glob import glob
from google.colab import drive
# Mount Google Drive (Colab only) so the dataset files under
# /content/drive/MyDrive can be read by the cells below.
drive.mount('/content/drive')




Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# File lists are sorted so that the i-th image is paired with the i-th mask.
# NOTE(review): this assumes images and masks sort identically (e.g. same
# file names in both folders) - confirm against the dataset.
train_images = sorted(glob('/content/drive/MyDrive/Data_check/images/train/*'))
train_mask = sorted(glob('/content/drive/MyDrive/Data_check/masks/train/*'))
test_images = sorted(glob('/content/drive/MyDrive/Data_check/images/test/*'))
test_mask = sorted(glob('/content/drive/MyDrive/Data_check/masks/test/*'))

# glob() silently returns [] for a missing folder; fail here with a clear
# message instead of later inside the tf.data pipeline or fit().
for _split,_imgs,_msks in (('train',train_images,train_mask),('test',test_images,test_mask)):
  if not _imgs:
    raise FileNotFoundError('no %s images found - is Google Drive mounted?' % _split)
  if len(_imgs) != len(_msks):
    raise ValueError('%s split has %d images but %d masks' % (_split,len(_imgs),len(_msks)))

In [None]:
def read_data(img,msk):
  """Read an image file and its mask file and PNG-decode both.

  Args:
    img: path of the image file; decoded with 3 channels.
    msk: path of the mask file; decoded with 1 channel.

  Returns:
    Tuple `(image, mask)` of decoded tensors.
  """
  image = tf.image.decode_png(tf.io.read_file(img),channels=3)
  mask = tf.image.decode_png(tf.io.read_file(msk),channels=1)
  return image,mask

def resize_normalize(image,mask,height=640,width=640):
  """Resize/pad an image and its mask to a fixed size and scale them by 1/255.

  Args:
    image: decoded image tensor.
    mask: decoded mask tensor.
    height: target height in pixels (default 640).
    width: target width in pixels (default 640).

  Returns:
    Tuple `(img, mask)` of float32 tensors of spatial size `height` x `width`,
    each divided by 255.
  """
  # The image is resized with the default (bilinear) interpolation.
  img = tf.image.resize_with_pad(image,height,width)
  img = tf.cast(img,tf.float32)/255.0

  # Masks hold labels, not intensities: nearest-neighbour resizing keeps the
  # original label values instead of blending them at object borders.
  mask = tf.image.resize_with_pad(mask,height,width,method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  mask = tf.cast(mask,tf.float32) / 255.0

  return img,mask

def get_data(image,mask):
  """Load one (image, mask) pair from disk and return it resized and normalized.

  Args:
    image: path of the image file.
    mask: path of the mask file.

  Returns:
    The `(image, mask)` tuple produced by `resize_normalize` for the pair
    loaded by `read_data`.
  """
  raw_image,raw_mask = read_data(image,mask)
  return resize_normalize(raw_image,raw_mask)

In [None]:
epochs = 30
bs = 1
AUTOTUNE = tf.data.experimental.AUTOTUNE

def _make_dataset(image_paths,mask_paths,batch_size,shuffle=False):
  """Build a cached, batched and prefetched dataset of (image, mask) pairs.

  Args:
    image_paths: list of image file paths.
    mask_paths: list of mask file paths, aligned with `image_paths`.
    batch_size: number of pairs per batch.
    shuffle: when True, shuffle the pairs over a buffer covering the whole
      list so every epoch sees them in a new order.

  Returns:
    A `tf.data.Dataset` yielding `(image_batch, mask_batch)`.
  """
  ds = tf.data.Dataset.from_tensor_slices((image_paths,mask_paths))
  ds = ds.map(get_data,num_parallel_calls=AUTOTUNE)
  # Cache after decoding/resizing so files are read only during the first epoch.
  ds = ds.cache()
  # Shuffle after the cache (so the order differs per epoch) and before
  # batching (so batches are re-composed as well).
  if shuffle and len(image_paths) > 1:
    ds = ds.shuffle(buffer_size=len(image_paths))
  ds = ds.batch(batch_size)
  return ds.prefetch(AUTOTUNE)

train_set = _make_dataset(train_images,train_mask,bs,shuffle=True)
val_set = _make_dataset(test_images,test_mask,bs)

  "Even though the `tf.config.experimental_run_functions_eagerly` "


In [None]:
# `learning_rate` replaces the deprecated `lr` argument of the optimizer.
# from_logits=False because the model's last layer already applies a sigmoid.
z.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)

  "The `lr` argument is deprecated, use `learning_rate` instead.")


In [None]:
# Train the model, validating on the test split after every epoch.
# One epoch is len(train_images)//bs batches.
z.fit(
    train_set,
    validation_data=val_set,
    epochs=epochs,
    steps_per_epoch=len(train_images)//bs,
    callbacks=my_callbacks,
    verbose=1,
)