#  Run Pix2Pix
This notebook is set up to run in a PyTorch environment (check the kernel shown at the top right). <br>
This is the main folder path: [~/ml/](http://localhost:8888/tree/ml)<br>
Image dataset is located here: [~/ml/dataset_oct_histology/](http://localhost:8888/tree/ml/dataset_oct_histology)<br>
[Github Link](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)<br>
<br>
## Install

In [None]:
# Set up general varibles
root_path = '~/ml/'
dataset_path = root_path + 'dataset_oct_histology/'
code_main_folder = root_path + 'pix2pix_and_CycleGAN/'

# Install environment dependencies
!pip install --upgrade pip
!pip install opencv-python
    
# Get main model
!git clone --single-branch https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix {code_main_folder}
!pip install -r {code_main_folder}requirements.txt

## Build Dataset
This library requires OCT and histology images to be paired together.<br>
The code below pairs the training images from [patches_256px_256px](http://localhost:8888/tree/ml/dataset_oct_histology/patches_256px_256px) into [patches_256px_256px_combined](http://localhost:8888/tree/ml/dataset_oct_histology/patches_256px_256px_combined). Each image in `train_A` is placed side by side with the image of the same file name in `train_B`.

In [None]:
import os
import shutil
import numpy as np
import cv2

# Generic function to combine two images
def combine_images(img_fold_A, img_fold_B, img_fold_AB):
    """Pair images from two folders side by side, in the pix2pix "AB" layout.

    Every file in ``img_fold_A`` that has a file with the same name in
    ``img_fold_B`` is processed as follows:
    - both files are read as 3-channel color images;
    - the two images are concatenated horizontally (A on the left, B on the right);
    - the result is written under the same file name into ``img_fold_AB``.

    Args:
        img_fold_A: folder with the "A" images; a leading ``~`` is expanded.
        img_fold_B: folder with the "B" images; a leading ``~`` is expanded.
        img_fold_AB: output folder; a leading ``~`` is expanded. It is deleted
            and recreated on every call, so its previous contents are lost.
            Missing parent folders are created.

    Raises:
        ValueError: if the two images of a pair have different heights.
        IOError: if a combined image cannot be written.
    """
    import warnings

    img_fold_A = os.path.expanduser(img_fold_A)
    img_fold_B = os.path.expanduser(img_fold_B)
    img_fold_AB = os.path.expanduser(img_fold_AB)

    # Start from a clean output folder so pairs from a previous run do not
    # survive; makedirs also creates any missing parent folders.
    if os.path.isdir(img_fold_AB):
        shutil.rmtree(img_fold_AB)
    os.makedirs(img_fold_AB, exist_ok=True)

    # os.listdir order is arbitrary; sort for a deterministic processing order.
    for name in sorted(os.listdir(img_fold_A)):
        path_A = os.path.join(img_fold_A, name)
        path_B = os.path.join(img_fold_B, name)
        # Only files present in both folders form a pair.
        if not (os.path.isfile(path_A) and os.path.isfile(path_B)):
            continue

        im_A = cv2.imread(path_A, cv2.IMREAD_COLOR)
        im_B = cv2.imread(path_B, cv2.IMREAD_COLOR)
        # cv2.imread returns None (it does not raise) for unreadable or
        # non-image files such as .DS_Store; skip them instead of failing
        # with an obscure error inside np.concatenate.
        if im_A is None or im_B is None:
            warnings.warn('Skipping %r: could not read it as an image pair' % name)
            continue
        if im_A.shape[0] != im_B.shape[0]:
            raise ValueError('Cannot pair %r: heights differ (%d vs %d)'
                             % (name, im_A.shape[0], im_B.shape[0]))

        im_AB = np.concatenate([im_A, im_B], axis=1)
        path_AB = os.path.join(img_fold_AB, name)
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(path_AB, im_AB):
            raise IOError('Could not write combined image %r' % path_AB)
            
# Set input and output folders
# (alternatives: 'patches_1024px_512px/' and 'patches_1024px_512px_combined/')
patches_folder = dataset_path + 'patches_256px_256px/'
patches_combined_folder = dataset_path + 'patches_256px_256px_combined/'

# Combine train dataset
combine_images(
    img_fold_A=(patches_folder + 'train_A/'),
    img_fold_B=(patches_folder + 'train_B/'),
    img_fold_AB=(patches_combined_folder + 'train/'))

# Delete any combined test set left in this folder by a previous run. The test
# set is added later, to make sure no test-image information leaks into training.
# expanduser is required: the path starts with '~', which os.path and shutil do
# not expand, so without it the folder is never found and never deleted.
patches_combined_test_folder = (patches_combined_folder + 'test/')
_test_folder_expanded = os.path.expanduser(patches_combined_test_folder)
if os.path.isdir(_test_folder_expanded):
    shutil.rmtree(_test_folder_expanded)

## Train
Run the code below to train the model.<br>
Results can be viewed here: [~/ml/checkpoints/pix2pix/web/index.html](http://localhost:8888/view/ml/checkpoints/pix2pix/web/index.html) as the model trains.<br>

In [None]:
# Train pix2pix on the combined training pairs in {patches_combined_folder};
#  checkpoints are written under {root_path}checkpoints.
# The default settings include flip, which also trains on left-right flipped copies.
# '--preprocess none' uses the images at their stored size, without resizing or cropping.
# '--preprocess crop' would instead allow images larger than 256x256 and randomly crop
#  them to the training size. This is not recommended here, because most crops would be
#  black with no information.
# If training gets stuck, restart with --continue_train to resume from the latest model,
#  and add --epoch_count <number> to keep the epoch numbering right.
!python {code_main_folder}train.py --name pix2pix --dataroot {patches_combined_folder} --model pix2pix --checkpoints_dir {root_path}checkpoints --preprocess none



## Test

Main test results can be viewed here after running the test command below: [~/ml/results/pix2pix/test_latest/index.html](http://localhost:8888/view/ml/results/pix2pix/test_latest/index.html)


In [None]:
# Generate test dataset
patches_folder_test = dataset_path + 'patches_1024px_512px/'
patches_combined_folder_test = dataset_path + 'patches_1024px_512px_combined/'
combine_images(
    img_fold_A = (patches_folder_test + 'test_A/'),
    img_fold_B = (patches_folder_test + 'test_B/'),
    img_fold_AB = (patches_combined_folder_test + 'test/'))

# Main test results
#  --preprocess none allows to directly process a non 256x256 images using the convolutional proprety of the network
!python {code_main_folder}test.py --name pix2pix --dataroot {patches_combined_folder_test} --model pix2pix --checkpoints_dir {root_path}checkpoints --results_dir {root_path}results --preprocess none

# Go over test results, apply mask when required 
test_resoults_folder = root_path + 'results/pix2pix/test_latest/images/'
img_list = os.listdir(img_fold_A)

In [None]:
# Post process images.
# For every generated image (*_fake_B*) in the test results, set to black each
# pixel that is pure black in the matching *_real_B* image. This presumably
# gives the prediction the same black background as the ground truth. The fake
# images are overwritten in place; re-running the cell gives the same result.
import warnings

test_resoults_folder = root_path + 'results/pix2pix/test_latest/images/'
test_results_dir = os.path.expanduser(test_resoults_folder)
img_list = sorted(os.listdir(test_results_dir))
img_names = set(img_list)

img_fake_B_list = [name for name in img_list if '_fake_B' in name]
img_real_B_list = [name for name in img_list if '_real_B' in name]

for fake_name in img_fake_B_list:
    # Pair by name, not by list position: os.listdir order is arbitrary, so
    # indexing two separately filtered lists can match a fake image with the
    # wrong real image. Only the last '_fake_B' occurrence is replaced.
    real_name = '_real_B'.join(fake_name.rsplit('_fake_B', 1))
    if real_name not in img_names:
        warnings.warn('No real_B image found for %r; left unmasked' % fake_name)
        continue

    img_real_file_path = os.path.join(test_results_dir, real_name)
    img_fake_file_path = os.path.join(test_results_dir, fake_name)

    im_real = cv2.imread(img_real_file_path, cv2.IMREAD_COLOR)
    im_fake = cv2.imread(img_fake_file_path, cv2.IMREAD_COLOR)
    # cv2.imread returns None instead of raising when a file cannot be decoded.
    if im_real is None or im_fake is None:
        warnings.warn('Could not read %r / %r; left unmasked' % (real_name, fake_name))
        continue

    # Generate mask: True where all three channels of the real image are 0.
    is_black_pixel = np.all(im_real == 0, axis=2)

    # Apply mask: a 2-D boolean index on a 3-channel image selects whole
    # pixels, so every channel of the masked pixels is set to 0.
    im_fake[is_black_pixel] = 0

    # Save (cv2.imwrite signals failure through its return value).
    if not cv2.imwrite(img_fake_file_path, im_fake):
        raise IOError('Could not write %r' % img_fake_file_path)