# Sequence-to-sequence (seq2seq) model
This is based on the paper "Attention Is All You Need".

Applications include machine translation.

Seq2seq models follow the basic architecture of an autoencoder (an encoder followed by a decoder); the model here uses the multi-head attention mechanism.

In the following, I implement a seq2seq model step by step, based on my own understanding. It may differ from the model described in the seq2seq paper and from the TensorFlow seq2seq package.

In [1]:
import re
import collections

import numpy as np

import torch
import torch.nn as nn

# Pick the compute device once: CUDA when it is both requested and available,
# otherwise fall back to the CPU.
use_gpu = True
device = torch.device('cuda' if use_gpu and torch.cuda.is_available() else 'cpu')

## Multi-head attention

In [199]:
def get_list(x, n, forced=False):
  """Expand x to a list of n entries.

  Args:
    x: a single value, or a sized iterable that may already hold one entry per slot.
    n: int, the required number of entries.
    forced: if True, always replicate x, even if it already is an iterable of length n.

  Returns:
    x itself (the same object, not a copy) when forced is False and x is a sized,
    non-string iterable with len(x) == n; otherwise [x]*n, i.e. a list holding
    n references to the same object x.
  """
  # collections.Iterable was removed in Python 3.10; the ABCs live in collections.abc.
  from collections.abc import Iterable, Sized
  # Strings are iterable but are meant as a single setting here, and iterables
  # without a length (e.g. generators) cannot be compared against n.
  per_slot = (isinstance(x, Iterable) and isinstance(x, Sized)
              and not isinstance(x, (str, bytes)))
  if forced or not per_slot or len(x) != n:
    return [x]*n
  return x

class MultiLayerPerceptron(nn.Module):
  r"""Multi-layer perceptron with optional DenseNet- or ResNet-style skip connections.

  Args:
    in_dim: int, input feature dimension
    hidden_dims: sequence of int, hidden dimensions; hidden_dims[-1] is the output dimension
    nonlinearity: default nn.ReLU(inplace=True); a single module shared by all layers,
      or a list with one module per activation layer
    bias: bool or list of bool (one per linear layer); passed to nn.Linear
    dense: if True, use DenseNet architecture,
      i.e., concatenate all previous (post-activation) outputs as the current input
    residual: if True, use ResNet architecture. That's to say:
      add a combination of the previous post-activation outputs (see weighted_avg)
      to the current affine output, and then pass it to the nonlinearity.
      Mutually exclusive with dense; requires all hidden_dims to be equal
    last_nonlinearity: default False; if True, add nonlinearity to the output layer.
      Note that with dense=True the returned tensor is then the concatenation of all
      collected outputs, not just the last one
    forward_input: if True, forward the input to subsequent layers when either dense or residual is True
      (with residual this requires in_dim == hidden_dims[0])
    return_all: if True, return a list of outputs; otherwise, return the output from the last layer
    residual_mode: default 'last', only add the previous (nonlinear) output to the current affine output;
      only used when residual is True. 'unweighted' adds the plain mean of all previous outputs;
      'weighted' adds a weighted mean with default weights 1, 2, ..., k (later outputs count more).
      To do: make the weight learnable

  Shape:
    Input: (N, *, in_dim)
    Output: if return_all is False, then (N, *, out_dim); else a list of tensors
    (depends on forward_input, dense, and residual)

  Attributes:
    layers: nn.Sequential holding 'linear{i}' (nn.Linear) and 'activation{i}' modules;
      the dimensions depend on in_dim, hidden_dims, dense, residual, forward_input

  Examples:

    >>> x = torch.randn(3,4,5)
    >>> model = MultiLayerPerceptron(5, [5,5,5], dense=True)
    >>> model(x).shape
    torch.Size([3, 4, 5])

  """
  def __init__(self, in_dim, hidden_dims, nonlinearity=nn.ReLU(inplace=True), bias=True, 
               dense=True, residual=False, last_nonlinearity=False, forward_input=False, return_all=False, 
               residual_mode='last'):
    super(MultiLayerPerceptron, self).__init__()
    num_layers = len(hidden_dims)
    self.dense = dense
    self.residual = residual
    self.last_nonlinearity = last_nonlinearity
    self.forward_input = forward_input
    self.return_all = return_all
    self.residual_mode = residual_mode
    # make sure the dimensions are right
    assert not (dense and residual), 'dense and residual are mutually exclusive'
    if residual:
      # residual sums need every collected output to have the same width
      assert all(hidden_dims[i] == hidden_dims[i-1] for i in range(1, num_layers)), \
        'residual=True requires all hidden_dims to be equal'
      if forward_input:
        assert in_dim == hidden_dims[0], \
          'residual=True with forward_input=True requires in_dim == hidden_dims[0]'

    # nonlinearity and bias can be set layer by layer by providing a list input
    nonlinearities = get_list(nonlinearity, num_layers if last_nonlinearity else num_layers-1)
    biases = get_list(bias, num_layers)

    self.layers = nn.Sequential()
    for i, out_dim in enumerate(hidden_dims):
      self.layers.add_module('linear{}'.format(i), nn.Linear(in_dim, out_dim, bias=biases[i]))
      if i < num_layers-1 or last_nonlinearity:
        self.layers.add_module('activation{}'.format(i), nonlinearities[i])
      # prepare the input dimension for the next layer
      if dense:
        # the next layer sees the concatenation of everything collected so far;
        # the raw input is part of it only when forward_input is True
        if i == 0 and not forward_input:
          in_dim = 0
        in_dim += out_dim
      else:
        in_dim = out_dim

  def weighted_avg(self, y, mode='last', weight=None):
    """Combine the collected outputs y (a list of equally shaped tensors) into one tensor.

    Args:
      y: non-empty list of tensors
      mode: 'last' returns y[-1]; 'unweighted' returns the element-wise mean of y;
        'weighted' returns the element-wise weighted mean of y
      weight: optional 1-D weights (one per entry of y) for mode 'weighted';
        defaults to 1, 2, ..., len(y). Weights are normalized to sum to one

    Returns:
      A tensor with the same shape as each entry of y.

    Raises:
      ValueError: if mode is not one of the three supported values.
    """
    if mode == 'last':
      return y[-1]
    if mode not in ('unweighted', 'weighted'):
      raise ValueError("mode must be 'last', 'unweighted' or 'weighted', got {!r}".format(mode))
    # Stack on a new trailing axis so the reduction runs over the list entries
    # and leaves the feature axis intact (cat would merge the two).
    stacked = torch.stack(y, dim=-1)
    if mode == 'unweighted':
      return stacked.mean(-1)
    if weight is None:
      weight = torch.arange(1, len(y)+1)
    # follow the dtype/device of the data instead of relying on a module-level device
    weight = torch.as_tensor(weight, dtype=stacked.dtype, device=stacked.device)
    weight = weight / weight.sum()
    return (stacked * weight).sum(-1)

  def forward(self, x):
    # y collects the post-activation outputs (and the input when forward_input is True)
    y = [x] if self.forward_input else []
    out = x
    for name, layer in self.layers.named_children():
      out = layer(out)
      if name.startswith('activation'):
        y.append(out)
        if self.dense:
          # DenseNet: the next linear layer consumes everything collected so far
          out = torch.cat(y, dim=-1)
      elif name.startswith('linear') and self.residual and len(y) > 0:
        # ResNet: add the skip term to the affine output before the nonlinearity
        out = out + self.weighted_avg(y, mode=self.residual_mode)

    if self.return_all:
      if not self.last_nonlinearity:
        # the final affine output was never appended by an activation layer
        y.append(out)
      return y
    else:
      return out


class MultiHeadAttention(nn.Module):
  r"""Use multi-head self attention mechanism to learn sequence embedding.

  Each head maps the input to keys and values with its own two-layer perceptrons.
  The unnormalized attention between two positions is the inner product of their
  keys (the keys also act as queries, so the scores are symmetric before masking;
  no 1/sqrt(key_dim) scaling is applied).

  Args:
    in_dim: int; input feature dimension
    out_dim: int; output feature dimension
    key_dim: int or a list of num_heads int;
      map input to (key, value) and use keys to calculate weights (attention)
    value_dim: int or a list of num_heads int; if None, set it to be out_dim
    num_heads: int
    mask: if True, each element in a sequence only attends to itself and its left side;
      useful for decoder
    knn: int; only attend to the top k elements with the highest unnormalized attention;
      ignored when None/0 or when it is not smaller than the sequence length

  Shape:
    Input: (N, seq_len, in_dim) for most cases; (N, *, seq, in_dim) is also possible
    Output: change the last dimension of input to out_dim

  Attributes:
    In the end all parameters are from nn.Linear modules.
    keys and values are two nn.ModuleList instances with num_heads of two-layer perceptron
    In the final layer we concatenate feature vector from all heads and pass it to a two-layer
    perceptron (out) and get the final output

  Examples:

    >>> x = torch.randn(3,5,7)
    >>> model = MultiHeadAttention(7, 11, 13, 17, num_heads=2, mask=False, knn=1)
    >>> model(x).shape
    torch.Size([3, 5, 11])

  """
  def __init__(self, in_dim, out_dim, key_dim, value_dim=None, 
               num_heads=1, mask=False, knn=None):
    super(MultiHeadAttention, self).__init__()
    if value_dim is None:
      value_dim = out_dim
    self.num_heads = num_heads
    self.mask = mask
    self.knn = knn
    key_dims = get_list(key_dim, num_heads)
    value_dims = get_list(value_dim, num_heads)
    self.keys = nn.ModuleList([MultiLayerPerceptron(in_dim, [key_dims[i]]*2) 
                               for i in range(num_heads)])
    self.values = nn.ModuleList([MultiLayerPerceptron(in_dim, [value_dims[i]]*2) 
                                 for i in range(num_heads)])
    # the heads are concatenated along the feature axis before the output perceptron
    self.out = MultiLayerPerceptron(sum(value_dims), [out_dim]*2)

  def forward(self, x):
    seq_len = x.size(-2)
    neg_inf = float('-Inf')
    future = None
    if self.mask:
      # future[i, j] is True when j > i, i.e. position j lies to the right of i.
      # Built once and shared by all heads; broadcasts over the leading dimensions.
      position = torch.arange(seq_len, device=x.device)
      future = position.unsqueeze(0) > position.unsqueeze(1)

    heads = []
    for key_net, value_net in zip(self.keys, self.values):
      keys = key_net(x)
      values = value_net(x)
      # inner product of keys as unnormalized attention: att[..., i, j] = <key_i, key_j>
      att = torch.matmul(keys, keys.transpose(-1, -2))
      if future is not None:
        att = att.masked_fill(future, neg_inf)
      if self.knn and self.knn < seq_len:
        # keep the knn largest scores per row by pushing the (seq_len - knn)
        # smallest ones (masked entries included) to -Inf
        dropped = att.topk(seq_len - self.knn, dim=-1, largest=False)[1]
        att = att.scatter(-1, dropped, neg_inf)
      # Use softmax to normalize attention; # To do: alternative to softmax
      att = torch.nn.functional.softmax(att, dim=-1)
      # attention-weighted sum of values: out[..., i, :] = sum_j att[..., i, j] * values[..., j, :]
      heads.append(torch.matmul(att, values))

    return self.out(torch.cat(heads, dim=-1))
  

class Encoder(nn.Module):
  """Stacked MultiHeadAttention layers

  Every layer applies two sub-layers in turn: a MultiHeadAttention module and a
  two-layer MultiLayerPerceptron. Each sub-layer is optionally wrapped in a
  residual connection and followed by a (parameter-free) layer normalization.

  Args:
    num_layers: int; number of (attention, perceptron) layers
    in_dim: int; input feature dimension
    out_dim: int or a sequence of num_layers int; output dimension of each layer
    key_dim: int or a list of num_layers entries; passed to the attention of each layer
    value_dim: same convention as key_dim; None lets each attention default to its out_dim
    num_heads: int or a list of num_layers int
    mask: passed to every MultiHeadAttention
    knn: passed to every MultiHeadAttention
    residual: if True, add each sub-layer's input to its output;
      requires in_dim and all out_dim entries to be equal
    normalization: 'layer_norm' applies nn.functional.layer_norm over the last
      dimension after each sub-layer; any other value disables normalization
    return_all: if True, forward returns the list of per-layer outputs
      instead of only the last one

  Shape:
    Input: (N, seq_len, in_dim)
    Output: (N, seq_len, out_dim of the last layer), or a list of such tensors
      (one per layer) when return_all is True
  """
  def __init__(self, num_layers, in_dim, out_dim, key_dim, value_dim=None, num_heads=1, 
               mask=False, knn=None, residual=True, normalization='layer_norm', return_all=False):
    super(Encoder, self).__init__()
    # We can set out_dim layer by layer, similar to key_dim, value_dim, num_heads.
    # get_list may hand back the caller's own sequence (possibly a tuple), so copy
    # it into a list before prepending in_dim.
    out_dim = [in_dim] + list(get_list(out_dim, num_layers))
    key_dim = get_list(key_dim, num_layers)
    value_dim = get_list(value_dim, num_layers)
    num_heads = get_list(num_heads, num_layers)
    self.num_layers = num_layers
    self.residual = residual
    self.return_all = return_all
    self.normalization = normalization
    if residual:
      # residual sums need matching input and output widths in every layer
      for i in range(num_layers):
        assert out_dim[i] == out_dim[i+1], \
          'residual=True requires in_dim and all out_dim entries to be equal'
    self.attentions = nn.ModuleList([MultiHeadAttention(
      out_dim[i], out_dim[i+1], key_dim[i], value_dim[i], num_heads[i], mask, knn) 
                                     for i in range(num_layers)])
    self.perceptrons = nn.ModuleList([MultiLayerPerceptron(out_dim[i+1], [out_dim[i+1]]*2) 
                                     for i in range(num_layers)])

  def forward(self, x):
    use_layer_norm = self.normalization == 'layer_norm'
    y = []
    out = x
    for attention, perceptron in zip(self.attentions, self.perceptrons):
      # attention first, then the perceptron; both get the same residual/norm treatment
      for sublayer in (attention, perceptron):
        update = sublayer(out)
        out = update + out if self.residual else update
        if use_layer_norm:
          out = nn.functional.layer_norm(out, (out.size(-1),))
      y.append(out)
    if self.return_all:
      return y
    return out

In [198]:
# Smoke test: run a one-layer Encoder (in_dim=7, out_dim=7, key_dim=11) with
# mask=True and knn=3 on a random (3, 5, 7) input. The recorded output below
# shows that the input shape is preserved.
x = torch.randn(3,5,7)

model = Encoder(1, 7, 7, 11, mask=True, knn=3)
model(x).shape

torch.Size([3, 5, 7])

In [181]:
# Scratch cell: collect the positions (i, j, k) where x[i, j, k] > x[i, k, j].
# torch.nonzero returns one row per matching position and one column per dimension.
x = torch.randn(3,5,5)
idx = torch.nonzero(x > x.transpose(-1,-2))

In [190]:
# Split the index matrix column-wise into one 1-D index tensor per dimension.
[idx[:, i] for i in range(idx.size(1))]

[tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2]),
 tensor([0, 1, 1, 2, 2, 2, 3, 4, 4, 4, 0, 0, 0, 1, 1, 2, 3, 3, 4, 4, 0, 0, 0, 1,
         1, 2, 2, 3, 3, 4]),
 tensor([3, 0, 3, 0, 1, 3, 4, 0, 1, 2, 1, 2, 3, 2, 3, 4, 2, 4, 0, 1, 1, 2, 4, 3,
         4, 1, 3, 0, 4, 2])]