<a href="https://colab.research.google.com/github/CasCard/RIG-Works/blob/main/unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow
import keras
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Dense ,Conv2D ,Conv2DTranspose ,Dropout ,BatchNormalization ,Input ,MaxPooling2D ,concatenate,UpSampling2D
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from google.colab import drive
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import random
import os
from tensorflow.keras.optimizers import Adam
from google.colab.patches import cv2_imshow

In [2]:
# Mount Google Drive at /content/gdrive; the video and weight paths used later
# in this notebook live under '/content/gdrive/My Drive/'.
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
def unet(input_size = (256,256,3)):
    """Build a lightweight U-Net style encoder/decoder producing a 3-channel image.

    Encoder: four stages of one 3x3 convolution followed by 2x2 max-pooling
    (64, 128, 256, 512 filters), then a 1024-filter bottleneck. Dropout (0.5)
    is applied to the 512-filter stage and to the bottleneck.
    Decoder: four 2x upsampling stages (512, 256, 128, 64 filters). Encoder
    features are concatenated only at the two deepest stages.

    Args:
        input_size: (height, width, channels) of the input tensor. Because the
            encoder pools by 2 four times and the decoder concatenates encoder
            features, height and width should be multiples of 16.

    Returns:
        An uncompiled Keras ``Model`` whose output has 3 channels with a
        sigmoid activation.
    """
    def conv(filters, kernel_size):
        # Every hidden convolution shares activation, padding and initializer.
        return Conv2D(filters, kernel_size, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')

    def down(tensor):
        return MaxPooling2D(pool_size=(2, 2))(tensor)

    def up(tensor):
        return UpSampling2D(size = (2,2))(tensor)

    # Honour the caller-supplied shape (it was previously ignored in favour of
    # a hard-coded (256,256,3)).
    input_layer = Input(input_size)

    # ---- Encoder ----
    conv1 = conv(64, 3)(input_layer)
    pool1 = down(conv1)
    conv2 = conv(128, 3)(pool1)
    pool2 = down(conv2)
    conv3 = conv(256, 3)(pool2)
    pool3 = down(conv3)
    conv4 = conv(512, 3)(pool3)
    drop4 = Dropout(0.5)(conv4)
    pool4 = down(drop4)

    # ---- Bottleneck ----
    conv5 = conv(1024, 3)(pool4)
    drop5 = Dropout(0.5)(conv5)

    # ---- Decoder (with skip connections at the two deepest levels) ----
    up6 = conv(512, 2)(up(drop5))
    merge6 = concatenate([drop4,up6], axis = 3)
    conv6 = conv(512, 3)(merge6)

    up7 = conv(256, 2)(up(conv6))
    merge7 = concatenate([conv3,up7], axis = 3)
    conv7 = conv(256, 3)(merge7)

    # NOTE(review): the two shallow decoder stages run WITHOUT skip connections.
    # The earlier code built concatenate([conv2, up8]) but never fed it to
    # conv8, and the conv1 skip was commented out. Wiring them in would change
    # the input-channel count of conv8/conv9, so weights saved from this
    # architecture would no longer load -- left as-is on purpose; confirm intent.
    up8 = conv(128, 2)(up(conv7))
    conv8 = conv(128, 3)(up8)

    up9 = conv(64, 2)(up(conv8))
    conv9 = conv(64, 3)(up9)

    # 1x1 convolution + sigmoid maps features to 3 output channels in [0, 1].
    conv10 = Conv2D(3, 1, activation = 'sigmoid')(conv9)

    return Model(inputs = [input_layer], outputs = [conv10])

# Instantiate the network with its default input size; `model` is the global
# that get_result() reads when predicting.
model = unet()


In [7]:
import tensorflow as tf
def loss_fn(y_true,y_pred):
    """Weighted image-reconstruction loss: 0.1*MSE term + 0.6*MAE + 0.3*SSIM dissimilarity.

    Args:
        y_true: target images; scaled by 255 before the SSIM comparison, so
            values are presumably in [0, 1] -- confirm against the data pipeline.
        y_pred: predicted images, same shape as ``y_true``.

    Returns:
        A tensor with one loss value per image (the two scalar terms broadcast
        against the per-image SSIM term).
    """
    # Sum-reduced squared error divided by the pixel count of a 256x256 image.
    # NOTE(review): the divisor is hard-coded, so this term is only normalised
    # for 256x256 inputs, and because of the SUM reduction it grows with the
    # batch size -- confirm that is intended.
    mse = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.SUM)
    k1 = mse(y_true, y_pred)
    k1 = (k1/(256*256))

    # MeanAbsoluteError with its default reduction already returns the mean
    # over every element. The earlier extra division by 256*256 shrank this
    # term by a factor of 65536, so its 0.6 weight had almost no effect.
    mae = tf.keras.losses.MeanAbsoluteError()
    k = mae(y_true, y_pred)

    # SSIM is a similarity (1 means identical); (1 - ssim) / 2 turns it into a
    # dissimilarity to be minimised.
    ssim = tf.image.ssim(y_true *255 , y_pred * 255, max_val = 255, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03)
    m = (1-ssim)/2
    return ((0.1*k1) + (0.6*k) + (0.3*m))
# Compile with Adam and the custom loss; MSE is tracked as a metric.
# `learning_rate` replaces the deprecated `lr` alias, which newer Keras
# releases no longer accept.
model.compile(optimizer = Adam(learning_rate = 1e-4), loss =loss_fn , metrics = [tf.keras.metrics.MeanSquaredError()])


In [5]:
def get_result(image ,plot = True):
    """Run the global ``model`` on one image and return the rescaled prediction.

    Args:
        image: array reshapeable to (-1, 256, 256, 3); it is divided by 255.0
            before prediction, so 0-255 pixel values are presumably expected.
        plot: accepted for call compatibility; not used by this function.

    Returns:
        The prediction reshaped to (256, 256, 3) and multiplied by 255.
    """
    # Add the batch axis and scale the input in one step.
    scaled_batch = image.reshape((-1,256,256 ,3)) / 255.0
    prediction = model.predict(scaled_batch)
    # Drop the batch axis, then scale the output back up by 255.
    return 255 * prediction.reshape((256,256,3))

In [None]:
import tensorflow as tf
# Abort early unless TensorFlow reports the first GPU device.
device_name = tf.test.gpu_device_name()
if device_name == '/device:GPU:0':
    print('Found GPU at: {}'.format(device_name))
else:
    raise SystemError('GPU device not found')

Found GPU at: /device:GPU:0


In [9]:
from google.colab.patches import cv2_imshow
import cv2
import numpy as np

# Run the trained network over every frame of the input video and save the
# predictions as an MJPG-encoded AVI.
video1 = '/content/gdrive/My Drive/auv_files/redrod2.mp4'
frame_size = (256, 256)  # (width, height) fed to and produced by the network

# Build the network and load the trained weights ONCE. Doing this per frame
# re-created the whole graph and re-read the weight file for every frame.
model = unet()
model.load_weights("/content/gdrive/My Drive/auv_files/weights/weights_unet_2.h5")

cap = cv2.VideoCapture(video1)
# The writer's frame size must match the frames passed to write(); it was
# (128,128) while 256x256 frames were written, which makes OpenCV drop them.
out = cv2.VideoWriter('/content/gdrive/My Drive/auv_files/output_unet.avi',cv2.VideoWriter_fourcc('M','J','P','G'), 25, frame_size)
try:
    # No read before the loop: the earlier code read one frame up front and
    # discarded it, so the first frame never reached the output.
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # End of stream or read failure: nothing left to process.
            break
        frame = cv2.resize(frame, frame_size)
        img_ = get_result(frame ,plot = False)
        img_ = np.uint8(img_)
        out.write(img_)
finally:
    # Always release the capture and finalise the output file, even on error.
    cap.release()
    out.release()

