# Data generator
This notebook uses the simulated and real datasets to produce the generated images using 3 pix2pix and 3 CycleGAN networks.
It also produces the segmentation masks using Segformer.

## Usage:
Set `task_type` in the parameters cell to `'donkey'` (lane keeping) or `'kitti'` (vehicle detection); the notebook ships with `task_type='donkey'`.
Then run all cells.

## Requirements:
-Segformer checkpoints for segmentation: \
    ./[task_type]/content/segmentation_checkpoints/Model_weights.hdf5 \
    
    Refer to the [task_type]-specific instructions to generate the segmentation checkpoint.

-CycleGAN and pix2pix checkpoints for I2I translation:\
    ./[task_type]/content/gan_checkpoints/cyclegan_checkpoints/1/ \
    ./[task_type]/content/gan_checkpoints/cyclegan_checkpoints/2/ \
    ./[task_type]/content/gan_checkpoints/cyclegan_checkpoints/3/ \
    ./[task_type]/content/gan_checkpoints/pix2pix_checkpoints/1/ \
    ./[task_type]/content/gan_checkpoints/pix2pix_checkpoints/2/ \
    ./[task_type]/content/gan_checkpoints/pix2pix_checkpoints/3/ 

    Refer to the [task_type]-specific instructions to generate the GAN checkpoints.

-Dataset h5 files: \
    KITTI: \
    ./[task_type]/content/datasets/h5_out/bounding_boxes_real.h5\
    ./[task_type]/content/datasets/h5_out/bounding_boxes_sim.h5\
    ./[task_type]/content/datasets/h5_out/raw_image_real.h5\
    ./[task_type]/content/datasets/h5_out/raw_image_sim.h5\
    ./[task_type]/content/datasets/h5_out/segmentation_masks_real.h5\
    ./[task_type]/content/datasets/h5_out/segmentation_masks_sim.h5\
    ./[task_type]/content/datasets/h5_out/semantic_id_list_real.h5\
    ./[task_type]/content/datasets/h5_out/semantic_id_list_sim.h5\
    DONKEY:\
    ./[task_type]/content/datasets/h5_out/gt_real.h5\
    ./[task_type]/content/datasets/h5_out/raw_image_real.h5\
    ./[task_type]/content/datasets/h5_out/raw_image_sim.h5\
    ./[task_type]/content/datasets/h5_out/semantic_id_list_real.h5\
    ./[task_type]/content/datasets/h5_out/semantic_id_list_sim.h5\

    Refer to the [task_type]-specific instructions to generate the h5 files.

## Outputs:
    ./[task_type]/content/output_plots/[domain_type]/[domain_type] 
    ./[task_type]/content/output_plots/[domain_type]/[domain_type]_mask 
    ./[task_type]/content/output_plots/[domain_type]/[domain_type]_additional_mask 
    ./[task_type]/content/output_plots/[domain_type]/[domain_type]_mask_error


In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
import tensorflow as tf

import os
import matplotlib.pyplot as plt
import json
import tensorflow as tf
import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix
import os
import matplotlib.pyplot as plt

import data_generator_utils

## Set parameters and task type

In [None]:
OUTPUT_CHANNELS = 3
BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 512
IMG_HEIGHT = 512

%load_ext autoreload
%autoreload 2

task_type='donkey'
output_folder = './'+task_type+'/content/output_plots'
plot=False
two_classes=False
limit=10000

## Load necessary data and models

In [None]:
# Load the h5 datasets, set the task-specific constants, split/crop the data
# and load the segmentation model.

# Validate the task up front. Previously an unknown task_type only printed a
# message; the cell then loaded every h5 file and crashed further down with a
# NameError on `pattern` / `dataset_index_list_test`.
if task_type not in ("kitti", "donkey"):
    raise ValueError(f"Unknown task_type {task_type!r}: choose 'donkey' or 'kitti'")

h5_dir = './' + task_type + '/content/datasets/h5_out/'

loaded_dictionary_images_sim = data_generator_utils.load_h5_to_dictionary(h5_dir + 'raw_image_sim.h5')
loaded_dictionary_images_real = data_generator_utils.load_h5_to_dictionary(h5_dir + 'raw_image_real.h5')
loaded_semantic_id_real = data_generator_utils.load_h5_to_dictionary(h5_dir + 'semantic_id_list_real.h5')
loaded_semantic_id_sim = data_generator_utils.load_h5_to_dictionary(h5_dir + 'semantic_id_list_sim.h5')

if task_type == "kitti":
    loaded_bounding_sim = data_generator_utils.load_h5_to_dictionary(h5_dir + 'bounding_boxes_sim.h5')
    loaded_bounding_real = data_generator_utils.load_h5_to_dictionary(h5_dir + 'bounding_boxes_real.h5')
    height, width = 374, 1238
    road = 0
    additional_id = 3
    additional_id_init = 13
    dataset_index_list_test = ["0001", "0002", "0006", "0018", "0020"]
    # NOTE(review): one character per sequence, consumed by
    # data_generator_utils.get_gan_indexes — confirm the meaning of t/v/n there.
    pattern = 'tvvttvttnn'
else:  # task_type == "donkey" (guaranteed by the check above)
    height, width = 140, 320
    road = 1
    additional_id = 1
    additional_id_init = 2
    dataset_index_list_test = ["0001"]
    pattern = 'vvvvvvvvvv'

train_indexes_gan, test_indexes_gan = data_generator_utils.get_gan_indexes(
    dataset_index_list_test,
    loaded_dictionary_images_real,
    loaded_dictionary_images_sim,
    pattern,
)
(
    loaded_dictionary_images_real,
    loaded_dictionary_images_sim,
    loaded_semantic_id_real,
    loaded_semantic_id_sim,
) = data_generator_utils.crop_data_dictionaries(
    task_type,
    dataset_index_list_test,
    loaded_dictionary_images_real,
    loaded_dictionary_images_sim,
    loaded_semantic_id_real,
    loaded_semantic_id_sim,
)

checkpoint_file_path = "./" + task_type + "/content/segmentation_checkpoints/Model_weights.hdf5"
segmentation_model = data_generator_utils.load_segmentation_model(checkpoint_file_path, task_type)

## Generate sim and real data
- Generate and populate sim and real image folders ('/content/output_plots/sim/sim' and '/content/output_plots/real/real')
- Compute and populate semantic masks folders ('/content/output_plots/sim/sim_mask' and '/content/output_plots/real/real_mask')
- Compute and populate TSS and OC-TSS metric folders (referred to as the semantic mask error in the codebase: '/content/output_plots/sim/sim_mask_error')

In [None]:
# Folder names used for the real and simulated outputs.
real_path_name, sim_path_name = "real", "sim"

data_generator_utils.save_sim_real_outputs(
    task_type,
    segmentation_model,
    real_path_name,
    sim_path_name,
    additional_id,
    height,
    width,
    dataset_index_list_test,
    test_indexes_gan,
    loaded_dictionary_images_real,
    loaded_dictionary_images_sim,
)

## Generate CycleGAN data
- Generate and populate image folders ('/content/output_plots/cyclegan/cyclegan_id') for each cyclegan_id in checkpoint_names
- Compute and populate semantic masks folders ('/content/output_plots/cyclegan/cyclegan_id_mask')  for each cyclegan_id in checkpoint_names
- Compute and populate TSS and OC-TSS metric folders (referred to as the semantic mask error in the codebase: '/content/output_plots/cyclegan/cyclegan_id_mask_error') for each cyclegan_id in checkpoint_names

In [None]:
# Restore each CycleGAN generator checkpoint and write its generated images,
# masks and mask errors via data_generator_utils.save_cyclegan_outputs.
cyclegan_ids = (1, 2, 3)
checkpoint_paths = [f'./{task_type}/content/gan_checkpoints/cyclegan_checkpoints/{i}' for i in cyclegan_ids]
checkpoint_names = [f"cyclegan_{i}" for i in cyclegan_ids]

for checkpoint_path, checkpoint_name in zip(checkpoint_paths, checkpoint_names):
    print(checkpoint_path)
    # Build a fresh generator per checkpoint so weights never carry over.
    generator_cyclegan = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
    ckpt_cyclegan = tf.train.Checkpoint(generator_g=generator_cyclegan)

    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
    if latest_checkpoint is None:
        # Previously execution fell through and saved outputs produced by a
        # randomly initialised generator; skip this model instead.
        print(f'No Checkpoint in {checkpoint_path}! Check source path')
        continue
    # Only `generator_g` is tracked here, so any other objects stored in the
    # checkpoint are intentionally left unrestored; expect_partial() silences
    # the resulting unresolved-object warnings.
    ckpt_cyclegan.restore(latest_checkpoint).expect_partial()
    print('Checkpoint restored')

    data_generator_utils.save_cyclegan_outputs(
        task_type,
        segmentation_model,
        generator_cyclegan,
        checkpoint_name,
        additional_id,
        height,
        width,
        dataset_index_list_test,
        test_indexes_gan,
        loaded_dictionary_images_real,
        loaded_dictionary_images_sim,
    )

## Generate pix2pix data
- Generate and populate image folders ('/content/output_plots/pix2pix/pix2pix_id') for each pix2pix_id in checkpoint_names
- Compute and populate semantic masks folders ('/content/output_plots/pix2pix/pix2pix_id_mask')  for each pix2pix_id in checkpoint_names
- Compute and populate TSS and OC-TSS metric folders (referred to as the semantic mask error in the codebase: '/content/output_plots/pix2pix/pix2pix_id_mask_error') for each pix2pix_id in checkpoint_names

In [None]:
# Restore each pix2pix (mask-conditioned) generator checkpoint and write its
# outputs via data_generator_utils.save_pix2pix_mask_outputs.
pix2pix_ids = (1, 2, 3)
checkpoint_paths = [f'./{task_type}/content/gan_checkpoints/pix2pix_checkpoints/{i}' for i in pix2pix_ids]
checkpoint_names = [f"pix2pix_mask_{i}" for i in pix2pix_ids]

input_domain = "sim"
pix2pix_mask_type = "manual"

for checkpoint_path, checkpoint_name in zip(checkpoint_paths, checkpoint_names):
    # Build a fresh generator per checkpoint so weights never carry over.
    generator_pix2pix_mask_real = data_generator_utils.Generator()
    ckpt_pix2pix_mask_real = tf.train.Checkpoint(generator=generator_pix2pix_mask_real)

    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
    if latest_checkpoint is None:
        # Previously execution fell through and saved outputs produced by a
        # randomly initialised generator; skip this model instead.
        print(f'No Checkpoint in {checkpoint_path}! Check source path')
        continue
    # Only `generator` is tracked; expect_partial() silences warnings about
    # any other checkpointed objects that are intentionally not restored.
    ckpt_pix2pix_mask_real.restore(latest_checkpoint).expect_partial()
    print('Checkpoint restored')

    data_generator_utils.save_pix2pix_mask_outputs(
        task_type,
        limit,
        segmentation_model,
        two_classes,
        generator_pix2pix_mask_real,
        checkpoint_name,
        additional_id,
        height,
        width,
        input_domain,
        pix2pix_mask_type,
        dataset_index_list_test,
        test_indexes_gan,
        loaded_dictionary_images_real,
        loaded_dictionary_images_sim,
        loaded_semantic_id_real,
        loaded_semantic_id_sim,
    )