### Pix2Pix modeling

Model repo: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

Paper: https://arxiv.org/pdf/1611.07004.pdf

In [2]:
import numpy as np
import matplotlib.pyplot as plt

from skimage import data, img_as_float, img_as_ubyte, io, color
from skimage.measure import compare_ssim
from skimage.measure import compare_mse
from cellvision_lib import train_test_val
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [16]:
# Proprocesing the data for pix2pix model
import os
import glob
from shutil import copyfile

MAX_DEPTH = 100
NUM_SAMPLES = 109

# folder_path = '/gpfs/data/lionnetlab/cellvision/pilotdata/20181009-normalized'
# train, test, val = train_test_val(folder_path, channels = 1, train_pp = .67, test_pp = .165, val_pp = .165, set_seed = 1)

# train[0:10]

def clear_test_files(pix2pix_path):
    outer_paths = ['A','B']
    inner_paths = ['test','train','val']
    for outer in outer_paths:
        for inner in inner_paths:
            files = glob.glob('{root}/{split}/{inner}/*'.format(root=pix2pix_path, split=outer, inner=inner))
            for f in files:
                os.remove(f)
                
def setup_images_for_pix2pix(src_path, channel, num_images):
    print("Setting up {} images for channel {}".format(num_images, channel))
    pix_folder_path = '/gpfs/data/lionnetlab/cellvision/pilotdata/20181009-pix2pix/channel{}_{}'.format(channel,num_images)
    print("At path " + pix_folder_path)
    clear_test_files(pix_folder_path)
    train, test, val = train_test_val(src_path, 
                                      channels = channel, 
                                      train_pp = .67, 
                                      test_pp = .165, 
                                      val_pp = .165, 
                                      set_seed = 1)

    train_images = train[0:num_images]
    test_images = test[0:num_images]
    val_images = val[0:num_images]
    
    def get_pix_fname(base_dir, ref_fname, _channel):
        fname = os.path.basename(comp)
        end_sample_prefix_index = fname.find('_channel{}_z'.format(channel)) 
        start_z_index = end_sample_prefix_index + 11
        end_index = fname.find('.tif') 
        sample_prefix = fname[0:end_sample_prefix_index]
        z_depth = fname[start_z_index:end_index]
        new_comp_path = '{}/A/train/{}_channel{}_z{}.jpg'.format(base_dir, sample_prefix, _channel, z_depth)
        new_ref_path = '{}/B/train/{}_channel{}_z{}.jpg'.format(base_dir, sample_prefix, _channel, z_depth)
        if not os.path.exists('{}/A/train'.format(base_dir)):
            os.makedirs('{}/A/train'.format(base_dir))
        if not os.path.exists('{}/B/train'.format(base_dir)):
            os.makedirs('{}/B/train'.format(base_dir))
        return new_comp_path, new_ref_path
    
    for i, (comp, ref) in enumerate(train_images):
        new_comp_path, new_ref_path = get_pix_fname(pix_folder_path, ref, channel)
        copyfile(comp, new_comp_path)
        copyfile(ref, new_ref_path)
        
    print("done with training images")
    
    for i, (comp, ref) in enumerate(test_images):
        new_comp_path, new_ref_path = get_pix_fname(pix_folder_path, ref, channel)
        copyfile(comp, new_comp_path)
        copyfile(ref, new_ref_path)
        
    print("done with testing images")
        
    for i, (comp, ref) in enumerate(val_images):
        new_comp_path, new_ref_path = get_pix_fname(pix_folder_path, ref, channel)
        copyfile(comp, new_comp_path)
        copyfile(ref, new_ref_path)
        
    print("done with validation images")


folder_path = '/gpfs/data/lionnetlab/cellvision/pilotdata/20181009-top50'
# folder_path = '/gpfs/data/lionnetlab/cellvision/pilotdata/20181009-normalized'

# setup_images_for_pix2pix(folder_path, 1, 100)
setup_images_for_pix2pix(folder_path, 2, 100)
setup_images_for_pix2pix(folder_path, 3, 100)
setup_images_for_pix2pix(folder_path, 4, 100)
setup_images_for_pix2pix(folder_path, 5, 100)

# print()
# print(glob.glob('/gpfs/data/lionnetlab/cellvision/pilotdata/20181009-pix2pix/channel1/A/train/*'))
# print()
# print(glob.glob('/gpfs/data/lionnetlab/cellvision/pilotdata/20181009-pix2pix/channel1/B/train/*'))


Setting up 100 images for channel 2
done with training images
done with testing images
done with validation images
Setting up 100 images for channel 3
done with training images
done with testing images
done with validation images
Setting up 100 images for channel 4
done with training images
done with testing images
done with validation images
Setting up 100 images for channel 5
done with training images
done with testing images
done with validation images


In [None]:
"""
Commands to run with the pix2pix framework

python datasets/combine_A_and_B.py --fold_A /gpfs/data/lionnetlab/cellvision/pilotdata/20181009-pix2pix/testing/A --fold_B /gpfs/data/lionnetlab/cellvision/pilotdata/20181009-pix2pix/testing/B --fold_AB /gpfs/data/lionnetlab/cellvision/pilotdata/20181009-pix2pix/testing --num_imgs 200

bsub -Is -gpu "num=1:mode=exclusive_process:mps=yes" python train.py --dataroot /gpfs/data/lionnetlab/cellvision/pilotdata/20181009-pix2pix/testing --name cellvision5 --model pix2pix --direction AtoB --gpu 0 --display_id 0

bsub -Is -gpu "num=1:mode=exclusive_process:mps=yes" python test.py --dataroot /gpfs/data/lionnetlab/cellvision/pilotdata/20181009-pix2pix/testing --name cellvision5 --model pix2pix --direction AtoB

"""

In [17]:
import glob

results_dir = '/home/dg3047/capstone/pytorch-CycleGAN-and-pix2pix/results/cellvision5/test_latest/images'

# out_imgs = glob.glob(results_dir)

mses = []
ssims = []

for i in range(1,100):
    real_path_low = '{}/{}_real_A.png'.format(results_dir,i)
    real_path_high = '{}/{}_real_B.png'.format(results_dir,i)
    fake_path = '{}/{}_fake_B.png'.format(results_dir,i)
    
    real_img_high = color.rgb2gray(io.imread(real_path_high).astype(np.uint))
    real_img_low = color.rgb2gray(io.imread(real_path_low).astype(np.uint))
    fake_img = color.rgb2gray(io.imread(fake_path).astype(np.uint))
    
    _min = real_img_high_ft.min()
    _max = real_img_high_ft.max()

    low_high_ssim = compare_ssim(real_img_low, real_img_high, data_range=_max-_min)
    fake_high_ssim = compare_ssim(fake_img, real_img_high, data_range=_max-_min)
    ssims.append( (low_high_ssim,fake_high_ssim) )
    
    low_high_mse = compare_mse(real_img_low, real_img_high)
    fake_high_mse = compare_mse(fake_img, real_img_high)
    
    mses.append( (low_high_mse,fake_high_mse) )


ssims[0:2]



[(0.05714369656897892, 0.7171568099233542),
 (0.038133483581749734, 0.7674732462544108),
 (0.03288369849270328, 0.7554218806931998),
 (0.018270039938528673, 0.7620264046911202),
 (0.0545751440235492, 0.6839410853367558),
 (0.046952584255868485, 0.7304409837197041),
 (0.04147598324822782, 0.7056266584132316),
 (0.01306456762334604, 0.7515882475204598),
 (0.05609789309534377, 0.691740616354918),
 (0.02147631863767301, 0.7468515910674022),
 (0.02376342879878884, 0.7143934339438036),
 (0.03382621884363293, 0.7376405876895635),
 (0.018309694071439307, 0.7731926511276043),
 (0.07481906893185049, 0.6711480014641374),
 (0.013846834119998977, 0.7492556868119543),
 (0.0137544923348324, 0.7594305151533425),
 (0.05010281400588325, 0.7349324320381804),
 (0.008771227152750456, 0.7557971005435014),
 (0.01956554404470542, 0.7613542851160797),
 (0.04235491397497991, 0.7069071956328872),
 (0.0198487963839545, 0.7619011574941551),
 (0.06503811806811458, 0.6949993743729708),
 (0.02291022754750139, 0.68898

In [24]:
import statistics
low = [ssim[0] for ssim in ssims]
low_std = statistics.stdev(low)
low_mean = statistics.mean(low)

fake = [ssim[1] for ssim in ssims]
fake_std = statistics.stdev(fake)
fake_mean = statistics.mean(fake)

print("Low ssim mean {} with std {}".format(low_mean,low_std))
print("Fake ssim mean {} with std {}".format(fake_mean,fake_std))

low_mse = [mse[0] for mse in mses]
low_std_mse = statistics.stdev(low_mse)
low_mean_mse = statistics.mean(low_mse)

fake_mse = [mse[1] for mse in mses]
fake_std_mse = statistics.stdev(fake_mse)
fake_mean_mse = statistics.mean(fake_mse)
print()
print("Low mse mean {} with std {}".format(low_mean_mse,low_std_mse))
print("Fake mse mean {} with std {}".format(fake_mean_mse,fake_std_mse))


Low ssim mean 0.03523206115861787 with std 0.015971007061636254
Fake ssim mean 0.7272434952244862 with std 0.03228345536195664

Low mse mean 2.9115057547079674e-37 with std 1.7443831171497942e-37
Fake mse mean 1.2529523666832592e-38 with std 9.488763181891202e-39
