<a href="https://colab.research.google.com/github/abubakrsiddq/ImageDehazing/blob/main/models/FFA-net/ffa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import glob
import random
from PIL import Image
import time
import datetime

import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.losses import mean_squared_error
from tensorflow.keras.optimizers import Adam

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Preprocessing and loading of data

In [None]:
#ls drive/MyDrive/reside/archive/clear_images drive/MyDrive/reside/archive/haze  

In [3]:
# function to load the image in the form of tensors.

def load_image(img_path):
    img = tf.io.read_file(img_path)
    img = tf.io.decode_jpeg(img, channels = 3)
    img = tf.image.resize(img, size = (412, 548), antialias = True)
    img = img / 255.0
    img = tf.image.per_image_standardization(img)
    return img

In [4]:
def dataset_preposses(orig_path='/content/drive/MyDrive/dataset/clear_images',haze_path='/content/drive/MyDrive/dataset/haze',percentage=0.2,validation_size=64,test_size=64,seed_val=900):
  '''
  parameters:
  orig_path(string): path of ground truth folder
  haze_path(string): path of haze folder
  percentage(float): percentage of dataset to load
  validation_size(int): the no. of validation images
  test_size(int): the no. of test images

  returns:
  haze_list,validation_list,test_list
  '''
  random.seed(seed_val)
  pth=haze_path+'/*.jpg'
  haze_path_list = glob.glob(pth)
  orig_path_list=glob.glob(orig_path+'/*.jpg')
  #print(orig_path_list)
  random.shuffle(haze_path_list)
  #print(haze_path_list)
  haze_path_dict={}
  haze_count_dict={}
  haze_list=[]
  no_per_set=int(percentage*35)
  for i in haze_path_list:
    name=i.split('/')[-1].split('_')[0]
    if(int(name)>468):
      try:
        if(haze_count_dict[name]<no_per_set):
          haze_path_dict[name].append(i)
          
          haze_count_dict[name]+=1;
          
      except KeyError:
       
        haze_path_dict[name]=[]
        haze_path_dict[name].append(i)
        haze_count_dict[name]=1
    #print(haze_path_dict)
  test_list=haze_path_list[-1*test_size:]
  val_list=haze_path_list[-1*(validation_size+test_size):-1*test_size];

  for (key,val) in haze_path_dict.items():
    for i in val:
      haze_list.append(i)
  return haze_list,val_list,test_list


def gen_dataset(ar):
  '''
  parameters
  list of paths
  return
  list with gt attached 
  '''
  orig_path='/content/drive/MyDrive/dataset/clear_images'
  haze_pth='/content/drive/MyDrive/dataset/haze'
  lst=[]
  for i in ar:
    name=i.split('/')[-1].split('_')[0]
    pthlist=[i,orig_path+'/'+name+'.jpg']
    lst.append(pthlist)
  return lst

def data_path(orig_img_path = './drive/MyDrive/reside/archive/clear_images', hazy_img_path = './drive/MyDrive/reside/archive/haze'):
  
  (a,b,c)=dataset_preposses(orig_path=orig_img_path,haze_path=hazy_img_path)
  a=gen_dataset(a)
  b=gen_dataset(b)
  return a,b

In [5]:
# function to load tensor image data in batches.

def dataloader(train_data, val_data, batch_size):
    print(len(train_data))
    train_data_orig = tf.data.Dataset.from_tensor_slices([img[1] for img in train_data]).map(lambda x: load_image(x))
    train_data_haze = tf.data.Dataset.from_tensor_slices([img[0] for img in train_data]).map(lambda x: load_image(x))
    train = tf.data.Dataset.zip((train_data_haze, train_data_orig)).shuffle(buffer_size=100).batch(batch_size)
    
    val_data_orig = tf.data.Dataset.from_tensor_slices([img[1] for img in val_data]).map(lambda x: load_image(x))
    val_data_haze = tf.data.Dataset.from_tensor_slices([img[0] for img in val_data]).map(lambda x: load_image(x))
    val = tf.data.Dataset.zip((val_data_haze, val_data_orig)).shuffle(buffer_size=100).batch(batch_size)
    
    return train, val

In [6]:
# function to display output.
import cv2
def display_img(model, hazy_img, orig_img):
    
    dehazed_img = model(hazy_img, training = True)
    plt.figure(figsize = (15,15))
    
    display_list = [hazy_img[0], orig_img[0], dehazed_img[0]]
    title = ['Hazy Image', 'Ground Truth', 'Dehazed Image']
    
    for i in range(3):
        plt.subplot(1, 3, i+1)
        plt.title(title[i])
        plt.imshow(display_list[i])
        plt.axis('off')
        
    plt.show()
    #print("input image quality",display_list)#niqe(cv2.imread(display_list[1])))
    #print("input image quality",niqe(cv2.imread(display_list[2])))

# Network Function

In [3]:
def default_conv(in_channels, out_channels, kernel_size, bias=True,activation='relu'):
    return tf.keras.layers.Conv2D(out_channels, kernel_size,padding='same', use_bias=bias,activation=activation)
    
class PixAtLayer(tf.keras.Model):
    def __init__(self, channel):
        super(PixAtLayer, self).__init__()
        self.pa = tf.keras.Sequential()
        self.pa.add(tf.keras.layers.Conv2D(channel // 8, 1, padding='valid',activation='relu'))
        self.pa.add(tf.keras.layers.Conv2D( 1, 1,activation='sigmoid'))
    def call(self, x):
        y = self.pa(x)
        #return y
        return x * y

    def model(self):
        x = Input(shape = (352, 1216, 3))
        return Model(inputs=[x], outputs=self.call(x))

class adapavgpooling(tf.keras.Model):
  def __init__(self,outputsize):
    super(adapavgpooling,self).__init__()
    self.outputsize=outputsize

  def call(self,x):
    x_shape=tf.keras.backend.int_shape(x)
    batchsize1,dim1,dim2,channels1=x_shape
    stride=np.floor(dim2/self.outputsize).astype(np.int32)
    kernels=dim1-(self.outputsize-1)*stride
    adpooling=tf.keras.layers.AveragePooling2D(pool_size=(kernels,kernels),strides=(stride,stride))(x)
    return adpooling

  def model(self):
        x = Input(shape = (352, 1216, 3))
        return Model(inputs=[x], outputs=self.call(x))
'''    
def adapavgpooling(x,outsize):
    x_shape=tf.keras.backend.int_shape(x)
    batchsize1,dim1,dim2,channels1=x_shape
    stride=np.floor(dim1/outsize).astype(np.int32)
    kernels=dim1-(outsize-1)*stride
    adpooling=tf.keras.layers.AveragePooling2D(pool_size=(kernels,kernels),strides=(stride,stride))(x)
    
    return adpooling
'''
class ChanAtLayer(tf.keras.Model):
  def __init__(self, channel):
      super(ChanAtLayer, self).__init__()
      #self.avg_pool = tf.keras.layers.GlobalAveragePooling2D()
      self.ca = tf.keras.Sequential()
      self.ca.add(tf.keras.layers.Conv2D(channel // 8, 1,activation='relu'))
      self.ca.add(tf.keras.layers.Conv2D(channel, 1, activation='sigmoid'))
      self.adPool=adapavgpooling(1)
  def call(self, x):
      y = self.adPool(x)
      #print(y.shape)
      y = self.ca(y)
      
      return x * y

  def model(self):
        x = Input(shape = (352, 1216, 64))
        return Model(inputs=[x], outputs=self.call(x))


In [4]:
#sub = PixAtLayer(64)
#sub.model().summary()
#sub1 = ChanAtLayer(64)
#sub1.model().summary()

In [5]:
class Block_layer(tf.keras.Model):
    def __init__(self, conv, dim, kernel_size,):
        super(Block_layer, self).__init__()
        self.conv1=conv(dim,dim, kernel_size, bias=True,activation='relu')
        
        self.conv2=conv(dim,dim, kernel_size, bias=True)
        self.calayer=ChanAtLayer(dim)
        self.palayer=PixAtLayer(dim)
    def call(self, x):
        res=self.conv1(x)
        res=res+x 
        res=self.conv2(res)
        res=self.calayer(res)
        res=self.palayer(res)
        res += x 
        return res

    def model(self):
        x = Input(shape = (352, 1216, 64))
        return Model(inputs=[x], outputs=self.call(x))


In [6]:
#sub1 =Block_layer(default_conv,64,3)
#sub1.model().summary()

In [7]:
class Group_layer(tf.keras.Model):
    def __init__(self, conv, dim, kernel_size, blocks):
        super(Group_layer, self).__init__()
        modules = [ Block_layer(conv, dim, kernel_size)  for _ in range(blocks)]
        modules.append(tf.keras.layers.Conv2D(dim, kernel_size,padding='same'))
        self.gp = tf.keras.Sequential()
        for lay in modules:
          self.gp.add(lay)
        
    def call(self,input_tensor):
        res = self.gp(input_tensor)
        #res =tf.keras.layers.Add()([res,input_tensor])
        res+=input_tensor
        return res

    def model(self):
        x = Input(shape = (352, 1216, 64))
        return Model(inputs=[x], outputs=self.call(x))

In [8]:
#sub1 =Group_layer(default_conv,64,3,6)
#sub1.model().summary()

In [9]:
class FFAnet(tf.keras.Model):
    def __init__(self,gps,blocks,conv=default_conv):
        super(FFAnet, self).__init__()
        # define all layers in init
        # Layer of Block 1
        self.gps=gps
        self.dim=64
        kernel_size=3
        pre_process = [tf.keras.layers.Conv2D(self.dim, kernel_size)]
        assert self.gps==3
        self.g1= Group_layer(conv, self.dim, kernel_size,blocks=blocks)
        self.g2= Group_layer(conv, self.dim, kernel_size,blocks=blocks)
        self.g3= Group_layer(conv, self.dim, kernel_size,blocks=blocks)
        l=[
            adapavgpooling(1),
            tf.keras.layers.Conv2D(self.dim//16,1,padding='valid'),
            
            tf.keras.layers.Conv2D(self.dim*self.gps, 1, padding='valid', use_bias=True,activation='sigmoid')
            
            ]
       
        self.ca=tf.keras.Sequential()
        for lay in l:
          self.ca.add(lay)
        self.palayer=PixAtLayer(self.dim)

        post_precess = [
            conv(self.dim, self.dim, kernel_size),
            conv(self.dim, 3, kernel_size)]

        self.pre = tf.keras.Sequential(conv(3,self.dim, kernel_size),name='preprocess')

        self.post = tf.keras.Sequential(name='postprocess')
        for lay in post_precess:
          self.post.add(lay)
        
    def call(self, input_tensor = (None,None, 3), training=False):
        #input_tensor=tf.keras.layers.InputLayer(input_shape=(352,1216,3))(input_tensor)
        # forward pass: block 1 
        input_tensor=tf.keras.layers.experimental.preprocessing.Normalization(axis=-1,mean=[0.485, 0.456, 0.406], variance=[0.229**2, 0.224**2, 0.225**2])(input_tensor)#self.norm(input_tensor)
        x = self.pre(input_tensor)
        res1=self.g1(x)
        res2=self.g2(res1)
        res3=self.g3(res2)

        #return res3
        w=self.ca(tf.keras.layers.concatenate([res1,res2,res3],axis=-1))
        w=tf.keras.layers.Reshape((1,self.gps,self.dim))(w)
        out=w[:,:,0]*res1+w[:,:,1]*res2+w[:,:,2]*res3
        out=self.palayer(out)
        x=self.post(out)
        return x + input_tensor
        #return out
       
    def model(self):
        x = Input(shape = (352, 1216, 3))
        return Model(inputs=[x], outputs=self.call(x))

    def build_graph(self):
        x = Input(shape=(352, 1216, 3))
        return Model(inputs=[x], outputs=self.call(x))

In [10]:
model=FFAnet(gps=3,blocks=19)
model.model().summary()
dot_img_file = '/tmp/model_1.png'
tf.keras.utils.plot_model(
    model.build_graph(),                      # here is the trick (for now)
    to_file='model.png', dpi=96,              # saving  
    show_shapes=True, show_layer_names=True,  # show shapes and layer name
    expand_nested=False                       # will show nested block
)
model.load_weights('/content/drive/MyDrive/nets/ffa/weights')

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 352, 1216, 3 0                                            
__________________________________________________________________________________________________
normalization (Normalization)   (None, 352, 1216, 3) 7           input_1[0][0]                    
__________________________________________________________________________________________________
preprocess (Sequential)         (None, 352, 1216, 64 1792        normalization[0][0]              
__________________________________________________________________________________________________
group_layer (Group_layer)       (None, 352, 1216, 64 1471067     preprocess[0][0]                 
______________________________________________________________________________________________

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fafce9a5390>

# Transfer learning

In [15]:
import torch
pretrain=torch.load('/content/drive/MyDrive/ots_train_ffa_3_19.pk',map_location=torch.device('cpu'))
pretrain.keys()

dict_keys(['step', 'max_psnr', 'max_ssim', 'ssims', 'psnrs', 'losses', 'model'])

In [None]:
pretrain['model'].keys()

odict_keys(['module.g1.gp.0.conv1.weight', 'module.g1.gp.0.conv1.bias', 'module.g1.gp.0.conv2.weight', 'module.g1.gp.0.conv2.bias', 'module.g1.gp.0.calayer.ca.0.weight', 'module.g1.gp.0.calayer.ca.0.bias', 'module.g1.gp.0.calayer.ca.2.weight', 'module.g1.gp.0.calayer.ca.2.bias', 'module.g1.gp.0.palayer.pa.0.weight', 'module.g1.gp.0.palayer.pa.0.bias', 'module.g1.gp.0.palayer.pa.2.weight', 'module.g1.gp.0.palayer.pa.2.bias', 'module.g1.gp.1.conv1.weight', 'module.g1.gp.1.conv1.bias', 'module.g1.gp.1.conv2.weight', 'module.g1.gp.1.conv2.bias', 'module.g1.gp.1.calayer.ca.0.weight', 'module.g1.gp.1.calayer.ca.0.bias', 'module.g1.gp.1.calayer.ca.2.weight', 'module.g1.gp.1.calayer.ca.2.bias', 'module.g1.gp.1.palayer.pa.0.weight', 'module.g1.gp.1.palayer.pa.0.bias', 'module.g1.gp.1.palayer.pa.2.weight', 'module.g1.gp.1.palayer.pa.2.bias', 'module.g1.gp.2.conv1.weight', 'module.g1.gp.2.conv1.bias', 'module.g1.gp.2.conv2.weight', 'module.g1.gp.2.conv2.bias', 'module.g1.gp.2.calayer.ca.0.weight'

In [54]:
w=np.transpose(pretrain['model']['module.post.1.weight'],[2,3,1,0])
print(w.shape)
b=pretrain['model']['module.post.1.bias'].numpy()
print(b.shape)

torch.Size([3, 3, 64, 3])
(3,)


In [42]:
w=np.transpose(pretrain['model']['module.palayer.pa.2.weight'],[2,3,1,0])
print(w.shape)
b=pretrain['model']['module.palayer.pa.2.bias'].numpy()
print(b.shape)

torch.Size([1, 1, 8, 1])
(1,)


In [43]:
model.layers[-3].layers[0].layers[1].set_weights([w,b])

In [55]:
model.layers[-1].layers[1].set_weights([w,b])

In [25]:

w=np.transpose(pretrain['model']['module.g3.gp.19.weight'],[2,3,1,0])
print(w.shape)
b=pretrain['model']['module.g3.gp.19.bias'].numpy()
print(b.shape)

torch.Size([3, 3, 64, 64])
(64,)


In [26]:
model.layers[2].layers[0].layers[-1].set_weights([w,b])

In [16]:
def set_weight(grp,block):
  keyFind='module.g'+str(grp)+'.gp.'+str(block)+'.'
  w=np.transpose(pretrain['model'][keyFind+'conv1.weight'],[2,3,1,0])
  b=pretrain['model'][keyFind+'conv1.bias'].numpy()
  model.layers[grp-1].layers[0].layers[block].layers[0].set_weights([w,b])

  w=np.transpose(pretrain['model'][keyFind+'conv2.weight'],[2,3,1,0])
  b=pretrain['model'][keyFind+'conv2.bias'].numpy()
  model.layers[grp-1].layers[0].layers[block].layers[1].set_weights([w,b])

  w=np.transpose(pretrain['model'][keyFind+'calayer.ca.0.weight'],[2,3,1,0])
  b=pretrain['model'][keyFind+'calayer.ca.0.bias'].numpy()
  model.layers[grp-1].layers[0].layers[block].layers[2].layers[0].layers[0].set_weights([w,b])

  w=np.transpose(pretrain['model'][keyFind+'calayer.ca.2.weight'],[2,3,1,0])
  b=pretrain['model'][keyFind+'calayer.ca.2.bias'].numpy()
  model.layers[grp-1].layers[0].layers[block].layers[2].layers[0].layers[1].set_weights([w,b])

  w=np.transpose(pretrain['model'][keyFind+'palayer.pa.0.weight'],[2,3,1,0])
  b=pretrain['model'][keyFind+'palayer.pa.0.bias'].numpy()
  model.layers[grp-1].layers[0].layers[block].layers[3].layers[0].layers[0].set_weights([w,b])


  w=np.transpose(pretrain['model'][keyFind+'palayer.pa.2.weight'],[2,3,1,0])
  b=pretrain['model'][keyFind+'palayer.pa.2.bias'].numpy()
  model.layers[grp-1].layers[0].layers[block].layers[3].layers[0].layers[1].set_weights([w,b])

In [19]:
for i in range(19):
  set_weight(3,i)

In [56]:
model.save_weights('/content/drive/MyDrive/nets/ffa/weights')

# Training

In [None]:
# Hyperparameters
epochs = 10
batch_size = 1

train_data, val_data = data_path(orig_img_path = './drive/MyDrive/dataset/clear_images', hazy_img_path = './drive/MyDrive/dataset/haze',)
train, val = dataloader(train_data, val_data, batch_size)

optimizer = Adam(learning_rate = 1e-4)
net = FFAnet(gps=3,blocks=8)
net.load_weights('/content/drive/MyDrive/nets/ffa/weights')
train_loss_tracker = tf.keras.metrics.MeanSquaredError(name = "train loss")
val_loss_tracker = tf.keras.metrics.MeanSquaredError(name = "val loss")

2604


In [None]:
def train_model(epochs, train, val, net, train_loss_tracker, val_loss_tracker, optimizer):
    
    for epoch in range(epochs):
        
        print("\nStart of epoch %d" % (epoch,), end=' ')
        start_time_epoch = time.time()
        start_time_step = time.time()
        
        # training loop
        
        for step, (train_batch_haze, train_batch_orig) in enumerate(train):

            with tf.GradientTape() as tape:

                train_logits = net(train_batch_haze, training = True)
                #print(train_logits.shape)
                loss = mean_squared_error(train_batch_orig, train_logits)

            grads = tape.gradient(loss, net.trainable_weights)
            optimizer.apply_gradients(zip(grads, net.trainable_weights))

            train_loss_tracker.update_state(train_batch_orig, train_logits)
            if step == 0:
                print('[', end='')
            if step % 64 == 0:
                print('=', end='')
        
        print(']', end='')
        print('  -  ', end='')
        print('Training Loss: %.4f' % (train_loss_tracker.result()), end='')
        
        # validation loop
        
        for step, (val_batch_haze, val_batch_orig) in enumerate(val):
            val_logits = net(val_batch_haze, training = False)
            val_loss_tracker.update_state(val_batch_orig, val_logits)
            
            if step % 32 ==0:
                display_img(net, val_batch_haze, val_batch_orig)
        
        print('  -  ', end='')
        print('Validation Loss: %.4f' % (val_loss_tracker.result()), end='')
        print('  -  ', end=' ')
        print("Time taken: %.2fs" % (time.time() - start_time_epoch))
        
        net.save('trained_model')  
        #net.save_weights('/content/drive/MyDrive/nets/ffa/weights')         # save the model(variables, weights, etc)
        train_loss_tracker.reset_states()
        val_loss_tracker.reset_states()

In [None]:
%%time
train_model(epochs, train, val, net, train_loss_tracker, val_loss_tracker, optimizer)

In [None]:
net.save('/content/drive/MyDrive/nets/ffa')

net.save_weights('/content/drive/MyDrive/nets/ffa/weights')
#model=net

#model.model().summary()

In [11]:
import cv2
import matplotlib.pyplot as plt

def evaluate_gen(net):
    
    #test_img = glob.glob(test_img_path +'/*.jpg')
    test_img=glob.glob('/content/drive/MyDrive/Final_compare/HAZY/*.jpg')
    #random.shuffle(test_img)
    i=0;
    for img in test_img:
        print(img)
        img = tf.io.read_file(img)
        img = tf.io.decode_jpeg(img, channels = 3)
        
        img = tf.image.resize(img, size = (412,548), antialias = True)
        
        img = img / 255.0
        img=tf.image.per_image_standardization(img)
        print(i,end=" ")
        img = tf.expand_dims(img, axis = 0)      #transform input image from 3D to 4D ###
        
        dehaze =net(img, training = True)
        dehaze=tf.image.resize(dehaze, size = (413,550), antialias = True)
        #plt.figure(figsize = (80, 80))
        
        #display_list = [img[0], dehaze[0]]       #make the first dimension zero
        im=dehaze[0]
        directory = '/content/drive/MyDrive/Final_compare/ffa'
        os.chdir(directory)
        filename = str(i) + '_outdoor_gen.jpg'
        #print(filename)
        #cv2.imwrite(filename,im) 
        #plt.imsave(filename,im)
        tf.keras.preprocessing.image.save_img(
    filename, 

im)

        os.chdir('/content')
        i+=1;



In [None]:
#new_model = tf.keras.models.load_model('/content/drive/MyDrive/nets/gca',compile=False)
#new_model = tf.keras.models.load_model('/content/drive/MyDrive/nets/gca/trained_model',compile=False)
#new_model = tf.keras.models.load_model('/content/drive/MyDrive/nets/test_custom_loss_net',compile=False)

#net = FFAnet(gps=3,blocks=19)
#net.load_weights('/content/drive/MyDrive/nets/ffa/weights')
#evaluate_gen(new_model)
evaluate_gen(model)

/content/drive/MyDrive/Final_compare/HAZY/0_outdoor_hazy.jpg
0 /content/drive/MyDrive/Final_compare/HAZY/4_outdoor_hazy.jpg
1 /content/drive/MyDrive/Final_compare/HAZY/2_outdoor_hazy.jpg
2 /content/drive/MyDrive/Final_compare/HAZY/5_outdoor_hazy.jpg
3 /content/drive/MyDrive/Final_compare/HAZY/3_outdoor_hazy.jpg
4 /content/drive/MyDrive/Final_compare/HAZY/1_outdoor_hazy.jpg
5 /content/drive/MyDrive/Final_compare/HAZY/6_outdoor_hazy.jpg
6 /content/drive/MyDrive/Final_compare/HAZY/7_outdoor_hazy.jpg
7 /content/drive/MyDrive/Final_compare/HAZY/8_outdoor_hazy.jpg
8 /content/drive/MyDrive/Final_compare/HAZY/9_outdoor_hazy.jpg
9 /content/drive/MyDrive/Final_compare/HAZY/10_outdoor_hazy.jpg
10 /content/drive/MyDrive/Final_compare/HAZY/11_outdoor_hazy.jpg
11 /content/drive/MyDrive/Final_compare/HAZY/12_outdoor_hazy.jpg
12 /content/drive/MyDrive/Final_compare/HAZY/13_outdoor_hazy.jpg
13 /content/drive/MyDrive/Final_compare/HAZY/14_outdoor_hazy.jpg
14 /content/drive/MyDrive/Final_compare/HAZY/15_

## Learn

In [None]:
import torch.nn as nn
import torch

def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(in_channels, out_channels, kernel_size,padding=(kernel_size//2), bias=bias)
    
class PALayer(nn.Module):
    def __init__(self, channel):
        super(PALayer, self).__init__()
        self.pa = nn.Sequential(
                nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True),
                nn.Sigmoid()
        )
    def forward(self, x):
        y = self.pa(x)
        return x * y

class CALayer(nn.Module):
    def __init__(self, channel):
        super(CALayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential(
                nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True),
                nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.ca(y)
        return x * y

class Block(nn.Module):
    def __init__(self, conv, dim, kernel_size,):
        super(Block, self).__init__()
        self.conv1=conv(dim, dim, kernel_size, bias=True)
        self.act1=nn.ReLU(inplace=True)
        self.conv2=conv(dim,dim,kernel_size,bias=True)
        self.calayer=CALayer(dim)
        self.palayer=PALayer(dim)
    def forward(self, x):
        res=self.act1(self.conv1(x))
        res=res+x 
        res=self.conv2(res)
        res=self.calayer(res)
        res=self.palayer(res)
        res += x 
        return res
class Group(nn.Module):
    def __init__(self, conv, dim, kernel_size, blocks):
        super(Group, self).__init__()
        modules = [ Block(conv, dim, kernel_size)  for _ in range(blocks)]
        modules.append(conv(dim, dim, kernel_size))
        self.gp = nn.Sequential(*modules)
    def forward(self, x):
        res = self.gp(x)
        res += x
        return res

class FFA(nn.Module):
    def __init__(self,gps,blocks,conv=default_conv):
        super(FFA, self).__init__()
        self.gps=gps
        self.dim=64
        kernel_size=3
        pre_process = [conv(3, self.dim, kernel_size)]
        assert self.gps==3
        self.g1= Group(conv, self.dim, kernel_size,blocks=blocks)
        self.g2= Group(conv, self.dim, kernel_size,blocks=blocks)
        self.g3= Group(conv, self.dim, kernel_size,blocks=blocks)
        self.ca=nn.Sequential(*[
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(self.dim*self.gps,self.dim//16,1,padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.dim//16, self.dim*self.gps, 1, padding=0, bias=True),
            nn.Sigmoid()
            ])
        self.palayer=PALayer(self.dim)

        post_precess = [
            conv(self.dim, self.dim, kernel_size),
            conv(self.dim, 3, kernel_size)]

        self.pre = nn.Sequential(*pre_process)
        self.post = nn.Sequential(*post_precess)

    def forward(self, x1):
        x = self.pre(x1)
        res1=self.g1(x)
        res2=self.g2(res1)
        res3=self.g3(res2)
        #return res3
        w=torch.cat([res1,res2,res3],dim=1)
        w=self.ca(w)
        #return w
        w=w.view(-1,self.gps,self.dim)[:,:,:,None,None]
        #print(h.shape)
        #return h
        
        out=w[:,0,::]*res1+w[:,1,::]*res2+w[:,2,::]*res3
        
        #return out
        out=self.palayer(out)
        #return out
        x=self.post(out)
        return x + x1

In [None]:
from torchsummary import summary
net=FFA(gps=3,blocks=19)
#if torch.cuda.is_available():
#net.cuda(device='cpu')
#print(summary(net, (3, 412, 548)))

In [None]:
!git clone https://github.com/gzuidhof/nn-transfer.git

Cloning into 'nn-transfer'...
remote: Enumerating objects: 190, done.[K
remote: Total 190 (delta 0), reused 0 (delta 0), pack-reused 190[K
Receiving objects: 100% (190/190), 35.62 KiB | 1.42 MiB/s, done.
Resolving deltas: 100% (107/107), done.


In [None]:
pip install ./nn-transfer/.

Processing ./nn-transfer
Building wheels for collected packages: nn-transfer
  Building wheel for nn-transfer (setup.py) ... [?25l[?25hdone
  Created wheel for nn-transfer: filename=nn_transfer-0.1.0-cp37-none-any.whl size=4370 sha256=bbc87a106758507340c1af338aaf074bafd2ea0e1d24526e17974d6dd61cfa39
  Stored in directory: /root/.cache/pip/wheels/df/e8/e3/8d4084052a5c54875f1ba69fefcc39f3568d088c0fe3e1ed2e
Successfully built nn-transfer
Installing collected packages: nn-transfer
Successfully installed nn-transfer-0.1.0


In [None]:
ls

0_outdoor_gen.jpg   1_outdoor_gen.jpg   2_outdoor_gen.jpg   3_outdoor_gen.jpg
10_outdoor_gen.jpg  20_outdoor_gen.jpg  30_outdoor_gen.jpg  40_outdoor_gen.jpg
11_outdoor_gen.jpg  21_outdoor_gen.jpg  31_outdoor_gen.jpg  41_outdoor_gen.jpg
12_outdoor_gen.jpg  22_outdoor_gen.jpg  32_outdoor_gen.jpg  4_outdoor_gen.jpg
13_outdoor_gen.jpg  23_outdoor_gen.jpg  33_outdoor_gen.jpg  5_outdoor_gen.jpg
14_outdoor_gen.jpg  24_outdoor_gen.jpg  34_outdoor_gen.jpg  6_outdoor_gen.jpg
15_outdoor_gen.jpg  25_outdoor_gen.jpg  35_outdoor_gen.jpg  7_outdoor_gen.jpg
16_outdoor_gen.jpg  26_outdoor_gen.jpg  36_outdoor_gen.jpg  8_outdoor_gen.jpg
17_outdoor_gen.jpg  27_outdoor_gen.jpg  37_outdoor_gen.jpg  9_outdoor_gen.jpg
18_outdoor_gen.jpg  28_outdoor_gen.jpg  38_outdoor_gen.jpg  [0m[01;34mnn-transfer[0m/
19_outdoor_gen.jpg  29_outdoor_gen.jpg  39_outdoor_gen.jpg


In [None]:
a=np.random.randn(3,3,64,1)

In [None]:
np.transpose(a,[3,2,0,1]).shape

(1, 64, 3, 3)