# Handwriting prediction - Model 2

This notebook is a personal attempt at coding Alex Graves's RNN to predict handwriting (section 4 of the paper, which can be found [here](https://arxiv.org/abs/1308.0850)). It differs from model 1 in structure, as it incorporates skip connections. This is an intermediate step before implementing section 5 of the paper. In my experience it works much better than model 1 and gives some good results.

The goal of this notebook is to implement the network in a straightforward manner, so code readability takes priority over performance. The network consists of stacked LSTM layers followed by a Gaussian mixture layer. Because handwriting is highly variable, it makes more sense to output a probability density function for the next stroke point at each time step than a single deterministic prediction.

The network appears to be working and generates sequences from a starting point that look like handwriting. It is interesting to note that when generating a sequence, the network chooses a style at random and sticks with it.

![example of sample](./pictures/sampleModel2_12.png)
![example of sample](./pictures/sampleModel2_10.png)
![example of sample](./pictures/sampleModel2_6.png)

The network is tweakable in sequence length, number of mixtures and dropout probability (the dropout parameter is accepted by the model but not used yet).

The notebook is divided into data treatment (I used [Greydanus's code](https://nbviewer.jupyter.org/github/greydanus/scribe/blob/master/dataloader.ipynb) for that part, as it is boring; his code is itself a variation of hardmaru's), network class, loss function and training.

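The dataset comes from the [IAM On-Line Handwriting Database](http://www.fki.inf.unibe.ch/databases/iam-on-line-handwriting-database). Download data/lineStrokes-all.tar.gz after signing up! The path should be ./data/lineStrokes if you want to use this notebook. The data loader also reads the text transcription of each line from ./data/ascii, so the matching transcriptions need to be extracted there as well.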

Enjoy :)

In [None]:
import os
import pickle
import random
import xml.etree.ElementTree as ET

import numpy as np
import svgwrite
from IPython.display import SVG, display
import matplotlib.cm as cm

import torch
from torch import nn, optim
import torch.nn.functional as F
use_cuda = torch.cuda.is_available()

import time
import matplotlib.pyplot as plt
%matplotlib inline

from ipywidgets import FloatProgress


# Network configuration
n_batch = 20
sequence_length = 300
hidden_size = 256
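U_items = int(sequence_length/25)   # characters kept per one-hot text encoding (data loader only)
n_layers = 3                        # informative only: the model hard-codes three LSTM layers
n_gaussians = 20
Kmixtures = 10                      # not used in this notebook (n_gaussians sets the mixture count)
gradient_threshold = 10
dropout = 0.2                       # passed to the model but not applied yet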

# Small number to avoid log(0) issue
eps = float(np.finfo(np.float32).eps)

# The network could use the extra space :)
torch.cuda.empty_cache()

## Dataloader
This code comes from [Greydanus](https://nbviewer.jupyter.org/github/greydanus/scribe/blob/master/dataloader.ipynb). Big thanks to its author!

That part is not that fun. `DataLoader` is a class that parses all the .xml files and caches the result in a pickle file for future use. It builds a training set of sequences x, y (same as x but shifted one time step), s (the text of each line) and c (a one-hot encoding of that text). Each batch holds `batch_size` sequences of `tsteps` points. `next_batch()` neatly returns a batch, and `reset_batch_pointer()` reshuffles the data and starts again from the first batch. See the training function for a proper use of that wonderful code.

In this notebook we won't use the one-hot vectors, as they are only needed for the attention mechanism of section 5 of the paper.
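A toy re-creation, in plain NumPy, of what the loader produces for one line: the target y is x shifted by one point, and the text is one-hot encoded, truncated or zero-padded to `U_items` characters. `make_pair` and the standalone `one_hot` are hypothetical helpers that mirror `DataLoader.next_batch` and `DataLoader.one_hot`; the real loader also shuffles lines and fills a whole batch.

```python
import numpy as np

def make_pair(data, tsteps):
    """Input/target pair as built in DataLoader.next_batch: y is x shifted by one step."""
    x = np.copy(data[:tsteps])
    y = np.copy(data[1:tsteps + 1])
    return x, y

def one_hot(s, U_items,
            alphabet=" abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"):
    """Mirror of DataLoader.one_hot: column 0 means 'unknown character' or padding."""
    seq = [alphabet.find(char) + 1 for char in s]        # find() gives -1 for unknown -> 0
    seq = seq[:U_items] + [0] * max(0, U_items - len(seq))  # truncate or pad to U_items
    encoded = np.zeros((U_items, len(alphabet) + 1))
    encoded[np.arange(U_items), seq] = 1
    return encoded

# Toy stroke array of (dx, dy, eos) rows
data = np.arange(15, dtype=np.float32).reshape(5, 3)
x, y = make_pair(data, tsteps=3)
c = one_hot("Hi!", U_items=5)
```

With the default alphabet of 63 characters, each row of `c` has 64 columns; `!` is not in the alphabet, so it lands in column 0 like the padding.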

Some examples of training data:

![batch2](./pictures/batch_model2.png)

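And some example code to load and display the data:

```python
# Assumes the DataLoader class and drawing helpers defined in the next cell
data_loader = DataLoader(n_batch, sequence_length, 20, U_items=U_items)  # 20 = data scale
x, y, s, c = data_loader.next_batch()
print(data_loader.pointer)
for i in range(n_batch):
    # x[i] holds (dx, dy, eos) offsets; the drawing helper accumulates them itself
    print(s[i][:U_items])
    draw_strokes_random_color(x[i], factor=0.5)
```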
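To plot a sample with matplotlib instead of SVG, the (dx, dy, eos) offsets have to be turned into absolute coordinates and split at every pen lift. A minimal sketch with a hypothetical helper `split_strokes`, assuming eos = 1 marks the last point of a stroke (as in `convert_stroke_to_array`):

```python
import numpy as np

def split_strokes(offsets):
    """Turn (dx, dy, eos) rows into a list of (n_i, 2) arrays of absolute points,
    one array per pen-down stroke."""
    coords = np.cumsum(offsets[:, :2], axis=0)      # offsets -> absolute positions
    ends = np.where(offsets[:, 2] == 1)[0] + 1       # split right after each eos point
    return [s for s in np.split(coords, ends) if len(s) > 0]

offsets = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 1], [5, 0, 0], [0, 2, 1]],
                   dtype=np.float32)
strokes = split_strokes(offsets)
```

Each stroke can then be drawn with `plt.plot(s[:, 0], -s[:, 1])`; the y axis is flipped because SVG coordinates grow downward.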

In [None]:
def get_bounds(data, factor):
    min_x = 0
    max_x = 0
    min_y = 0
    max_y = 0

    abs_x = 0
    abs_y = 0
    for i in range(len(data)):
        x = float(data[i, 0]) / factor
        y = float(data[i, 1]) / factor
        abs_x += x
        abs_y += y
        min_x = min(min_x, abs_x)
        min_y = min(min_y, abs_y)
        max_x = max(max_x, abs_x)
        max_y = max(max_y, abs_y)

    return (min_x, max_x, min_y, max_y)

# old version, where each path is entire stroke (smaller svg size, but
# have to keep same color)


def draw_strokes(data, factor=10, svg_filename='sample.svg'):
    min_x, max_x, min_y, max_y = get_bounds(data, factor)
    dims = (50 + max_x - min_x, 50 + max_y - min_y)

    dwg = svgwrite.Drawing(svg_filename, size=dims)
    dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))

    lift_pen = 1

    abs_x = 25 - min_x
    abs_y = 25 - min_y
    p = "M%s,%s " % (abs_x, abs_y)

    command = "m"

    for i in range(len(data)):
        if (lift_pen == 1):
            command = "m"
        elif (command != "l"):
            command = "l"
        else:
            command = ""
        x = float(data[i, 0]) / factor
        y = float(data[i, 1]) / factor
        lift_pen = data[i, 2]
        p += command + str(x) + "," + str(y) + " "

    the_color = "black"
    stroke_width = 1

    dwg.add(dwg.path(p).stroke(the_color, stroke_width).fill("none"))

    dwg.save()
    display(SVG(dwg.tostring()))


def draw_strokes_eos_weighted(
        stroke,
        param,
        factor=10,
        svg_filename='sample_eos.svg'):
    c_data_eos = np.zeros((len(stroke), 3))
    for i in range(len(param)):
        # make color gray scale, darker = more likely to eos
        c_data_eos[i, :] = (1 - param[i][6][0]) * 225
    draw_strokes_custom_color(
        stroke,
        factor=factor,
        svg_filename=svg_filename,
        color_data=c_data_eos,
        stroke_width=3)


def draw_strokes_random_color(
        stroke,
        factor=10,
        svg_filename='sample_random_color.svg',
        per_stroke_mode=True):
    c_data = np.array(np.random.rand(len(stroke), 3) * 240, dtype=np.uint8)
    if per_stroke_mode:
        switch_color = False
        for i in range(len(stroke)):
            if switch_color == False and i > 0:
                c_data[i] = c_data[i - 1]
            if stroke[i, 2] < 1:  # same stroke
                switch_color = False
            else:
                switch_color = True
    draw_strokes_custom_color(
        stroke,
        factor=factor,
        svg_filename=svg_filename,
        color_data=c_data,
        stroke_width=2)


def draw_strokes_custom_color(
        data,
        factor=10,
        svg_filename='test.svg',
        color_data=None,
        stroke_width=1):
    min_x, max_x, min_y, max_y = get_bounds(data, factor)
    dims = (50 + max_x - min_x, 50 + max_y - min_y)

    dwg = svgwrite.Drawing(svg_filename, size=dims)
    dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))

    lift_pen = 1
    abs_x = 25 - min_x
    abs_y = 25 - min_y

    for i in range(len(data)):

        x = float(data[i, 0]) / factor
        y = float(data[i, 1]) / factor

        prev_x = abs_x
        prev_y = abs_y

        abs_x += x
        abs_y += y

        if (lift_pen == 1):
            p = "M " + str(abs_x) + "," + str(abs_y) + " "
        else:
            p = "M +" + str(prev_x) + "," + str(prev_y) + \
                " L " + str(abs_x) + "," + str(abs_y) + " "

        lift_pen = data[i, 2]

        the_color = "black"

        if (color_data is not None):
            the_color = "rgb(" + str(int(color_data[i, 0])) + "," + str(
                int(color_data[i, 1])) + "," + str(int(color_data[i, 2])) + ")"

        dwg.add(dwg.path(p).stroke(the_color, stroke_width).fill(the_color))
    dwg.save()
    display(SVG(dwg.tostring()))

        
class DataLoader():
    def __init__(self, batch_size=50, tsteps=300, scale_factor = 10, U_items=10, limit = 500, alphabet="default"):
        self.data_dir = "./data"
        self.alphabet = alphabet
        self.batch_size = batch_size
        self.tsteps = tsteps
        self.scale_factor = scale_factor # divide data by this factor
        self.limit = limit # removes large noisy gaps in the data
        self.U_items = U_items

        data_file = os.path.join(self.data_dir, "strokes_training_data_generation.cpkl")
        stroke_dir = self.data_dir+"/lineStrokes"
        ascii_dir = self.data_dir+"/ascii"

        if not (os.path.exists(data_file)) :
            print ("creating training data cpkl file from raw source")
            self.preprocess(stroke_dir, ascii_dir, data_file)

        self.load_preprocessed(data_file)
        self.reset_batch_pointer()

    def preprocess(self, stroke_dir, ascii_dir, data_file):
        # create data file from raw xml files from iam handwriting source.
        print ("Parsing dataset...")
        
        # build the list of xml files
        filelist = []
        # Set the directory you want to start from
        rootDir = stroke_dir
        for dirName, subdirList, fileList in os.walk(rootDir):
#             print('Found directory: %s' % dirName)
            for fname in fileList:
#                 print('\t%s' % fname)
                filelist.append(dirName+"/"+fname)

        # function to read each individual xml file
        def getStrokes(filename):
            tree = ET.parse(filename)
            root = tree.getroot()

            result = []

            x_offset = 1e20
            y_offset = 1e20
            y_height = 0
            for i in range(1, 4):
                x_offset = min(x_offset, float(root[0][i].attrib['x']))
                y_offset = min(y_offset, float(root[0][i].attrib['y']))
                y_height = max(y_height, float(root[0][i].attrib['y']))
            y_height -= y_offset
            x_offset -= 100
            y_offset -= 100

            for stroke in root[1].findall('Stroke'):
                points = []
                for point in stroke.findall('Point'):
                    points.append([float(point.attrib['x'])-x_offset,float(point.attrib['y'])-y_offset])
                result.append(points)
            return result
        
        # function to read the transcription of one line from an ascii file
        def getAscii(filename, line_number):
            with open(filename, "r") as f:
                s = f.read()
            s = s[s.find("CSR"):]
            if len(s.split("\n")) > line_number+2:
                s = s.split("\n")[line_number+2]
                return s
            else:
                return ""
                
        # converts a list of arrays into a 2d numpy int16 array
        def convert_stroke_to_array(stroke):
            n_point = 0
            for i in range(len(stroke)):
                n_point += len(stroke[i])
            stroke_data = np.zeros((n_point, 3), dtype=np.int16)

            prev_x = 0
            prev_y = 0
            counter = 0

            for j in range(len(stroke)):
                for k in range(len(stroke[j])):
                    stroke_data[counter, 0] = int(stroke[j][k][0]) - prev_x
                    stroke_data[counter, 1] = int(stroke[j][k][1]) - prev_y
                    prev_x = int(stroke[j][k][0])
                    prev_y = int(stroke[j][k][1])
                    stroke_data[counter, 2] = 0
                    if (k == (len(stroke[j])-1)): # end of stroke
                        stroke_data[counter, 2] = 1
                    counter += 1
            return stroke_data

        # build stroke database of every xml file inside iam database
        strokes = []
        asciis = []
        for i in range(len(filelist)):
            if (filelist[i][-3:] == 'xml'):
                stroke_file = filelist[i]
#                 print 'processing '+stroke_file
                stroke = convert_stroke_to_array(getStrokes(stroke_file))
                
                ascii_file = stroke_file.replace("lineStrokes","ascii")[:-7] + ".txt"
                line_number = stroke_file[-6:-4]
                line_number = int(line_number) - 1
                ascii = getAscii(ascii_file, line_number)
                if len(ascii) > 10:
                    strokes.append(stroke)
                    asciis.append(ascii)
                else:
                    print ("======>>>> Line length was too short. Line was: " + ascii)
                
        assert(len(strokes)==len(asciis)), "There should be a 1:1 correspondence between stroke data and ascii labels."
        f = open(data_file,"wb")
        pickle.dump([strokes,asciis], f, protocol=2)
        f.close()
        print ("Finished parsing dataset. Saved {} lines".format(len(strokes)))


    def load_preprocessed(self, data_file):
        f = open(data_file,"rb")
        [self.raw_stroke_data, self.raw_ascii_data] = pickle.load(f)
        f.close()

        # goes thru the list, and only keeps the text entries that have more than tsteps points
        self.stroke_data = []
        self.ascii_data = []
        counter = 0

        for i in range(len(self.raw_stroke_data)):
            data = self.raw_stroke_data[i]
            if len(data) > (self.tsteps+2):
                # removes large gaps from the data
                data = np.minimum(data, self.limit)
                data = np.maximum(data, -self.limit)
                data = np.array(data,dtype=np.float32)
                data[:,0:2] /= self.scale_factor
                
                self.stroke_data.append(data)
                self.ascii_data.append(self.raw_ascii_data[i])

        # number of full batches per epoch (y is a shifted copy of x, hence the tsteps+2 length filter above)
        self.num_batches = int(len(self.stroke_data) / self.batch_size)
        print ("Loaded dataset:")
        print ("   -> {} individual data points".format(len(self.stroke_data)))
        print ("   -> {} batches".format(self.num_batches))

    def next_batch(self):
        # returns a randomised, tsteps sized portion of the training data
        x_batch = []
        y_batch = []
        ascii_list = []
        for i in range(self.batch_size):
            data = self.stroke_data[self.idx_perm[self.pointer]]
            x_batch.append(np.copy(data[:self.tsteps]))
            y_batch.append(np.copy(data[1:self.tsteps+1]))
            ascii_list.append(self.ascii_data[self.idx_perm[self.pointer]])
            self.tick_batch_pointer()
        one_hots = [self.one_hot(s) for s in ascii_list]
        return x_batch, y_batch, ascii_list, one_hots
    
    def one_hot(self, s):
        #index position 0 means "unknown"
        if self.alphabet == "default":
            alphabet = " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
        else:
            alphabet = self.alphabet
        seq = [alphabet.find(char) + 1 for char in s]
        if len(seq) >= self.U_items:
            seq = seq[:self.U_items]
        else:
            seq = seq + [0]*(self.U_items - len(seq))
        one_hot = np.zeros((self.U_items,len(alphabet)+1))
        one_hot[np.arange(self.U_items),seq] = 1
        return one_hot

    def tick_batch_pointer(self):
        self.pointer += 1
        if (self.pointer >= len(self.stroke_data)):
            self.reset_batch_pointer()
    def reset_batch_pointer(self):
        self.idx_perm = np.random.permutation(len(self.stroke_data))
        self.pointer = 0
        print ("pointer reset")

## Model

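This is where the fun begins. The network consists of LSTM cells stacked on top of each other and followed by a Gaussian mixture layer. As in the paper, it includes skip connections from the input to every hidden layer: each LSTM layer above the first receives the current point x concatenated with the output of the layer below. The mixture layer only reads the output of the last LSTM layer.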

`__init__(self, hidden_size = 256, n_gaussians = 20, dropout = 0.2)`:
This is the constructor. It takes the different parameters to create the different blocks of the model.
- hidden_size is the size of the output of each LSTM cell
- n_gaussians is the number of mixtures
- dropout is the dropout probability, i.e. the probability of zeroing a unit during training. It is accepted by the constructor but not implemented yet.

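The mixture parameters are created by dense layers that take the output of the last LSTM layer and scale it to the size the mixture needs: with a hidden size of 256 and M mixtures, the 256 values are mapped to 1 + 6M outputs (one end-of-stroke value plus six parameters per mixture). This gives ŷ of equation 17 of the paper. The code uses seven separate `nn.Linear` layers, one per parameter group, which is equivalent to a single dense layer whose output is sliced.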
![eq17](./pictures/eq17.png)
ŷ is then broken down into the different parameters of the mixture:
- e, computed from ê, is the end-of-stroke probability, i.e. the parameter of a [Bernoulli distribution](https://en.wikipedia.org/wiki/Bernoulli_distribution)
- $\pi$ is the weight of each normal distribution
- $\mu, \sigma, \rho$ are the mean, standard deviation and correlation factor of each [bivariate normal distribution](http://mathworld.wolfram.com/BivariateNormalDistribution.html)

The constructor just lays out the blocks but does not connect them. That's the job of the forward function.
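As a sketch of the "single dense layer" view (not the notebook's code, which keeps seven separate layers), one `nn.Linear` can produce all 1 + 6M values, which are then sliced with `torch.split`. The sizes below are illustrative; with 20 mixtures that is 121 values per time step and per sequence.

```python
import torch
from torch import nn

hidden_size, n_gaussians = 256, 20
seq_len, n_batch = 5, 2

# One dense layer producing every mixture parameter at once: 1 + 6 * M outputs
z = nn.Linear(hidden_size, 1 + 6 * n_gaussians)

out = torch.randn(seq_len, n_batch, hidden_size)  # stand-in for the last LSTM layer's outputs
y_hat = z(out)                                    # ŷ of eq. 17 (computed from one layer only)

# Slice ŷ into (ê, π̂, μ̂1, μ̂2, σ̂1, σ̂2, ρ̂) along the feature dimension
e_hat, pi_hat, mu1, mu2, sig1_hat, sig2_hat, rho_hat = torch.split(
    y_hat, [1] + [n_gaussians] * 6, dim=2)
```

The transforms of equations 18-22 (sigmoid-like, softmax, exp, tanh) are then applied to each slice exactly as in `forward`.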


`forward(self, x)`:
This is the forward propagation. It takes x, a batch of dimensions [sequence_length, batch_size, 3]. The 3 values are the x and y offsets of a point and eos (= 1 at the end of a stroke, when the pen is raised). Note that the forward function is also used to generate random sequences.

The first step is to compute the LSTM block. This is straightforward in PyTorch, but since I used `nn.LSTMCell` rather than `nn.LSTM`, I need a for loop over the whole sequence.
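To see that the explicit loop computes the same thing as the vectorised module, here is a small check on a single layer: an `nn.LSTMCell` is given the weights of an `nn.LSTM` and stepped through the sequence. The sizes are arbitrary toy values.

```python
import torch
from torch import nn

torch.manual_seed(0)
seq_len, n_batch, n_in, n_hidden = 7, 4, 3, 16

lstm = nn.LSTM(n_in, n_hidden)        # whole-sequence module (one layer)
cell = nn.LSTMCell(n_in, n_hidden)    # single-step module, as used in the model

# Share weights so both modules compute the same function (same gate ordering)
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

x = torch.randn(seq_len, n_batch, n_in)   # [sequence_length, batch, features]

# Vectorised version: one call over the whole sequence
out_lstm, _ = lstm(x)

# Explicit loop, as in HandwritingGenerationModel2.forward
h = torch.zeros(n_batch, n_hidden)
c = torch.zeros(n_batch, n_hidden)
outputs = []
for t in range(seq_len):
    h, c = cell(x[t], (h, c))
    outputs.append(h)
out_cell = torch.stack(outputs)           # [sequence_length, batch, hidden]
```

Since layer n only needs the outputs of layer n-1 (plus x), the skip-connection stack could likewise use three `nn.LSTM` modules run over whole sequences, feeding `torch.cat((x, h1), 2)` to the next layer; the per-step loop is a readability choice rather than a requirement.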

Then it's just a matter of computing equations 18-22 of the paper.
![eq18-22](./pictures/eq18-22.png)
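In text form, equations 18-22 as applied in `forward` (M is the number of mixtures, hats denote the raw slices of ŷ):

$$
e_t = \frac{1}{1 + \exp(\hat e_t)}, \qquad
\pi_t^j = \frac{\exp(\hat\pi_t^j)}{\sum_{j'=1}^{M} \exp(\hat\pi_t^{j'})}, \qquad
\mu_t^j = \hat\mu_t^j, \qquad
\sigma_t^j = \exp(\hat\sigma_t^j), \qquad
\rho_t^j = \tanh(\hat\rho_t^j)
$$

These constrain each parameter to its valid range: $e_t \in (0, 1)$, the $\pi_t^j$ sum to 1, $\sigma_t^j > 0$ and $\rho_t^j \in (-1, 1)$.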


`generate_sequence(self, x0, sequence_length = 100)` :
This is where I clearly sacrifice performance for readability. The goal of this function is to return a sequence that starts from either a single point or the beginning of a sequence, x0. In pseudo-code:
- Calculate the mixture parameters of the sequence so far
- Pick a random mixture based on the weights (pi_idx)
- Take a random point from the chosen bivariate normal distribution, and draw the end-of-stroke bit from the Bernoulli distribution
- Add it at the end of the sequence (concatenate it)
- Repeat
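One iteration of that loop, for a single time step, can be sketched in plain NumPy. `sample_step` is a hypothetical helper, and it uses NumPy's `Generator` API rather than the global `np.random` functions used in the notebook.

```python
import numpy as np

def sample_step(pi, mu1, mu2, sigma1, sigma2, rho, e, rng):
    """Draw one (dx, dy, eos) point from the mixture parameters of one time step.
    All mixture arrays have shape (M,); e is the end-of-stroke probability."""
    j = rng.choice(len(pi), p=pi)                      # pick a mixture by its weight
    cov = [[sigma1[j] ** 2, rho[j] * sigma1[j] * sigma2[j]],
           [rho[j] * sigma1[j] * sigma2[j], sigma2[j] ** 2]]
    dx, dy = rng.multivariate_normal([mu1[j], mu2[j]], cov)  # sample its bivariate normal
    eos = float(rng.random() < e)                      # Bernoulli draw for the pen lift
    return np.array([dx, dy, eos], dtype=np.float32)

rng = np.random.default_rng(0)
M = 3
point = sample_step(pi=np.array([0.2, 0.5, 0.3]), mu1=np.zeros(M), mu2=np.ones(M),
                    sigma1=np.full(M, 0.1), sigma2=np.full(M, 0.2), rho=np.zeros(M),
                    e=0.0, rng=rng)
```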


Looping this way is clearly bad practice, as the forward pass is rerun on the entire sequence for every new point. The sequence also keeps growing, so each new point takes longer to compute than the last. However, it fits in just a few lines and keeps the forward function clean.
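Carrying the LSTM state from one step to the next removes the replay: each new point then costs one cell evaluation instead of a full pass, so n points take O(n) steps instead of O(n²). A minimal sketch on a toy single `nn.LSTMCell` (not the notebook's model) checks that both approaches reach the same hidden state. Applying this to `HandwritingGenerationModel2` would require `forward` to accept and return the three (h, c) pairs, which the current version does not do.

```python
import torch
from torch import nn

torch.manual_seed(0)
cell = nn.LSTMCell(3, 8)   # toy one-layer stand-in for the model's LSTM stack

def run_from_scratch(seq):
    """What generate_sequence effectively does: replay the whole sequence each time."""
    h = torch.zeros(1, 8)
    c = torch.zeros(1, 8)
    for t in range(seq.shape[0]):
        h, c = cell(seq[t], (h, c))
    return h

seq = torch.randn(6, 1, 3)   # [sequence_length, batch=1, (dx, dy, eos)]

with torch.no_grad():
    # Incremental version: keep (h, c) and feed only the newest point each step
    state = (torch.zeros(1, 8), torch.zeros(1, 8))
    for t in range(seq.shape[0]):
        state = cell(seq[t], state)
    h_incremental = state[0]
    h_replayed = run_from_scratch(seq)
```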


`generate_sample(self, mu1, mu2, sigma1, sigma2, rho)` :
Returns random coordinates based on a bivariate normal distribution given by the function parameters. 
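`generate_sample` relies on `np.random.multivariate_normal`, which takes the full covariance matrix. In two dimensions the same draw can also be written in closed form from two independent standard normals (the hand-written 2x2 Cholesky factor). A sketch with a hypothetical `sample_bivariate`, checked empirically:

```python
import numpy as np

def sample_bivariate(mu1, mu2, sigma1, sigma2, rho, n, rng):
    """Draw n points from a bivariate normal with means (mu1, mu2), standard
    deviations (sigma1, sigma2) and correlation rho."""
    z1 = rng.standard_normal(n)
    z2 = rng.standard_normal(n)
    x1 = mu1 + sigma1 * z1
    # rho * z1 + sqrt(1 - rho^2) * z2 has unit variance and correlation rho with z1
    x2 = mu2 + sigma2 * (rho * z1 + np.sqrt(1 - rho ** 2) * z2)
    return np.stack([x1, x2], axis=1)

rng = np.random.default_rng(0)
pts = sample_bivariate(1.0, -2.0, 0.5, 2.0, 0.8, n=200_000, rng=rng)
```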


In [None]:
class HandwritingGenerationModel2(nn.Module):
    def __init__(self, hidden_size = 256, n_gaussians = 20, dropout = 0.2):
        super(HandwritingGenerationModel2, self).__init__()
        
        self.n_gaussians = n_gaussians
        
        self.hidden_size1 = hidden_size
        self.hidden_size2 = hidden_size
        self.hidden_size3 = hidden_size
        
        # input_size1 includes x, y, eos 
        self.input_size1 = 3
        
        # input_size2 includes x, y, eos and hidden_size1
        self.input_size2 = 3 + self.hidden_size1
        
        # input_size3 includes x, y, eos and hidden_size2
        self.input_size3 = 3 + self.hidden_size2
        
        # Creating the LSTM cells
        self.lstm1 = nn.LSTMCell(input_size= self.input_size1 , hidden_size = self.hidden_size1)
        self.lstm2 = nn.LSTMCell(input_size= self.input_size2 , hidden_size = self.hidden_size2)
        self.lstm3 = nn.LSTMCell(input_size= self.input_size3 , hidden_size = self.hidden_size3)
        
        # For gaussian mixtures
        self.z_e = nn.Linear(hidden_size, 1)
        self.z_pi = nn.Linear(hidden_size, n_gaussians)
        self.z_mu1 = nn.Linear(hidden_size, n_gaussians)
        self.z_mu2 = nn.Linear(hidden_size, n_gaussians)
        self.z_sigma1 = nn.Linear(hidden_size, n_gaussians)
        self.z_sigma2 = nn.Linear(hidden_size, n_gaussians)
        self.z_rho = nn.Linear(hidden_size, n_gaussians)
        
        
    def forward(self, x):
        # sequence length
        sequence_length = x.shape[0]
        
        # number of batches
        n_batch = x.shape[1]
        
        # Hidden and cell state for LSTM1
        h1_t = torch.zeros(n_batch, self.hidden_size1) # torch.Size([n_batch, hidden_size1])
        c1_t = torch.zeros(n_batch, self.hidden_size1) # torch.Size([n_batch, hidden_size1])
        
        # Hidden and cell state for LSTM2
        h2_t = torch.zeros(n_batch, self.hidden_size2) # torch.Size([n_batch, hidden_size2])
        c2_t = torch.zeros(n_batch, self.hidden_size2) # torch.Size([n_batch, hidden_size2])
        
        # Hidden and cell state for LSTM3
        h3_t = torch.zeros(n_batch, self.hidden_size3) # torch.Size([n_batch, hidden_size3])
        c3_t = torch.zeros(n_batch, self.hidden_size3) # torch.Size([n_batch, hidden_size3])
        
        # Outputs of LSTM3 over the whole sequence
        out = torch.zeros(sequence_length, n_batch, self.hidden_size3)
        
        if use_cuda:
            h1_t = h1_t.cuda()
            c1_t = c1_t.cuda()
            
            h2_t = h2_t.cuda()
            c2_t = c2_t.cuda()
            
            h3_t = h3_t.cuda()
            c3_t = c3_t.cuda()
            
            out = out.cuda()

            
        for i in range(sequence_length):
            # ===== Computing 1st layer =====
            h1_t, c1_t = self.lstm1(x[i], (h1_t, c1_t)) # torch.Size([n_batch, hidden_size1])
            
            
            # ===== Computing 2nd layer =====
            input_lstm2 = torch.cat((x[i], h1_t), 1) # torch.Size([n_batch, 3 + hidden_size1])
            h2_t, c2_t = self.lstm2(input_lstm2, (h2_t, c2_t)) 
            
            
            # ===== Computing 3rd layer =====
            input_lstm3 = torch.cat((x[i], h2_t), 1) # torch.Size([n_batch, 3 + hidden_size2])
            h3_t, c3_t = self.lstm3(input_lstm3, (h3_t, c3_t))
            out[i, :, :] = h3_t
            
            
        # ===== Computing MDN =====
        es = self.z_e(out)
        # print("es shape ", es.shape) # -> torch.Size([sequence_length, batch, 1])
        es = 1 / (1 + torch.exp(es))
        # print("es shape", es.shape) # -> torch.Size([sequence_length, batch, 1])

        pis = self.z_pi(out)
        # print("pis shape ", pis.shape) # -> torch.Size([sequence_length, batch, n_gaussians])
        pis = torch.softmax(pis, 2)
        # print(pis.shape) # -> torch.Size([sequence_length, batch, n_gaussians])

        mu1s = self.z_mu1(out)
        mu2s = self.z_mu2(out)
        # print("mu shape :  ", mu1s.shape) # -> torch.Size([sequence_length, batch, n_gaussians])

        sigma1s = self.z_sigma1(out)
        sigma2s = self.z_sigma2(out)
        # print("sigmas shape ", sigma1s.shape) # -> torch.Size([sequence_length, batch, n_gaussians])
        sigma1s = torch.exp(sigma1s)
        sigma2s = torch.exp(sigma2s)
        # print(sigma1s.shape) # -> torch.Size([sequence_length, batch, n_gaussians])

        rhos = self.z_rho(out)
        rhos = torch.tanh(rhos)
        # print("rhos shape ", rhos.shape) # -> torch.Size([sequence_length, batch, n_gaussians])

        es = es.squeeze(2) 
        # print("es shape ", es.shape) # -> torch.Size([sequence_length, batch])


        return es, pis, mu1s, mu2s, sigma1s, sigma2s, rhos
    
    
    def generate_sample(self, mu1, mu2, sigma1, sigma2, rho):
        mean = [mu1, mu2]
        cov = [[sigma1 ** 2, rho * sigma1 * sigma2], [rho * sigma1 * sigma2, sigma2 ** 2]]
        
        x = np.float32(np.random.multivariate_normal(mean, cov, 1))
        return torch.from_numpy(x)
        
        
    def generate_sequence(self, x0, sequence_length = 100):
        sequence = x0
        
        # A fun little widget
        print("Generating sequence ...")
        f = FloatProgress(min=0, max=sequence_length)
        display(f)
        
        for i in range(sequence_length):
            es, pis, mu1s, mu2s, sigma1s, sigma2s, rhos = self.forward(sequence)
            
            # Selecting a mixture 
            pi_idx = np.random.choice(range(self.n_gaussians), p=pis[-1, 0, :].detach().cpu().numpy())
            
            # taking last parameters from sequence corresponding to chosen gaussian
            mu1 = mu1s[-1, :, pi_idx].item()
            mu2 = mu2s[-1, :, pi_idx].item()
            sigma1 = sigma1s[-1, :, pi_idx].item()
            sigma2 = sigma2s[-1, :, pi_idx].item()
            rho = rhos[-1, :, pi_idx].item()
            
            prediction = self.generate_sample(mu1, mu2, sigma1, sigma2, rho)
            eos = torch.distributions.bernoulli.Bernoulli(torch.tensor([es[-1, :].item()])).sample()
            
            sample = torch.zeros_like(x0) # torch.Size([1, 1, 3])
            sample[0, 0, 0] = prediction[0, 0]
            sample[0, 0, 1] = prediction[0, 1]
            sample[0, 0, 2] = eos
            
            sequence = torch.cat((sequence, sample), 0) # torch.Size([sequence_length, 1, 3])
            
            f.value += 1
        
        return sequence.squeeze(1).detach().cpu().numpy()


## Implementing density probability

It's time to implement the probability density of our next point given our output vector (the Gaussian mixture parameters). In the paper, this is given by equations 23-25. This will be useful when computing the loss function.

![eq23-25](./pictures/eq23-25.png)
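Equations 23-25 in text form, using the same symbols as equations 18-22 ($x_1, x_2$ are the two offsets and $(x_{t+1})_3$ is the eos bit):

$$
\Pr(x_{t+1} \mid y_t) = \sum_{j=1}^{M} \pi_t^j \, \mathcal{N}(x_{t+1} \mid \mu_t^j, \sigma_t^j, \rho_t^j)
\begin{cases} e_t & \text{if } (x_{t+1})_3 = 1 \\ 1 - e_t & \text{otherwise} \end{cases}
$$

$$
\mathcal{N}(x \mid \mu, \sigma, \rho) = \frac{1}{2\pi\sigma_1\sigma_2\sqrt{1-\rho^2}} \exp\left[\frac{-Z}{2(1-\rho^2)}\right],
\qquad
Z = \frac{(x_1-\mu_1)^2}{\sigma_1^2} + \frac{(x_2-\mu_2)^2}{\sigma_2^2} - \frac{2\rho(x_1-\mu_1)(x_2-\mu_2)}{\sigma_1\sigma_2}
$$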

I chose to leave out the Bernoulli part for now; it is computed in the loss function.

`gaussianMixture(y, pis, mu1s, mu2s, sigma1s, sigma2s, rhos)` :

Remember the forward function of our model: gaussianMixture(...) takes its outputs as parameters and computes the result of equation 23 for the whole sequence over every batch. A note on parameter y: it is basically the same tensor as x but shifted one time step. Think of it as $x_{t+1}$ in equation 23. The density predicted at step t is evaluated at the next point, which is why the prediction made from the last point of x still has a target to learn from.
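The `repeat(...).permute(...)` calls in `gaussianMixture` can be replaced by broadcasting: slicing `y[:, :, 0:1]` keeps a trailing axis of size 1 that expands against the `[T, B, M]` mixture tensors. A sketch of that variant (`gaussian_mixture_broadcast` is a hypothetical name), fed with random but valid parameters:

```python
import math
import torch

def gaussian_mixture_broadcast(y, pis, mu1s, mu2s, sigma1s, sigma2s, rhos):
    """Eq. 23 without the eos term. y: [T, B, 3]; mixture params: [T, B, M]."""
    x1 = y[:, :, 0:1]                        # [T, B, 1], broadcasts over the M mixtures
    x2 = y[:, :, 1:2]
    d1 = (x1 - mu1s) / sigma1s
    d2 = (x2 - mu2s) / sigma2s
    one_minus_rho2 = 1 - rhos ** 2
    Z = d1 ** 2 + d2 ** 2 - 2 * rhos * d1 * d2                           # eq. 25
    N = torch.exp(-Z / (2 * one_minus_rho2)) / (
        2 * math.pi * sigma1s * sigma2s * torch.sqrt(one_minus_rho2))      # eq. 24
    return torch.sum(pis * N, dim=2)                                       # [T, B]

torch.manual_seed(0)
T, B, M = 4, 2, 3
y = torch.randn(T, B, 3)
pis = torch.softmax(torch.randn(T, B, M), dim=2)
mu1s, mu2s = torch.randn(T, B, M), torch.randn(T, B, M)
sigma1s, sigma2s = torch.rand(T, B, M) + 0.5, torch.rand(T, B, M) + 0.5
rhos = 0.9 * torch.tanh(torch.randn(T, B, M))
Pr = gaussian_mixture_broadcast(y, pis, mu1s, mu2s, sigma1s, sigma2s, rhos)
```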


In [None]:
def gaussianMixture(y, pis, mu1s, mu2s, sigma1s, sigma2s, rhos):
    n_mixtures = pis.size(2)
    
    # Takes x1 and repeats it over the number of gaussian mixtures
    x1 = y[:,:, 0].repeat(n_mixtures, 1, 1).permute(1, 2, 0) 
    # print("x1 shape ", x1.shape) # -> torch.Size([sequence_length, batch, n_gaussians])
    
    # first term of Z (eq 25)
    x1norm = ((x1 - mu1s) ** 2) / (sigma1s ** 2 )
    # print("x1norm shape ", x1.shape) # -> torch.Size([sequence_length, batch, n_gaussians])
    
    x2 = y[:,:, 1].repeat(n_mixtures, 1, 1).permute(1, 2, 0)  
    # print("x2 shape ", x2.shape) # -> torch.Size([sequence_length, batch, n_gaussians])
    
    # second term of Z (eq 25)
    x2norm = ((x2 - mu2s) ** 2) / (sigma2s ** 2 )
    # print("x2norm shape ", x2.shape) # -> torch.Size([sequence_length, batch, n_gaussians])
    
    # third term of Z (eq 25)
    coxnorm = 2 * rhos * (x1 - mu1s) * (x2 - mu2s) / (sigma1s * sigma2s) 
    
    # Computing Z (eq 25)
    Z = x1norm + x2norm - coxnorm
    
    # Gaussian bivariate (eq 24)
    N = torch.exp(-Z / (2 * (1 - rhos ** 2))) / (2 * np.pi * sigma1s * sigma2s * (1 - rhos ** 2) ** 0.5) 
    # print("N shape ", N.shape) # -> torch.Size([sequence_length, batch, n_gaussians]) 
    
    # Pr is the result of eq 23 without the eos part
    Pr = pis * N 
    # print("Pr shape ", Pr.shape) # -> torch.Size([sequence_length, batch, n_gaussians])   
    Pr = torch.sum(Pr, dim=2) 
    # print("Pr shape ", Pr.shape) # -> torch.Size([sequence_length, batch])   
    
    if use_cuda:
        Pr = Pr.cuda()
    
    return Pr
    
    

## Computing loss fn

The goal is to maximize the likelihood of our estimated bivariate normal distributions and Bernoulli distribution. Think about it this way. We generate parameters for our distributions but we want them to fit as best as possible to our data. Each training step's goal is to converge toward the best parameters for our data. [Click here to read more about likelihood function](https://en.wikipedia.org/wiki/Likelihood_function).

In the paper, the loss is given by equation 26 :

![eq26](./pictures/eq26.png)
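Equation 26 in text form, with the same notation as equations 23-25:

$$
\mathcal{L}(\mathbf{x}) = \sum_{t=1}^{T} \left( -\log \sum_{j=1}^{M} \pi_t^j \, \mathcal{N}(x_{t+1} \mid \mu_t^j, \sigma_t^j, \rho_t^j)
- \begin{cases} \log e_t & \text{if } (x_{t+1})_3 = 1 \\ \log(1 - e_t) & \text{otherwise} \end{cases} \right)
$$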

We previously calculated the first part of the equation in gaussianMixture(...). What's left is to add the Bernoulli loss (second part of the equation). The loss is summed over time steps and then averaged over the batch.
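`loss_fn` takes `log(Pr + eps)`, which works but can underflow when every mixture assigns a tiny density to a point. A possible alternative, sketched below under the notebook's tensor shapes, stays in log space: the log-density of each mixture is computed directly and combined with `torch.logsumexp`, and the eos term uses `F.binary_cross_entropy`. `log_space_loss` is a hypothetical name; it takes `log(pis)` because the model returns softmaxed weights (with raw logits, `F.log_softmax` would be the better choice).

```python
import math
import torch
import torch.nn.functional as F

def log_space_loss(y, es, pis, mu1s, mu2s, sigma1s, sigma2s, rhos):
    """Eq. 26 computed in log space. y: [T, B, 3], es: [T, B], mixture params: [T, B, M]."""
    d1 = (y[:, :, 0:1] - mu1s) / sigma1s
    d2 = (y[:, :, 1:2] - mu2s) / sigma2s
    one_minus_rho2 = 1 - rhos ** 2
    Z = d1 ** 2 + d2 ** 2 - 2 * rhos * d1 * d2
    log_N = (-Z / (2 * one_minus_rho2)
             - torch.log(2 * math.pi * sigma1s * sigma2s * torch.sqrt(one_minus_rho2)))
    log_mix = torch.logsumexp(torch.log(pis) + log_N, dim=2)        # [T, B]
    # es is the predicted end-of-stroke probability, y[..., 2] the observed bit
    bce = F.binary_cross_entropy(es, y[:, :, 2], reduction="none")  # [T, B]
    return torch.mean(torch.sum(-log_mix + bce, dim=0))  # sum over time, mean over batch

torch.manual_seed(0)
T, B, M = 5, 2, 3
y = torch.randn(T, B, 3)
y[:, :, 2] = (torch.rand(T, B) < 0.3).float()
es = torch.rand(T, B) * 0.8 + 0.1
pis = torch.softmax(torch.randn(T, B, M), dim=2)
mu1s, mu2s = torch.randn(T, B, M), torch.randn(T, B, M)
sigma1s, sigma2s = torch.rand(T, B, M) + 0.5, torch.rand(T, B, M) + 0.5
rhos = 0.9 * torch.tanh(torch.randn(T, B, M))
loss = log_space_loss(y, es, pis, mu1s, mu2s, sigma1s, sigma2s, rhos)
```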



In [None]:
def loss_fn(Pr, y, es):
    loss1 = - torch.log(Pr + eps) # -> torch.Size([sequence_length, batch])
    
    # Bernoulli likelihood of the observed eos bit -> torch.Size([sequence_length, batch])
    bernoulli = y[:, :, 2] * es + (1 - y[:, :, 2]) * (1 - es)
    
    loss2 = - torch.log(bernoulli + eps)
    loss = loss1 + loss2 
    # print("loss shape", loss.shape) # -> torch.Size([sequence_length, batch])  
    loss = torch.sum(loss, 0) 
    # print("loss shape", loss.shape) # -> torch.Size([batch]) 
    
    return torch.mean(loss)
    
    

## Training

The hardest part is behind us! All that's left is to train our model. I used an Adam optimizer with a learning rate of 0.005. I haven't fiddled around too much with it as it already yields good results. The gradients are rescaled so that their overall norm stays below gradient_threshold (`clip_grad_norm_`), to avoid exploding gradients. A sequence is generated every 500 batches to see how the model is learning. Looks like it works!

![sample](./pictures/sampleModel2_8.png)
![sample](./pictures/sampleModel2_9.png)
![sample](./pictures/sampleModel2_10.png)
![sample](./pictures/sampleModel2_11.png)

The network is able to pick a style and stick with it. Of course the result is unreadable, but it is convincing enough.

This is what training looks like:
![training](./pictures/training_model1.png)

And the loss function after 10 epochs. In orange is the average loss per epoch, in blue the loss per batch.
![loss plot](./pictures/model2_loss.png)
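`train_network` calls `torch.nn.utils.clip_grad_norm_`, which rescales all gradients together when their combined L2 norm exceeds `gradient_threshold`. Clamping each gradient element into [-gradient_threshold, gradient_threshold] is a different operation, `torch.nn.utils.clip_grad_value_`. A toy comparison on a throwaway `nn.Linear` with deliberately huge gradients (`make_grads` is a hypothetical helper):

```python
import torch
from torch import nn

torch.manual_seed(0)
gradient_threshold = 10

def make_grads():
    """Small layer whose gradients are far above the threshold."""
    layer = nn.Linear(4, 4)
    loss = (layer(torch.randn(8, 4)) * 1000).pow(2).sum()
    loss.backward()
    return layer

# Norm clipping (what train_network uses): one common scale factor for all gradients
a = make_grads()
total_before = torch.nn.utils.clip_grad_norm_(a.parameters(), gradient_threshold)
total_after = torch.norm(torch.stack([p.grad.norm() for p in a.parameters()]))

# Value clipping: every gradient element is clamped independently
b = make_grads()
torch.nn.utils.clip_grad_value_(b.parameters(), gradient_threshold)
max_abs = max(p.grad.abs().max().item() for p in b.parameters())
```

Norm clipping preserves the direction of the update, while value clipping can change it because large components are clamped and small ones are not.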

In [None]:
def train_network(model, epochs = 5, generate = True):
    data_loader = DataLoader(n_batch, sequence_length, 20) # 20 = datascale
    
    optimizer = optim.Adam(model.parameters(), lr=0.005)
    
    if use_cuda:
        model = model.cuda()
        
    # Arrays to plot loss over time
    time_batch = []
    time_epoch = [0]
    loss_batch = []
    loss_epoch = []
    
    start = time.time()
    
    # Loop over epochs
    for epoch in range(epochs):
        data_loader.reset_batch_pointer()
        
        # Loop over batches
        for batch in range(data_loader.num_batches):
            # Loading a batch
            x, y, _, _ = data_loader.next_batch()
            x = np.float32(np.array(x)) # -> (n_batch, sequence_length, 3)
            y = np.float32(np.array(y)) # -> (n_batch, sequence_length, 3)

            x = torch.from_numpy(x).permute(1, 0, 2) # torch.Size([sequence_length, n_batch, 3])
            y = torch.from_numpy(y).permute(1, 0, 2) # torch.Size([sequence_length, n_batch, 3])
            
            if use_cuda:
                x = x.cuda()
                y = y.cuda()
            
            # Forward pass
            es, pis, mu1s, mu2s, sigma1s, sigma2s, rhos = model.forward(x)
            
            # Calculate probability density and loss
            Pr = gaussianMixture(y, pis, mu1s, mu2s, sigma1s, sigma2s, rhos)
            loss = loss_fn(Pr,y, es)
            
            # Back propagation
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient cliping
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_threshold)
            optimizer.step()
            
            # Useful infos over training
            if batch % 20 == 0:
                print("Epoch : ", epoch, " - step ", batch, "/", data_loader.num_batches, " - loss ", loss.item(), " in ", time.time() - start)
                start = time.time()
            
                if generate and batch % 500 == 0:
                    x0 = torch.Tensor([0,0,1]).view(1,1,3)

                    if use_cuda:
                        x0 = x0.cuda()

                    # No gradients are needed for sampling; this avoids building autograd graphs
                    with torch.no_grad():
                        sequence = model.generate_sequence(x0, sequence_length = 500)
                    draw_strokes_random_color(sequence, factor=0.5)
                    
                    
            # Save loss per batch
            time_batch.append(epoch + batch / data_loader.num_batches)
            loss_batch.append(loss.item())
        
        # Save loss per epoch
        time_epoch.append(epoch + 1)
        loss_epoch.append(sum(loss_batch[epoch * data_loader.num_batches : (epoch + 1) * data_loader.num_batches]) / data_loader.num_batches)
        
        # Save model after each epoch
        torch.save(model.state_dict(), "./models/prediction_model2.py")
        
    # Plot loss 
    plt.plot(time_batch, loss_batch)
    plt.plot(time_epoch, [loss_batch[0]] + loss_epoch, color="orange", linewidth=5)
    plt.xlabel("Epoch", fontsize=15)
    plt.ylabel("Loss", fontsize=15)
    plt.show()
        
    return model, time_batch, loss_batch, time_epoch, [loss_batch[0]] + loss_epoch

In [None]:
model = HandwritingGenerationModel2(hidden_size, n_gaussians, dropout)
model, time_batch, loss_batch, time_epoch, loss_epoch = train_network(model, epochs=10, generate=True)

## Test cell


```python
data_loader = DataLoader(n_batch, sequence_length, 20, U_items=U_items) # 20 = datascale

model = HandwritingGenerationModel2(hidden_size, n_gaussians, dropout)

x, y, s, c = data_loader.next_batch()
x = np.float32(np.array(x)) # -> (n_batch, sequence_length, 3)
y = np.float32(np.array(y)) # -> (n_batch, sequence_length, 3)

x = torch.from_numpy(x).permute(1, 0, 2) # torch.Size([sequence_length, n_batch, 3])
y = torch.from_numpy(y).permute(1, 0, 2)

if use_cuda:
    model = model.cuda()
    x = x.cuda()
    y = y.cuda()
    

es, pis, mu1s, mu2s, sigma1s, sigma2s, rhos = model.forward(x)
```