<a href="https://colab.research.google.com/github/greyhound101/Multihead_attention/blob/master/concatenate_add.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from tensorflow import keras
from tensorflow.keras import backend as K
class LayerNormalization(keras.layers.Layer):
    """Normalize over the last axis: ``gamma * (x - mean) / (std + eps) + beta``.

    ``gamma`` (init ones) and ``beta`` (init zeros) are trainable vectors sized
    to the last input dimension.

    NOTE(review): a later ``from tensorflow.keras.layers import *`` in this
    notebook rebinds the name ``LayerNormalization`` to Keras' built-in layer.

    Args:
        eps: Constant added to the standard deviation (not the variance) to
            avoid division by zero.
    """

    def __init__(self, eps=1e-6, **kwargs):
        super(LayerNormalization, self).__init__(**kwargs)
        self.eps = eps

    def get_config(self):
        # Serialize ``eps`` so the layer round-trips through from_config.
        config = super(LayerNormalization, self).get_config()
        config['eps'] = self.eps
        return config

    def build(self, input_shape):
        # String initializers avoid depending on ``Ones``/``Zeros`` being imported.
        self.gamma = self.add_weight(name='gamma', shape=input_shape[-1:],
                                     initializer='ones', trainable=True)
        self.beta = self.add_weight(name='beta', shape=input_shape[-1:],
                                    initializer='zeros', trainable=True)
        super(LayerNormalization, self).build(input_shape)

    def call(self, x):
        mean = K.mean(x, axis=-1, keepdims=True)
        std = K.std(x, axis=-1, keepdims=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

    def compute_output_shape(self, input_shape):
        # Normalization is element-wise; the shape is unchanged.
        return input_shape
class abc(keras.layers.Layer):
    """Multi-head scaled dot-product attention.

    Queries and keys are projected to ``q_k`` features and values to their own
    feature dimension. Each projection is split across ``head_num`` heads,
    attention is computed per head, the heads are merged back, and an output
    projection ``Wo`` (plus optional bias and activation) is applied.

    Accepts a single tensor (self-attention, q = k = v) or a list ``[q, k, v]``.
    The reshape helpers read ``K.shape(x)[0..2]``, so inputs are expected to be
    rank 3 — presumably (batch, seq_len, features); TODO confirm with callers.

    Args:
        head_num: Number of heads. Must divide both the value feature
            dimension and ``q_k``.
        q_k: Total query/key projection size, shared across heads.
        activation: Activation applied to the q/k/v projections and to the
            output. Resolved with ``keras.activations.get``.
        use_bias: Whether every projection carries a bias vector.
        kernel_initializer, bias_initializer, kernel_regularizer,
        bias_regularizer, kernel_constraint, bias_constraint: Forwarded to
            ``add_weight`` for the projection kernels / biases.
        history_only: Stored and serialized only. NOTE(review): ``call`` never
            applies a causal mask, so this flag currently has no effect.

    Raises:
        IndexError: In ``build``, if ``head_num`` does not divide the value
            feature dimension or ``q_k``.

    After each call, ``self.intensity`` holds the pre-softmax scores and
    ``self.attention`` the softmax weights, both reshaped by
    ``_reshape_attention_from_batches``.
    """

    def __init__(self,
                 head_num,
                 q_k,
                 activation='relu',
                 use_bias=True,
                 kernel_initializer='glorot_normal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 history_only=False,
                 **kwargs):
        # Initialize the base Layer before assigning any attributes.
        super(abc, self).__init__(**kwargs)
        self.supports_masking = True
        self.head_num = head_num
        self.q_k = q_k
        self.activation = keras.activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)
        self.kernel_constraint = keras.constraints.get(kernel_constraint)
        self.bias_constraint = keras.constraints.get(bias_constraint)
        self.history_only = history_only

        # Projection weights, created in build().
        self.Wq = self.Wk = self.Wv = self.Wo = None
        self.bq = self.bk = self.bv = self.bo = None

        # Scores / attention weights from the most recent call.
        self.intensity = self.attention = None

    def get_config(self):
        config = {
            'head_num': self.head_num,
            # q_k is a required constructor argument; without it from_config fails.
            'q_k': self.q_k,
            'activation': keras.activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': keras.initializers.serialize(self.kernel_initializer),
            'bias_initializer': keras.initializers.serialize(self.bias_initializer),
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': keras.regularizers.serialize(self.bias_regularizer),
            'kernel_constraint': keras.constraints.serialize(self.kernel_constraint),
            'bias_constraint': keras.constraints.serialize(self.bias_constraint),
            'history_only': self.history_only,
        }
        base_config = super(abc, self).get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        # Output keeps q's leading dims and takes v's feature size.
        if isinstance(input_shape, list):
            q, k, v = input_shape
            return q[:-1] + (v[-1],)
        return input_shape

    def compute_mask(self, inputs, input_mask=None):
        # Propagate the query mask (first entry when a list of masks is given).
        if isinstance(input_mask, list):
            return input_mask[0]
        return input_mask

    def _add_projection(self, suffix, input_dim, output_dim):
        """Create kernel ``<name>_W<suffix>`` and, if ``use_bias``, bias ``<name>_b<suffix>``.

        Returns ``(kernel, bias)``; ``bias`` is ``None`` when ``use_bias`` is False.
        """
        kernel = self.add_weight(
            shape=(input_dim, output_dim),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            name='%s_W%s' % (self.name, suffix),
        )
        bias = None
        if self.use_bias:
            bias = self.add_weight(
                shape=(output_dim,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                name='%s_b%s' % (self.name, suffix),
            )
        return kernel, bias

    def build(self, input_shape):
        if isinstance(input_shape, list):
            q, k, v = input_shape
        else:
            q = k = v = input_shape
        feature_dim = int(v[-1])
        if feature_dim % self.head_num != 0:
            raise IndexError('Invalid head number %d with the given input dim %d' % (self.head_num, feature_dim))
        # q/k projections are also split across heads, so q_k must divide evenly.
        if self.q_k % self.head_num != 0:
            raise IndexError('Invalid head number %d with the given q_k %d' % (self.head_num, self.q_k))
        # Creation order (Wq, bq, Wk, bk, Wv, bv, Wo, bo) and weight names are
        # kept identical so previously saved weights still load.
        self.Wq, self.bq = self._add_projection('q', int(q[-1]), self.q_k)
        self.Wk, self.bk = self._add_projection('k', int(k[-1]), self.q_k)
        self.Wv, self.bv = self._add_projection('v', feature_dim, feature_dim)
        self.Wo, self.bo = self._add_projection('o', feature_dim, feature_dim)
        super(abc, self).build(input_shape)

    @staticmethod
    def _reshape_to_batches(x, head_num):
        # (batch, seq, feat) -> (batch * head_num, seq, feat // head_num)
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
        head_dim = feature_dim // head_num
        x = K.reshape(x, (batch_size, seq_len, head_num, head_dim))
        x = K.permute_dimensions(x, [0, 2, 1, 3])
        return K.reshape(x, (batch_size * head_num, seq_len, head_dim))

    @staticmethod
    def _reshape_attention_from_batches(x, head_num):
        # (batch * head_num, seq_q, seq_k) -> (batch, seq_q, head_num, seq_k)
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
        x = K.reshape(x, (batch_size // head_num, head_num, seq_len, feature_dim))
        return K.permute_dimensions(x, [0, 2, 1, 3])

    @staticmethod
    def _reshape_from_batches(x, head_num):
        # (batch * head_num, seq, head_dim) -> (batch, seq, head_num * head_dim)
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
        x = K.reshape(x, (batch_size // head_num, head_num, seq_len, feature_dim))
        x = K.permute_dimensions(x, [0, 2, 1, 3])
        return K.reshape(x, (batch_size // head_num, seq_len, feature_dim * head_num))

    @staticmethod
    def _reshape_mask(mask, head_num):
        # Repeat a (batch, seq) mask once per head -> (batch * head_num, seq).
        if mask is None:
            return mask
        seq_len = K.shape(mask)[1]
        mask = K.expand_dims(mask, axis=1)
        mask = K.tile(mask, [1, head_num, 1])
        return K.reshape(mask, (-1, seq_len))

    @staticmethod
    def _scaled_dot_product_attention(query, key, value, q_k):
        """Return ``(output, scores, weights)`` for head-batched q/k/v.

        NOTE(review): scores are divided by sqrt(q_k), the *total* q/k size,
        not by the per-head size q_k / head_num used in standard multi-head
        attention. Kept as-is to preserve trained behavior.
        """
        scores = K.batch_dot(query, key, axes=2) / K.sqrt(K.cast(q_k, dtype=K.floatx()))
        # Subtract the row max before exp for a numerically stable softmax.
        e = K.exp(scores - K.max(scores, axis=-1, keepdims=True))
        weights = e / K.sum(e, axis=-1, keepdims=True)
        return K.batch_dot(weights, value), scores, weights

    def call(self, inputs, mask=None):
        # NOTE(review): ``mask`` is accepted (supports_masking=True) but is not
        # applied to the attention scores; neither is ``history_only``.
        if isinstance(inputs, list):
            q, k, v = inputs
        else:
            q = k = v = inputs
        q = K.dot(q, self.Wq)
        k = K.dot(k, self.Wk)
        v = K.dot(v, self.Wv)
        if self.use_bias:
            q += self.bq
            k += self.bk
            v += self.bv
        if self.activation is not None:
            q = self.activation(q)
            k = self.activation(k)
            v = self.activation(v)
        y, intensity, attention = self._scaled_dot_product_attention(
            self._reshape_to_batches(q, self.head_num),
            self._reshape_to_batches(k, self.head_num),
            self._reshape_to_batches(v, self.head_num),
            self.q_k,
        )
        self.intensity = self._reshape_attention_from_batches(intensity, self.head_num)
        self.attention = self._reshape_attention_from_batches(attention, self.head_num)
        y = self._reshape_from_batches(y, self.head_num)
        y = K.dot(y, self.Wo)
        if self.use_bias:
            y += self.bo
        if self.activation is not None:
            y = self.activation(y)
        return y
from tensorflow import keras
from keras.activations import softmax
from tensorflow.keras import backend as K
import tensorflow as tf
class LayerNormalization(keras.layers.Layer):
    """Normalize over the last axis: ``gamma * (x - mean) / (std + eps) + beta``.

    ``gamma`` (init ones) and ``beta`` (init zeros) are trainable vectors sized
    to the last input dimension.

    NOTE(review): a later ``from tensorflow.keras.layers import *`` in this
    notebook rebinds the name ``LayerNormalization`` to Keras' built-in layer.

    Args:
        eps: Constant added to the standard deviation (not the variance) to
            avoid division by zero.
    """

    def __init__(self, eps=1e-6, **kwargs):
        super(LayerNormalization, self).__init__(**kwargs)
        self.eps = eps

    def get_config(self):
        # Serialize ``eps`` so the layer round-trips through from_config.
        config = super(LayerNormalization, self).get_config()
        config['eps'] = self.eps
        return config

    def build(self, input_shape):
        # String initializers avoid depending on ``Ones``/``Zeros`` being imported.
        self.gamma = self.add_weight(name='gamma', shape=input_shape[-1:],
                                     initializer='ones', trainable=True)
        self.beta = self.add_weight(name='beta', shape=input_shape[-1:],
                                    initializer='zeros', trainable=True)
        super(LayerNormalization, self).build(input_shape)

    def call(self, x):
        mean = K.mean(x, axis=-1, keepdims=True)
        std = K.std(x, axis=-1, keepdims=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

    def compute_output_shape(self, input_shape):
        # Normalization is element-wise; the shape is unchanged.
        return input_shape
class abc(keras.layers.Layer):
    """Additive attention gate that re-weights ``img`` using a gating tensor ``y``.

    ``y`` and ``img`` are each projected to ``inr`` channels with 1x1 convs. The
    projected ``y`` is bilinearly upsampled by ``up`` and added to projected
    ``img``, then passed through ReLU and a 1x1 conv down to a single channel.
    A softmax over the spatial axes turns that into a map, and ``img`` is
    multiplied by the map.

    Args:
        inr: Channel count of the intermediate 1x1 projections.
        mo: Stored and serialized; NOTE(review): not used by this layer.
        up: Integer upsampling factor applied to the projected ``y`` so it
            matches ``img`` spatially.
    """

    def __init__(self, inr, mo, up, **kwargs):
        super(abc, self).__init__(**kwargs)
        self.inr = inr
        self.mo = mo
        self.up = up

    def get_config(self):
        # The original returned None, which breaks model serialization.
        config = super(abc, self).get_config()
        config.update({'inr': self.inr, 'mo': self.mo, 'up': self.up})
        return config

    def build(self, input_shape):
        super(abc, self).build(input_shape)
        self.cv1 = keras.layers.Conv2D(self.inr, 1)
        self.cv2 = keras.layers.Conv2D(self.inr, 1)
        self.cv3 = keras.layers.Conv2D(1, 1)
        # Stored under its own attribute so the int factor in ``self.up`` is
        # not overwritten (get_config needs it).
        self.upsample = keras.layers.UpSampling2D(
            interpolation='bilinear', size=(self.up, self.up))

    def call(self, img, y):
        gate = self.upsample(self.cv1(y))
        x = self.cv2(img)
        x = tf.nn.relu(gate + x)
        x = self.cv3(x)

        # Conv2D here is channels_last, so x is (batch, H, W, 1) and the spatial
        # axes are 1 and 2. The original normalized over axes [2, 3] (W and the
        # size-1 channel axis), which is the NCHW convention and yields a
        # per-row softmax instead of one over the whole H x W map.
        spatial_axes = [1, 2]
        shifted = x - tf.reduce_max(x, axis=spatial_axes, keepdims=True)
        e = tf.exp(shifted)
        attention_map = e / tf.reduce_sum(e, axis=spatial_axes, keepdims=True)

        return tf.math.multiply(img, attention_map)

class SpatialGate(keras.layers.Layer):
    """Spatial attention gate in the style of CBAM.

    The channel-wise mean and max of ``img`` are concatenated into a
    2-channel map. That map goes through a single-filter conv and a sigmoid,
    and ``img`` is multiplied by the result (broadcast over channels).

    Args:
        kernel_size: Kernel size of the gating conv. Defaults to 1, matching
            the original behavior; CBAM itself uses a larger kernel (e.g. 7).
    """

    def __init__(self, kernel_size=1, **kwargs):
        super(SpatialGate, self).__init__(**kwargs)
        self.kernel_size = kernel_size

    def get_config(self):
        # The original returned None, which breaks model serialization.
        config = super(SpatialGate, self).get_config()
        config['kernel_size'] = self.kernel_size
        return config

    def build(self, input_shape):
        super(SpatialGate, self).build(input_shape)
        self.cv2 = keras.layers.Conv2D(1, self.kernel_size)

    def call(self, img):
        # Reduce over the last (channel) axis, keeping it as size 1.
        img_avg = K.mean(img, axis=-1, keepdims=True)
        img_max = K.max(img, axis=-1, keepdims=True)
        # Backend concat instead of building a new Concatenate layer per call.
        pooled = K.concatenate([img_avg, img_max], axis=-1)
        gate = keras.activations.sigmoid(self.cv2(pooled))
        return tf.math.multiply(img, gate)
class ChannelGate(keras.layers.Layer):
    """Channel attention gate followed by a ``SpatialGate``, in the style of CBAM.

    Global average- and max-pooled descriptors of ``img`` go through a shared
    two-layer MLP (``inr -> inr // ratio -> inr``). The two outputs are summed
    and passed through a sigmoid, then broadcast-multiplied onto ``img``. The
    result is passed through a ``SpatialGate``.

    Args:
        inr: Channel count of ``img``. It must match, because the gate is
            reshaped to (1, 1, inr) and multiplied onto ``img``.
        ratio: Reduction ratio of the MLP's hidden layer.
    """

    def __init__(self, inr, ratio, **kwargs):
        super(ChannelGate, self).__init__(**kwargs)
        self.inr = inr
        self.ratio = ratio

    def get_config(self):
        # The original called super(abc, self), which raises TypeError because
        # ChannelGate is not a subclass of abc; it also returned None.
        config = super(ChannelGate, self).get_config()
        config.update({'inr': self.inr, 'ratio': self.ratio})
        return config

    def build(self, input_shape):
        super(ChannelGate, self).build(input_shape)
        # Integer division: Dense expects integer units (inr / ratio was a float).
        self.dns1 = keras.layers.Dense(self.inr // self.ratio, activation='relu')
        self.dns2 = keras.layers.Dense(self.inr)
        self.spt = SpatialGate()

    def call(self, img):
        # Global pooling over H and W (channels_last), done with tf ops rather
        # than instantiating new pooling layers on every call.
        avg_desc = tf.reduce_mean(img, axis=[1, 2])
        max_desc = tf.reduce_max(img, axis=[1, 2])
        img_avg = self.dns2(self.dns1(avg_desc))
        img_max = self.dns2(self.dns1(max_desc))
        gate = keras.activations.sigmoid(img_max + img_avg)
        gate = tf.reshape(gate, (-1, 1, 1, self.inr))
        return self.spt(tf.math.multiply(img, gate))

from tensorflow import keras 
from tensorflow.keras.regularizers import l2
from tensorflow.keras.layers import *
from tensorflow.keras.applications import *
def load_model():
    """Build a 3-class classifier on an ImageNet-pretrained DenseNet121.

    The backbone's ``conv5_block16_concat`` feature map is projected to 512
    channels with a 1x1 conv and re-weighted by ``ChannelGate(512, 8)``. It is
    then globally average-pooled and fed to a 3-way softmax ``Dense`` head.

    Returns:
        A ``keras.Model`` from the DenseNet121 input to class probabilities.
    """
    # Reset Keras global state (e.g. auto-generated layer names) before building.
    K.clear_session()
    # NOTE(review): include_top=True also loads the ImageNet classifier head
    # although only an intermediate feature map is used. Switching to
    # include_top=False would change the accepted input shape, so it is kept.
    backbone = densenet.DenseNet121(include_top=True, weights='imagenet')
    features = backbone.get_layer('conv5_block16_concat').output
    features = Conv2D(512, 1)(features)
    features = ChannelGate(512, 8)(features)

    features = GlobalAveragePooling2D()(features)

    probs = Dense(3, activation="softmax")(features)
    # ``Model`` is not brought in by this cell's star-imports; keras.Model is
    # the same class and is always in scope via ``from tensorflow import keras``.
    return keras.Model(inputs=backbone.input, outputs=probs)


from tensorflow import keras 
from tensorflow.keras.regularizers import l2
from tensorflow.keras.layers import *
from tensorflow.keras.applications import *
def load_model():
    """Build a 3-class classifier that fuses three DenseNet121 stages.

    The feature maps ``conv3_block12_concat``, ``conv4_block24_concat`` and
    ``conv5_block16_concat`` of an ImageNet-pretrained DenseNet121 are each
    re-weighted by a ``ChannelGate`` and globally average-pooled. The pooled
    vectors are concatenated and fed to a 3-way softmax ``Dense`` head.

    Returns:
        A ``keras.Model`` from the DenseNet121 input to class probabilities.
    """
    # Reset Keras global state (e.g. auto-generated layer names) before building.
    K.clear_session()
    backbone = densenet.DenseNet121(include_top=True, weights='imagenet')

    # Layer creation order is kept exactly as before so auto-generated layer
    # names (and thus name-based weight loading) are unchanged.
    d = backbone.get_layer('conv5_block16_concat').output
    d = ChannelGate(1024, 8)(d)

    a = backbone.get_layer('conv3_block12_concat').output
    a = ChannelGate(512, 8)(a)

    b = backbone.get_layer('conv4_block24_concat').output
    b = ChannelGate(1024, 8)(b)

    a = GlobalAveragePooling2D()(a)
    b = GlobalAveragePooling2D()(b)
    d = GlobalAveragePooling2D()(d)

    conc = Concatenate(axis=1)([a, b, d])
    probs = Dense(3, activation="softmax")(conc)

    # ``Model`` is not brought in by this cell's star-imports; keras.Model is
    # the same class and is always in scope via ``from tensorflow import keras``.
    return keras.Model(inputs=backbone.input, outputs=probs)

Cloning into 'yolov5'...
remote: Enumerating objects: 9185, done.
remote: Counting objects: 100% (1/1), done.
remote: Total 9185 (delta 0), reused 0 (delta 0), pack-reused 9184
Receiving objects: 100% (9185/9185), 9.60 MiB | 29.36 MiB/s, done.
Resolving deltas: 100% (6383/6383), done.
