In [None]:
from __future__ import division, print_function

import matplotlib.pyplot as plt
%pylab inline
import sys
import logging
import numpy as np

sys.path.append('/work/draw')
FORMAT = '[%(asctime)s] %(name)-15s %(message)s'
DATEFMT = "%H:%M:%S"
logging.basicConfig(format=FORMAT, datefmt=DATEFMT, level=logging.INFO)

#from theano.misc.pkl_utils import dump
#from theano.misc.pkl_utials import load
import os
import theano
import theano.tensor as T
import fuel
import ipdb
import time
import cPickle
import h5py

import random
from skimage.filters import threshold_adaptive
from skimage.draw import line

from argparse import ArgumentParser
from theano import tensor

from fuel.streams import DataStream
from fuel.schemes import SequentialScheme
from fuel.transformers import Flatten

from blocks.algorithms import GradientDescent, CompositeRule, StepClipping, RMSProp, Adam
from blocks.bricks import Tanh, Identity, MLP, Linear, Rectifier, Softmax, Logistic
from blocks.bricks.cost import BinaryCrossEntropy, CategoricalCrossEntropy
from blocks.bricks.recurrent import SimpleRecurrent, LSTM
from blocks.initialization import Constant, IsotropicGaussian, Orthogonal 
from blocks.filter import VariableFilter
from blocks.graph import ComputationGraph
from blocks.roles import PARAMETER
from blocks.monitoring import aggregation
from blocks.extensions import FinishAfter, Timing, Printing, ProgressBar
from blocks.extensions.saveload import Checkpoint
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
from blocks.main_loop import MainLoop
from blocks.model import Model

####TODO get extras working
try:
    from blocks.extras import Plot
except ImportError:
    pass


import draw.datasets as datasets
from draw.draw import *
from draw.samplecheckpoint import SampleCheckpoint
from draw.partsonlycheckpoint import PartsOnlyCheckpoint

sys.setrecursionlimit(100000)

In [None]:
def floatX(X):
    """Cast ``X`` to a numpy array with Theano's configured float dtype.

    Shared variables and compiled functions use ``theano.config.floatX``, so
    numeric constants are routed through this helper before being wrapped.
    """
    target_dtype = theano.config.floatX
    return np.asarray(X, dtype=target_dtype)

def clip_norm(g, c, n):
    '''Rescale gradient ``g`` so that the norm ``n`` does not exceed ``c``.

    g -- the gradient
    c -- the threshold; clipping is disabled when c <= 0
    n -- the norm compared against the threshold
    When n >= c the gradient is scaled by c / n, otherwise returned unchanged.
    '''
    if not c > 0:
        return g
    return T.switch(T.ge(n, c), g * c / n, g)

def clip_norms(gs, c):
    """Jointly clip a list of gradients by their combined L2 norm.

    The norm is computed over all gradients together; each gradient is then
    passed through ``clip_norm`` with that shared norm and threshold ``c``.
    """
    squared_total = sum([T.sum(grad ** 2) for grad in gs])
    global_norm = T.sqrt(squared_total)
    return [clip_norm(grad, c, global_norm) for grad in gs]

# Regularizers
def max_norm(p, maxnorm = 0.):
    """Max-norm constraint: shrink columns of ``p`` whose L2 norm exceeds ``maxnorm``.

    Norms are taken over axis 0 (one per column). The 1e-7 in the denominator
    guards against division by zero. Disabled when maxnorm <= 0.
    """
    if not maxnorm > 0:
        return p
    col_norms = T.sqrt(T.sum(T.sqr(p), axis=0))
    capped = T.clip(col_norms, 0, maxnorm)
    return p * (capped / (1e-7 + col_norms))

def gradient_regularize(p, g, l1 = 0., l2 = 0.):
    """Add L2 (weight decay) and L1 penalty gradients to ``g``.

    p  -- the parameter the gradient belongs to
    g  -- gradient of the loss w.r.t. ``p``
    l1 -- L1 coefficient; adds l1 * sign(p)
    l2 -- L2 coefficient; adds l2 * p
    Returns a new gradient expression. ``g`` itself is never modified in place,
    and zero coefficients are skipped so no dead terms are added to the graph.
    """
    # ``g = g + ...`` instead of ``g += ...``: a numpy array passed in by the
    # caller must not be mutated in place.
    if l2:
        g = g + p * l2
    if l1:
        g = g + T.sgn(p) * l1
    return g

def weight_regularize(p, maxnorm = 0.):
    """Apply post-update weight constraints to ``p`` (currently only max-norm)."""
    return max_norm(p, maxnorm)

def Adam(params, cost, lr=0.0002, b1=0.1, b2=0.001, e=1e-8, l1 = 0., l2 = 0., maxnorm = 0., c = 8):
    """Build Theano update pairs for the Adam optimiser.

    NOTE: this name shadows ``blocks.algorithms.Adam`` imported at the top.

    ``b1``/``b2`` use the "complement" convention: they weight the *new*
    gradient (m_t = b1*g + (1-b1)*m), i.e. b1 = 1 - beta1 and b2 = 1 - beta2 in
    Kingma & Ba's notation (defaults 0.1 / 0.001 <-> beta1 = 0.9, beta2 = 0.999).

    params  -- list of shared variables to optimise
    cost    -- scalar symbolic cost
    lr      -- base learning rate
    e       -- epsilon added to the denominator
    l1, l2  -- penalty coefficients passed to ``gradient_regularize``
    maxnorm -- max-norm constraint passed to ``weight_regularize``
    c       -- global gradient-norm clipping threshold (<= 0 disables)
    Returns a list of (shared_variable, new_value) update pairs, including the
    step counter.
    """
    updates = []
    grads = T.grad(cost, params)
    grads = clip_norms(grads, c)

    # step counter; incremented by the last update pair on every call
    i = theano.shared(floatX(0.))
    i_t = i + 1.
    # Bias correction. The moving averages decay by (1 - b1) and (1 - b2) per
    # step, so the correction terms are 1 - (1 - b1)**t and 1 - (1 - b2)**t.
    # Using 1 - b1**t would make both terms ~1 from the first step and
    # effectively disable bias correction.
    fix1 = 1. - (1. - b1)**(i_t)
    fix2 = 1. - (1. - b2)**(i_t)
    lr_t = lr * (T.sqrt(fix2) / fix1)

    for p, g in zip(params, grads):
        m = theano.shared(p.get_value() * 0.)
        v = theano.shared(p.get_value() * 0.)
        m_t = (b1 * g) + ((1. - b1) * m)
        v_t = (b2 * T.sqr(g)) + ((1. - b2) * v)
        g_t = m_t / (T.sqrt(v_t) + e)
        g_t = gradient_regularize(p, g_t, l1=l1, l2=l2)
        p_t = p - (lr_t * g_t)
        p_t = weight_regularize(p_t, maxnorm=maxnorm)

        updates.append((m, m_t))
        updates.append((v, v_t))
        updates.append((p, p_t))

    updates.append((i, i_t))
    return updates


def RMSprop(params, cost, lr = 0.001, l1 = 0., l2 = 0., maxnorm = 0., rho=0.9, epsilon=1e-6, c = 8):
    """Build Theano update pairs for RMSprop.

    The gradients are processed in this order:
    - clipped by their global norm (threshold ``c``);
    - augmented with L1/L2 penalties;
    - divided by the running root-mean-square of the gradient (decay ``rho``).
    Max-norm constraints are applied to the updated parameters.
    Returns a list of (shared_variable, new_value) pairs; for every parameter
    the accumulator pair precedes the parameter pair.
    """
    clipped = clip_norms(T.grad(cost, params), c)
    updates = []
    for param, grad in zip(params, clipped):
        grad = gradient_regularize(param, grad, l1 = l1, l2 = l2)
        accumulator = theano.shared(param.get_value() * 0.)
        new_accumulator = rho * accumulator + (1 - rho) * grad ** 2
        step = lr * (grad / T.sqrt(new_accumulator + epsilon))
        new_param = weight_regularize(param - step, maxnorm = maxnorm)
        updates.extend([(accumulator, new_accumulator), (param, new_param)])
    return updates

In [None]:
# Open the IAM author-lines HDF5 file read-only. The handle is deliberately left
# open: the next cell iterates over it as iamMain[auth][frag] (top-level keys
# appear to be author ids, children line fragments -- confirm against the file).
iamMain = h5py.File('/fileserver/iam/iam-binary/iam_author_lines_bin.hdf5', 'r')

In [None]:
# Collect every line fragment that is strictly taller than sizeCutHt AND wider
# than sizeCutWt (so a full training patch fits), together with its author key.
# Progress is printed every 100 authors.
frags = []
auths = []
sizeCutHt=100
sizeCutWt=200
for i, auth in enumerate(iamMain.keys(), 1):
    authGroup = iamMain[auth]
    for frag in authGroup:
        fragData = authGroup[frag]
        # Check the size from the dataset's metadata so that fragments which
        # are too small are never read from disk, and kept ones are read once.
        fragShape = fragData.shape
        if fragShape[0]>sizeCutHt and fragShape[1]>sizeCutWt:
            frags.append(np.asarray(fragData))
            auths.append(str(auth))
    if i%100==0:
        print('auth %s is done' % auth)

In [None]:
import random
def get_random_patch(images, patchSize, varThresh = 0.03, maxTries = None):
    """Cut random, non-blank patches out of one image or a batch of images.

    images    -- a single 2-D image (anything with a 2-D ``.shape``), or a
                 sequence of 2-D images (a mini-batch of any size); pixel values
                 are assumed to be in 0..255
    patchSize -- (patch_height, patch_width)
    varThresh -- a patch is accepted only when the variance of its pixels scaled
                 to [0, 1] is strictly greater than this value, which rejects
                 almost-empty background patches
    maxTries  -- optional cap on sampling attempts per image. ``None`` (default)
                 keeps sampling until a patch is accepted (never terminates for
                 a blank image); with a cap, the last sampled patch is used once
                 the cap is reached.

    Returns one flattened patch scaled to [0, 1] for a single image, or a list
    of such patches (one per image) for a batch.
    """
    # Decide by dimensionality, not by length: a batch may have any size and a
    # single image may have any height.
    if hasattr(images, 'shape') and len(images.shape) == 2:
        return _sample_patch(images, patchSize, varThresh, maxTries)
    return [_sample_patch(img, patchSize, varThresh, maxTries) for img in images]


def _sample_patch(img, patchSize, varThresh, maxTries):
    """Sample one patch from a single 2-D image following get_random_patch's rules."""
    ph, pw = patchSize[0], patchSize[1]
    ht, wt = img.shape
    tries = 0
    while True:
        # +1: a patch may also sit flush against the bottom/right edge.
        top = random.randrange(ht - ph + 1)
        left = random.randrange(wt - pw + 1)
        patch = img[top:top + ph, left:left + pw]
        tries += 1
        if np.var(patch / 255.0) > varThresh:
            break
        if maxTries is not None and tries >= maxTries:
            break
    return patch.flatten() / 255.0

In [None]:
def pickleModel(trainedClassifier, file2save2 = None, filePath='/work/notebooks/drawModels/', fileName = 'myModels'):
    """Pickle ``trainedClassifier`` (e.g. a compiled Theano function) to disk.

    file2save2 -- explicit file name inside ``filePath``; when omitted the file
                  name is ``fileName + '.zip'``. Despite the suffix, the file is
                  a plain pickle, not a zip archive. The suffix is kept so that
                  existing ``loadModel`` paths keep working.
    filePath   -- target directory (with or without a trailing separator)
    fileName   -- base name used when ``file2save2`` is omitted
    The highest pickle protocol is used: it is binary and much more compact
    than cPickle's default ASCII protocol 0. ``cPickle.load`` detects the
    protocol automatically.
    """
    if file2save2 is None:
        finFile = os.path.join(filePath, fileName + '.zip')
    else:
        finFile = os.path.join(filePath, file2save2)

    with open(finFile, 'wb') as f:
        cPickle.dump(trainedClassifier, f, cPickle.HIGHEST_PROTOCOL)

        
def loadModel(filePath):
    """Unpickle and return the object stored at ``filePath``.

    SECURITY: unpickling can execute arbitrary code. Only load files written by
    this notebook or another trusted source.
    """
    with open(filePath, 'rb') as handle:
        return cPickle.load(handle)

    

In [None]:
load = False # True: load previously pickled Theano functions instead of compiling them
name = 'draw' # experiment name prefix (used for the output subdirectory)
dataset = 'iam' # dataset tag used in the experiment name
epochs = 200000 # NOTE(review): the training cell may re-assign this
batch_size = 32 # NOTE(review): the training cell may re-assign this
learning_rate = 0.001
attention = '15,15' #"R,W": read/write attention grid sizes (R x R and W x W filters); "" disables attention
n_iter = 64 #number of DRAW steps (glimpses / canvas updates)
enc_dim = 800 #hidden size of the encoder LSTM
dec_dim = 800 #hidden size of the decoder LSTM
z_dim = 200 #size of the latent sample z drawn at each step
image_size = (100,200) #(height, width) of the training patches
channels = 1

In [None]:
img_height, img_width = image_size
# Length of one flattened input image; the model consumes flat vectors.
x_dim = channels * img_height * img_width

# Weight/bias initialisation schemes. All three are currently identical (small
# isotropic Gaussian weights, zero biases). mlpinits is not used by any active
# code in the visible cells.
mlpinits = {
    #'weights_init': Orthogonal(),
    'weights_init': IsotropicGaussian(0.01),
    'biases_init': Constant(0.),
}

rnninits = {
    #'weights_init': Orthogonal(),
    'weights_init': IsotropicGaussian(0.01),
    'biases_init': Constant(0.),
}
inits = {
    #'weights_init': Orthogonal(),
    'weights_init': IsotropicGaussian(0.01),
    'biases_init': Constant(0.),
}

# Configure attention mechanism
# attention == "R,W": attentive reader/writer with R x R and W x W filter grids.
# attention == "":    no attention; the reader/writer see the full image.
if attention != "":
    read_N, write_N = attention.split(',')

    read_N = int(read_N)
    write_N = int(write_N)
    # Two N x N glimpses per channel (presumably of the image and of the error
    # image, as in DRAW -- confirm in draw.draw.AttentionReader).
    read_dim = 2 * channels * read_N ** 2

    reader = AttentionReader(x_dim=x_dim, dec_dim=dec_dim,
                             channels=channels, width=img_width, height=img_height,
                             N=read_N, **inits)
    writer = AttentionWriter(input_dim=dec_dim, output_dim=x_dim,
                             channels=channels, width=img_width, height=img_height,
                             N=write_N, **inits)
    # short tag used in the run name
    attention_tag = "r%d-w%d" % (read_N, write_N)
else:
    # Full read: two complete flattened images (presumably image + error image).
    read_dim = 2*x_dim

    reader = Reader(x_dim=x_dim, dec_dim=dec_dim, **inits)
    writer = Writer(input_dim=dec_dim, output_dim=x_dim, **inits)

    attention_tag = "full"

#----------------------------------------------------------------------

# Fall back to the dataset tag when no experiment name was configured.
if name is None:
    name = dataset
# Learning rate
def lr_tag(value):
    """ Convert a float into a short tag-usable string representation.

    The tag is the leading significant digit followed by the negated decimal
    exponent. E.g.:
        0.1   -> 11
        0.01  -> 12
        0.001 -> 13
        0.005 -> 53
    """
    neg_exponent = -np.floor(np.log10(value))
    first_digit = "{:e}".format(value)[0]
    return "%s%d" % (first_digit, neg_exponent)

lr_str = lr_tag(learning_rate)

# Run identification: a timestamped subdirectory name plus a descriptive run
# name encoding the main hyper-parameters.
subdir = name + "-" + time.strftime("%Y%m%d-%H%M%S");
longname = "%s-%s-t%d-enc%d-dec%d-z%d-lr%s" % (dataset, attention_tag, n_iter, enc_dim, dec_dim, z_dim, lr_str)
# NOTE(review): pickle_file is not referenced by any visible cell.
pickle_file = subdir + "/" + longname + ".pkl"

# Echo the configuration so it is recorded in the notebook output.
print("\nRunning experiment %s" % longname)
print("               dataset: %s" % dataset)
print("          subdirectory: %s" % subdir)
print("         learning rate: %g" % learning_rate)
print("             attention: %s" % attention)
print("          n_iterations: %d" % n_iter)
print("     encoder dimension: %d" % enc_dim)
print("           z dimension: %d" % z_dim)
print("     decoder dimension: %d" % dec_dim)
print("            batch size: %d" % batch_size)
print("                epochs: %d" % epochs)
print()

#----------------------------------------------------------------------

# Encoder/decoder recurrent cores.
encoder_rnn = LSTM(dim=enc_dim, name="RNN_enc", **rnninits)
decoder_rnn = LSTM(dim=dec_dim, name="RNN_dec", **rnninits)
# Linear projections into the LSTMs. The 4*dim output size presumably provides
# the pre-activations of the LSTM's four gates (Blocks LSTM convention).
# The encoder input is read_dim + dec_dim, presumably the reader glimpse
# concatenated with the previous decoder state.
encoder_mlp = MLP([Identity()], [(read_dim+dec_dim), 4*enc_dim], name="MLP_enc", **inits)
decoder_mlp = MLP([Identity()], [             z_dim, 4*dec_dim], name="MLP_dec", **inits)
# Maps encoder states (enc_dim) to latent samples (z_dim).
q_sampler = Qsampler(input_dim=enc_dim, output_dim=z_dim, **inits)

# Assemble the DRAW model with n_iter steps and initialise all parameters.
draw = DrawModel(
            n_iter, 
            reader=reader,
            encoder_mlp=encoder_mlp,
            encoder_rnn=encoder_rnn,
            sampler=q_sampler,
            decoder_mlp=decoder_mlp,
            decoder_rnn=decoder_rnn,
            writer=writer)
draw.initialize()

#------------------------------------------------------------------------
# Symbolic input: a mini-batch of flattened patches, one row per image.
x = tensor.matrix('features')
# NOTE(review): n_samplings is not referenced by any visible cell.
n_samplings = tensor.scalar('n_samplings', dtype=theano.config.floatX)

# Outputs of draw.reconstructMORE(x). Shapes come partly from earlier notes and
# partly from how the values are used later; confirm against draw.draw.
#   x_recons     -- reconstructed image, i.e. the final canvas; x_recons[i] is
#                   reshaped to image_size in the training cell
#   kl_terms     -- KL terms; summed over axis 0 (presumably the n_iter axis)
#                   and averaged over the remainder for the cost
#   canvas       -- cumulative canvases (c += writer(h_dec) at every step),
#                   shape (n_iter, batch_size, x_dim); the last gives x_recons
#   h_enc, c_enc -- encoder recurrent states, (n_iter, batch_size, enc_dim)
#   z            -- latent samples, (n_iter, batch_size, z_dim)
#   h_dec, c_dec -- decoder recurrent states, presumably (n_iter, batch_size, dec_dim)
x_recons, kl_terms, canvas, h_enc, c_enc, z, h_dec, c_dec = draw.reconstructMORE(x)

# TODO: an author-classification head is currently disabled. To re-enable it:
# - define y = tensor.matrix('targets');
# - apply an MLP to z[-1] (e.g. Logistic + Softmax, output size len(authDict))
#   to get pYx;
# - compute soft_term = CategoricalCrossEntropy().apply(y, pYx);
# - add soft_term to the cost below and y to the theano.function inputs.

recons_term = BinaryCrossEntropy().apply(x, x_recons)
recons_term.name = "recons_term"

# Negative variational bound: reconstruction cross-entropy plus the KL terms
# summed over axis 0 and averaged. The classification term ``soft_term`` is
# deliberately absent: it is only defined by the disabled head described
# above, and referencing it raised NameError.
cost = recons_term + kl_terms.sum(axis=0).mean()
cost.name = "nll_bound"

In [None]:
# Computation graph of the training cost; its variables are filtered below to
# collect the trainable parameters.
cg = ComputationGraph([cost])
#cg.variables

In [None]:
# All variables tagged with the PARAMETER role in the cost graph (the weights
# and biases the optimiser will update).
params = VariableFilter(roles=[PARAMETER])(cg.variables)

In [None]:
# ``Adam`` here is the hand-written update builder defined earlier in this
# notebook, which shadows blocks.algorithms.Adam. Gradients are clipped to a
# global norm of 10.
#updates = RMSprop(params, cost, learning_rate, c=10)
updates = Adam(params, cost, learning_rate, c=10)

In [None]:
# Either reload Theano functions pickled by an earlier run, or compile them
# from the symbolic graph built above.
# SECURITY: loadModel unpickles, which can execute arbitrary code; only load
# trusted files.
# NOTE(review): the training cell also refers to ``drawLikeX`` and
# ``decodering``, which neither branch defines.
if load:
    # paths refer to a previous run (iteration 190000) -- update as needed
    train = loadModel('/work/notebooks/drawModels/draw_drawlikeIAMTRAIN190000.zip')
    predict = loadModel('/work/notebooks/drawModels/draw_drawlikeIAMPREDICT190000.zip')
    drawRandom = loadModel('/work/notebooks/drawModels/draw_drawlikeIAMDRAWRANDOM190000.zip')
else:
    # train: one optimisation step (applies ``updates``); returns the cost and
    # all intermediate states, with x_recons last
    train = theano.function([x], [cost, canvas, h_enc, c_enc, z, h_dec, c_dec, x_recons], 
                            updates = updates, allow_input_downcast=True)
    # predict: same forward outputs without any parameter update
    predict = theano.function([x], [canvas, h_enc, c_enc, z, h_dec, c_dec, x_recons], allow_input_downcast=True)
    # drawRandom: presumably generates one sample from the model (no input)
    drawRandom = theano.function([], draw.sample(1), allow_input_downcast=True)

In [None]:
# Keep three fixed fragments (indices into the loaded list) aside for
# qualitative checks, each paired with its author label.
holdoutFrag = []
holdoutAuth = []
for hold in [4, 7032, 11133]:
    holdoutFrag.append(frags[hold])
    holdoutAuth.append(auths[hold])  # author label (previously frags[hold], i.e. the image)

In [None]:
# Remove the held-out fragments (original indices 4, 7032, 11133) from the
# training data. Deleting from the highest index down keeps the remaining
# original indices valid, so no manual index shifting is needed.
# NOTE: not idempotent -- re-running this cell deletes three more fragments.
for heldIdx in sorted([4, 7032, 11133], reverse=True):
    del frags[heldIdx]
    del auths[heldIdx]

In [None]:
# Shuffle fragments and author labels with the same permutation so that the
# pairs stay aligned. The comprehension variable is ``idx``, not ``x``: in
# Python 2 list-comprehension variables leak into the enclosing scope, and ``x``
# would overwrite the symbolic Theano input ``x``.
randIndex = random.sample(xrange(len(frags)), len(frags))
frags = [frags[idx] for idx in randIndex]
auths = [auths[idx] for idx in randIndex]

In [None]:
def oneHotOutput(miniAuths, authDict):
    """Build one-hot target rows for a mini-batch of author labels.

    miniAuths -- sequence of author keys, one per example
    authDict  -- mapping from author key to its column index
    Returns an array of dtype theano.config.floatX with shape
    (len(miniAuths), len(authDict)) and a single 1 in each row.
    """
    encoded = np.zeros((len(miniAuths), len(authDict)), dtype=theano.config.floatX)
    for row, author in enumerate(miniAuths):
        encoded[row, authDict[author]] = 1
    return encoded
    

In [None]:
#min of one image shape is 82
runname = 'binaryBeginnings'
epochCost = []

# epochs and batch_size come from the configuration cell; they are no longer
# re-assigned here, where they silently overrode the config.
iteration = 0
modelDir = '/work/notebooks/drawModels/'

# drawLikeX and decodering are used for plotting/checkpointing but no visible
# cell creates them. Look them up once so that a missing helper skips its
# panel/pickle instead of raising NameError in the middle of training.
drawLikeXFn = globals().get('drawLikeX')
decoderingFn = globals().get('decodering')

for epoch in xrange(epochs):
    print(' ')
    costCollect = []
    print ("EPOCH: ", epoch)
    # Full mini-batches only. The stop value len(frags) - batch_size + 1 keeps a
    # final batch that ends exactly at len(frags).
    for start in xrange(0, len(frags) - batch_size + 1, batch_size):
        inputs = frags[start:start + batch_size]
        patches = get_random_patch(inputs, image_size)
        # trainOut = [cost, canvas, h_enc, c_enc, z, h_dec, c_dec, x_recons]
        trainOut = train(patches)
        costCollect.append(trainOut[0])

        if iteration%100==0:
            print ('   iteration: ', iteration)
            print ("   cost: ", trainOut[0])

        if iteration%700 == 0:
            nPanels = 3 if decoderingFn is None else 4
            plt.figure(1, figsize = (20,20))
            # input patch
            plt.subplot(1,nPanels,1)
            plt.imshow(patches[0].reshape(*image_size), cmap = 'gray')
            # its reconstruction (x_recons of the first batch element)
            plt.subplot(1,nPanels,2)
            plt.imshow(trainOut[-1][0].reshape(*image_size), cmap = 'gray')
            # random sample from the model
            plt.subplot(1,nPanels,3)
            plt.imshow(drawRandom()[-1].reshape(*image_size), cmap = 'gray')

            if decoderingFn is not None:
                # Average the last-step latent z (predict output 3, step -1)
                # over 10 random patches of the first fragment, then decode it.
                patchCollect = []
                for _ in xrange(10):
                    zPatch = get_random_patch(inputs[0], image_size)
                    patchCollect.append(predict(zPatch.reshape(1,image_size[0]*image_size[1]))[3][-1])
                patchCollect = np.mean(np.asarray(patchCollect), axis = 0)
                plt.subplot(1,nPanels,4)
                plt.imshow(decoderingFn(patchCollect.reshape(1,1,z_dim))[0].reshape(*image_size), cmap = 'gray')
            plt.show()

        if iteration%1000 == 0:
            # per-epoch mean cost, skipping the first 10 epochs
            plt.plot(epochCost[10:])
            plt.show()

            #### PICKLE MODEL (helpers that do not exist are skipped)
            for fn, tag in [(train, 'TRAIN'), (predict, 'PREDICT'),
                            (drawLikeXFn, 'DRAWLIKEX'), (drawRandom, 'DRAWRANDOM'),
                            (decoderingFn, 'DECODERING')]:
                if fn is not None:
                    pickleModel(fn, filePath=modelDir,
                                fileName = runname+tag+str(iteration))

        iteration+=1

    #### SAVE COST TO FILE
    epochCost.append(np.mean(costCollect))
    np.savetxt(runname+"_COST.csv", epochCost, delimiter=",")
    print ('Epoch cost average:', np.mean(costCollect))

In [None]:
def create_reading_square(muX, muY, imgshp = [28,28], fill=True):
    """Rasterise grid centres into binary masks.

    muX, muY -- arrays of shape (l, N) with the x / y coordinates of N grid
                centres for each of l items (presumably attention-filter means)
    imgshp   -- (height, width) of the output masks
    fill     -- if True, also fill the bounding box spanned by the grid,
                including its last row and column
    Returns a float array of shape (l, height, width) with 1 at every marked
    pixel. Coordinates are clipped to the image and floored to integer pixels.
    """
    height, width = imgshp[0], imgshp[1]
    # Integer indices are required: numpy rejects float indices and slices.
    muX = np.floor(np.clip(muX, 0, width-1)).astype(int)
    muY = np.floor(np.clip(muY, 0, height-1)).astype(int)
    l = muX.shape[0]
    c = np.zeros([l, height, width])
    for i in range(l):
        # mark every (y, x) combination of the grid centres
        c[i][np.ix_(muY[i], muX[i])] = 1

    if fill:
        maxX, minX = np.max(muX, axis=1), np.min(muX, axis=1)
        maxY, minY = np.max(muY, axis=1), np.min(muY, axis=1)
        for i in range(l):
            # +1 so the box includes its bottom row and right column
            c[i, minY[i]:maxY[i]+1, minX[i]:maxX[i]+1] = 1

    return c

In [None]:
def plot_n_by_n_images(images,epoch=None,folder=None, n = 10, shp=[28,28]):
    """Tile n*n images into an n-by-n grid and return it as a PIL image.

    images -- sequence of at least n*n images, each reshapeable to ``shp``,
              with pixel values expected in [0, 1]
    epoch  -- label used for the file name ``<epoch>.png`` (any type)
    folder -- if both ``folder`` and ``epoch`` are given, the grid is saved there
    n      -- grid side length; images[0 .. n*n-1] are used in row-major order
    shp    -- (height, width) of one image
    Returns the tiled image as a PIL ``Image``.
    """
    a, b = shp
    img_out = np.zeros((a*n, b*n))
    for idx in range(n * n):
        row, col = divmod(idx, n)
        # rows advance by the image height a, columns by the image width b
        img_out[row*a:(row+1)*a, col*b:(col+1)*b] = np.reshape(images[idx], (a, b))
    # Clip before the uint8 cast so out-of-range values saturate instead of
    # wrapping around.
    img_out = (255*np.clip(img_out, 0, 1)).astype(np.uint8)
    # NOTE(review): ``Image`` (PIL) is not explicitly imported in the visible
    # cells; add ``from PIL import Image`` to the imports cell if it is missing.
    img_out = Image.fromarray(img_out)
    if folder is not None and epoch is not None:
        img_out.save(os.path.join(folder, str(epoch) + ".png"))
    return img_out

In [None]:
class BatchIterator(object):
    """
    Cyclic iterator over mini-batches of ``data``; permutes (unless testing)
    and restarts at the end, so it never raises StopIteration.

    batch_indices -- either an int n (use rows 0..n-1 of the data) or an
                     explicit sequence of row indices to draw batches from
    batchsize     -- rows per batch
    data          -- an array, or a list/tuple of arrays indexed in parallel
    testing       -- if True, iterate in fixed order (n must be a multiple of
                     batchsize) instead of random permutations
    process_func  -- applied to every batch array before it is returned
    Each ``next()`` returns a list with one batch per array in ``data``.
    """

    def __init__(self, batch_indices, batchsize, data, testing=False, process_func=None):
        if isinstance(batch_indices, (int, np.integer)):
            self.n = int(batch_indices)
            self.batchidx = np.arange(self.n)
        else:
            self.n = len(batch_indices)
            self.batchidx = np.array(batch_indices)

        self.batchsize = batchsize
        self.testing = testing
        if process_func is None:
            process_func = lambda x:x
        self.process_func = process_func

        if not isinstance(data, (list, tuple)):
            data = [data]

        self.data = data
        # createindices yields *positions* 0..n-1 into self.batchidx
        if not self.testing:
            self.createindices = lambda: np.random.permutation(self.n)
        else: # testing == true
            assert self.n % self.batchsize == 0, "for testing n must be multiple of batch size"
            self.createindices = lambda: range(self.n)

        self.perm = self.createindices()
        assert self.n > self.batchsize

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def _get_permuted_batches(self,n_batches):
        """Return ``n_batches`` arrays of positions, refilling the permutation as needed."""
        batches = []
        for i in range(n_batches):

            # extend the permutation if it has no more than one batch left
            if len(self.perm) <= self.batchsize:
                new_perm = self.createindices()
                self.perm = np.hstack([self.perm, new_perm])

            batches.append(self.perm[:self.batchsize])
            self.perm = self.perm[self.batchsize:]
        return batches

    def next(self):
        """Return the next list of batches (one per data array)."""
        positions = self._get_permuted_batches(1)[0]   # extract a single batch
        # Map positions through batchidx so an explicit index list is honoured
        # (for an int argument batchidx is arange(n), i.e. the identity).
        rows = self.batchidx[np.asarray(positions)]
        data_batches = [self.process_func(data_n[rows]) for data_n in self.data]
        return data_batches
    
    
def threaded_generator(generator, num_cached=50):
    """Run ``generator`` in a background thread, buffering up to ``num_cached`` items.

    Items are yielded in their original order. If the producer raises an
    exception, it is re-raised in the consuming thread once the items produced
    before it have been yielded. Previously the consumer blocked forever on the
    empty queue.
    """
    # this code is written by jan Schluter
    # copied from https://github.com/benanne/Lasagne/issues/12
    try:
        import Queue as queue_module  # Python 2
    except ImportError:
        import queue as queue_module  # Python 3
    import threading

    queue = queue_module.Queue(maxsize=num_cached)
    sentinel = object()  # guaranteed unique reference
    failure = []  # exception raised by the producer, if any

    # define producer (putting items into queue)
    def producer():
        try:
            for item in generator:
                queue.put(item)
        except Exception as exc:  # forwarded to the consumer below
            failure.append(exc)
        finally:
            # always signal the end so the consumer cannot block forever
            queue.put(sentinel)

    # start producer (in a background thread)
    thread = threading.Thread(target=producer)
    thread.daemon = True
    thread.start()

    # run as consumer (read items from queue, in current thread)
    item = queue.get()
    while item is not sentinel:
        yield item
        queue.task_done()
        item = queue.get()
    if failure:
        raise failure[0]