In [None]:
# Mount Google Drive at /content/drive/ (Colab only). Re-running this cell
# prints "Drive already mounted"; pass force_remount=True to remount.
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [None]:
# Install TensorFlow Addons (the recorded run installed 0.16.1).
# NOTE(review): `-U` without a pin installs whatever release is newest, which may
# not match the installed TensorFlow; consider `%pip install tensorflow_addons==<ver>`
# once the TF version for this notebook is fixed.
!pip install -U tensorflow_addons

Collecting tensorflow_addons
  Downloading tensorflow_addons-0.16.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)
[?25l[K     |▎                               | 10 kB 11.4 MB/s eta 0:00:01[K     |▋                               | 20 kB 17.9 MB/s eta 0:00:01[K     |▉                               | 30 kB 12.9 MB/s eta 0:00:01[K     |█▏                              | 40 kB 9.7 MB/s eta 0:00:01[K     |█▌                              | 51 kB 3.8 MB/s eta 0:00:01[K     |█▊                              | 61 kB 4.4 MB/s eta 0:00:01[K     |██                              | 71 kB 4.7 MB/s eta 0:00:01[K     |██▍                             | 81 kB 4.4 MB/s eta 0:00:01[K     |██▋                             | 92 kB 4.9 MB/s eta 0:00:01[K     |███                             | 102 kB 4.3 MB/s eta 0:00:01[K     |███▏                            | 112 kB 4.3 MB/s eta 0:00:01[K     |███▌                            | 122 kB 4.3 MB/s eta 0:00:01[K     |███

In [None]:
#@title Import Packages

from keras.layers import Conv2D, ReLU, BatchNormalization, Add, Subtract, Concatenate, Input, MaxPooling2D, Layer, InputSpec, Conv2DTranspose, GaussianNoise, GaussianDropout, UpSampling2D, LeakyReLU, AveragePooling2D
from keras.optimizers import adam_v2
from keras.models import Model, load_model
from keras import initializers, regularizers, constraints
from keras import backend as K

import tensorflow_addons as tfa
import tensorflow as tf

print(tfa.__version__)

from PIL import Image
import numpy as np
import os

0.16.1


In [None]:
#@title Instance Normalization Block

class InstanceNormalization(Layer):
    """Instance normalization layer.
    Normalize the activations of the previous layer at each step,
    i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.
    Statistics are computed separately for every sample (and for every
    index along `axis` when one is given); the batch axis is never reduced.
    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `InstanceNormalization`.
            Setting `axis=None` will normalize all values in each
            instance of the batch.
            Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
        epsilon: Small float added to the standard deviation (not the
            variance) before dividing, to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.
    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a Sequential model.
        Rank-2 inputs are only accepted with `axis=None`.
    # Output shape
        Same shape as input.
    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast Stylization](
        https://arxiv.org/abs/1607.08022)
    """
    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        # Resolve string / dict specs into Keras objects once, at construction.
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)
 
    def build(self, input_shape):
        """Validate `axis` against the input rank and create gamma / beta."""
        ndim = len(input_shape)
        # The batch axis can never be the normalization axis.
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')
 
        # For (batch, features) inputs, a per-feature axis would leave no axis
        # to reduce over in call(), so only axis=None is accepted.
        if (self.axis is not None) and (ndim == 2):
            raise ValueError('Cannot specify axis for rank 1 tensor')
 
        self.input_spec = InputSpec(ndim=ndim)
 
        # One gamma/beta per index along `axis`, or a single scalar pair when
        # each instance is normalized as a whole.
        if self.axis is None:
            shape = (1,)
        else:
            shape = (input_shape[self.axis],)
 
        # NOTE(review): keep the weight names and creation order (gamma, then
        # beta) stable; previously saved .h5 models reloaded with
        # custom_objects presumably depend on them.
        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True
 
    def call(self, inputs, training=None):
        """Normalize `inputs` per instance, then apply optional gamma / beta.

        `training` is unused: the same computation runs regardless of its value.
        """
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(0, len(input_shape)))
 
        # Reduce over every axis except `self.axis` (if given) and the batch
        # axis 0; e.g. axis=-1 on a 4-D tensor reduces over axes [1, 2].
        if self.axis is not None:
            del reduction_axes[self.axis]
 
        del reduction_axes[0]
 
        # epsilon is added to the standard deviation, not the variance.
        mean = K.mean(inputs, reduction_axes, keepdims=True)
        stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
        normed = (inputs - mean) / stddev
 
        # Shape gamma/beta as 1s everywhere except along `axis` so they
        # broadcast against the normalized tensor.
        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]
 
        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed
 
    def get_config(self):
        """Return the constructor arguments merged with the base Layer config."""
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

In [None]:
#@title Convolution Block

def conv_block(x,filter_size,k_size=(3,3), stride=(1,1), lk_alpha=0):
  """Conv2D ('same' padding) -> channel-wise InstanceNormalization -> LeakyReLU.

  Args:
    x: input Keras tensor.
    filter_size: number of convolution filters.
    k_size: convolution kernel size.
    stride: convolution strides; (2, 2) is used elsewhere to halve the spatial size.
    lk_alpha: LeakyReLU negative slope (0 behaves like a plain ReLU).

  Returns:
    The activated output tensor.
  """
  conv = Conv2D(filter_size, k_size, padding='same', strides=stride)
  norm = InstanceNormalization(axis=-1)
  activation = LeakyReLU(alpha=lk_alpha)
  return activation(norm(conv(x)))

In [None]:
#@title Resnet Blocks

def rn_block(x,filter_size,k_size=(3,3), down=False, lk_alpha=0):
  """Residual block: (conv -> IN -> LeakyReLU -> conv -> IN) + shortcut -> LeakyReLU.

  Args:
    x: input Keras tensor.
    filter_size: filters of both 'same'-padded convolutions.
    k_size: kernel size of both convolutions.
    down: when True, the shortcut passes through a 1x1 Conv2D so its channel
      count matches `filter_size`; no spatial downsampling happens here.
    lk_alpha: LeakyReLU negative slope.

  Returns:
    The activated sum of the residual branch and the shortcut.
  """
  residual = Conv2D(filter_size, k_size, padding='same')(x)
  residual = InstanceNormalization(axis=-1)(residual)
  residual = LeakyReLU(alpha=lk_alpha)(residual)
  residual = Conv2D(filter_size, k_size, padding='same')(residual)
  residual = InstanceNormalization(axis=-1)(residual)

  shortcut = Conv2D(filter_size, (1,1))(x) if down else x

  merged = Add()([residual, shortcut])
  return LeakyReLU(alpha=lk_alpha)(merged)

def rn_sequence(lane, round, depth, k_size, lk_alpha):
  """Stack `round` residual blocks with `depth` filters.

  Only the first block projects its shortcut with a 1x1 conv (down=True), so
  the sequence can change the channel count; the rest use identity shortcuts.
  (`round` shadows the builtin but is kept for interface compatibility.)
  """
  for step in range(round):
    lane = rn_block(lane, depth, k_size, down=(step == 0), lk_alpha=lk_alpha)
  return lane

In [None]:
#@title Resnet-based Patch Discriminator

def rn_patch_d(input_shape,opt,k_size,init_k=32,lk_alpha=0,multi_stage=False):
  """Build and compile a residual PatchGAN-style discriminator.

  A residual stem (3x3 kernels, 1x1 shortcut projection) is followed by five
  stages, each with two residual blocks and a 2x2 max-pool. The filter count
  starts at `init_k` and doubles per stage. A 5x5 sigmoid conv then yields one
  score per output patch (an 8x8 grid for a 256x256 input).

  Args:
    input_shape: shape of one input image, e.g. (256, 256, 1).
    opt: optimizer passed to compile().
    k_size: kernel size of the residual blocks inside the stages.
    init_k: number of filters in the stem and first stage.
    lk_alpha: LeakyReLU negative slope used by every block, stem included.
    multi_stage: currently unused; kept for interface compatibility.

  Returns:
    A Model compiled with binary cross-entropy.
  """
  inputs=Input(shape=input_shape)

  # The stem passes lk_alpha as well, so the activation slope is uniform
  # across the whole network.
  lane=rn_block(inputs, init_k, (3,3), down=True, lk_alpha=lk_alpha)

  depth=init_k
  # Spatial size for a 256x256 input: 256 -> 128 -> 64 -> 32 -> 16 -> 8.
  for _ in range(5):
    lane=rn_sequence(lane, 2, depth, k_size, lk_alpha)
    lane=MaxPooling2D(pool_size=(2,2), strides=(2,2), padding="same")(lane)
    depth*=2

  output=Conv2D(1,(5,5),activation='sigmoid',padding='same')(lane)

  net=Model(inputs,output)
  net.compile(loss='binary_crossentropy', optimizer=opt)

  return net

In [None]:
#@title Regular Patch Discriminator 

def rg_patch_d(input_shape,opt,k_size,lk_alpha,init_k=32):
  """Build and compile a plain convolutional PatchGAN-style discriminator.

  Five stages, each with two conv_blocks followed by a 2x2 max-pool. The
  filter count starts at `init_k` and doubles per stage. A 7x7 sigmoid conv
  maps the final features to a one-channel grid of per-patch scores (8x8 for
  a 256x256 input).

  Args:
    input_shape: shape of one input image, e.g. (256, 256, 1).
    opt: optimizer passed to compile().
    k_size: kernel size of every conv_block.
    lk_alpha: LeakyReLU negative slope.
    init_k: filters in the first stage.

  Returns:
    A Model compiled with binary cross-entropy.
  """
  image_in = Input(shape=input_shape)

  features = image_in
  filters = init_k
  for _stage in range(5):
    for _layer in range(2):
      features = conv_block(features, filters, k_size, lk_alpha=lk_alpha)
    features = MaxPooling2D(pool_size=(2,2), strides=(2,2), padding="same")(features)
    filters *= 2

  scores = Conv2D(1, (7,7), activation='sigmoid', padding='same')(features)

  net = Model(image_in, scores)
  net.compile(loss='binary_crossentropy', optimizer=opt)
  return net

In [None]:
#@title Pix2Pix Generator

def pp2_generator(input_shape,output_k,lk_alpha=0, k_size=(3,3), depth=32, gdrop=False):
  """Build the encoder / residual / decoder image-translation generator.

  Encoder: three stride-2 conv_blocks with depth, 2*depth and 4*depth filters,
  each optionally followed by GaussianDropout(0.1).
  Bottleneck: nine residual blocks with 8*depth filters.
  Decoder: three stride-2 Conv2DTranspose layers with 4*depth, 2*depth and
  depth filters, then a 7x7 tanh Conv2D with `output_k` channels, so outputs
  lie in [-1, 1] (the range produced by image_encode).

  Args:
    input_shape: shape of one input image, e.g. (256, 256, 1).
    output_k: number of output channels.
    lk_alpha: LeakyReLU negative slope.
    k_size: kernel size of the encoder and residual convolutions.
    depth: filters of the first encoder layer.
    gdrop: add GaussianDropout(0.1) after each encoder layer.

  Returns:
    An uncompiled Model.
  """
  inputs=Input(input_shape)
  lane=inputs

  # Encoder: 256 -> 128 -> 64 -> 32 for a 256x256 input.
  for _ in range(3):
    lane=conv_block(lane,depth,k_size,stride=(2,2), lk_alpha=lk_alpha)
    if gdrop:
      lane=GaussianDropout(0.1)(lane)
    depth*=2

  lane=rn_sequence(lane, 9, depth, k_size, lk_alpha)

  # Decoder. Integer division keeps the filter counts ints; true division
  # produced floats such as 256.0.
  # NOTE(review): no normalization or activation sits between these
  # transposed convs, so together they act as one linear upsampling; confirm
  # this is intended.
  for _ in range(3):
    depth//=2
    lane=Conv2DTranspose(depth,(2,2),strides=(2,2),padding='same')(lane)

  lane=Conv2D(output_k,(7,7),activation='tanh',padding='same')(lane)

  return Model(inputs,lane)

In [None]:
#@title Gaussian Blur Block

def gaussian_blur(img, kernel_size=5, sigma=5):
    """Blur each channel of `img` independently with a normalized Gaussian kernel.

    Args:
        img: NHWC image tensor. NOTE(review): the kernel is built from Python
            floats (presumably float32), so `img` likely needs a matching dtype.
        kernel_size: side length of the square kernel; odd sizes give a kernel
            centred on the pixel.
        sigma: standard deviation of the Gaussian, in pixels.

    Returns:
        The blurred tensor (stride 1, 'SAME' padding, so same spatial size).
    """
    channels = tf.shape(img)[-1]

    # Sample offsets around 0, e.g. [-2, -1, 0, 1, 2] for kernel_size=5.
    offsets = tf.range(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0)
    grid_x, grid_y = tf.meshgrid(offsets, offsets)
    weights = tf.exp(-(grid_x ** 2 + grid_y ** 2) / (2.0 * sigma ** 2))
    weights = weights / tf.reduce_sum(weights)

    # depthwise_conv2d filter layout: (k, k, in_channels, channel_multiplier=1).
    depthwise_kernel = tf.tile(weights[..., tf.newaxis], [1, 1, channels])[..., tf.newaxis]

    return tf.nn.depthwise_conv2d(img, depthwise_kernel, [1, 1, 1, 1],
                                  padding='SAME', data_format='NHWC')

In [None]:
#@title Model Composite for Generator Training

def Train_g_composite(g_1,g_2,d_1,d_2,image_shape,opt,loss_weights=(1,1,1,1)):
  """Build and compile the composite model that trains both generators.

  Roles:
    g_1: translates images from style 1 to style 2.
    g_2: translates images from style 2 to style 1.
    d_1: scores whether an image looks like style 1.
    d_2: scores whether an image looks like style 2.

  Both generators are marked trainable and both discriminators frozen before
  compiling, so this composite is meant to update only the generators.

  Args:
    g_1, g_2, d_1, d_2: Keras models with the roles above.
    image_shape: shape of one input image, e.g. (256, 256, 1).
    opt: optimizer for the generator updates.
    loss_weights: weights of the four outputs in order (adversarial d_1,
      adversarial d_2, cycle 1, cycle 2); defaults to all 1.

  Returns:
    A compiled Model with inputs [x_1, x_2] (style-1 and style-2 images) and
    outputs [d_1(g_2(x_2)), d_2(g_1(x_1)), g_2(g_1(x_1)), g_1(g_2(x_2))],
    trained with mse on the two adversarial scores and mae on the two cycle
    reconstructions.
  """
  g_1.trainable=True
  g_2.trainable=True

  d_1.trainable=False
  d_2.trainable=False

  input_1=Input(shape=image_shape)   # real image of style 1
  input_2=Input(shape=image_shape)   # real image of style 2

  # Translate each input into the other style.
  alter_1_2=g_1(input_1)             # style 1 -> style 2
  alter_2_1=g_2(input_2)             # style 2 -> style 1

  # Translate back; these should reconstruct the inputs (cycle consistency).
  cycle_1=g_2(alter_1_2)             # originally style 1
  cycle_2=g_1(alter_2_1)             # originally style 2

  # Adversarial scores: each generator tries to make the target-style
  # discriminator rate its output as real.
  ad_2=d_2(alter_1_2)
  ad_1=d_1(alter_2_1)

  # Identity-mapping outputs (g_2(x_1) ~ x_1, g_1(x_2) ~ x_2) are not part of
  # this model; add them as extra outputs with mae loss, and add matching
  # targets in the training loop, to enable an identity loss.
  net=Model([input_1,input_2], [ad_1, ad_2, cycle_1, cycle_2])
  net.compile(loss=['mse', 'mse', 'mae', 'mae'], loss_weights=list(loss_weights), optimizer=opt)

  return net

In [None]:
#@title Model Composite for Discriminator Training

def Train_d_composite(g_1,g_2,d_1,d_2,image_shape,opt):
  """Build and compile the composite model that trains both discriminators.

  The generators are frozen and the discriminators made trainable before
  compiling. For real inputs x_1 (style 1) and x_2 (style 2) the model outputs
  [d_1(x_1), d_2(x_2), d_1(g_2(x_2)), d_2(g_1(x_1))]: first the scores of the
  real images, then the scores of the translated (fake) images. Every output
  uses mse loss and reports the 'accuracy' metric.

  NOTE(review): the training loop feeds soft labels (create_one_patch /
  create_zero_patch), against which the 'accuracy' metric is probably not
  meaningful; Train_CycleGAN computes its own thresholded accuracy.
  """
  for generator in (g_1, g_2):
    generator.trainable=False
  for discriminator in (d_1, d_2):
    discriminator.trainable=True

  real_style_1=Input(shape=image_shape)
  real_style_2=Input(shape=image_shape)

  fake_style_2=g_1(real_style_1)   # style 1 translated to style 2
  fake_style_1=g_2(real_style_2)   # style 2 translated to style 1

  scores=[
    d_1(real_style_1),
    d_2(real_style_2),
    d_1(fake_style_1),
    d_2(fake_style_2),
  ]

  net=Model([real_style_1, real_style_2], scores)
  net.compile(loss=['mse']*4, optimizer=opt, metrics=['accuracy'])

  return net

In [None]:
#@title Image Preprocessing & Preparation

def to_binary(image):
  """Threshold an array at 0.5: values >= 0.5 become 1, all others 0.

  Args:
    image: array-like of scores, e.g. discriminator outputs shaped
      (batch, h, w, 1).

  Returns:
    A new array with singleton dimensions squeezed out, holding only 0/1
    values and keeping the dtype of the squeezed input. The input is not
    modified, and any rank works (a batch of one squeezes to 2-D).
  """
  squeezed = np.squeeze(image)
  # Vectorized comparison replaces the rank-3-only triple loop, which also
  # wrote back into the caller's array through the squeeze view.
  return (squeezed >= 0.5).astype(squeezed.dtype)

def image_encode(x):
  """Map 8-bit pixel values [0, 255] linearly onto [-1, 1] (127.5 maps to 0)."""
  half_range = 127.5
  centred = x - half_range
  return centred / half_range

def image_decode(x):
  """Inverse of image_encode: map values in [-1, 1] back to uint8 pixels [0, 255].

  Values are clipped to [0, 255] before the cast, so slightly out-of-range
  model outputs saturate instead of wrapping around. NumPy's float -> uint8
  cast is undefined for out-of-range values. Fractional parts are truncated.

  Args:
    x: array-like of floats, nominally in [-1, 1].

  Returns:
    A uint8 ndarray of the same shape.
  """
  return np.clip((np.asarray(x) + 1) * 127.5, 0, 255).astype('uint8')

def _read_image(path, gray=False):
  """Read the image file at `path` into an ndarray, converting to 8-bit grayscale ('L') when `gray`."""
  with Image.open(path) as img:
    if gray:
      img = img.convert('L')
    return np.array(img)


def _blend_edges(img_name, HED_path, Sobel_path, ratio, gray):
  """Blend the HED and Sobel maps named `img_name` as ratio*Sobel + (1-ratio)*HED,
  append a trailing channel axis, and scale to [-1, 1] with image_encode."""
  HED_img = _read_image(os.path.join(HED_path, img_name), gray)
  Sobel_img = _read_image(os.path.join(Sobel_path, img_name), gray)
  img = ratio*Sobel_img + (1-ratio)*HED_img
  return image_encode(np.expand_dims(img, -1))


def _read_original(path, gray):
  """Read an original (non-edge) image as RGB, or grayscale when `gray`, scaled to [-1, 1]."""
  with Image.open(path) as img:
    if img.mode != 'RGB':
      img = img.convert('RGB')
    if gray:
      img = img.convert('L')
    return image_encode(np.array(img))


def load_image(num,HED_path,Sobel_path,ratio=0.4,original_path='',gray=False):
  """Sample `num` images (with replacement) and return their blended edge maps.

  File names are listed from `HED_path`; the same name must exist in
  `Sobel_path` (and in `original_path` when given).

  Args:
    num: number of images to draw (duplicates are possible).
    HED_path, Sobel_path: directories of HED and Sobel edge maps.
    ratio: Sobel weight in the blend ratio*Sobel + (1-ratio)*HED.
    original_path: optional directory of matching original images.
    gray: convert images to grayscale ('L') before blending.

  Returns:
    The array of blended maps scaled to [-1, 1] with a trailing channel axis
    added; when `original_path` is given, the tuple (album, org_album) where
    org_album holds the matching originals scaled to [-1, 1].

  Raises:
    ValueError: if `HED_path` contains no (non-hidden) entries.
  """
  # Sorted so a seeded np.random picks the same files on any filesystem;
  # dot-entries (e.g. .ipynb_checkpoints) are skipped because they are not images.
  img_names = sorted(name for name in os.listdir(HED_path) if not name.startswith('.'))
  if not img_names:
    raise ValueError('no images found in {}'.format(HED_path))
  index = np.random.randint(len(img_names), size=num)

  album = np.array([_blend_edges(img_names[ii], HED_path, Sobel_path, ratio, gray) for ii in index])

  if not original_path:
    return album

  org_album = np.array([_read_original(os.path.join(original_path, img_names[ii]), gray) for ii in index])
  return album, org_album


def specific_load_image(index,HED_path,Sobel_path,ratio=0.4,gray=False):
  """Load the blended edge maps '<id>.png' for each id in `index`, in order.

  Args:
    index: iterable of image ids (file stems).
    HED_path, Sobel_path: directories of HED and Sobel edge maps.
    ratio: Sobel weight in the blend ratio*Sobel + (1-ratio)*HED.
    gray: convert images to grayscale ('L') before blending.

  Returns:
    Array of blended maps scaled to [-1, 1], shaped (len(index), H, W, 1)
    when the source PNGs are single-channel.
  """
  return np.array([_blend_edges('{}.png'.format(ii), HED_path, Sobel_path, ratio, gray) for ii in index])

def save_image(index,album,path,gray=False):
  """Decode a batch of [-1, 1] images and write them to <path>/Round_<index>/<i>.png.

  Args:
    index: round number used in the sub-directory name.
    album: batch of images with values in [-1, 1] (decoded with image_decode).
    path: parent directory. It is combined with os.path.join, so an empty
      string writes under the current working directory instead of the
      filesystem root.
    gray: squeeze singleton dims (e.g. a trailing channel of 1) so PIL
      receives a 2-D grayscale array.
  """
  album = image_decode(album)

  save_dir = os.path.join(path, 'Round_{}'.format(index))
  os.makedirs(save_dir, exist_ok=True)

  for ii, photo in enumerate(album):
    if gray:
      photo = np.squeeze(photo)
    Image.fromarray(photo.astype('uint8')).save(os.path.join(save_dir, '{}.png'.format(ii)))

In [None]:
#@title Output Data Generation

def create_zero_patch(num,length):
  """Soft 'fake' patch labels: uniform noise in [0, 0.1), shape (num, length, length, 1)."""
  noise = np.random.rand(num, length, length, 1)
  return noise / 10

def create_one_patch(num,length):
  """Soft 'real' patch labels: values in (0.9, 1], shape (num, length, length, 1)."""
  noise = np.random.rand(num, length, length, 1)
  return 1.0 - noise / 10

In [None]:
#@title Training

def Train_CycleGAN(g_1, g_2, d_1, d_2, load_path_1_HED, load_path_1_Sobel, load_path_2_HED, load_path_2_Sobel, save_path_1, save_path_2, image_shape, gd_ratio=2, pred_index_1=None, pred_index_2=None, save_us=True, reload_multi=32, if_gray=False, painter_path='', aux_save_path='/content/another_save', model_save_dir='/content/save_model'):
  """Run the CycleGAN training loop for generators g_1/g_2 and discriminators d_1/d_2.

  Every round trains the generator composite on one batch of 4. Every
  `gd_ratio` rounds the discriminator composite is trained as well. Training
  images are blended HED/Sobel edge maps drawn at random from the load paths;
  the rolling albums are refreshed every `reload_multi` rounds.

  Args:
    g_1, g_2, d_1, d_2: Keras models (roles as in Train_g_composite).
    load_path_1_HED, load_path_1_Sobel: edge-map directories for style 1.
    load_path_2_HED, load_path_2_Sobel: edge-map directories for style 2.
    save_path_1: directory that receives the painter's renderings of the
      g_2 previews (only when a painter is loaded).
    save_path_2: currently unused.
    image_shape: shape of one input image, e.g. (256, 256, 1).
    gd_ratio: train the discriminators once every `gd_ratio` rounds.
    pred_index_1: currently unused.
    pred_index_2: optional ids ('<id>.png') of style-2 images previewed
      through g_2; None disables previews.
    save_us: whether to save previews every 8 rounds.
    reload_multi: rounds between album refreshes; reload_multi*4 images are
      replaced each time.
    if_gray: convert loaded images to grayscale.
    painter_path: optional saved Keras model applied to g_2's previews. An
      empty path (the default) skips the painter step.
    aux_save_path: directory for the raw g_2 previews.
    model_save_dir: directory for g_2 checkpoints; created if missing.
  """
  print('---trace on---')

  base_num=128      # images held in each rolling album at start-up
  batch_size=4

  opt_1=adam_v2.Adam(learning_rate=2e-4, beta_1=0.5)   # generator composite
  opt_2=adam_v2.Adam(learning_rate=1e-4, beta_1=0.5)   # discriminator composite

  # load_model('') can never succeed, so the painter is optional.
  g_p=None
  if painter_path:
    g_p=load_model(painter_path, custom_objects={"InstanceNormalization": InstanceNormalization})

  total_round=4096
  save_img_round=8
  save_net_round=128
  d_val_round=32

  reload_num=reload_multi*batch_size
  reload_ii=-1

  train_g=Train_g_composite(g_1,g_2,d_1,d_2,image_shape,opt_1)
  train_d=Train_d_composite(g_1,g_2,d_1,d_2,image_shape,opt_2)

  # Side length of the discriminators' square grid of patch scores.
  patch_size=d_1.output_shape[1]

  print('---load images---')

  album_g_1=load_image(base_num,load_path_1_HED, load_path_1_Sobel,gray=if_gray)
  print('album_g_1 loaded')
  album_g_2=load_image(base_num,load_path_2_HED, load_path_2_Sobel,gray=if_gray)
  print('album_g_2 loaded')

  album_d_1=load_image(base_num,load_path_1_HED, load_path_1_Sobel,gray=if_gray)
  print('album_d_1 loaded')
  album_d_2=load_image(base_num,load_path_2_HED, load_path_2_Sobel,gray=if_gray)
  print('album_d_2 loaded')

  pred_album_2=None
  if pred_index_2 is not None:
    pred_album_2=specific_load_image(pred_index_2,load_path_2_HED, load_path_2_Sobel,gray=if_gray)

  print('---begin training---')

  # Soft labels for the discriminators (drawn once and reused every round)
  # and hard "real" labels for the generators' adversarial outputs.
  patch_one_d =create_one_patch(batch_size,patch_size)
  patch_zero_d=create_zero_patch(batch_size,patch_size)
  patch_one_g =np.ones((batch_size,patch_size,patch_size,1))
  patch_sum=batch_size*patch_size*patch_size

  for ii in range(total_round):
    if ii%reload_multi==0:
      # Drop the oldest reload_num images of every album and append fresh ones.
      new_album_g_1=load_image(reload_num,load_path_1_HED, load_path_1_Sobel,gray=if_gray)
      new_album_g_2=load_image(reload_num,load_path_2_HED, load_path_2_Sobel,gray=if_gray)

      new_album_d_1=load_image(reload_num,load_path_1_HED, load_path_1_Sobel,gray=if_gray)
      new_album_d_2=load_image(reload_num,load_path_2_HED, load_path_2_Sobel,gray=if_gray)

      album_g_1=np.concatenate([album_g_1[reload_num:], new_album_g_1], axis=0)
      album_g_2=np.concatenate([album_g_2[reload_num:], new_album_g_2], axis=0)

      album_d_1=np.concatenate([album_d_1[reload_num:], new_album_d_1], axis=0)
      album_d_2=np.concatenate([album_d_2[reload_num:], new_album_d_2], axis=0)

      reload_ii+=1

    # Offset of this round's batch inside the rolling albums.
    start=ii*batch_size-reload_ii*reload_num
    stop=start+batch_size

    data_g_1=album_g_1[start:stop]
    data_g_2=album_g_2[start:stop]

    data_d_1=album_d_1[start:stop]
    data_d_2=album_d_2[start:stop]

    # Generator targets: fool both discriminators and reconstruct both inputs.
    feedback_1=train_g.train_on_batch([data_g_1,data_g_2], [patch_one_g, patch_one_g, data_g_1, data_g_2])

    if ii%gd_ratio==0:
      # Discriminator targets: real images -> ~1, translated images -> ~0.
      feedback_2=train_d.train_on_batch([data_d_1,data_d_2], [patch_one_d, patch_one_d, patch_zero_d, patch_zero_d])

      print('Round {}: Generator Loss: {} Dircriminator Loss: {}'.format(ii, feedback_1[0], feedback_2[0]))

    else:
      print('Round {}: Generator Loss: {}'.format(ii, feedback_1[0]))

    if ii%save_img_round==0 and save_us and pred_album_2 is not None:
      mimic_1=g_2.predict(pred_album_2)
      if g_p is not None:
        mimic_paint=g_p.predict(mimic_1)
        save_image(ii,mimic_paint,save_path_1,if_gray)
      save_image(ii,mimic_1,aux_save_path,True)

    if ii%d_val_round==0:
      # Thresholded (>= 0.5) patch accuracy on the current generator batch.
      real_fb_1, real_fb_2, fake_fb_1, fake_fb_2=train_d.predict([data_g_1, data_g_2])

      real_accuracy_1=np.sum(to_binary(real_fb_1))/patch_sum
      real_accuracy_2=np.sum(to_binary(real_fb_2))/patch_sum

      fake_accuracy_1=1-np.sum(to_binary(fake_fb_1))/patch_sum
      fake_accuracy_2=1-np.sum(to_binary(fake_fb_2))/patch_sum

      d_accuracy=(real_accuracy_1+real_accuracy_2+fake_accuracy_1+fake_accuracy_2)/4

      print('\n')
      print('Discriminator accuracy: ', d_accuracy)
      print('\n')

    if ii%save_net_round==0:
      # HDF5 saving does not create missing parent directories.
      os.makedirs(model_save_dir, exist_ok=True)
      g_2_name='g_2_{}__{}.h5'.format(ii,feedback_1[0])
      g_2.save(os.path.join(model_save_dir, g_2_name))

In [None]:
# Configuration for a 256x256 single-channel edge-map run.
image_shape=(256,256,1)

# NOTE(review): `opt` is not used by the calls below.
opt=adam_v2.Adam(learning_rate=2e-4, beta_1=0.5)

# Style 1 is the painting data and style 2 the photo data (see load paths):
# g_1 maps style 1 -> style 2 and g_2 maps style 2 -> style 1.
g_1, g_2 = [pp2_generator(image_shape, 1, 0.1, depth=64) for _ in range(2)]
# Each discriminator gets its own Adam instance (lr 2e-5).
d_1, d_2 = [
    rg_patch_d(image_shape, adam_v2.Adam(learning_rate=2e-5, beta_1=0.5),
               lk_alpha=0.05, k_size=(5,5), init_k=32)
    for _ in range(2)
]

load_path_1_HED=r'/content/data/Painting_HED'
load_path_1_Sobel=r'/content/data/Painting_Sobel'
load_path_2_HED=r'/content/data/Photo_HED'
load_path_2_Sobel=r'/content/data/Photo_Sobel'

# NOTE(review): both save paths are empty; set them before a real run.
save_path_1=r''
save_path_2=r''

# Preview image ids ('<id>.png'); the style-2 previews also include id 11.
preview_ids=[37,46,84,115,132,172,181,183,211,217,262,265,282,337]

Train_CycleGAN(
    g_1, g_2, d_1, d_2,
    load_path_1_HED, load_path_1_Sobel, load_path_2_HED, load_path_2_Sobel,
    save_path_1, save_path_2, image_shape,
    gd_ratio=4,
    pred_index_1=preview_ids,
    pred_index_2=[11]+preview_ids,
)

# In Development: from image to ink painting pipeline

In [None]:
def to_ink(input, ratio=0.4, generator=None, painter=None):
  """Turn a photo into an ink-painting-style image.

  Pipeline:
  1. Center-crop the photo to a square and resize it to 256x256 (Lanczos).
  2. Save it to /content/sample_data/scaled_input.png.
  3. Extract Sobel (cv_sobel) and HED (hed) edge maps.
  4. Blend them as ratio*Sobel + (1-ratio)*HED and scale to [-1, 1].
  5. Translate with `generator`, then render with `painter`.

  Args:
    input: path of the source image (anything PIL can open).
    ratio: Sobel weight in the edge blend; 0.4 matches load_image's default.
    generator: style-2 -> style-1 model; defaults to the module-level `g_2`.
    painter: model applied to the generator output; defaults to a
      module-level `g_p`, which must then exist.

  Returns:
    A PIL Image built from the squeezed uint8 painter output.
  """
  scaled_input = r"/content/sample_data/scaled_input.png"
  sobel_out = r"/content/sample_data/sobel.png"
  hed_out = r"/content/sample_data/hed.png"
  maxsize = (256, 256)

  input_img = Image.open(input)
  input_img = crop_center_square(np.array(input_img))
  input_img = Image.fromarray(input_img.astype('uint8'))
  # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
  input_img = input_img.resize(maxsize, Image.LANCZOS)
  input_img.save(scaled_input)

  cv_sobel(scaled_input, sobel_out)
  hed(scaled_input, hed_out)

  HED_img = np.array(Image.open(hed_out))
  Sobel_img = np.array(Image.open(sobel_out))

  img = ratio*Sobel_img + (1-ratio)*HED_img
  img = np.expand_dims(img, -1)   # channel axis
  img = np.expand_dims(img, 0)    # batch axis
  img = image_encode(img)

  if generator is None:
    generator = g_2
  if painter is None:
    painter = g_p

  img = generator.predict(img)
  output = painter.predict(img)
  output = np.squeeze(image_decode(output))

  return Image.fromarray(output.astype('uint8'))

In [None]:
# Fetch the pytorch-hed repository; hed() below runs its run.py.
# NOTE(review): `git clone` fails when /content/pytorch-hed already exists
# (see the "destination path ... already exists" output); harmless on re-runs.
!git clone https://github.com/sniklaus/pytorch-hed.git
%cd pytorch-hed
!ls
%cd ..
# NOTE(review): the listing above shows no download.bash in the repo, and this
# runs from /content after `%cd ..`; the recorded output is
# "No such file or directory". Confirm how the bsds500 weights are obtained,
# then remove or fix this line.
!bash download.bash

fatal: destination path 'pytorch-hed' already exists and is not an empty directory.
/content/pytorch-hed
comparison  images  LICENSE  README.md	requirements.txt  run.py
/content
bash: download.bash: No such file or directory


In [None]:
import numpy as np
import cv2
import os

def cv_sobel(in_path, out_path):
  """Write a combined Sobel edge map of an image as an 8-bit grayscale file.

  The image is read with OpenCV and converted to grayscale. It is then
  differentiated in x and y (3x3 Sobel, float64). The absolute responses are
  saturated to [0, 255], cast to uint8 and OR-ed together.

  NOTE(review): training-set Sobel maps should be produced with this same
  clipping so that inference matches training.

  Raises:
    FileNotFoundError: if OpenCV cannot read `in_path`.
    OSError: if OpenCV fails to write `out_path`.
  """
  image = cv2.imread(in_path)
  if image is None:
    # cv2.imread reports failure by returning None instead of raising.
    raise FileNotFoundError('could not read image: {}'.format(in_path))
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

  grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0)
  grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1)

  # |gradient| can reach 4*255 for a 3x3 kernel; clip so the uint8 cast
  # saturates instead of wrapping around.
  edge_x = np.uint8(np.clip(np.absolute(grad_x), 0, 255))
  edge_y = np.uint8(np.clip(np.absolute(grad_y), 0, 255))

  combined = cv2.bitwise_or(edge_x, edge_y)

  if not cv2.imwrite(out_path, combined):
    raise OSError('could not write image: {}'.format(out_path))

def hed(in_path, out_path, script='/content/pytorch-hed/run.py'):
  """Run the pytorch-hed script with the bsds500 model to write the HED edge map of `in_path` to `out_path`.

  The script runs under the kernel's own interpreter (sys.executable). It is
  passed an argument list rather than a shell string, so paths containing
  spaces or quotes reach it unchanged.

  Args:
    in_path: input image path.
    out_path: output edge-map path.
    script: path of pytorch-hed's run.py.

  Raises:
    subprocess.CalledProcessError: if the script exits with a non-zero status.
  """
  import subprocess
  import sys
  subprocess.run(
    [sys.executable, script, '--model', 'bsds500', '--in', in_path, '--out', out_path],
    check=True,
  )

def crop_center_square(a):
  """Return the largest centred square crop of an image array (a view, no copy).

  The crop is taken over the first two axes (rows, columns); any trailing
  channel axis is kept. When the excess is odd, the extra row or column is
  dropped from the bottom or right.
  """
  height = a.shape[0]
  width = a.shape[1]
  side = min(height, width)
  top = (height - side) // 2
  left = (width - side) // 2
  return a[top:top + side, left:left + side]