# Image-to-Image Translation using Pix2Pix

In [13]:
import tensorflow as tf
import sys
import time
import os
import random
from keras.optimizers import Adam
from keras.backend import clear_session

In [14]:
# Make our own modules importable: put the data_processing and model folders
# on sys.path. Relative sys.path entries are resolved against the current
# working directory, so the notebook must be run from the project root.
sys.path.extend(["data_processing", "model"])

# import our own methods
from load_datasets import load_train_dataset, load_test_dataset
from model_setup import generator, discriminator
from training_methods import train_step  # defined under model/training_methods.py
from visualize_data import plot_images_at_epoch  # defined under data_processing/visualize_data.py

In [15]:
# Show which CPU devices TensorFlow can see.
devices = tf.config.list_physical_devices("CPU")
print(devices)
# NOTE: memory growth can only be configured on GPU devices, e.g.
#   for gpu in tf.config.list_physical_devices("GPU"):
#       tf.config.experimental.set_memory_growth(gpu, True)

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]


In [16]:
IMAGE_SIZE = 128  # do NOT change

# path to project directory
PROJECT_DIR = "/net/merisi/pbigalke/teaching/METFUT2024/CGAN_Pix2Pix_MSG"

# paths to the train and validation image sets
DATASET_PATH = f"{PROJECT_DIR}/VIS_IR_images"
TRAIN_PATH = f"{DATASET_PATH}/train"
TEST_PATH = f"{DATASET_PATH}/val"  # the validation split serves as the test set

# Path where to store output. exist_ok=True replaces the racy
# "check exists, then create" pattern and is a no-op if the folder exists.
OUT_PATH = f"{PROJECT_DIR}/output/training_IR_VIS"
os.makedirs(OUT_PATH, exist_ok=True)

In [17]:
# load and preprocess dataset
BATCH_SIZE = 10  # number of image pairs per batch, shared by both splits


def _load_split(loader, split_path, batch_size):
    """Build one dataset split with ``loader``, print it, and return it."""
    split_dataset = loader(split_path, batch_size)
    print(split_dataset)
    return split_dataset


# Training data first, then test data. The printed element_spec shows pairs of
# float32 batches of shape (None, 128, 128, 1).
train_dataset = _load_split(load_train_dataset, TRAIN_PATH, BATCH_SIZE)
test_dataset = _load_split(load_test_dataset, TEST_PATH, BATCH_SIZE)

<_BatchDataset element_spec=(TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None))>
<_BatchDataset element_spec=(TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None))>


In [18]:
# Reset Keras' global state (e.g. auto-generated layer names) before
# building fresh models, so re-running this cell starts clean.
clear_session()

# set up generator
gen_model = generator()
gen_model.summary()

# set up discriminator
discr_model = discriminator()
discr_model.summary()

# Create optimizers for generator and discriminator.
# NOTE: pass `learning_rate`, not the deprecated `lr` alias. Newer tf.keras
# optimizers only warn about `lr` and keep the default rate (1e-3), and
# Keras 3 rejects it outright, so 2e-4 would silently not be applied.
gen_optimizer = Adam(learning_rate=2e-4, beta_1=0.5)
discr_optimizer = Adam(learning_rate=2e-4, beta_1=0.5)


input layer (None, 128, 128, 1)
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 128, 128, 1)]        0         []                            
                                                                                                  
 sequential (Sequential)     (None, 64, 64, 64)           1024      ['input_1[0][0]']             
                                                                                                  
 sequential_1 (Sequential)   (None, 32, 32, 128)          131584    ['sequential[0][0]']          
                                                                                                  
 sequential_2 (Sequential)   (None, 16, 16, 256)          525312    ['sequential_1[0][0]']        
                                                              



In [19]:
# define training procedure
def _save_example_images(model, test_dataset, outpath_img, epoch):
    """Plot one random example of the first test batch in training and inference mode.

    Writes ``example_images_epoch{epoch}_train.png`` (generator called with
    ``training=True``) and ``example_images_epoch{epoch}_NOtrain.png``
    (``training=False``) into ``outpath_img``.
    """
    # Take the first batch. It is only a random batch if test_dataset
    # reshuffles on every iteration.
    for ir_batch, vis_batch in test_dataset.take(1):

        # predict vis img from ir with generator
        predict_vis_batch = model(ir_batch, training=True)

        # Select a random image of this batch. Use the real batch size, which
        # can be smaller than the configured BATCH_SIZE.
        rand_idx = random.randint(0, int(ir_batch.shape[0]) - 1)
        ir_img = ir_batch.numpy()[rand_idx]
        vis_img = vis_batch.numpy()[rand_idx]
        predict_vis_img = predict_vis_batch.numpy()[rand_idx]

        plot_images_at_epoch(ir_img, predict_vis_img, vis_img,
                             output_file=f"{outpath_img}/example_images_epoch{epoch}_train.png",
                             normalized=True)

        # same image, generator in inference mode
        predict_vis_batch = model(ir_batch, training=False)
        predict_vis_img = predict_vis_batch.numpy()[rand_idx]
        plot_images_at_epoch(ir_img, predict_vis_img, vis_img,
                             output_file=f"{outpath_img}/example_images_epoch{epoch}_NOtrain.png",
                             normalized=True)


def fit(generator, discriminator, gen_optimizer, discr_optimizer, train_dataset, test_dataset, epochs,
        outpath_img=None, outpath_model=None, save_after_epochs=5):
    """Train the generator/discriminator pair for ``epochs`` passes over the data.

    Args:
        generator: generator model; called as ``generator(x, training=...)``
            for the example images and passed to ``train_step``.
        discriminator: discriminator model, passed to ``train_step``.
        gen_optimizer: optimizer for the generator, passed to ``train_step``.
        discr_optimizer: optimizer for the discriminator, passed to ``train_step``.
        train_dataset: dataset yielding ``(input_image, target)`` batches.
        test_dataset: dataset yielding ``(ir_batch, vis_batch)`` batches; only
            its first batch is used for the per-epoch example images.
        epochs: number of passes over ``train_dataset``.
        outpath_img: directory for per-epoch example images; ``None`` disables them.
        outpath_model: directory for model states (saving is not implemented yet).
        save_after_epochs: save interval in epochs; ``0``/``None`` disables saving.
    """
    start_training = time.time()

    # loop over number of epochs
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}")
        start_epoch = time.time()

        # Losses of the most recent batch. They stay None if train_dataset
        # yields no batches, instead of raising UnboundLocalError below.
        gen_loss = disc_loss = None
        for input_image, target in train_dataset:
            gen_loss, disc_loss = train_step(input_image, target, generator, discriminator,
                                             gen_optimizer, discr_optimizer)

        # save example images of this epoch
        if outpath_img is not None:
            _save_example_images(generator, test_dataset, outpath_img, epoch)

        # Save model state. Guard against a zero interval (modulo by zero).
        if save_after_epochs and epoch % save_after_epochs == 0:
            print("TODO: implement saving of model")

        # calculate the test and print it:
        print("TODO: implement test loss, accuracy etc.")

        # print some information on the progress of training
        print("TODO: save losses in an array and return for later plotting")
        if gen_loss is None:
            print("No training batches in this epoch, no losses to report.")
        else:
            print(f"Generator loss: {gen_loss:.2f}, Discriminator loss: {disc_loss:.2f}")
        print(f"Time for epoch {epoch+1}: {(time.time()-start_epoch)/60.:.2f} min.")
        print(f"Total runtime: {(time.time()-start_training)/60.:.2f} min.")


In [21]:
import matplotlib.pyplot as plt

def plot_images_at_epoch(ir_image, predict_image, vis_image, output_file=None, normalized=False):
    """Show the IR input, the predicted VIS image and the true VIS image side by side.

    Args:
        ir_image: input IR image, in any form accepted by ``imshow``.
        predict_image: generator output for ``ir_image``.
        vis_image: ground-truth VIS image.
        output_file: if given, the figure is also saved to this path.
        normalized: if True the gray scale spans [-1, 1], otherwise [0, 255].
    """
    if normalized:
        low, high = -1, 1
    else:
        low, high = 0, 255

    fig, axes = plt.subplots(1, 3, figsize=(15, 15))
    panels = zip(
        axes,
        ("IR image", "predicted VIS image", "true VIS image"),
        (ir_image, predict_image, vis_image),
    )
    for ax, panel_title, img in panels:
        ax.set_title(panel_title)
        ax.imshow(img, cmap="gray", vmin=low, vmax=high)
        ax.axis("off")

    # only write to disk when a target file was requested
    if output_file is not None:
        fig.savefig(output_file, bbox_inches="tight")
    plt.show()
    plt.close(fig)

In [26]:
# do the training
import numpy as np

epochs = 1

# Output folders. exist_ok=True avoids the check-then-create race and is a
# no-op when the folder already exists.
outpath_img = f"{OUT_PATH}/example_images"
os.makedirs(outpath_img, exist_ok=True)
outpath_model = f"{OUT_PATH}/model_states"
os.makedirs(outpath_model, exist_ok=True)

# Sanity check: print the value ranges of one example from the first test
# batch (IR input, true VIS, generator prediction).
for ir_batch, vis_batch in test_dataset.take(1):

    # predict vis img from ir with generator
    predict_vis_batch = gen_model(ir_batch, training=True)

    # Select a random image of this batch. Use the real batch size, which can
    # be smaller than BATCH_SIZE.
    rand_idx = random.randint(0, int(ir_batch.shape[0]) - 1)
    ir_img = ir_batch.numpy()[rand_idx]
    vis_img = vis_batch.numpy()[rand_idx]
    predict_vis_img = predict_vis_batch.numpy()[rand_idx]

    print(np.min(ir_img), np.max(ir_img))
    print(np.min(vis_img), np.max(vis_img))
    print(np.min(predict_vis_img), np.max(predict_vis_img))
    print()

    #plot_images_at_epoch(ir_img, predict_vis_img, vis_img,
    #                    output_file=f"{outpath_img}/example_images_epoch0_train.png",
    #                    normalized=True)


#fit(gen_model, discr_model, gen_optimizer, discr_optimizer, train_dataset, test_dataset, epochs,
#    outpath_img=outpath_img, outpath_model=outpath_model)

-0.9998462 -0.99640137
-0.99996924 -0.99215686
-0.7846868 0.7026362

