In [17]:
import os
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import glob
import time
import pickle
import cv2
import pynvml 
from tensorflow.python.client import device_lib
import math
import pandas as pd
# Confirm which devices TensorFlow can see (CPU and, if present, GPU); the
# resulting list is displayed as this cell's output.
local_devices = device_lib.list_local_devices()
local_devices

[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 12758827889516579140,
 name: "/device:GPU:0"
 device_type: "GPU"
 memory_limit: 22881945600
 locality {
   bus_id: 1
   links {
   }
 }
 incarnation: 11139578637269181262
 physical_device_desc: "device: 0, name: GeForce RTX 3090, pci bus id: 0000:01:00.0, compute capability: 8.6"]

In [6]:
# U-Net model construction.
# Make the convolution layers lightweight: Conv2D_Light is the pluggable
# convolution block called throughout W_Net. Several alternative
# implementations are listed in this cell. Only one is a real definition at a
# time; the others are disabled by wrapping them in triple-quoted string
# literals, which are no-op expressions.

# Variant 1 (disabled): SqueezeNet Fire-module style block.
# - A 1x1 "squeeze" conv to int(filters/8) channels.
# - Parallel 1x1 and 3x3 "expand" convs of int(filters/2) channels each.
# - The two expand outputs are concatenated along the channel axis.
# NOTE(review): ignores the kernel_size, padding and activation arguments.
"""
#way1:SqueezeNet

def Conv2D_Light(filters,kernel_size,inputs,padding='same',activation='relu'):
    x  = tf.keras.layers.Conv2D(int(filters/8), 1, padding='same', activation='relu')(inputs)
    x1 = tf.keras.layers.Conv2D(int(filters/2), 1, padding='same', activation='relu')(x)
    x2 = tf.keras.layers.Conv2D(int(filters/2), 3, padding='same', activation='relu')(x)
    x  = tf.concat([x1, x2], axis=-1)
    return x
 
""" 
# Variant 2 (disabled): depthwise-separable convolution (SeparableConv2D).
# NOTE(review): the kernel size is hard-coded to 3, and the padding/activation
# arguments are ignored.
"""
#way2:分离卷积

def Conv2D_Light(filters,kernel_size,inputs,padding='same',activation='relu'):
    x  = tf.keras.layers.SeparableConv2D(filters, 3, padding='same', activation='relu')(inputs)
    return x

"""

# Variant 3 (active): a plain, standard convolution.
def Conv2D_Light(filters, kernel_size, inputs, padding='same', activation='relu'):
    """Apply one ``tf.keras.layers.Conv2D`` to ``inputs`` and return its output.

    This is the live implementation of the interchangeable ``Conv2D_Light``
    building block used by ``W_Net``. The SqueezeNet, separable-convolution
    and attention variants in this cell are disabled (kept inside string
    literals).

    Args:
        filters: Number of output channels.
        kernel_size: Convolution kernel size (an int or an ``(h, w)`` tuple).
        inputs: Input feature map (a Keras tensor).
        padding: Keras padding mode, e.g. ``'same'``.
        activation: Activation applied by the convolution, e.g. ``'relu'``.

    Returns:
        The convolution output tensor with ``filters`` channels.
    """
    # kernel_size, padding and activation are forwarded instead of being
    # hard-coded to 3/'same'/'relu'. W_Net always passes exactly those values,
    # so the model it builds is unchanged.
    return tf.keras.layers.Conv2D(filters, kernel_size, padding=padding,
                                  activation=activation)(inputs)

# Variant 4 (disabled): a 3x3 convolution followed by channel attention and
# then spatial attention. The heading inside the string reads
# "Self-Attention (channel self-attention + spatial self-attention)".
#  - Channel attention: sigmoid(global max-pool + global average-pool) of the
#    conv output `x`, multiplied back onto `x` to give `x1`.
#  - Spatial attention: the channel-wise max and mean of the block *input*,
#    concatenated and passed through Conv2D(1, 3). The result is multiplied
#    onto `x1`.
# NOTE(review): before re-enabling this variant, confirm the following are
# intended:
#  - The spatial map uses activation='relu' rather than a sigmoid, so it is
#    unbounded and can zero out positions.
#  - The spatial map is computed from `inputs` rather than from the
#    channel-attended `x1`.
#  - The channel branch has no learned MLP.
# This variant also ignores the kernel_size, padding and activation arguments.
""" 
#way4:Self-Attention （通道自注意+空间自注意）
def Conv2D_Light(filters,kernel_size,inputs,padding='same',activation='relu'):
    x = tf.keras.layers.Conv2D(filters, 3, padding='same', activation='relu')(inputs)
    #通道注意
    #通道attention
    maxpool_channel = tf.reduce_max(tf.reduce_max(x,axis=1,keepdims=True),axis=2,keepdims=True)
    avgpool_channel = tf.reduce_mean(tf.reduce_mean(x,axis=1,keepdims=True),axis=2,keepdims=True)
    channel_attention = tf.nn.sigmoid(maxpool_channel+avgpool_channel)
    x1 = x * channel_attention
    #空间attention
    maxpool_spatial=tf.reduce_max(inputs,axis=3,keepdims=True)
    avgpool_spatial=tf.reduce_mean(inputs,axis=3,keepdims=True)
    max_avg_pool_spatial=tf.concat([maxpool_spatial,avgpool_spatial],axis=3)
    spatial_attention = tf.keras.layers.Conv2D(1, 3, padding='same', activation='relu')(max_avg_pool_spatial)
    x2 = x1 * spatial_attention
    return x2
""" 

def _resize_like(tensor, reference):
    """Bilinearly resize ``tensor`` to the static height/width of ``reference``."""
    return tf.image.resize(tensor, (reference.shape[1], reference.shape[2]))


def _align_like(tensor, reference):
    """Resize ``tensor`` to ``reference``'s height/width only when they differ.

    With 'same' padding, each 2x max-pool rounds odd sizes up. A later 2x
    transposed convolution can therefore overshoot its skip connection by one
    pixel when the input height/width is not divisible by 16.

    When the sizes already match (e.g. the default 720x1280 input), ``tensor``
    is returned as-is, so no extra layer is added to the model.
    """
    if tensor.shape[1:3].as_list() == reference.shape[1:3].as_list():
        return tensor
    return _resize_like(tensor, reference)


def W_Net(input_shape=(720, 1280, 3)):
    """Build the W-Net model: two U-shaped encoder/decoder passes in a row.

    Pass 1 encodes the image with four 2x max-pools (x1..x5, down to H/16).
    It then decodes with two 4x transposed convolutions, concatenating x3 at
    H/4 and x1 at full resolution (x6..x8).

    Pass 2 max-pools that full-resolution map by 16 (x9) and fuses it with the
    pass-1 bottleneck x5 (x10). It then decodes with four 2x transposed
    convolutions, using x4, x3, x2 and x1 as skip connections (x11..x19).

    A final 1x1 sigmoid convolution produces the output.

    Args:
        input_shape: ``(height, width, channels)`` of a single input image.
            Defaults to the previously hard-coded ``(720, 1280, 3)``. Height
            and width must be concrete integers, because skip connections are
            aligned with ``tf.image.resize`` using static shapes.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping float32 ``(batch, H, W, C)``
        inputs to ``(batch, H, W, 1)`` sigmoid outputs.

    Raises:
        ValueError: if the input height or width is ``None``.
    """
    if input_shape[0] is None or input_shape[1] is None:
        raise ValueError(
            "W_Net needs a static input height and width, got %r" % (input_shape,))

    inputs = tf.keras.layers.Input(shape=tuple(input_shape), dtype=tf.float32)

    # Pass 1 encoder: four 2x max-pools, H,W -> H/16,W/16. Sizes in the
    # comments are exact when H and W are divisible by 16 (ceil otherwise).
    # NOTE(review): the skip connections further down reuse the pre-BatchNorm
    # outputs x2, x3, x4 (not x2_B, x3_B, x4_B). This is kept as in the
    # original; confirm it is intended.
    x1 = Conv2D_Light(32, 3, inputs, padding='same', activation='relu')    # H,    W
    x1_M = tf.keras.layers.MaxPooling2D(padding='same')(x1)

    x2 = Conv2D_Light(64, 3, x1_M, padding='same', activation='relu')      # H/2,  W/2
    x2_B = tf.keras.layers.BatchNormalization()(x2)
    x2_M = tf.keras.layers.MaxPooling2D(padding='same')(x2_B)

    x3 = Conv2D_Light(128, 3, x2_M, padding='same', activation='relu')     # H/4,  W/4
    x3_B = tf.keras.layers.BatchNormalization()(x3)
    x3_M = tf.keras.layers.MaxPooling2D(padding='same')(x3_B)

    x4 = Conv2D_Light(256, 3, x3_M, padding='same', activation='relu')     # H/8,  W/8
    x4_B = tf.keras.layers.BatchNormalization()(x4)
    x4_M = tf.keras.layers.MaxPooling2D(padding='same')(x4_B)

    x5 = Conv2D_Light(256, 3, x4_M, padding='same', activation='relu')     # H/16, W/16

    # Pass 1 decoder: two 4x transposed-conv upsamplings back to H,W.
    x5_UP = tf.keras.layers.Conv2DTranspose(128, 4, strides=4, padding='same', activation='relu')(x5)
    x5_UP = tf.keras.layers.BatchNormalization()(x5_UP)
    x5_UP = _resize_like(x5_UP, x3)
    x6 = tf.concat([x3, x5_UP], axis=-1)                                   # H/4,  W/4

    x6 = Conv2D_Light(128, 3, x6, padding='same', activation='relu')
    x6 = tf.keras.layers.BatchNormalization()(x6)

    x7 = tf.keras.layers.Conv2DTranspose(32, 4, strides=4, padding='same', activation='relu')(x6)
    x7 = tf.keras.layers.BatchNormalization()(x7)
    x7 = _resize_like(x7, x1)
    x8 = tf.concat([x7, x1], axis=-1)                                      # H,    W

    x8 = Conv2D_Light(64, 3, x8, padding='same', activation='relu')
    x8 = tf.keras.layers.BatchNormalization()(x8)
    x8 = Conv2D_Light(64, 3, x8, padding='same', activation='relu')
    x8 = tf.keras.layers.BatchNormalization()(x8)

    # Pass 2 encoder: one 16x max-pool straight back down, fused with x5.
    x9 = tf.keras.layers.MaxPooling2D(pool_size=(16, 16), padding='same')(x8)
    x9 = tf.keras.layers.BatchNormalization()(x9)                          # H/16, W/16

    x9 = Conv2D_Light(256, 3, x9, padding='same', activation='relu')
    x9 = _resize_like(x9, x5)
    x10 = tf.concat([x9, x5], axis=-1)
    x10 = Conv2D_Light(256, 3, x10, padding='same', activation='relu')
    x10 = tf.keras.layers.BatchNormalization()(x10)

    # Pass 2 decoder: four 2x transposed-conv upsamplings with skips x4..x1.
    x11 = tf.keras.layers.Conv2DTranspose(128, 2, strides=2, padding='same', activation='relu')(x10)
    x11 = tf.keras.layers.BatchNormalization()(x11)                        # H/8,  W/8
    x11 = _resize_like(x11, x4)
    x12 = tf.concat([x11, x4], axis=-1)

    x12 = Conv2D_Light(128, 3, x12, padding='same', activation='relu')
    x12 = tf.keras.layers.BatchNormalization()(x12)

    x13 = tf.keras.layers.Conv2DTranspose(128, 2, strides=2, padding='same', activation='relu')(x12)
    x13 = tf.keras.layers.BatchNormalization()(x13)                        # H/4,  W/4
    # These three skips were previously concatenated without any alignment,
    # which breaks for input sizes not divisible by 16.
    x13 = _align_like(x13, x3)
    x15 = tf.concat([x13, x3], axis=-1)
    x15 = Conv2D_Light(128, 3, x15, padding='same', activation='relu')
    x15 = tf.keras.layers.BatchNormalization()(x15)

    x16 = tf.keras.layers.Conv2DTranspose(64, 2, strides=2, padding='same', activation='relu')(x15)
    x16 = tf.keras.layers.BatchNormalization()(x16)                        # H/2,  W/2
    x16 = _align_like(x16, x2)
    x17 = tf.concat([x16, x2], axis=-1)
    x17 = Conv2D_Light(64, 3, x17, padding='same', activation='relu')
    x17 = tf.keras.layers.BatchNormalization()(x17)

    x18 = tf.keras.layers.Conv2DTranspose(32, 2, strides=2, padding='same', activation='relu')(x17)
    x18 = tf.keras.layers.BatchNormalization()(x18)                        # H,    W
    x18 = _align_like(x18, x1)
    x19 = tf.concat([x18, x1], axis=-1)
    x19 = Conv2D_Light(32, 3, x19, padding='same', activation='relu')
    x19 = tf.keras.layers.BatchNormalization()(x19)

    # Single-channel sigmoid map at full input resolution.
    output = tf.keras.layers.Conv2D(1, 1, padding='same', activation='sigmoid')(x19)

    return tf.keras.Model(inputs=inputs, outputs=output)

# Instantiate the network with its default input size and print a
# layer-by-layer summary (shapes and parameter counts).
model = W_Net()
model.summary()

Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            [(None, 720, 1280, 3 0                                            
__________________________________________________________________________________________________
conv2d_58 (Conv2D)              (None, 720, 1280, 32 896         input_5[0][0]                    
__________________________________________________________________________________________________
max_pooling2d_20 (MaxPooling2D) (None, 360, 640, 32) 0           conv2d_58[0][0]                  
__________________________________________________________________________________________________
conv2d_59 (Conv2D)              (None, 360, 640, 64) 18496       max_pooling2d_20[0][0]           
____________________________________________________________________________________________