In [None]:
import time
import warnings
from datetime import datetime, timedelta

from matplotlib import pyplot as plt

from encoders.cgan import (
    load_imgs,
    preprocess_data,
    define_discriminator,
    define_generator,
    define_gan,
    train,
)
import config as cfg
from encoders.stn import Localization, BilinearInterpolation

In [None]:

# Every split of the current dataset lives under data/<TRAIN_DATASET>/<split>/,
# with one sub-folder per image stream.
DATA_ROOT = f"data/{cfg.TRAIN_DATASET}"


def load_split(split):
    """Load the four image streams of one dataset split.

    Args:
        split: Name of the split folder under ``DATA_ROOT`` (e.g. ``"train"``).

    Returns:
        Tuple ``(backgrounds, paired, objects, depth)`` as returned by
        ``load_imgs`` for the ``background``, ``composite``, ``foreground``
        and ``depth`` sub-folders respectively.
    """
    split_dir = f"{DATA_ROOT}/{split}"
    return (
        load_imgs(f"{split_dir}/background/"),
        load_imgs(f"{split_dir}/composite/"),
        load_imgs(f"{split_dir}/foreground/"),
        load_imgs(f"{split_dir}/depth/"),
    )


backgrounds, paired, objects, depth = load_split("train")
print('Loaded train: ', backgrounds.shape, paired.shape, objects.shape, depth.shape)

# The streams are handed to preprocess_data as a list in
# (background, composite, foreground, depth) order.
data = [backgrounds, paired, objects, depth]
dataset = preprocess_data(data)

# Validation: set cfg.VAL_SPLIT to the name of a held-out split folder.
# It defaults to "train", which makes validation meaningless (the model is
# scored on the data it trains on), so warn when that happens.
if cfg.USE_VAL:
    val_split = getattr(cfg, "VAL_SPLIT", "train")
    if val_split == "train":
        warnings.warn(
            "cfg.USE_VAL is set but validation data is loaded from the train "
            "split; define cfg.VAL_SPLIT to point at a held-out split."
        )
    val_backgrounds, val_paired, val_objects, val_depth = load_split(val_split)
    print('Loaded validation: ', val_backgrounds.shape, val_paired.shape, val_objects.shape, val_depth.shape)
    val_data = [val_backgrounds, val_paired, val_objects, val_depth]
    val_dataset = preprocess_data(val_data)

In [None]:
# Preview the first few samples of every image stream, one stream per row.
n_samples = 3
sample_rows = [
    ('background', backgrounds),
    ('depth', depth),
    ('foreground', objects),
    ('composite', paired),
]
n_rows = len(sample_rows)
# Don't index past the end of a stream that holds fewer than n_samples images.
n_cols = min([n_samples] + [len(imgs) for _, imgs in sample_rows])

fig, axes = plt.subplots(n_rows, n_cols, squeeze=False, figsize=(3 * n_cols, 3 * n_rows))
for row_axes, (label, imgs) in zip(axes, sample_rows):
    for col, ax in enumerate(row_axes):
        img = imgs[col].astype('uint8')
        # imshow rejects (H, W, 1) arrays; drop a trailing singleton channel so
        # single-channel streams (e.g. depth, if stored that way) still render.
        if img.ndim == 3 and img.shape[-1] == 1:
            img = img[..., 0]
        ax.imshow(img, cmap='gray' if img.ndim == 2 else None)
        # Hide ticks but keep the axes so the row label stays visible.
        ax.set_xticks([])
        ax.set_yticks([])
    row_axes[0].set_ylabel(label)
plt.show()

In [None]:
if cfg.STN_CHECK:
    # Smoke-test the spatial transformer: predict transform parameters (theta)
    # with a freshly constructed Localization layer and warp the images with
    # BilinearInterpolation. Only a couple of images are needed for the
    # preview, so don't push the whole dataset through the layers.
    n_check = min(2, len(objects))
    batch = objects[:n_check]

    # Scale pixel values from [0, 255] to [-1, 1] before the STN layers.
    processed = (batch - 127.5) / 127.5
    theta = Localization()(processed)
    print(theta)
    # Warp to the input's own spatial size instead of a hard-coded 256x256.
    x = BilinearInterpolation(
        height=int(processed.shape[1]), width=int(processed.shape[2])
    )([processed, theta])

    # imshow clips float RGB data to [0, 1], so map [-1, 1] back to [0, 1]
    # for display; otherwise every negative value renders as black.
    fig, axes = plt.subplots(2, n_check, squeeze=False, figsize=(4 * n_check, 8))
    for i in range(n_check):
        axes[0, i].imshow((processed[i] + 1.0) / 2.0)
        axes[0, i].set_title(f'Original Image {i}')
        axes[1, i].imshow((x[i] + 1.0) / 2.0)
        axes[1, i].set_title(f'Transformed Image {i}')
    for ax in axes.flat:
        ax.axis('off')
    plt.show()

In [None]:
# Model construction. Axis 0 of the loaded arrays indexes individual images,
# so the per-image input shape is everything after it.
image_shape = backgrounds.shape[1:]

# Discriminator and generator. cfg.USE_STN is forwarded to the generator
# (presumably toggles its spatial-transformer stage; see encoders.cgan).
d_model = define_discriminator(image_shape, cfg.D_LR)
g_model = define_generator(image_shape, cfg.USE_STN)

# Composite model built from both networks, with cfg.G_LR as its rate setting.
gan_model = define_gan(g_model, d_model, image_shape, cfg.G_LR)

g_model.summary()

In [None]:
# Train the GAN and report how long it took. perf_counter() is a monotonic
# clock, so the measurement can't be skewed by system-clock adjustments
# during a long run (unlike datetime.now()).
start = time.perf_counter()
try:
    if cfg.USE_VAL:
        train(d_model, g_model, gan_model, dataset, val_dataset)
    else:
        train(d_model, g_model, gan_model, dataset)
finally:
    # Report elapsed time even if training is interrupted (e.g. KeyboardInterrupt).
    execution_time = timedelta(seconds=time.perf_counter() - start)
    print("Execution time is: ", execution_time)