# Freq20-MNIST1/2 dataset generation script.

#### Prepare the imports and set up the directory

In [None]:
from __future__ import division, print_function, unicode_literals
import numpy as np
import matplotlib.pyplot as plt
import h5py
from fuel.datasets import MNIST
from scipy.ndimage import convolve
from scipy.ndimage.filters import gaussian_filter

from itertools import product
from fuel.converters.base import fill_hdf5_file
from fuel.datasets import IndexableDataset

import os.path
import os
%matplotlib inline

In [None]:
# Directory that save_as_fuel / compress_fuel join with the output file names.
# It is a relative path, so it is resolved against the working directory in
# effect when those functions run (i.e. after the `%cd ..` below).
data_dir = './'
# Step one level up from the directory the notebook was started in.
%cd ..

# Please make sure we are now at the root directory of the project.
%pwd

### Texture Generation Functions.

In [None]:
# gen texture
def gen_textures(freqs, size, func, normalize=True):
    """Generate 2D textures from frequency tuples.

    Each texture is built by adding coefficients into an otherwise empty
    ``size x size`` complex spectrum and taking the real part of its
    inverse 2D FFT.

    Args:
        freqs: Integer spectrum locations, indexable as
            (n_textures, locs_per_texture, 2). Each pair is used directly
            as an index into the spectrum.
        size: Side length of the (square) textures.
        func: Zero-argument callable returning the coefficient to add at a
            location; it is called once per location.
        normalize: If True, rescale every texture independently so that its
            minimum is 0 and its maximum is 1.

    Returns:
        float32 array of shape (n_textures, size, size).
    """
    n_textures = len(freqs)
    spectrum = np.zeros((n_textures, size, size), 'complex64')
    for i in range(n_textures):
        for loc in freqs[i]:
            spectrum[i, loc[0], loc[1]] += func()

    spat = np.float32(np.real(np.fft.ifft2(spectrum)))
    if normalize:
        spat -= spat.min(axis=(1, 2), keepdims=True)
        peak = spat.max(axis=(1, 2), keepdims=True)
        # A constant texture (e.g. only the DC location set) has zero range.
        # Leave it at 0 instead of computing 0 / 0 and returning NaNs.
        np.divide(spat, peak, out=spat, where=peak > 0)
    return spat

def gauss_filt(fs, std):
    """Return an ``fs x fs`` Gaussian kernel.

    The kernel is the response of ``gaussian_filter`` (standard deviation
    ``std``) to a unit impulse placed at cell ``(fs // 2, fs // 2)``.
    """
    centre = fs // 2
    impulse = np.zeros((fs, fs))
    impulse[centre, centre] = 1.
    return gaussian_filter(impulse, std)


# choose_freqs
def choose_freqs(n, n_freqs_per_texture, texture_size, conv_filt_size, conv_filt_std, rng=np.random):
    """Choose spectrum locations for a group of textures.

    Three strategies are used:

    * ``n <= 2`` with one frequency per texture: rejection-sample until the
      first coordinates of the two textures differ by at least 10 (at most
      100 attempts; the last draw is returned with a warning if all fail).
    * ``n == 3`` with one frequency per texture: draw the first coordinates
      sequentially from segments spaced ``texture_size / 5`` apart, then
      shuffle them.
    * otherwise: draw ``10 * n`` random candidates and greedily keep a
      subset whose Gaussian-blurred spectra do not overlap, trying 10
      random orderings and keeping the largest subset.

    Args:
        n: Number of textures requested.
        n_freqs_per_texture: Number of locations per texture.
        texture_size: Side length of the square texture spectrum.
        conv_filt_size: Side length of the Gaussian kernel used to blur the
            candidate spectra (general case only).
        conv_filt_std: Standard deviation of that kernel.
        rng: Object providing ``randint`` and ``permutation``
            (``np.random`` or a ``RandomState``).

    Returns:
        Integer array of shape (k, n_freqs_per_texture, 2). ``k == n`` in
        the two special cases; in the general case ``k`` is however many
        candidates survived the overlap test, which need not equal ``n``.
    """
    # Local import: the file defines no module-level logger, and the original
    # reference to an undefined ``logger`` raised NameError on this path.
    import logging

    S = texture_size

    # Special case for a pair of single freq textures
    if n <= 2 and n_freqs_per_texture == 1:
        retries_left = 100
        while True:
            # Draw x before y on every attempt so the random stream matches
            # the historical behaviour for a given seed.
            x = rng.randint(low=1, high=S-1, size=(n))
            y = rng.randint(low=1, high=S//2-4, size=(n))
            diffs = np.array([np.abs(x[i] - x[j]) for i in range(n) for j in range(i+1, n)])
            assert len(diffs) <= 1
            if not np.any(diffs < 10):
                break
            retries_left -= 1
            if retries_left == 0:
                # Give up and return the last (too close) draw.
                logging.getLogger(__name__).warning(
                    'Used all the retries in data generation. Still did not get good one')
                break
        res = np.concatenate([x[None], y[None]], 0).T
        return res[:, None, :]

    # Special case for three freqs per observation
    if n == 3 and n_freqs_per_texture == 1:
        # NOTE: y is drawn here and again below; the first draw is discarded
        # but kept so that results for a given seed do not change.
        y = rng.randint(low=1, high=S//2-4, size=(n))
        # X is handled in n segments: each draw starts at least min_dist
        # after the previous one and leaves room for the remaining draws.
        min_dist = S/5
        x = -min_dist + 1
        xs = []
        for s in range(n):
            high = S - 1 - min_dist * (n - 1 - s)
            low = x + min_dist
            x = rng.randint(low=low, high=high, size=(1))
            xs += [x]
        x = rng.permutation(np.array(xs)[:, 0])
        y = rng.randint(low=1, high=S//2-4, size=(n))
        assert x.shape == y.shape
        res = np.concatenate([x[None], y[None]], 0).T
        return res[:, None, :]

    # General case: oversample candidates, then keep a non-overlapping subset.
    N = n * 10

    # Initial guess
    freqs = rng.randint(low=1, high=S-1, size=(N, n_freqs_per_texture, 2))

    texture_candidates = gen_textures(freqs, S, lambda: 1)
    # Get symmetric frequencies
    symm = np.abs(np.fft.fft2(texture_candidates))
    # remove DC
    symm[:, 0, 0] = 0
    # Expand the points with a gaussian filter
    filt = gauss_filt(conv_filt_size, conv_filt_std)
    symm_conv = convolve(symm, filt[None])

    permutations = [rng.permutation(np.arange(N)) for _ in range(10)]

    # Any accumulated value above the single-candidate maximum means two
    # blurred spectra overlap.
    max_val = np.max(symm_conv)

    best_indices = []
    for perm in permutations:
        final_indices = []
        mask = np.zeros((S, S))
        for i in perm:
            im = symm_conv[i]
            if np.max(mask + im) > max_val:
                continue
            final_indices += [i]
            mask += im
        if len(final_indices) > len(best_indices):
            best_indices = final_indices
    return freqs[best_indices]

### Texture + MNIST script.

In [None]:
class TexturedMNIST(IndexableDataset):
    """MNIST digits composited as textured layers over a textured background.

    The dataset exposes four sources: 'features' (the composited images),
    'targets' (digit labels), 'mask' (one mask per layer) and 'codes' (the
    texture used for each layer).
    """

    def __init__(self, which_sets, n_digits=1, scale=1., textures=None,
                 **kwargs):
        """Build the dataset for a single subset.

        Args:
            which_sets: Single-element sequence naming the subset: 'train',
                'valid' or 'test'. 'valid' is served from indices
                50000-59999 of the MNIST 'train' set.
            n_digits: Number of digits per image (``create_masks`` supports
                1 and 2).
            scale: 1 or 2; 2 upsamples the digits by pixel repetition.
            textures: Optional HDF5 file name to crop textures from. When
                None, random single-frequency textures are generated.
            **kwargs: Forwarded to ``IndexableDataset``.

        Raises:
            NotImplementedError: If ``scale`` is neither 1 nor 2.
        """
        assert len(which_sets) == 1, 'Only one concurrent set implemented'
        seed_inc, n = {'train': (0, 50000),
                       'valid': (1, 10000),
                       'test': (2, 10000)}[which_sets[0]]
        rng = np.random.RandomState(1 + seed_inc)

        # Compare the single entry rather than the container so that a tuple
        # such as ('valid',) is handled the same way as ['valid'].
        if which_sets[0] == 'valid':
            rang = range(50000, 60000)
            which_sets = ['train']
        else:
            rang = range(n)

        # Get the whole MNIST data and prepare a dataset
        mnist = MNIST(which_sets)
        orig_feats, labs = mnist.get_data(None, rang)

        S = orig_feats.shape[-1]
        if scale == 2.:
            S = int(S * scale)
        elif scale != 1.:
            # ``NotImplemented`` is a constant, not an exception class.
            raise NotImplementedError('Only scale 1 and 2 are implemented')
        N = len(rang)
        m = n_digits
        # One texture per digit plus one for the background.
        if textures is None:
            codes = TexturedMNIST.create_textures(N, S, m + 1, rng)
        else:
            codes = TexturedMNIST.load_textures(textures, N, S, m + 1, rng)

        masks_independent, perm, perm_top_bottom = TexturedMNIST.create_masks(orig_feats, scale, m, rng)
        labs = np.concatenate([labs[p] for p in perm], axis=1)
        if perm_top_bottom is not None:
            # Reorder the two labels with the same coin flips that ordered
            # the two digit masks.
            assert labs.shape[1] == 2
            perm_top_bottom = perm_top_bottom.astype(labs.dtype).reshape(perm_top_bottom.shape[0:2])
            labs_inv = labs[:, ::-1]
            labs = (perm_top_bottom * labs + (1 - perm_top_bottom) * labs_inv).astype(labs.dtype)
        feats, masks = TexturedMNIST.compose(codes, masks_independent)

        self.sources = ['features', 'targets', 'mask', 'codes']
        self.data_sources = [feats, labs, masks, codes]
        super(TexturedMNIST, self).__init__(
            indexables={'features': feats, 'targets': labs,
                        'mask': masks, 'codes': codes}, **kwargs)

    @staticmethod
    def create_masks(feats, scale, m, rng, shift=True):
        """Turn digit images into per-digit masks for ``m`` digits per image.

        Args:
            feats: 4-D array of digit images whose values span exactly
                [0, 1] (asserted).
            scale: 1 keeps the resolution; any other value takes the 2x
                pixel-repetition path (the strides are hard-coded to 2).
            m: Number of digits per image, 1 or 2.
            rng: ``RandomState`` used for the digit pairing and ordering.
            shift: For ``m == 1``, a truthy value shifts the digit by
                ``-shift`` along both image axes (``True`` therefore means
                one pixel). For ``m == 2`` the argument is ignored and a
                shift of ``2 * scale`` is applied in opposite directions to
                the two digits.

        Returns:
            Tuple ``(masks, perm, perm_top_bottom)``. ``masks`` has the
            digit index on axis 1, ``perm`` is a list with one index array
            per digit, and ``perm_top_bottom`` holds the 0/1 ordering flags
            for ``m == 2`` (None for ``m == 1``).

        Raises:
            NotImplementedError: If ``m`` is neither 1 nor 2.
        """
        if scale != 1.:
            W, H = feats.shape[-2:]
            # Array dimensions must be integers even when scale is 2.0.
            feats2x = np.ones(feats.shape[:2] + (int(W*scale), int(H*scale)))
            feats2x[:, :, 0::2, 0::2] = feats
            feats2x[:, :, 1::2, 0::2] = feats
            feats2x[:, :, 0::2, 1::2] = feats
            feats2x[:, :, 1::2, 1::2] = feats
            feats = feats2x
        mask = np.float32(feats / 1.)
        assert np.min(mask) == 0
        print(np.max(mask))
        assert np.max(mask) == 1.
        if m == 1:
            if shift:
                mask = TexturedMNIST.shift_ims(mask, -shift, -shift)
            return mask[:, None], [np.asarray(range(mask.shape[0]))], None
        elif m == 2:
            shift = 2 * scale
            perm = [rng.permutation(mask.shape[0]), rng.permutation(mask.shape[0])]
            m1 = mask[perm[0]]
            m2 = mask[perm[1]]
            if shift:
                m1 = TexturedMNIST.shift_ims(m1, -shift, -shift)
                m2 = TexturedMNIST.shift_ims(m2, shift, shift)

            # Shuffle top vs. bottom digit
            m_12 = np.concatenate([m1[:, None], m2[:, None]], 1)
            m_21 = np.concatenate([m2[:, None], m1[:, None]], 1)
            perm_top_bottom = np.reshape(rng.binomial(1, 0.5, mask.shape[0]), (mask.shape[0],) + (1,) * mask.ndim)
            stacked = np.float32(perm_top_bottom * m_12 + (1 - perm_top_bottom) * m_21)
            return stacked, perm, perm_top_bottom
        # Previously this fell through and returned None, which only failed
        # later when the caller unpacked the result.
        raise NotImplementedError('Only 1 or 2 digits are implemented')

    @staticmethod
    def create_textures(N, S, m, rng):
        """Generate ``m`` random single-frequency textures for each of ``N`` images.

        Returns:
            Array of shape (N, m, 1, S, S).
        """
        freqs = [choose_freqs(m, n_freqs_per_texture=1, texture_size=S,
                              conv_filt_size=5, conv_filt_std=1, rng=rng) for _ in range(N)]

        # Squash the first two dims
        freqs = np.concatenate(freqs, 0)
        textures = gen_textures(freqs, S, lambda: np.exp(2j * np.pi * rng.rand()))
        return textures.reshape(N, m, 1, S, S)

    @staticmethod
    def load_textures(textures_name, N, S, m, rng):
        """Crop ``N * m`` random ``S x S`` patches from textures stored in HDF5.

        Args:
            textures_name: Path of an HDF5 file with a 4-D 'textures'
                dataset; channel 0 of each texture is used.
            N: Number of images.
            S: Side length of the patches.
            m: Number of textures per image.
            rng: ``RandomState`` used to pick textures and offsets.

        Returns:
            float32 array of shape (N, m, 1, S, S).
        """
        codes = np.zeros((N * m, S, S), dtype=np.float32)
        with h5py.File(textures_name, 'r') as f:
            textures = f['textures'][:]
            nr_textures, _, n_rows, n_cols = textures.shape
            idxs = rng.randint(0, nr_textures, N*m)
            # Each offset is bounded by the axis it actually slices. The
            # original bounded them by the opposite axis, which only worked
            # for square textures. The draw order is unchanged, so square
            # textures give the same crops as before.
            # NOTE: the upper bound is exclusive, so the last valid offset is
            # never drawn and textures of size exactly S are rejected; kept
            # as-is to preserve results for existing seeds.
            col_offsets = rng.randint(0, n_cols-S, N*m)
            row_offsets = rng.randint(0, n_rows-S, N*m)
            for i, (j, c, r) in enumerate(zip(idxs, col_offsets, row_offsets)):
                codes[i] = textures[j, 0, r:r+S, c:c+S]
        return codes.reshape((N, m, 1, S, S))

    @staticmethod
    def compose(textures, masks_independent):
        """ Use the first texture as bg

        Layers are painted in order: layer ``i`` (``i >= 1``) is drawn
        through mask ``i - 1`` on top of everything painted before it.

        Args:
            textures: Array of shape (N, n_layers, ...); layer 0 is the
                background.
            masks_independent: One mask per non-background layer on axis 1,
                ignoring occlusion.

        Returns:
            Tuple ``(res, masks)``: the composited images, and per-layer
            masks (same shape as ``textures``) in which every layer has been
            attenuated by the layers above it, so they sum to 1 over axis 1.
        """
        # textures.shape == (N, n_layers, ...)
        assert textures.shape[1] == masks_independent.shape[1] + 1

        masks = np.ones(textures.shape, dtype='float32')
        res = textures[:, 0]
        for i in range(1, textures.shape[1]):
            m = masks_independent[:, i - 1]
            masks[:, i] = m
            # Everything below the new layer is hidden where it is opaque.
            for j in range(i):
                masks[:, j] *= 1 - m
            t = textures[:, i]
            res = m * t + (1 - m) * res
        assert np.allclose(np.sum(masks, axis=1), 1)
        return res, masks

    @staticmethod
    def shift_ims(ims, x, y):
        """Shift a 4-D batch of images by whole pixels, zero-filling the gap.

        Args:
            ims: 4-D array; axis 3 is shifted by ``x`` and axis 2 by ``y``.
            x: Integer-valued shift along the last axis. Positive values
                move content towards higher indices.
            y: Integer-valued shift along axis 2, same sign convention.

        Returns:
            Array of the same shape as ``ims``.
        """
        assert x % 1 == 0, 'Shift has to be integer'
        assert y % 1 == 0, 'Shift has to be integer'
        (x, y) = (int(x), int(y))
        # TODO: This would be faster in uint8
        assert ims.ndim == 4
        if x > 0:
            x_slice = np.zeros((ims.shape[0], ims.shape[1], ims.shape[2], x),
                                  dtype='float32')
            ims = np.concatenate([x_slice, ims[:, :, :, :-x]], 3)
        if x < 0:
            x_slice = np.zeros((ims.shape[0], ims.shape[1], ims.shape[2], -x),
                                  dtype='float32')
            ims = np.concatenate([ims[:, :, :, -x:], x_slice], 3)

        if y > 0:
            y_slice = np.zeros((ims.shape[0], ims.shape[1], y, ims.shape[3]),
                                  dtype='float32')
            ims = np.concatenate([y_slice, ims[:, :, :-y, :]], 2)
        if y < 0:
            y_slice = np.zeros((ims.shape[0], ims.shape[1], -y, ims.shape[3]),
                                  dtype='float32')
            ims = np.concatenate([ims[:, :, -y:, :], y_slice], 2)

        return ims

### Simple Visualization Script

In [None]:
def show(ims, titles=None):
    """Display one or more images in a near-square grid of subplots.

    Args:
        ims: A single image or a list/tuple of images. Each image is 2-D,
            or 3-D with a leading axis of length 1.
        titles: Optional title, or list/tuple of titles, one per image.
            Images beyond the number of titles are not drawn.

    Raises:
        ValueError: If ``ims`` is an empty list or tuple.
    """
    if not isinstance(ims, (tuple, list)):
        ims = [ims]
    if titles is None:
        titles = [''] * len(ims)
    if not isinstance(titles, (tuple, list)):
        titles = [titles]
    if not ims:
        # Previously this surfaced as a ZeroDivisionError in the grid maths.
        raise ValueError('show() needs at least one image')

    n_ims = len(ims)
    H = int(np.sqrt(n_ims))
    # Round the column count UP. Rounding to nearest could give W * H < n_ims
    # (e.g. 10 images -> 3 x 3), silently dropping the trailing images.
    W = -(-n_ims // H)
    fig, axes = plt.subplots(nrows=H, ncols=W, figsize=(3*W, 3*H))
    if n_ims == 1:
        # A 1 x 1 grid yields a bare Axes object rather than an array.
        axes = np.array([axes])
    for im, ax, title in zip(ims, axes.flatten(), titles):
        if len(im.shape) == 3:
            assert im.shape[0] == 1
            im = im[0]
        ax.matshow(im, cmap="gray")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)

# Generate Textures

In [None]:
# Fixed seed: this generator supplies the random phases of the textures
# generated below, so they are reproducible across runs.
rng = np.random.RandomState(546298)

In [None]:
# Base grid: every pairing of 5 first-axis frequencies with 4 second-axis
# frequencies, i.e. 20 (f1, f2) pairs.
freqs = np.array(list(product([ 4, 10, 21, 32, 37], [ 0,  4, 8, 12])))

# Group the pairs by first-axis frequency: shape (5, 4, 2).
freqs = freqs[:20].reshape((5, 4, 2))
# Add k to the second coordinate of every pair in group k.
freqs[:, :, 1] += np.arange(freqs.shape[0])[:, None]
# Flatten to one single-location entry per texture: shape (20, 1, 2), the
# (n_textures, locs_per_texture, 2) layout gen_textures expects.
freqs = np.concatenate(freqs, 0)[:, None, :]
freqs.shape

In [None]:
# One texture per frequency pair, each with a random unit-magnitude phase.
# The side length is the smallest for which the largest frequency is still a
# valid spectrum index.
textures = gen_textures(freqs, np.max(freqs) + 1, lambda: np.exp(2j * np.pi * rng.rand()))
# Preview all textures, each titled with its (f1, f2) pair.
show([t for t in textures], titles=['({},{})'.format(f1, f2) for f1, f2 in freqs[:, 0]])


In [None]:
def prepare_data(nr_digits, scale, orig_feats, labs, textures, shift, rng=None):
    """Compose textured multi-digit images from digit images and a texture bank.

    Args:
        nr_digits: Number of digits per image, 1 or 2.
        scale: Upsampling factor passed to ``TexturedMNIST.create_masks``;
            the output side length is ``scale`` times the input side length.
        orig_feats: Digit images, passed to ``create_masks`` unchanged.
        labs: Digit labels, indexed by the example indices that
            ``create_masks`` returns.
        textures: Texture bank indexed as ``textures[k, rows, cols]``. A
            random crop is taken for every layer; each image uses
            ``nr_digits + 1`` distinct textures.
        shift: Forwarded to ``create_masks``.
        rng: ``RandomState`` for digit pairing and texture choice. Defaults
            to a fixed seed. The crop offsets always come from a separate
            generator with a fixed seed.

    Returns:
        Tuple ``(features, targets, code_idxs, masks)``: float32 composited
        images, int64 labels, int64 texture indices per layer (layer 0 is
        the background) and float32 per-layer masks.
    """
    rng = np.random.RandomState(14579) if rng is None else rng

    S = int(scale * orig_feats.shape[-1])
    tex_S = textures.shape[-1]
    N = orig_feats.shape[0]
    allN = np.arange(N)

    masks, mask_idxs, perm_top_bottom = TexturedMNIST.create_masks(orig_feats, scale, nr_digits, rng, shift=shift)
    mask_idxs = np.array(mask_idxs)
    print(mask_idxs.shape)

    if nr_digits == 2:
        # Row i of z_order is the digit slot shown at depth i for each image.
        z_order = np.concatenate([perm_top_bottom, (1 - perm_top_bottom)], axis=1)
        z_order = np.transpose(np.reshape(z_order, z_order.shape[0:2]), (1, 0))
    else:
        assert nr_digits == 1
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
        # same dtype and works on every NumPy version.
        z_order = np.zeros((nr_digits, N), dtype=int)
        for i in allN:
            z_order[:, i] = rng.permutation(np.arange(nr_digits))

    targets = np.concatenate([labs[mask_idxs[(z_order[i], allN)]] for i in range(nr_digits)], 1)
    shuffled_masks = np.concatenate([masks[(allN, z_order[i])][:, :, None] for i in range(nr_digits)], 1)

    # Pick nr_digits + 1 distinct textures (background + digits) per image.
    idxs = np.arange(textures.shape[0])
    code_idxs = np.array([rng.choice(idxs, size=(nr_digits+1), replace=False) for i in allN])

    # Separate fixed-seed generator for the crop offsets.
    rnd = np.random.RandomState(598)

    codes = np.zeros((N, nr_digits+1, 1, S, S))
    for i in range(N):
        for j in range(nr_digits+1):
            xs, ys = rnd.randint(0, tex_S - S + 1, 2)
            codes[i, j, 0, :, :] = textures[code_idxs[i, j], xs:xs+S, ys:ys+S]

    feats, masks = TexturedMNIST.compose(codes, shuffled_masks)

    return feats.astype(np.float32), targets.astype(np.int64), code_idxs.astype(np.int64), masks.astype(np.float32)

In [None]:
def save_as_fuel(filename, feats, code_idxs, targets, masks, 
                 feats_test, targets_test, code_idxs_test, masks_test, 
                 trainsize=50000, valsize=10000):
    """Write train/valid/test subsets to an HDF5 file via ``fill_hdf5_file``.

    The first ``trainsize`` rows of the training arrays become 'train', the
    following ``valsize`` rows become 'valid', and the ``*_test`` arrays are
    stored whole as 'test'. Each subset holds the sources 'features',
    'mask', 'codes' and 'targets'. The file is created (overwriting any
    existing one) at ``data_dir/filename``.
    """
    train_rows = slice(None, trainsize)
    valid_rows = slice(trainsize, trainsize + valsize)
    source_names = ('features', 'mask', 'codes', 'targets')
    # Arrays are listed in the same order as source_names.
    per_subset = (
        ('train', (feats[train_rows], masks[train_rows],
                   code_idxs[train_rows], targets[train_rows])),
        ('valid', (feats[valid_rows], masks[valid_rows],
                   code_idxs[valid_rows], targets[valid_rows])),
        ('test', (feats_test, masks_test, code_idxs_test, targets_test)),
    )
    split = tuple(
        (subset, source, array)
        for subset, arrays in per_subset
        for source, array in zip(source_names, arrays)
    )

    out_path = os.path.join(data_dir, filename)
    with h5py.File(out_path, mode='w') as f:
        fill_hdf5_file(f, split)


In [None]:
def compress_fuel(source_filename, target_filename):
    """Copy an HDF5 file into a new file with gzip-compressed datasets.

    Every top-level entry of the source file is read fully into memory and
    written to the target with ``compression='gzip'``; the file-level
    attributes are copied afterwards. Both paths are relative to
    ``data_dir``.

    Args:
        source_filename: Name of the existing HDF5 file to read.
        target_filename: Name of the HDF5 file to create (overwritten if it
            exists).
    """
    source_path = os.path.join(data_dir, source_filename)
    target_path = os.path.join(data_dir, target_filename)
    # Context managers guarantee both files are closed, and the target is
    # flushed, even if copying fails. The original never closed either file.
    with h5py.File(source_path, mode='r') as source, \
            h5py.File(target_path, mode='w') as target:
        # NOTE(review): assumes every top-level entry is a dataset (it is
        # sliced with [:]); attributes attached to individual datasets are
        # not copied - confirm that is intended.
        for data in source:
            print('converting {}'.format(data))
            target.create_dataset(data, data=source[data][:], compression='gzip')
        for attr in source.attrs:
            target.attrs[attr] = source.attrs[attr]

In [None]:
# Get the whole MNIST data
mnist = MNIST(['train'])
orig_feats, labs = mnist.get_data(None, range(60000))

mnist = MNIST(['test'])
test_feats, test_labs = mnist.get_data(None, range(10000))

# Normal 2 digit dataset

In [None]:
# Settings for the two-digit dataset: two digits per image, no upsampling.
nr_digits = 2
scale = 1
# Side length of the generated images.
S = int(scale * orig_feats.shape[-1])

In [None]:
# Build the two-digit training data (train + valid rows) with a fixed seed.
feats, targets, code_idxs, masks = prepare_data(nr_digits, scale, orig_feats, labs, textures, True,
                                                rng=np.random.RandomState(52698))

In [None]:
# Preview examples 2..5: the composited image followed by the masks of
# layer 0 (background), layer 1 and layer 2.
subs = slice(2, 6)
show(list(feats[subs]) + list(masks[subs, 0])  + list(masks[subs, 1])  + list(masks[subs, 2]))

### Test Data

In [None]:
# Build the two-digit test data with its own fixed seed.
feats_test, targets_test, code_idxs_test, masks_test = prepare_data(nr_digits, scale, test_feats, test_labs, textures, True,
                                                                    rng=np.random.RandomState(2938))

In [None]:
# Preview test examples 2..5: the composited image followed by the masks of
# layer 0 (background), layer 1 and layer 2.
subs = slice(2, 6)
show(list(feats_test[subs]) + list(masks_test[subs, 0])  + list(masks_test[subs, 1])  + list(masks_test[subs, 2]))

## Save

In [None]:
# Save as Fuel HDF5
save_as_fuel('freq20-2MNIST.h5', feats, code_idxs, targets, masks,
             feats_test, targets_test, code_idxs_test, masks_test,
             50000, 10000)
# Compress
compress_fuel('freq20-2MNIST.h5', 'freq20-2MNIST_compressed.h5')
!rm {os.path.join(data_dir, 'freq20-2MNIST.h5')}

## 1-digit dataset creation

In [None]:
# Settings for the single-digit dataset: one digit per image, no upsampling.
nr_digits = 1
scale = 1
# Side length of the generated images.
S = int(scale * orig_feats.shape[-1])

# Build the training data (train + valid rows); shift=False leaves the digit
# where it is in the source image.
feats_single, targets_single, code_idxs_single, masks_single = prepare_data(
    nr_digits, scale, orig_feats, labs, textures, False, rng=np.random.RandomState(52698))

# Preview examples 2..5: the composited image followed by the masks of
# layer 0 (background) and layer 1 (digit).
subs = slice(2, 6)
show(list(feats_single[subs]) + list(masks_single[subs, 0])  + list(masks_single[subs, 1]))
# Build the test data with its own fixed seed.
feats_single_test, targets_single_test, code_idxs_single_test, masks_single_test = prepare_data(nr_digits, scale, test_feats, test_labs, textures, False,
                                                                    rng=np.random.RandomState(2938))

# Save as Fuel HDF5: first 50000 rows -> 'train', next 10000 -> 'valid'.
save_as_fuel('freq20-1MNIST.h5', feats_single, code_idxs_single, targets_single, masks_single,
             feats_single_test, targets_single_test, code_idxs_single_test, masks_single_test,
             50000, 10000)

# Compress into a second file, then delete the uncompressed intermediate
# (shell `rm`, via IPython).
compress_fuel('freq20-1MNIST.h5', 'freq20-1MNIST_compressed.h5')
!rm {os.path.join(data_dir, 'freq20-1MNIST.h5')}