In [1]:
import tensorflow as tf

In [2]:
def conv3otherRelu(filters, kernel_size=None, stride=None, padding=None):
    if kernel_size is None:
        kernel_size = 3
    if stride is None:
        stride = 1
    if padding is None:
        padding = 'same'

    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(filters, kernel_size, stride, padding, use_bias=True),
        tf.keras.layers.ReLU()
    ])

In [3]:
import tensorflow as tf

def l2_norm(x):
    norm = tf.norm(x, ord=2, axis=-2, keepdims=True)
    return tf.einsum("bcn, bn->bcn", x, 1 / norm)



#normalized_tensor = l2_norm(input_tensor)


In [4]:
import tensorflow as tf

class PAM_Module(tf.keras.layers.Layer):
    def __init__(self, in_places, scale=8, eps=1e-6):
        super(PAM_Module, self).__init__()
        self.gamma=tf.Variable(initial_value=tf.zeros((1,)), trainable=True)
        self.inplaces=in_places
        self.l2_norm=l2_norm
        self.eps=eps

        self.query_conv=tf.keras.layers.Conv2D(in_places//scale, kernel_size=1, padding='same')
        self.key_conv=tf.keras.layers.Conv2D(in_places//scale,  kernel_size=1, padding='same')
        self.value_conv=tf.keras.layers.Conv2D(in_places//scale,  kernel_size=1, padding='same')

    def call(self, x):
        batch_size, width, height, channels = x.shape
        print("Shape before query_conv:", x.shape)
        Q= tf.reshape(self.query_conv(x), (batch_size, -1, width * height))
        K= tf.reshape(self.key_conv(x), (batch_size, -1, width * height))
        V= tf.reshape(self.value_conv(x), (batch_size, -1, width * height))
        print('Q shape', Q.shape)
        print('K shape', K.shape)
        Q = self.l2_norm(tf.transpose(Q, perm=[0, 2, 1]))
        K = self.l2_norm(K)
        print('Q shape', Q.shape)
        print('K shape', K.shape)
        tailor_sum = 1 / (width * height + tf.einsum("bnc,bcm->bnm", Q, tf.reduce_sum(K, axis=-1) + self.eps))
        value_sum = tf.expand_dims(tf.reduce_sum(V, axis=2), axis=-1)
        value_sum = tf.tile(value_sum, [1, 1, width * height])

        matrix = tf.einsum('bmn, bcn->bmc', K, V)
        matrix_sum = value_sum + tf.einsum("bnm, bmc->bcn", Q, matrix)

        weight_value = tf.einsum("bcn, bn->bcn", matrix_sum, tailor_sum)
        weight_value = tf.reshape(weight_value, (batch_size, channels, height, width))

        return x + self.gamma * weight_value


In [5]:
import tensorflow as tf

class CAM_Module(tf.keras.layers.Layer):
    def __init__(self):
        super(CAM_Module, self).__init__()
        self.gamma = tf.Variable(initial_value=tf.zeros((1,)), trainable=True)
        self.softmax = tf.keras.layers.Softmax(axis=-1)

    def call(self, x):
        batch_size, height, width, channels = x.shape

        proj_query = tf.reshape(x, (batch_size, -1, channels))
        proj_key = tf.transpose(proj_query, perm=[0, 2, 1])
        energy = tf.matmul(proj_query, proj_key)
        energy_max = tf.reduce_max(energy, axis=-1, keepdims=True)
        energy_new = energy_max - energy
        attention = self.softmax(energy_new)

        proj_value = tf.reshape(x, (batch_size, -1, channels))

        out = tf.matmul(attention, proj_value)
        out = tf.reshape(out, (batch_size, height, width, channels))

        out = self.gamma * out + x
        return out


In [6]:
import tensorflow as tf

class PAM_CAM_Layer(tf.keras.layers.Layer):
    def __init__(self, in_ch):
        super(PAM_CAM_Layer, self).__init__()
        self.conv1 = conv3otherRelu(in_ch)

        self.PAM = PAM_Module(in_ch)
        self.CAM = CAM_Module()

        self.dropout1 = tf.keras.layers.Dropout(0.1)
        self.conv2P = conv3otherRelu(in_ch)
        self.dropout2 = tf.keras.layers.Dropout(0.1)
        self.conv2C = conv3otherRelu(in_ch)
        self.dropout3 = tf.keras.layers.Dropout(0.1)
        self.conv3 = conv3otherRelu(in_ch)

    def call(self, x):
        print('shape of feature map fed to attention', x.shape)
        x = self.conv1(x)
        x_pam = self.PAM(x)
        x_cam = self.CAM(x)
        x = self.dropout1(x_pam)
        x = self.conv2P(x)
        x = x + self.dropout2(x_cam)
        x = self.conv2C(x)
        x = self.dropout3(x)
        x = self.conv3(x)
        return x


In [None]:
import tensorflow as tf

class DecoderBlock(tf.keras.layers.Layer):
    def __init__(self, in_channels, n_filters):
        super(DecoderBlock, self).__init__()

        self.conv1 = tf.keras.layers.Conv2D(in_channels // 4, kernel_size=1, padding='same')
        self.norm1 = tf.keras.layers.BatchNormalization()
        self.relu1 = tf.keras.layers.ReLU()

        self.deconv2 = tf.keras.layers.Conv2DTranspose(in_channels // 4, kernel_size=3, strides=2, padding='same', output_padding=1)
        self.norm2 = tf.keras.layers.BatchNormalization()
        self.relu2 = tf.keras.layers.ReLU()

        self.conv3 = tf.keras.layers.Conv2D(n_filters, kernel_size=1, padding='same')
        self.norm3 = tf.keras.layers.BatchNormalization()
        self.relu3 = tf.keras.layers.ReLU()

    def call(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        x = self.deconv2(x)
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.relu3(x)
        return x


In [10]:
import tensorflow as tf

class MAResUNet(tf.keras.Model):
    def __init__(self, num_channels=3, num_classes=1):
        super(MAResUNet, self).__init__()

        filters = [256, 512, 1024, 2048]
        resnet = tf.keras.applications.ResNet50(include_top=False, weights='imagenet', input_shape=(None, None, num_channels))
        resnet.trainable = False
        #resnet.summary()
        self.firstconv = resnet.get_layer('conv1_conv')
        self.firstbn = resnet.get_layer('conv1_bn')
        self.firstrelu = resnet.get_layer('conv1_relu')
        self.firstmaxpool = resnet.get_layer('pool1_pool')
        self.encoder1 = resnet.get_layer('conv2_block1_out')
        self.encoder2 = resnet.get_layer('conv3_block4_out')
        self.encoder3 = resnet.get_layer('conv4_block6_out')
        self.encoder4 = resnet.get_layer('conv5_block3_out')

        self.attention4 = PAM_CAM_Layer(filters[3])
        self.attention3 = PAM_CAM_Layer(filters[2])
        self.attention2 = PAM_CAM_Layer(filters[1])
        self.attention1 = PAM_CAM_Layer(filters[0])

        self.decoder4 = DecoderBlock(filters[3], filters[2])
        self.decoder3 = DecoderBlock(filters[2], filters[1])
        self.decoder2 = DecoderBlock(filters[1], filters[0])
        self.decoder1 = DecoderBlock(filters[0], filters[0])

        self.finaldeconv1 = tf.keras.layers.Conv2DTranspose(32, 4, strides=2, padding='same')
        self.finalrelu1 = tf.keras.layers.Activation('relu')
        self.finalconv2 = tf.keras.layers.Conv2D(32, 3, padding='same')
        self.finalrelu2 = tf.keras.layers.Activation('relu')
        self.finalconv3 = tf.keras.layers.Conv2D(num_classes, 3, padding='same')

    def call(self, x):
        # Encoder
        x1 = self.firstconv(x)
        x1 = self.firstbn(x1)
        x1 = self.firstrelu(x1)
        x1 = self.firstmaxpool(x1)
        e1 = self.encoder1(x1)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        print('output from res',e4.shape)
        e4 = self.attention4(e4)

        # Decoder
        d4 = self.decoder4(e4) + self.attention3(e3)
        d3 = self.decoder3(d4) + self.attention2(e2)
        d2 = self.decoder2(d3) + self.attention1(e1)
        d1 = self.decoder1(d2)

        out = self.finaldeconv1(d1)
        out = self.finalrelu1(out)
        out = self.finalconv2(out)
        out = self.finalrelu2(out)
        out = self.finalconv3(out)

        return out


In [11]:
net = MAResUNet(3)
in_batch, inchannel, in_h, in_w = 10, 3, 512, 512
x = tf.random.normal((in_batch, in_h, in_w, inchannel))
out = net(x)

output from res (10, 126, 126, 64)
shape of feature map fed to attention (10, 126, 126, 64)


ResourceExhaustedError: Exception encountered when calling layer "conv2d_38" (type Conv2D).

OOM when allocating tensor with shape[10,126,126,2048] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:Conv2D]

Call arguments received:
  • inputs=tf.Tensor(shape=(10, 126, 126, 64), dtype=float32)

In [None]:
import tensorflow as tf

class QueryConv(tf.keras.layers.Layer):
    def __init__(self, in_places, scale=8):
        super(QueryConv, self).__init__()
        self.query_conv = tf.keras.layers.Conv2D(in_places // scale, kernel_size=1, padding='same')

    def call(self, x):
        return self.query_conv(x)

class KeyConv(tf.keras.layers.Layer):
    def __init__(self, in_places, scale=8):
        super(KeyConv, self).__init__()
        self.key_conv = tf.keras.layers.Conv2D(in_places // scale, kernel_size=1, padding='same')

    def call(self, x):
        return self.key_conv(x)

class ValueConv(tf.keras.layers.Layer):
    def __init__(self, in_places, scale=8):
        super(ValueConv, self).__init__()
        self.value_conv = tf.keras.layers.Conv2D(in_places // scale, kernel_size=1, padding='same')

    def call(self, x):
        return self.value_conv(x)


@tf.function
def l2_norm(x):
    norm = tf.norm(x, ord=2, axis=-1, keepdims=True)
    print("Shape before normalization:", x.shape)
    print("Shape of norm:", norm.shape)
    return x / norm


# Sample input tensor
batch_size, channels, width, height = 10, 64, 128, 128
x = tf.random.normal((batch_size, width, height, channels))

# Your provided code snippet with added print statements
query_conv = QueryConv(in_places=channels)
key_conv = KeyConv(in_places=channels)
value_conv = ValueConv(in_places=channels)

Q = query_conv(x)
K = key_conv(x)
V = value_conv(x)
print("Shape after query_conv:", Q.shape)
print("Shape after key_conv:", K.shape)
print("Shape after value_conv:", V.shape)

Q = l2_norm(Q)
K = l2_norm(K)
print("Shape after l2_norm:", Q.shape, K.shape)

eps = 1e-6
tailor_sum = 1 / (width * height + tf.einsum("bnc, bc->bn", Q, tf.squeeze(tf.reduce_sum(K, axis=-1)) + eps))
value_sum = tf.expand_dims(tf.reduce_sum(V, axis=2), axis=-1)
value_sum = tf.tile(value_sum, [1, 1, width * height])
print("Shape after einsum and expand_dims:", tailor_sum.shape, value_sum.shape)

matrix = tf.einsum('bmn, bcn->bmc', K, V)
matrix_sum = value_sum + tf.einsum("bnm, bmc->bcn", Q, matrix)
print("Shape after einsum for matrix and matrix_sum:", matrix.shape, matrix_sum.shape)

weight_value = tf.einsum("bcn, bn->bcn", matrix_sum, tailor_sum)
print("Shape after einsum for weight_value:", weight_value.shape)

output = x + weight_value

print("Final output shape:", output.shape)
