


# Requirements

In [1]:
!pip install svgwrite
import numpy as np
import tensorflow as tf
import sys
from copy import deepcopy

Collecting svgwrite
  Downloading svgwrite-1.4.1-py3-none-any.whl (66 kB)
[?25l[K     |█████                           | 10 kB 38.5 MB/s eta 0:00:01[K     |█████████▉                      | 20 kB 18.7 MB/s eta 0:00:01[K     |██████████████▊                 | 30 kB 15.4 MB/s eta 0:00:01[K     |███████████████████▋            | 40 kB 13.8 MB/s eta 0:00:01[K     |████████████████████████▌       | 51 kB 9.8 MB/s eta 0:00:01[K     |█████████████████████████████▍  | 61 kB 11.3 MB/s eta 0:00:01[K     |████████████████████████████████| 66 kB 4.1 MB/s 
[?25hInstalling collected packages: svgwrite
Successfully installed svgwrite-1.4.1


In [2]:
#!/usr/bin/env python3

import numpy as np
import tensorflow as tf

def positional_encoding(max_seq_len, dm):
    """
    Calculates the sinusoidal positional encoding for a transformer

    max_seq_len: integer representing the maximum sequence length
    dm: integer representing the model depth
    Returns: numpy.ndarray of shape (max_seq_len, dm). For position pos and
             even column c:
                 PE[pos, c]     = sin(pos / 10000 ** (c / dm))
                 PE[pos, c + 1] = cos(pos / 10000 ** (c / dm))
             When dm is odd, the last (even) column holds only the sine term.
    """
    positions = np.arange(max_seq_len)[:, np.newaxis]
    even_cols = np.arange(0, dm, 2)
    angles = positions / np.power(10000, even_cols / dm)

    PE = np.zeros((max_seq_len, dm))
    PE[:, 0::2] = np.sin(angles)
    # With an odd dm there is one fewer odd column than even column. The old
    # double loop wrote to column dm in that case and raised IndexError.
    PE[:, 1::2] = np.cos(angles[:, :dm // 2])
    return PE


def sdp_attention(Q, K, V, mask=None):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(dk) + mask * -1e9) V

    Q: tensor with shape (..., seq_len_q, dk) containing the query matrix
    K: tensor with shape (..., seq_len_v, dk) containing the key matrix
    V: tensor with shape (..., seq_len_v, dv) containing the value matrix
    mask: optional tensor that can be broadcast into
          (..., seq_len_q, seq_len_v); it is multiplied by -1e9 and added to
          the logits, so positions where it is 1 get (near) zero weight.
          None disables masking.
    The preceding dimensions of Q, K, and V are the same
    Returns: output, weights
             output: tensor with shape (..., seq_len_q, dv) containing the
                     attention output
             weights: tensor with shape (..., seq_len_q, seq_len_v) containing
                      the attention weights
    """
    # Matmul Q and K
    QK = tf.matmul(Q, K, transpose_b=True)

    # Scale the dot product. Cast dk to the logits' own dtype; the old
    # hard-coded float32 made the division fail for float64 inputs.
    dk = tf.cast(tf.shape(K)[-1], QK.dtype)
    scaled = QK / tf.math.sqrt(dk)

    # Add mask if not None. The cast lets integer or differently-typed masks
    # be used. Note: -1e9 assumes float32/float64 logits.
    if mask is not None:
        scaled += tf.cast(mask, scaled.dtype) * -1e9

    # Pass scaled attention through softmax activation
    weights = tf.nn.softmax(scaled, axis=-1)

    # Matmul by value matrix for output
    output = tf.matmul(weights, V)

    return output, weights


class MultiHeadAttention(tf.keras.layers.Layer):
    """
    Class to perform multi head attention

    Queries, keys and values are projected by separate Dense layers of width
    dm and split into h heads of size dm // h. Scaled dot-product attention
    is applied per head, and the heads are concatenated and projected back
    to dm by a final Dense layer.
    """
    def __init__(self, dm, h):
        """
        dm: integer representing the model dimensionality
        h: integer representing the number of heads; dm must be divisible by h
        Raises: ValueError if h is not positive or dm is not divisible by h
        """
        super(MultiHeadAttention, self).__init__()
        # Fail early with a clear message. Otherwise the mismatch only
        # surfaces later as an opaque reshape error inside split_heads.
        if h <= 0 or dm % h != 0:
            raise ValueError(
                "dm ({}) must be divisible by a positive h ({})".format(dm, h))
        self.h = h
        self.dm = dm
        self.depth = dm // self.h
        # Creation order matters: it fixes the weight order used when
        # saving/loading weights, so it is kept as Wq, Wk, Wv, linear.
        self.Wq = tf.keras.layers.Dense(dm)
        self.Wk = tf.keras.layers.Dense(dm)
        self.Wv = tf.keras.layers.Dense(dm)
        self.linear = tf.keras.layers.Dense(dm)

    def split_heads(self, x, batch_size):
        """
        Splits the last dimension of tensor x into (h, depth)
        Transpose the result such that the shape is
        (batch_size, h, seq_len, depth)
        """
        x = tf.reshape(x, (batch_size, -1, self.h, self.depth))
        x = tf.transpose(x, perm=[0, 2, 1, 3])
        return x

    def call(self, Q, K, V, mask):
        """
        Q: tensor with shape (batch, seq_len_q, dq) containing the query input
        K: tensor with shape (batch, seq_len_v, dk) containing the key input
        V: tensor with shape (batch, seq_len_v, dv) containing the value input
        mask: optional attention mask broadcastable to
              (batch, h, seq_len_q, seq_len_v), forwarded to sdp_attention
              (1 = masked out), or None for no masking
        Returns: output, weights
                 output: tensor with shape (batch, seq_len_q, dm) containing
                         the multi head attention output
                 weights: tensor with shape (batch, h, seq_len_q, seq_len_v)
                          containing the per-head attention weights
        """
        batch_size = tf.shape(Q)[0]

        # Generate query, key, and value matrices
        Q = self.Wq(Q)
        K = self.Wk(K)
        V = self.Wv(V)

        # Split between heads
        Q = self.split_heads(Q, batch_size)
        K = self.split_heads(K, batch_size)
        V = self.split_heads(V, batch_size)

        # Scaled Dot Product Attention
        attention, weights = sdp_attention(Q, K, V, mask)

        # Merge the heads back to (batch, seq_len_q, dm) and project
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        attention = tf.reshape(attention, (batch_size, -1, self.dm))
        output = self.linear(attention)

        return output, weights


class DecoderBlock(tf.keras.layers.Layer):
    """
    Class representation of a decoder block for a transformer

    Masked multi-head self-attention is followed by a two-layer feed-forward
    network. Each sub-layer is wrapped in dropout, a residual connection and
    layer normalization (applied after the residual add). There is no
    encoder-decoder attention.
    """
    def __init__(self, dm, h, hidden, drop_rate=0.1, name=None):
        """
        dm: Dimensionality of the model
        h: Number of heads
        hidden: Number of hidden units in the fully connected layer
        drop_rate: Dropout rate
        name: Optional layer name; None keeps the Keras auto-generated name
        """
        # Pass the name through the supported constructor argument instead of
        # overwriting the private `_name` attribute afterwards.
        super(DecoderBlock, self).__init__(name=name)
        # Creation order is kept unchanged so saved weights still load in order.
        self.mha1 = MultiHeadAttention(dm, h)
        self.dense_hidden = tf.keras.layers.Dense(
            units=hidden,
            activation='relu'
        )
        self.dense_output = tf.keras.layers.Dense(units=dm)
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(drop_rate)
        self.dropout2 = tf.keras.layers.Dropout(drop_rate)

    def call(self, inputs, look_ahead_mask, training=False):
        """
        inputs: tensor of shape (batch, target_seq_len, dm) containing the
                input to the decoder block
        look_ahead_mask: mask applied in the self-attention layer (1 = masked
                         out), or None
        training: boolean to determine if the model is training (enables
                  dropout)
        Returns: tensor of shape (batch, target_seq_len, dm) containing the
                 block's output
        """
        # Self-attention sub-layer
        attn_out, _ = self.mha1(inputs, inputs, inputs, look_ahead_mask)
        attn_out = self.dropout1(attn_out, training=training)
        out = self.layernorm1(inputs + attn_out)

        # Feed-forward sub-layer
        dense_output = self.dense_hidden(out)
        dense_output = self.dense_output(dense_output)
        dense_output = self.dropout2(dense_output, training=training)
        out = self.layernorm2(out + dense_output)

        return out


class Decoder(tf.keras.Model):
    """
    A branched Transformer Decoder

    The model uses a linear layer to bring the input feature dimension up to
    the model's dimension before adding fixed sinusoidal positional encoding.

    The model passes the encoded inputs through (Nb) decoder blocks before
    passing their final output to two branches, of (No) and (Np) decoder blocks
    respectively, that predict the X and Y offsets, and pen state
    probabilities, of the next point for each given point in the sequence.

    The offset branch produces an output of size (batch, sequence_len, 2),
    where the feature dimension is [X offset, Y Offset] from the previous point.

    The pen state branch produces an output of size (batch, sequence_len, 3),
    where the feature dimension is [p0, p1, p2], each representing the
    probabilities that the pen will be down, up, or finished, respectively.

    Returns: offsets_predictions, pen_state_predictions
    """
    def __init__(self,
                 Nb,                # Number of blocks in model base
                 No,                # Number of blocks in offset branch
                 Np,                # Number of blocks in pen state branch
                 dm,                # Model dimensionality
                 h,                 # Number of heads used in attention
                 hidden,            # Hidden layer dimensionality
                 max_seq_len,       # Maximum sequence length
                 drop_rate=0.1):    # Drop rate used in dropout layers

        super(Decoder, self).__init__()
        self.Nb = Nb
        self.No = No
        self.Np = Np
        self.dm = dm
        # NOTE: attribute creation order below fixes the layer/weight order
        # used by save_weights/load_weights, so it must not be reordered.
        self.projection = tf.keras.layers.Dense(dm, name='base_projection')
        # Precast once to float32 to match the activations. This avoids a
        # float64 -> float32 conversion on every forward pass.
        self.positional_encoding = positional_encoding(
            max_seq_len, dm).astype(np.float32)
        self.dropout = tf.keras.layers.Dropout(drop_rate)

        # The "base"/"offset"/"pen" name prefixes identify which branch a
        # block belongs to.
        self.base_blocks = [
            DecoderBlock(dm, h, hidden, drop_rate, name="base_block_" + str(n))
            for n in range(Nb)
        ]
        self.offset_blocks = [
            DecoderBlock(dm, h, hidden, drop_rate, name="offset_block_" + str(n))
            for n in range(No)
        ]
        self.pen_blocks = [
            DecoderBlock(dm, h, hidden, drop_rate, name="pen_block_" + str(n))
            for n in range(Np)
        ]

        self.offset_dense = tf.keras.layers.Dense(dm, name='offset_dense')
        self.offset_out = tf.keras.layers.Dense(2, name='offset_out')
        self.pen_dense = tf.keras.layers.Dense(dm, name='pen_dense')
        self.pen_out = tf.keras.layers.Dense(3, name='pen_out',
                                             activation='softmax')

    @staticmethod
    def _run_branch(x, blocks, dense, out, look_ahead_mask, training):
        """Run x through one output branch: its decoder blocks, then dense, then out."""
        for block in blocks:
            x = block(x, look_ahead_mask, training)
        return out(dense(x))

    def call(self,
             inputs,                # Input data, (batch, seq_len, features)
             look_ahead_mask=None,  # Mask used for attention
             training=False):       # Whether the model is training or not

        seq_len = int(inputs.shape[1])
        max_seq_len = self.positional_encoding.shape[0]
        if seq_len > max_seq_len:
            raise ValueError(
                "Input sequence length {} exceeds the model's max_seq_len {}"
                .format(seq_len, max_seq_len))

        # Project to model dimension
        x = self.projection(inputs)

        # Add positional encoding and pass through dropout layer
        x *= tf.math.sqrt(tf.cast(self.dm, 'float32'))
        x += self.positional_encoding[:seq_len]
        x = self.dropout(x, training=training)

        # Pass through base decoder blocks
        for block in self.base_blocks:
            x = block(x, look_ahead_mask, training)

        # Both branches start from the shared base output
        offset = self._run_branch(x, self.offset_blocks, self.offset_dense,
                                  self.offset_out, look_ahead_mask, training)
        pen = self._run_branch(x, self.pen_blocks, self.pen_dense,
                               self.pen_out, look_ahead_mask, training)

        return offset, pen


In [3]:
# Function from utils.py for sketch-rnn in the Magenta github repository
# at https://github.com/magenta/magenta/tree/main/magenta/models/sketch_rnn
def to_big_strokes(stroke, max_len=250):
  """Converts from stroke-3 to stroke-5 format and pads to given length.

  stroke: array-like of shape (n, 3) with rows [dx, dy, pen_lifted]
  max_len: number of rows in the output; must be >= n
  Returns: float array of shape (max_len, 5) with rows
           [dx, dy, pen_down, pen_lifted, end]; rows n and later are padding
           with end = 1.
  Raises: ValueError if the stroke has more than max_len points.
  """
  # (But does not insert special start token).
  stroke = np.asarray(stroke)
  n = len(stroke)
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if n > max_len:
    raise ValueError(
        "stroke has {} points, more than max_len={}".format(n, max_len))

  result = np.zeros((max_len, 5), dtype=float)
  if n:  # an empty stroke (e.g. []) cannot be column-sliced
    result[:n, 0:2] = stroke[:, 0:2]
    result[:n, 3] = stroke[:, 2]
    result[:n, 2] = 1 - result[:n, 3]
  result[n:, 4] = 1
  return result

def clean(data, max_length=100):
    """
    Filters stroke-3 samples by length and converts them to padded stroke-5.

    data: iterable of samples in stroke-3 format (each of shape (n, 3))
    max_length: samples with more than max_length points are dropped; the
                rest are converted with to_big_strokes and padded to this
                length
    Returns: float array of shape (num_kept, max_length, 5); the shape is
             (0, max_length, 5) when no sample is kept (np.asarray([]) would
             otherwise give a shape-(0,) array).
    """
    dataset = [to_big_strokes(sample, max_length)
               for sample in data if len(sample) <= max_length]
    if not dataset:
        return np.zeros((0, max_length, 5))
    return np.stack(dataset)


class Dataset:
    """
    Loads a numpy .npz file to be used for training.

    Attributes (tf.data.Dataset objects whose elements are stroke-5 arrays
    of shape (max_length, 5)):
        train: shuffled and batched
        valid: shuffled and batched
        test:  not shuffled, not batched
    """

    def __init__(self,
                 filepath,             # Path to file to load
                 batch_size=32,        # Batch size to use
                 max_length=250,       # Maximum sequence length per example
                 shuffle_buffer=None): # Shuffle buffer size; None = whole split

        # NOTE: allow_pickle=True lets np.load unpickle object arrays, which
        # can execute arbitrary code. Only load trusted files.
        # The NpzFile returned by np.load keeps the file open; the context
        # manager closes it once the splits have been read.
        with np.load(filepath, encoding='latin1', allow_pickle=True) as data:
            # Clean up dataset, removing samples over max_length
            train = clean(data['train'], max_length)
            valid = clean(data['valid'], max_length)
            test = clean(data['test'], max_length)

        # from_tensor_slices slices the array along its first axis directly.
        # Building a Python list of per-sample tensors first is unnecessary.
        self.train = tf.data.Dataset.from_tensor_slices(train)
        self.valid = tf.data.Dataset.from_tensor_slices(valid)
        self.test = tf.data.Dataset.from_tensor_slices(test)

        # The shuffle buffer used to be max_length (the sequence length),
        # which only shuffles within a small sliding window of samples. By
        # default the buffer now covers the whole split, giving a uniform
        # shuffle.
        train_buffer = shuffle_buffer or max(len(train), 1)
        valid_buffer = shuffle_buffer or max(len(valid), 1)
        self.train = self.train.shuffle(train_buffer).batch(batch_size)
        self.valid = self.valid.shuffle(valid_buffer).batch(batch_size)


In [4]:
def train_model(Nb,           # Number of blocks in model base
                No,           # Number of blocks in offset branch
                Np,           # Number of blocks in pen state branch
                dm,           # Model dimensionality
                h,            # Number of heads used in attention
                hidden,       # Hidden layer dimensionality
                max_len,      # Maximum sequence length
                batch_size,   # Batch size
                epochs,       # Number of epochs to train for
                filepath,     # Path to file to use for training dataset
                verbose=1,    # 0: No printing, 1: Print loss after each epoch,
                              # 2: Also print loss every 50 batches
                weights=None, # Path to weights to use for continuing training
                              # If None, model weights will be initialized
                save_path='50_epoch_best.h5'):  # Where best weights are saved
    """
    Creates and trains a model used for predicting future points in an
    unfinished drawing from Google's Quick, Draw! dataset.

    The offset prediction branch is trained using Mean Squared Error loss, and
    the pen state prediction branch is trained using Categorical Crossentropy
    loss. Each loss updates its own branch plus the shared base, using a
    separate Adam optimizer.

    After each epoch, the weights are written to save_path if that epoch's
    mean training offset loss is the lowest seen so far. The pen loss is not
    considered.

    Returns: The model, MSE loss history, CCE loss history (one running-mean
             entry per batch)
    """

    # Load dataset
    data = Dataset(filepath, batch_size=batch_size, max_length=max_len)

    # Create model
    model = Decoder(Nb, No, Np, dm, h, hidden, max_len)

    # Run a dummy set of inputs through to initialize weights
    inputs = np.random.uniform(size=(1, max_len, 5))
    model(inputs, None)

    # Load weights if continuing training
    if weights is not None:
        model.load_weights(weights)

    model.summary()

    # Separate the variables per branch so each loss only updates its own
    # branch plus the shared base. Variables are collected from the
    # sub-layers themselves rather than by substring-matching variable
    # names, which depend on the Keras version's naming scheme.
    def _trainable(layers):
        """Flatten the trainable variables of the given layers into one list."""
        return [w for layer in layers for w in layer.trainable_weights]

    base_weights = _trainable([model.projection, *model.base_blocks])
    offset_weights = base_weights + _trainable(
        [*model.offset_blocks, model.offset_dense, model.offset_out])
    pen_weights = base_weights + _trainable(
        [*model.pen_blocks, model.pen_dense, model.pen_out])

    # Loss functions, metrics, optimizers
    offset_loss_func = tf.keras.losses.MeanSquaredError()
    pen_loss_func = tf.keras.losses.CategoricalCrossentropy()

    pen_train_loss = tf.keras.metrics.Mean(name='pen_train_loss')
    offset_train_loss = tf.keras.metrics.Mean(name='offset_train_loss')

    # A constant rate; CustomSchedule(dm) can be substituted for warmup.
    learning_rate = 0.0001

    offset_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    pen_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # Sequence length is constant throughout the dataset (inputs are
    # max_len - 1 long after the shift), so the causal mask is built once.
    mask = 1 - tf.linalg.band_part(tf.ones((max_len - 1, max_len - 1)), -1, 0)

    # Create lists to store loss histories
    offset_losses = []
    pen_losses = []

    # Any finite loss improves on this
    prev_best = float('inf')

    def train_step(inputs, real):
        """Single training step: forward pass, then one update per branch."""

        # Persistent tape: gradients are taken twice (once per loss)
        with tf.GradientTape(persistent=True) as tape:
            offsets, pen_states = model(inputs, mask, True)

            # Columns 0-1 are offsets, columns 2-4 are the pen state
            offset_loss = offset_loss_func(real[:, :, :2], offsets)
            pen_loss = pen_loss_func(real[:, :, 2:], pen_states)

        # Apply gradients to offset branch & base
        grads = tape.gradient(offset_loss, offset_weights)
        offset_optimizer.apply_gradients(zip(grads, offset_weights))

        # Apply gradients to pen state branch & base
        grads = tape.gradient(pen_loss, pen_weights)
        pen_optimizer.apply_gradients(zip(grads, pen_weights))

        # Update loss states
        offset_train_loss(offset_loss)
        pen_train_loss(pen_loss)

        del tape

    # Training Loop
    for epoch in range(epochs):

        # Reset loss metrics at the start of the epoch
        offset_train_loss.reset_states()
        pen_train_loss.reset_states()

        for batch, inp in enumerate(data.train):

            # Targets are the next point: the sequence shifted by one step
            train_step(inp[:, :-1], inp[:, 1:])

            # Update loss histories
            offset_losses.append(offset_train_loss.result())
            pen_losses.append(pen_train_loss.result())

            if verbose == 2 and batch % 50 == 0:  # Print every 50 batches
                print("Epoch {}, batch {}: Offset Loss: {} Pen Loss {}".format(
                    epoch + 1,
                    batch,
                    offset_train_loss.result(),
                    pen_train_loss.result()
                ))

        if verbose >= 1:  # Print results after each epoch
            print("Epoch {}: Offset Loss: {:.4f} Pen Loss {:.4f}".format(
                epoch + 1,
                offset_train_loss.result(),
                pen_train_loss.result()
            ))

        # Save best performing weights (by mean training offset loss)
        epoch_loss = offset_train_loss.result()
        if epoch_loss < prev_best:
            model.save_weights(save_path)
            prev_best = epoch_loss

    return model, offset_losses, pen_losses


class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """
    Transformer learning-rate schedule: linear warmup, then inverse
    square-root decay.

        lr(step) = d_model ** -0.5 * min(step ** -0.5,
                                         step * warmup_steps ** -1.5)
    """
    def __init__(self, d_model, warmup_steps=25000):
        """
        d_model: model dimensionality
        warmup_steps: number of steps of linear warmup
        """
        super(CustomSchedule, self).__init__()

        # Keep the raw value for get_config(); the float32 tensor is what
        # the schedule computes with.
        self._d_model_value = d_model
        self.d_model = tf.cast(d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        """Return the learning rate for the given step."""
        # The step may arrive as an integer tensor (e.g. an optimizer's
        # iteration counter); rsqrt requires a floating-point input.
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        """Serializable constructor arguments (needed to save/reload)."""
        return {"d_model": self._d_model_value,
                "warmup_steps": self.warmup_steps}

In [6]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Rebuild both models from saved weights. With epochs=0 the training loop
# does not run: each call loads the dataset, builds the model, loads the
# given weights and prints the summary. The second call rebinds
# offset_losses and pen_losses.
large_model, offset_losses, pen_losses = train_model(
    Nb=8,                                  # Base blocks
    No=4,                                  # Offset blocks
    Np=4,                                  # Pen state blocks
    dm=128,                                # Model dimensionality
    h=8,                                   # Heads
    hidden=512,                            # Hidden units
    max_len=100,                           # Max sequence length
    batch_size=64,                         # Batch size
    epochs=0,                              # Epochs
    filepath='cat.npz',                    # File path
    verbose=2,                             # Verbosity
    weights='Double size 50 epochs.h5',    # Weights to load
)


small_model, offset_losses, pen_losses = train_model(
    Nb=4,                                  # Base blocks
    No=2,                                  # Offset blocks
    Np=2,                                  # Pen state blocks
    dm=128,                                # Model dimensionality
    h=8,                                   # Heads
    hidden=512,                            # Hidden units
    max_len=100,                           # Max sequence length
    batch_size=64,                         # Batch size
    epochs=0,                              # Epochs
    filepath='cat.npz',                    # File path
    verbose=2,                             # Verbosity
    weights='160 Epochs.h5',               # Weights to load
)


Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
base_projection (Dense)      multiple                  768       
_________________________________________________________________
dropout (Dropout)            multiple                  0         
_________________________________________________________________
base_block_0 (DecoderBlock)  multiple                  198272    
_________________________________________________________________
base_block_1 (DecoderBlock)  multiple                  198272    
_________________________________________________________________
base_block_2 (DecoderBlock)  multiple                  198272    
_________________________________________________________________
base_block_3 (DecoderBlock)  multiple                  198272    
_________________________________________________________________
base_block_4 (DecoderBlock)  multiple                  1982

In [7]:
# libraries required for visualisation:
import os
import svgwrite
import numpy as np
import tensorflow as tf
from IPython.display import SVG, display
import PIL
from PIL import Image
import matplotlib.pyplot as plt
# import data_Manager
import math
from matplotlib import animation

# data_Manager = Data
# set numpy output to something sensible
np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)


def get_bounds(data, factor=10):
    """Return the (min_x, max_x, min_y, max_y) extent of a stroke drawing.

    Columns 0 and 1 of each row are relative (dx, dy) offsets. Each offset is
    divided by ``factor`` and accumulated into an absolute position that
    starts at the origin, so the origin is always inside the bounds.
    """
    lo_x = hi_x = lo_y = hi_y = 0
    pos_x = pos_y = 0
    for idx in range(len(data)):
        pos_x += float(data[idx, 0]) / factor
        pos_y += float(data[idx, 1]) / factor
        lo_x, hi_x = min(lo_x, pos_x), max(hi_x, pos_x)
        lo_y, hi_y = min(lo_y, pos_y), max(hi_y, pos_y)
    return lo_x, hi_x, lo_y, hi_y


def slerp(p0, p1, t):
    """Spherical interpolation.

    p0, p1: numpy vectors to interpolate between
    t: interpolation fraction; t=0 gives p0 and t=1 gives p1
    Falls back to linear interpolation when the angle between the vectors is
    (numerically) 0 or pi. There the spherical formula would divide by
    sin(omega) ~ 0 and return NaN.
    """
    cos_omega = np.dot(p0 / np.linalg.norm(p0), p1 / np.linalg.norm(p1))
    # Rounding can push the normalized dot product slightly outside [-1, 1],
    # where arccos returns NaN.
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
    so = np.sin(omega)
    if np.isclose(so, 0.0):
        return (1.0 - t) * p0 + t * p1
    return np.sin((1.0 - t) * omega) / so * p0 + np.sin(t * omega) / so * p1


def lerp(p0, p1, t):
    """Linearly interpolate from p0 (at t=0) to p1 (at t=1)."""
    w0 = 1.0 - t
    return w0 * p0 + t * p1


def to_normal_strokes(big_stroke):
    """Convert from stroke-5 format to stroke-3.

    big_stroke: array of shape (n, 5) with rows [dx, dy, p1, p2, p3]
    Rows from the first end-of-drawing row (column 4 > 0) onward are
    dropped. If there is no such row, every row is converted.
    Returns: float array of shape (kept, 3) with rows [dx, dy, p2], where p2
             is taken from column 3 of the input.
    """
    big_stroke = np.asarray(big_stroke)
    end_rows = np.flatnonzero(big_stroke[:, 4] > 0)
    # An end token in row 0 means an empty drawing. The old loop used 0 both
    # for "not found" and "found at row 0", so it returned the whole padded
    # array in that case.
    n = int(end_rows[0]) if end_rows.size else len(big_stroke)
    result = np.zeros((n, 3))
    result[:, 0:2] = big_stroke[:n, 0:2]
    result[:, 2] = big_stroke[:n, 3]
    return result


# little function that displays vector images and saves them to .svg
def draw_strokes(data, factor=0.2, svg_filename='/tmp/sketch_rnn/svg/sample.svg'):
    """Render a stroke-5 drawing as SVG, save it, and display it inline.

    data: stroke-5 array; converted with to_normal_strokes first
    factor: offsets are divided by this before drawing (smaller = larger)
    svg_filename: output path; its parent directory is created if missing
    """
    data = to_normal_strokes(data)
    min_x, max_x, min_y, max_y = get_bounds(data, factor)
    dims = (50 + max_x - min_x, 50 + max_y - min_y)

    # The default path lives in a directory that usually does not exist yet,
    # which made dwg.save() raise FileNotFoundError.
    out_dir = os.path.dirname(svg_filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    dwg = svgwrite.Drawing(svg_filename, size=dims)
    dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))

    # Build an SVG path of relative moves: "m" after a pen lift, then "l"
    # once, then implicit (empty) commands that repeat the previous "l".
    lift_pen = 1
    abs_x = 25 - min_x
    abs_y = 25 - min_y
    path_parts = ["M%s,%s " % (abs_x, abs_y)]
    command = "m"
    for i in range(len(data)):
        if lift_pen == 1:
            command = "m"
        elif command != "l":
            command = "l"
        else:
            command = ""
        x = float(data[i, 0]) / factor
        y = float(data[i, 1]) / factor
        lift_pen = data[i, 2]
        path_parts.append(command + str(x) + "," + str(y) + " ")

    the_color = "black"
    stroke_width = 3
    dwg.add(dwg.path("".join(path_parts)).stroke(the_color, stroke_width).fill("none"))
    dwg.save()
    display(SVG(dwg.tostring()))


def create_animation(drawing, fps=30, idx=0, lw=5, filename='video.mp4'):
    """
    Animate a drawing point by point and save it as an MP4 video.

    Adapted from
    https://colab.research.google.com/github/zaidalyafeai/Notebooks/blob/master/Strokes_QuickDraw.ipynb#scrollTo=0ABX6O4kYwYS

    drawing: sequence of strokes, where stroke[0] holds the x coordinates and
             stroke[1] the y coordinates of that stroke
    fps: frames per second of the saved video
    idx: unused; kept so existing calls that pass it keep working
    lw: line width, also used as padding around the drawing's bounds
    filename: path of the video to write (saved with '-vcodec libx264')
    Raises: ValueError if drawing contains no strokes (the axis limits would
            be infinite)
    """
    if len(drawing) == 0:
        raise ValueError("create_animation needs at least one stroke")

    seq_length = 0

    xmax = 0
    ymax = 0

    xmin = math.inf
    ymin = math.inf

    # Retrieve the bounds and the total number of points of the drawing
    for stroke in drawing:
        x = stroke[0]
        y = stroke[1]

        seq_length += len(x)
        xmax = max(max(x), xmax)
        ymax = max(max(y), ymax)

        xmin = min(min(x), xmin)
        ymin = min(min(y), ymin)

    i = 0
    j = 0

    # First set up the figure, the axis, and the plot element we want to animate.
    # NOTE(review): both axes are inverted here (limits go from max to min).
    # That suits y in image coordinates but also mirrors x; confirm intended.
    fig = plt.figure()
    ax = plt.axes(xlim=(xmax + lw, xmin - lw), ylim=(ymax + lw, ymin - lw))
    ax.set_facecolor("white")
    line, = ax.plot([], [], lw=lw)

    # Remove the grid and ticks. ax.grid is a method; assigning False to it
    # (as before) replaced the method instead of turning the grid off.
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])

    # initialization function: plot the background of each frame
    def init():
        line.set_data([], [])
        return line,

    # Animation function, called once per frame. Each stroke takes
    # len(stroke) + 1 frames: its points are revealed one at a time, then a
    # new line is started for the next stroke.
    def animate(frame):
        nonlocal i, j, line
        x = drawing[i][0]
        y = drawing[i][1]
        line.set_data(x[0:j], y[0:j])

        if j >= len(x):
            i += 1
            j = 0
            line, = ax.plot([], [], lw=lw)
        else:
            j += 1
        return line,

    # Call the animator. blit=True means only re-draw the parts that have changed.
    anim = animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=seq_length + len(drawing), blit=True)
    plt.close()

    # Save the animation as an mp4.
    anim.save(filename, fps=fps, extra_args=['-vcodec', 'libx264'])

In [8]:
# Load dataset
def to_big_strokes(stroke, max_len=250):
  """Converts from stroke-3 to stroke-5 format and pads to given length.

  stroke: array-like of shape (n, 3) with rows [dx, dy, pen_lifted]
  max_len: number of rows in the output; must be >= n
  Returns: float array of shape (max_len, 5) with rows
           [dx, dy, pen_down, pen_lifted, end]; rows n and later are padding
           with end = 1.
  Raises: ValueError if the stroke has more than max_len points.
  """
  # (But does not insert special start token).
  stroke = np.asarray(stroke)
  n = len(stroke)
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if n > max_len:
    raise ValueError(
        "stroke has {} points, more than max_len={}".format(n, max_len))

  result = np.zeros((max_len, 5), dtype=float)
  if n:  # an empty stroke (e.g. []) cannot be column-sliced
    result[:n, 0:2] = stroke[:, 0:2]
    result[:n, 3] = stroke[:, 2]
    result[:n, 2] = 1 - result[:n, 3]
  result[n:, 4] = 1
  return result

def clean(data, max_length=100):
    """
    Filters stroke-3 samples by length and converts them to padded stroke-5.

    data: iterable of samples in stroke-3 format (each of shape (n, 3))
    max_length: samples with more than max_length points are dropped; the
                rest are converted with to_big_strokes and padded to this
                length
    Returns: float array of shape (num_kept, max_length, 5); the shape is
             (0, max_length, 5) when no sample is kept (np.asarray([]) would
             otherwise give a shape-(0,) array).
    """
    dataset = [to_big_strokes(sample, max_length)
               for sample in data if len(sample) <= max_length]
    if not dataset:
        return np.zeros((0, max_length, 5))
    return np.stack(dataset)

def create_mask(batch_size, seq_len):
    """
    Build the look-ahead (causal) attention mask for the decoder.

    batch_size: unused; the mask has leading dimensions (1, 1) and relies on
                broadcasting over batch and heads
    seq_len: length of the sequence the mask is for
    Returns: float32 tensor of shape (1, 1, seq_len, seq_len) with 1 strictly
             above the diagonal (future positions, masked out) and 0 on and
             below it.
    """
    ones = tf.ones((1, 1, seq_len, seq_len))
    lower_triangle = tf.linalg.band_part(ones, -1, 0)
    return 1 - lower_triangle


def predict(model, input, max_len=100):
    """
    Autoregressively extend a partial stroke-5 drawing with the model.

    model: callable taking (inputs, mask) and returning (offsets, pen_states)
           shaped (1, seq_len, 2) and (1, seq_len, 3), e.g. a Decoder
    input: array of shape (seq_len, 5) used as the prompt (not modified)
    max_len: generation stops once the sequence has this many points
             (default 100, the maximum length used in this notebook)
    Returns: array of shape (total_len, 5): the prompt followed by the
             generated points, ending when a point's end flag (last column)
             is 1 or max_len is reached.
    """
    inputs = input[np.newaxis, :]
    while inputs[0, -1, -1] != 1 and inputs.shape[1] < max_len:
        mask = create_mask(1, inputs.shape[1])
        offsets, pen_states = model(inputs, mask)

        # Only the prediction for the last position is used; offsets are
        # rounded to whole units.
        next_offset = np.round(np.asarray(offsets)[0, -1])
        # Convert the pen probabilities to a one-hot state. Rounding each
        # probability separately (as before) gave an all-zero, invalid state
        # whenever no class exceeded 0.5.
        next_pen = np.eye(3)[int(np.argmax(np.asarray(pen_states)[0, -1]))]

        pred = np.concatenate((next_offset, next_pen))
        inputs = np.concatenate((inputs, pred.reshape(1, 1, 5)), axis=1)

    return inputs[0]

# The NpzFile returned by np.load holds the file open; the context manager
# closes it once the test split has been cleaned.
with np.load('cat.npz', encoding='latin1', allow_pickle=True) as npz:
    data = clean(npz['test'], 100)


# Demos

In [None]:
# Full circle input: prompt with the first 18 points of test drawing 0

input = data[0, :18].copy()

complete = predict(large_model, input)

# Shrink the offsets 3x for display
input[:, :2] /= 3
complete[:, :2] /= 3

draw_strokes(input, svg_filename='input.svg')
draw_strokes(complete, svg_filename="cat.svg")

In [None]:
# Half circle input: prompt with the first 11 points of test drawing 0

input = data[0, :11].copy()

complete = predict(large_model, input)

# Shrink the offsets 3x for display
input[:, :2] /= 3
complete[:, :2] /= 3

draw_strokes(input, svg_filename='input.svg')
draw_strokes(complete, svg_filename="cat.svg")

In [None]:
# Start from ears: prompt with the first 15 points of test drawing 21

drawing = deepcopy(data[21])

# `input` is a view into `drawing`, so the in-place rescale below also
# rescales the prompt that gets drawn.
input = drawing[:15]

complete = predict(large_model, input)

drawing[:, :2] /= 3
complete[:, :2] /= 3

# Save the prompt to its own file so the completion does not overwrite it.
draw_strokes(input, svg_filename="input.svg")
draw_strokes(complete, svg_filename="cat.svg")

In [None]:
# Generate from nothing: start from a single zero-offset point whose pen
# state is [1, 0, 0] (pen down)

input = np.array([[0, 0, 1, 0, 0]])

complete = predict(small_model, input)

# Shrink the offsets 3x for display
complete[:, :2] = complete[:, :2] / 3

draw_strokes(complete, svg_filename="cat.svg")

In [None]:
# Comparison to full drawing: test drawing 2010, its first 20 points used as
# the prompt, and the model's completion of that prompt

drawing = deepcopy(data[2010])

# `input` is a view into `drawing`; the rescale below applies to both.
input = drawing[:20]

complete = predict(large_model, input)

drawing[:, :2] /= 3
complete[:, :2] /= 3

# Write each figure to its own file; sharing one filename meant only the
# last SVG survived on disk.
draw_strokes(drawing, svg_filename="drawing.svg")
draw_strokes(input, svg_filename="input.svg")
draw_strokes(complete, svg_filename="cat.svg")