### 1. Import dependencies


In [None]:
import discriminator_model as D, generator_model as G
import datagenerator as DTG
import augmentation

import config
import training_loop as TL
import visualization

import callbacks

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

### 2. Create models

In [None]:
INPUT_SHAPE = (256, 256, 3)

# Discriminator architecture: filters[i] is used together with striding[i].
# (A smaller variant tried earlier used filters [16, 32, 64, 128].)
# NOTE(review): 356 looks like a possible typo (256 or 384?) -- confirm before
# changing it, since that alters the architecture and invalidates saved weights.
dis_filt_list = [32, 64, 128, 192, 256, 356]
dis_striding_list = [1, 2, 1, 2, 1, 1]

lr_gen = 1.5e-4
lr_dis = 1e-4
ADAM_BETA_1 = 0.6

# Both discriminators (and both generators) share one configuration; keeping
# it in a single dict prevents the two copies from drifting apart.
dis_kwargs = dict(input_shape=INPUT_SHAPE,
                  filters=dis_filt_list,
                  striding=dis_striding_list,
                  ksize=4)
gen_kwargs = dict(input_shape=INPUT_SHAPE,
                  filters=48,
                  residual_blocks=6,
                  ksize=4)

# Creation order (dis1, dis2, gen1, gen2) matches the original notebook.
dis1 = D.CreateDiscriminator(**dis_kwargs)
dis2 = D.CreateDiscriminator(**dis_kwargs)
gen1 = G.CreateGenerator(**gen_kwargs)
gen2 = G.CreateGenerator(**gen_kwargs)

# One optimizer per network; they are handed to TL.fit in the "Fit" cell.
g1_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_gen, beta_1=ADAM_BETA_1)
d1_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_dis, beta_1=ADAM_BETA_1)
g2_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_gen, beta_1=ADAM_BETA_1)
d2_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_dis, beta_1=ADAM_BETA_1)

# compile() is called without optimizer/loss; training uses the optimizers
# above through TL.fit rather than Keras' built-in fit loop.
for net in (gen1, gen2, dis1, dis2):
    net.compile()

gen1.summary()
dis1.summary()
 

### 4. Data generator


In [None]:
def to_grayscale(img, shape=None):
    """Convert an image to grayscale, keeping the channel count, as a batch of one.

    The per-pixel channel mean is replicated back across all channels, so the
    result still matches the per-sample model input shape.

    Args:
        img: array whose last axis is the channel axis and whose remaining
            elements fill ``shape`` (a leading batch axis of size 1 is fine).
        shape: target per-sample shape ``(H, W, C)``; defaults to the
            notebook-level ``INPUT_SHAPE``.

    Returns:
        Array of shape ``(1, *shape)``.
    """
    if shape is None:
        shape = INPUT_SHAPE
    gray = np.mean(img, axis=-1, keepdims=True)
    # Reshaping the channel-reduced array straight to (1, H, W, C) can never
    # succeed (H*W vs H*W*C elements), so repeat the mean across channels first.
    out = np.repeat(gray, shape[-1], axis=-1)
    return out.reshape((1, *shape))

In [None]:

# Alternative data sources kept for quick experiments:
# datagen = DTG.TestMNISTGeneator(0, 4)
# datagen = DTG.CatDogsCIFARGenerator(augmentation=False)

# Image folders for domain X1 and domain X2 -- replace the placeholders with
# real paths before running.
pa = 'path_to_images_x1'
pb = 'path_to_images_x2'

fill_mode = 'nearest'
rand_aug = augmentation.RandAugmentation(
    vshift=0.05,              # 5% vertical random shift
    hshift=0.05,              # 5% horizontal random shift
    zoom=0.02,                # 2% (+-1%) random zoom
    rotate=10,                # +-10deg random rotate
    hflip=True,               # rand 50% horizontal flip
    normalize=True,           # normalize
    fill_strategy=fill_mode,  # fill with mirror/nearest
)

# Not used below; kept so the iterator can be rebuilt without random
# augmentation -- presumably a no-op augmentation, confirm in `augmentation`.
no_aug = augmentation.AugmentationUnit()

datagen = DTG.ImgFileIterator(pa, pb, INPUT_SHAPE, rand_aug, fill_mode=fill_mode)

# Another alternative, using the same random augmentation:
# datagen = DTG.TestMNISTGeneator(6, 3, rand_aug)

# Sanity check: number of batches and the shapes of one (x1, x2) batch.
print(len(datagen))
x1, x2 = datagen[0]
print(x1.shape, x2.shape)

Preview sample image pairs (x1, x2) produced by the data generator

In [None]:
# Show CNT*CNT (x1, x2) pairs side by side: the first image of each batch.
CNT = 1
n_pairs = CNT ** 2
n_batches = len(datagen)
arr = []
for i in range(n_pairs):
    print(i + 1, '/', n_pairs, '       ', end='\r')
    # Wrap around so asking for more pairs than there are batches does not
    # index past the end of the iterator.
    x1, x2 = datagen[i % n_batches]
    arr.append(x1[0, :])
    arr.append(x2[0, :])

plt.imshow(visualization.pack_into_array3dim(np.array(arr), (CNT, 2 * CNT)), cmap='gray')

In [None]:
import os

# Records the two named loss series during training.
hist_callback = callbacks.HistoryLossesCallback(['generators_loss', 'discriminators_loss'])
# Shows generator output on samples from `datagen` while training runs.
runtime_visual_callback = callbacks.DynamicGenOutputCallback(datagen, scale=1)
# Build the checkpoint path portably. The former literal 'CheckPoints\model1'
# contained the invalid escape sequence "\m" (SyntaxWarning on Python 3.12+),
# and a backslash is not a path separator on POSIX systems.
wsaver_callback = callbacks.WeightsSaveCallback(os.path.join('CheckPoints', 'model1'), frequency=1)

### 5. Fit

In [None]:

# Assumed to be the epoch count taken by TL.fit (checkpoints are loaded by
# `eph` later) -- confirm against training_loop.fit.
EPOCHS = 100
training_callbacks = [hist_callback, runtime_visual_callback, wsaver_callback]

history = TL.fit(
    datagen,
    gen1, gen2,
    dis1, dis2,
    g1_optimizer, g2_optimizer,
    d1_optimizer, d2_optimizer,
    EPOCHS,
    training_callbacks,
)



### 6. Learning statistics

In [None]:
# Alternative view of the value returned by TL.fit:
#     visualization.plot_history(history)

# Plot the 'generators_loss' / 'discriminators_loss' series collected by the
# history callback; the size is left as the cell's displayed value.
hist_callback.plot()
hist_callback.get_size()

In [None]:
# Restore all four networks from the checkpoint saver.
# NOTE(review): eph=None presumably selects the latest saved epoch -- confirm.
networks_to_restore = [gen1, gen2, dis1, dis2]
wsaver_callback.load_weights(networks_to_restore, eph=None)

### 7. Overview of generator outputs

In [None]:

# Each panel shows an H x W grid of samples.
H, W = 2, 2
SAMPLES = H * W

# Initial data: independently drawn random batches for each domain.
dlen = len(datagen)
x1_pack = np.squeeze([datagen[np.random.randint(0, dlen)][0] for _ in range(SAMPLES)])
x2_pack = np.squeeze([datagen[np.random.randint(0, dlen)][1] for _ in range(SAMPLES)])

# Generated: run each generator once and reuse its raw output for the cycle
# reconstruction below instead of recomputing it.
fake_x2 = gen1(x1_pack)
fake_x1 = gen2(x2_pack)
gx2 = np.squeeze(fake_x2)
gx1 = np.squeeze(fake_x1)

# Cycle-generated: translate back to the source domain.
cx1 = np.squeeze(gen2(fake_x2))
cx2 = np.squeeze(gen1(fake_x1))

# Identity: each generator applied to images from its own output domain.
ix1 = np.squeeze(gen2(x1_pack))
ix2 = np.squeeze(gen1(x2_pack))

# (image stack, title) for every subplot, row by row.
panels = [
    [(x1_pack, 'Input data X1'),
     (gx2, 'Generator 1 out'),
     (x2_pack, 'Input data X2'),
     (gx1, 'Generator 2 out')],
    [(cx1, 'Cycle out cx1 = gen2(gen1(x1))'),
     (cx2, 'Cycle out cx2 = gen1(gen2(x2))'),
     (ix1, 'Identity out gen2(x1)'),
     (ix2, 'Identity out gen1(x2)')],
]

fig, axs = plt.subplots(len(panels), len(panels[0]), figsize=(24, 12), dpi=200)
for ax_row, panel_row in zip(axs, panels):
    for ax, (images, title) in zip(ax_row, panel_row):
        ax.imshow(visualization.pack_into_array3dim(images, (H, W)), cmap='gray')
        ax.set_title(title)
        ax.grid(False)

fig.tight_layout()
