# Create the `JPegImages` folder and populate it with white-balanced 8-bit .jpg versions of the .tiff files

In [None]:
import numpy as np
import matplotlib
matplotlib.use('Agg')
%matplotlib inline
from matplotlib import pyplot as plt
import tifffile as tiff
import os
import itertools
import PIL
from PIL import Image

The directory `../data` should contain links to `/data/dstl/three_band/` as `dstl` and to `/data/VOCdevkit2007/VOC2007/` as `voc` (the latter is not used in this notebook).

In [None]:
data_dir = os.path.join(os.getcwd(), '..', 'data', 'dstl')
tiff_loc = os.path.join(data_dir, 'TIFFImages')  # input: 11-bit .tiff tiles
jpeg_loc = os.path.join(data_dir, 'JPegImages')  # output: 8-bit .jpg tiles
# Create the output folder up front: saving a JPEG into a missing directory
# fails. os.mkdir (not os.makedirs) on purpose, so that a missing/broken
# `dstl` link raises here instead of silently creating an empty tree.
if not os.path.isdir(jpeg_loc):
    os.mkdir(jpeg_loc)
print(data_dir)

In [None]:
def convert(filename):
    """Convert one TIFF tile to an 8-bit JPEG without any white balancing.

    Reads ``<tiff_loc>/<filename>.tiff`` and writes ``<jpeg_loc>/<filename>.jpg``.

    Parameters
    ----------
    filename : str
        Base name of the tile, without extension.
    """
    print(filename)
    im = tiff.imread(os.path.join(tiff_loc, filename + '.tiff'))
    im >>= 3  # 11 -> 8 bits per pixel (assumes 11-bit input, as in the collage loop)
    # Clamp before the uint8 cast: any value above 2047 in the source would
    # otherwise wrap around instead of saturating at 255.
    im = np.minimum(im, 255)
    # Channel-first (3, H, W) -> channel-last (H, W, 3) as PIL expects.
    im_rgb = Image.fromarray(im.astype(np.uint8).transpose((1, 2, 0)))
    im_rgb.save(os.path.join(jpeg_loc, filename + '.jpg'))

In [None]:
# Per-file conversion without white balancing. It works, but it is disabled:
# white balance is better computed per collage, as done in the loop below.
if False:
    for entry in os.listdir(tiff_loc):
        if not os.path.isfile(os.path.join(tiff_loc, entry)):
            continue
        stem, ext = os.path.splitext(entry)
        already_converted = os.path.isfile(os.path.join(jpeg_loc, stem + '.jpg'))
        if ext != '.tiff' or already_converted:
            continue
        convert(stem)

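The function `find_scalers` returns, for each color channel, the intensity range `[x_min, x_max)` that excludes the darkest and the brightest 0.05% of the pixels. This ensures that a few outliers do not dominate the contrast stretch in `fix_white_balance`, which maps this range onto the full 8-bit range 0–255 (values outside it are clipped).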

In [None]:
def find_scalers(hists, plot_histograms=False, bins=None):
    """Find per-channel intensity limits that crop 0.05% of pixels at each end.

    Parameters
    ----------
    hists : ndarray, shape (3, n_bins)
        Pixel-count histogram of each of the three colour channels.
    plot_histograms : bool
        If True, plot each channel's histogram and cumulative count together
        with the chosen limits.
    bins : array of int, shape (n_bins + 1,), optional
        Bin edges that were used to build `hists`. Defaults to
        ``np.arange(n_bins + 1)``, i.e. one bin per integer intensity starting
        at 0, which equals the ``np.arange(2**11)`` edges used for the
        collage histograms in this notebook.

    Returns
    -------
    x_mins, x_maxs : ndarray of int, shape (3,)
        ``x_min`` is the last bin edge below which fewer than 0.05% of the
        pixels lie; ``x_max`` is the first bin edge below which more than
        99.95% of the pixels lie.

    Raises
    ------
    ValueError
        If `hists` contains no pixels at all.
    """
    hists = np.asarray(hists)
    if bins is None:
        # Previously read from a global defined in a later cell (hidden state).
        bins = np.arange(hists.shape[1] + 1, dtype=int)
    bins = np.asarray(bins)
    mids = .5*(bins[:-1] + bins[1:])

    n_pixels = hists.sum() / 3.0  # one histogram per channel, ergo /3
    if n_pixels <= 0:
        raise ValueError("find_scalers: the histograms contain no pixels")
    c_min = n_pixels / 2000  # 1/2000 = 0.05% of pixels cropped at either end
    c_max = n_pixels - c_min

    x_mins = np.zeros((3,), dtype=int)
    x_maxs = np.zeros((3,), dtype=int)

    for color in range(3):
        # cdf[k] = number of pixels counted in the bins left of edge bins[k].
        cdf = np.zeros(len(bins), dtype=float)
        cdf[1:] = hists[color, :].cumsum()
        below = bins[cdf < c_min]  # never empty: cdf[0] == 0 < c_min
        above = bins[cdf > c_max]
        x_mins[color] = below[-1]
        # A channel with fewer counts than the channel average may never exceed
        # c_max; fall back to the top edge instead of raising IndexError.
        x_maxs[color] = above[0] if above.size else bins[-1]

        if plot_histograms:
            peak = hists[color, :].max()
            fig, ax = plt.subplots()
            ax.plot(mids, hists[color, :])
            ax.plot((x_mins[color], x_mins[color]), (0, peak))
            ax.plot((x_maxs[color], x_maxs[color]), (0, peak))
            ax.set(title="Channel {:d}: histogram".format(color),
                   xlabel="intensity", ylabel="pixel count")

            fig, ax = plt.subplots()
            ax.plot(bins, cdf)
            ax.plot((x_mins[color], x_mins[color]), (0, cdf[-1]))
            ax.plot((x_maxs[color], x_maxs[color]), (0, cdf[-1]))
            ax.set(title="Channel {:d}: cumulative count".format(color),
                   xlabel="intensity", ylabel="pixels below")

    return x_mins, x_maxs
    

In [None]:
def fix_white_balance(im, x_mins, x_maxs, plot_histograms=False):
    """Stretch each channel of `im`, in place, from [x_min, x_max) onto 0..255.

    Parameters
    ----------
    im : ndarray, shape (3, H, W), integer dtype
        Image with the colour channel on the first axis. Modified in place.
    x_mins, x_maxs : sequences of int, length 3
        Per-channel limits, e.g. as returned by `find_scalers`. Values at or
        below x_min map to 0; values at or above x_max - 1 map to the top
        output value (at most 255).
    plot_histograms : bool
        If True, plot the histogram of every rescaled channel.

    Raises
    ------
    ValueError
        If x_max <= x_min for some channel.
    """
    for index, (channel, x_min, x_max) in enumerate(zip(im, x_mins, x_maxs)):
        # Python ints keep the in-place uint32 arithmetic below from being
        # promoted by numpy integer scalars.
        x_min, x_max = int(x_min), int(x_max)
        if x_max <= x_min:
            raise ValueError("fix_white_balance: channel {:d} has x_max ({:d}) "
                             "<= x_min ({:d})".format(index, x_max, x_min))
        # Clip to x_max - 1 (not x_max) so that
        # (x_max - 1 - x_min) * 256 // (x_max - x_min) <= 255 fits into uint8.
        scaled = np.clip(channel, x_min, x_max - 1).astype(np.uint32)
        scaled -= x_min
        scaled *= 256
        # Floor division: `/=` on an integer array raises TypeError in Python 3.
        scaled //= (x_max - x_min)
        channel[...] = scaled  # `channel` is a view, so this writes into `im`
        assert channel.max() < 256  # so that it will fit into uint8 later on

        if plot_histograms:
            hist_bins = np.arange(2**8 + 1, dtype=int)  # one bin per 8-bit value
            hist_mids = .5*(hist_bins[:-1] + hist_bins[1:])
            hist = np.histogram(channel, bins=hist_bins)[0]

            fig, ax = plt.subplots()
            ax.plot(hist_mids, hist)
            ax.set(title="Channel {:d} after white balance".format(index),
                   xlabel="8-bit intensity", ylabel="pixel count")

In [None]:
# Collage histograms of the 11-bit data: one bin per integer intensity
# (np.histogram closes its last bin, so 2046 and 2047 share it).
bins = np.arange(2**11, dtype=int)
mids = .5*(bins[:-1] + bins[1:])

N_TILES = 5  # each collage is an N_TILES x N_TILES grid of tiles
# NOTE(review): the [1:] skips collage 6010 -- presumably it was converted in
# an earlier run; drop the slice to process every collage from scratch.
collage_ids = range(6010, 6181, 10)[1:]

cum_mean = np.zeros((3,), dtype=float)

for collage_iter, collage_id in enumerate(collage_ids):
    # The white balance is computed for the whole collage rather than per
    # tile, so that all tiles of one collage get the same colour scaling.
    ims = []
    filenames = []
    hists = np.zeros((3, len(mids)))
    print("Working on collage {:d}".format(collage_iter))
    for i, j in itertools.product(range(N_TILES), range(N_TILES)):
        filename = '{:d}_{:d}_{:d}_RGB'.format(collage_id, i, j)
        filenames.append(filename)

        im = tiff.imread(os.path.join(tiff_loc, filename + '.tiff'))
        ims.append(im)

        for color in range(3):
            hists[color, :] += np.histogram(im[color, :, :], bins=bins)[0]
    print("Read {:d} tiles of collage {:d}".format(len(ims), collage_id))

    x_mins, x_maxs = find_scalers(hists, plot_histograms=False)
    print("Found scalers for collage")

    for im, filename in zip(ims, filenames):
        fix_white_balance(im, x_mins, x_maxs, plot_histograms=False)  # in place
        cum_mean += im.mean(axis=(1, 2))  # per-channel mean (not over the color axis)
        im_rgb = Image.fromarray(im.astype(np.uint8).transpose((1, 2, 0)))
        im_rgb.save(os.path.join(jpeg_loc, filename + '.jpg'))
        print("Saved file {:s}".format(filename + '.jpg'))
    # Running average of the per-tile channel means over all collages so far.
    print(cum_mean / (N_TILES**2 * (collage_iter + 1)))


        
    

In [None]:
# Average over all processed tiles (25 per collage) of each tile's
# per-channel mean after white balancing; shown as the cell's output.
cum_mean / (25 * (collage_iter + 1))