In [None]:
import numpy as np
from matplotlib import pyplot
from encoders.cgan import *
from datetime import datetime
from encoders.stn import Localization, BilinearInterpolation

In [None]:
# Training data: four image sets per split (background, composite, foreground
# object, depth), each read with load_imgs from data/bottle/<split>/<subfolder>/.
np.random.seed(10)  # seeds NumPy's global RNG only

LOAD_VALIDATION = False  # set True to also load and preprocess data/bottle/val/
_IMAGE_SUBDIRS = ("background", "composite", "foreground", "depth")


def _load_split(split):
    """Load the four image sets of ``split`` (e.g. "train" or "val") via ``load_imgs``.

    Returns a list ordered like ``_IMAGE_SUBDIRS``:
    [backgrounds, composites, foreground objects, depth].
    """
    return [load_imgs(f"data/bottle/{split}/{sub}/") for sub in _IMAGE_SUBDIRS]


backgrounds, paired, objects, depth = _load_split("train")
print('Loaded train: ', backgrounds.shape, paired.shape, objects.shape, depth.shape)

# Define data (this list order is what preprocess_data receives)
data = [backgrounds, paired, objects, depth]
dataset = preprocess_data(data)

# Validation split is opt-in. It is kept behind a flag rather than inside a bare
# triple-quoted string, which Jupyter would echo as this cell's output.
if LOAD_VALIDATION:
    val_data = _load_split("val")
    val_backgrounds, val_paired, val_objects, val_depth = val_data
    print('Loaded validation: ', val_backgrounds.shape, val_paired.shape,
          val_objects.shape, val_depth.shape)
    val_dataset = preprocess_data(val_data)

In [None]:
# Preview the first few samples of each image set, one subplot row per set.
n_samples = 3

# (row label, image array) pairs. The row count is derived from this list, so a
# row can no longer be skipped: the old hard-coded offset 3 * n_samples with
# n_rows = 4 left row 3 blank and pushed the objects into row 4.
sample_rows = [
    ('Background', backgrounds),
    ('Composite (target)', paired),
    ('Foreground object', objects),
]
n_rows = len(sample_rows)

fig, axes = pyplot.subplots(n_rows, n_samples, squeeze=False)
for row_axes, (label, images) in zip(axes, sample_rows):
    for col, ax in enumerate(row_axes):
        ax.axis('off')
        # assumes pixel values already lie in [0, 255], as the uint8 cast implies
        ax.imshow(images[col].astype('uint8'))
    row_axes[0].set_title(label, loc='left', fontsize='small')

pyplot.show()

In [None]:
# ---- Training hyper-parameters ----
# Validation
VAL_SAMPLES = 3          # how many validation samples to use
VAL_FREQUENCY = 70_000   # run validation once every this many steps
# Training loop
EPOCHS = 1_000           # number of passes over the training set
BATCH_SIZE = 1           # images per batch
# Optimiser learning rates
D_LR = 1e-4              # discriminator
G_LR = 2e-4              # generator

In [None]:
# STN sanity check: run the foreground images through a freshly constructed
# Localization network and the bilinear sampler, then plot one original image
# next to two warped outputs.
STN_CHECK = False

if STN_CHECK:
    # Map pixels to [-1, 1]; assumes inputs lie in [0, 255] — confirm against load_imgs.
    processed = (objects - 127.5) / 127.5
    theta = Localization()(processed)
    print(theta)
    # Sample at the input's own resolution instead of a hard-coded 256x256.
    height, width = processed.shape[1:3]
    x = BilinearInterpolation(height=height, width=width)([processed, theta])

    # Needs at least two images, since x[1] is shown.
    panels = [
        ('Original Image', processed[0]),
        ('Transformed Image', x[0]),
        ('Transformed Image (2nd sample)', x[1]),
    ]
    pyplot.figure(figsize=(10, 4))
    for idx, (title, img) in enumerate(panels, start=1):
        pyplot.subplot(1, len(panels), idx)
        # imshow clips float RGB to [0, 1]; rescale from [-1, 1] so negative
        # values are not all rendered as black.
        pyplot.imshow(np.clip((np.asarray(img) + 1.0) / 2.0, 0.0, 1.0))
        pyplot.title(title)
        pyplot.axis('off')
    pyplot.show()

In [None]:
# Model input shape: the per-image dimensions of the training backgrounds
# (leading batch axis dropped).
image_shape = backgrounds.shape[1:]

# Discriminator is built before the generator, matching the original order
# (construction order may matter if weight initialisation draws from a global RNG).
d_model = define_discriminator(image_shape, D_LR)
g_model = define_generator(image_shape, False)

# Composite model wiring the generator and discriminator together.
gan_model = define_gan(g_model, d_model, image_shape, G_LR)

g_model.summary()

In [None]:
# Train the models and report the wall-clock duration of the run.
start = datetime.now()
train(d_model, g_model, gan_model, dataset,
      VAL_SAMPLES, VAL_FREQUENCY, EPOCHS, BATCH_SIZE)
stop = datetime.now()

execution_time = stop - start  # a datetime.timedelta
print("Execution time is: ", execution_time)