In [1]:
from utils import h36motion3d as datasets
from torch.utils.data import DataLoader
from torch import nn
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import torch.optim as optim
import torch.autograd
import torch
import numpy as np
from utils.loss_funcs import *
from utils.data_utils import define_actions
from utils.h36_3d_viz import visualize
import time
import math

import torch.nn.functional as F

# Reproducibility: give every random number generator used here the same fixed seed.
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)

# Restrict cuDNN to deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [2]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device: %s' % device)

Using device: cpu


In [3]:
class Attention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        attn_dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, attn_dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from every query position to every key position.

        Args:
            query: tensor of shape (B, H, L, D).
            key: tensor of shape (B, H, M, D).
            value: tensor of shape (B, H, M, Dv).
            mask: optional tensor broadcastable to (B, H, L, M); positions
                where ``mask == 0`` receive (numerically) zero weight.

        Returns:
            ``(output, attn)`` where ``output`` has shape (B, H, L, Dv) and
            ``attn`` holds the (dropout-applied) attention weights of shape
            (B, H, L, M).
        """
        d_k = key.size(-1)  # feature size of the keys, used for scaling

        # Similarity of every query (L) with every key (M), scaled by sqrt(d_k).
        scores = torch.einsum('BHLD, BHMD -> BHLM', query, key) / math.sqrt(d_k)

        if mask is not None:
            # A large negative score becomes ~0 after the softmax.
            scores = scores.masked_fill(mask == 0, -1e9)

        # Normalise over the key axis. The softmax must be taken over the
        # scores (not over `value`), otherwise the scores and the mask are ignored.
        attn = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum of the values: contract the key axis M.
        output = torch.einsum('BHLM, BHMD -> BHLD', attn, value)

        return output, attn


In [4]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a sequence of embeddings.

    Even feature indices hold ``sin(pos * w_i)`` and odd feature indices hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: size of the feature (last) axis of the input.
        dropout: dropout probability applied after adding the encoding.
            Defaults to 0.1 so the module can be built from ``d_model`` alone.
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # Frequencies w_i, one per (sin, cos) pair of features.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # even feature indices
        # For an odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])  # odd feature indices

        # Stored as (max_len, 1, d_model); a buffer follows the module across
        # devices without being a trainable parameter.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the positional encoding to ``x`` and apply dropout.

        Args:
            x: tensor whose last two axes are (sequence length, d_model), e.g.
                (batch, seq_len, d_model) as used by the attention modules in
                this notebook.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        seq_len = x.size(-2)
        if seq_len > self.pe.size(0):
            raise ValueError(
                'sequence length %d exceeds max_len %d' % (seq_len, self.pe.size(0)))
        # The encoding is ADDED (not concatenated). Slicing by the time axis
        # gives a (seq_len, d_model) table that broadcasts over the batch.
        x = x + self.pe[:seq_len, 0, :]
        return self.dropout(x)

In [5]:
class EncoderLayer(nn.Module):
    """A single encoder block: multi-head self-attention followed by a position-wise feed-forward network."""

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super().__init__()
        # Residual connection and layer norm are handled inside each sub-module.
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        """Run the block on `enc_input`; returns (block output, self-attention weights)."""
        # Self-attention: queries, keys and values are all the encoder input.
        attended, attn_weights = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        # Feed-forward network on the attended representation.
        return self.pos_ffn(attended), attn_weights

In [6]:
class DecoderLayer(nn.Module):
    """A single decoder block: (masked) self-attention, encoder-decoder attention, then a feed-forward network."""

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super().__init__()
        # Residual connection and layer norm are handled inside each sub-module.
        attn_args = (n_head, d_model, d_k, d_v)
        # Self-attention over the decoder's own sequence.
        self.slf_attn = MultiHeadAttention(*attn_args, dropout=dropout)
        # Attention whose keys/values come from the encoder output.
        self.enc_attn = MultiHeadAttention(*attn_args, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None):
        """Run the block; returns (block output, self-attention weights, encoder-decoder attention weights)."""
        # Queries, keys and values all come from the decoder input.
        self_attended, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask)
        # Queries come from the decoder; keys and values from the encoder output.
        cross_attended, dec_enc_attn = self.enc_attn(
            self_attended, enc_output, enc_output, mask=dec_enc_attn_mask)
        return self.pos_ffn(cross_attended), dec_slf_attn, dec_enc_attn

Define the `MultiHeadAttention` class as follows:
$$
\begin{split}
    \text{MultiHead}(Q,K,V) & = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^{O}\\
    \text{where } \text{head}_i & = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
\end{split}
$$

In [9]:
#TO COMPLETE!

class MultiHeadAttention(nn.Module):
    """Multi-head attention with output projection, residual connection and layer norm.

    Args:
        n_head: number of attention heads.
        d_model: feature size of the inputs and of the output.
        d_k: per-head feature size of the queries and keys.
        d_v: per-head feature size of the values.
        dropout: dropout probability applied to the projected output.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k  # per-head dimension of the queries/keys
        self.d_v = d_v  # per-head dimension of the values

        # Projections for all heads at once (W^Q, W^K, W^V stacked over heads).
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        # Output projection W^O applied to the concatenated heads; `forward`
        # needs it to map back to d_model before the residual sum.
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = Attention()

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    # NOTE: `forward` must be a method of the class (not nested inside
    # `__init__`), otherwise calling the module raises NotImplementedError.
    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q: queries of shape (batch, len_q, d_model).
            k: keys of shape (batch, len_k, d_model).
            v: values of shape (batch, len_v, d_model).
            mask: optional mask; a head axis is inserted at dim 1 before it is
                handed to the attention module.

        Returns:
            ``(output, attn)`` where ``output`` has shape
            (batch, len_q, d_model) and ``attn`` is whatever the attention
            module returns as its weights.
        """
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        residual = q  # kept for the residual connection after attention

        # Project, then split the heads: b x len x n_head x d
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Move the head axis forward for the attention product: b x n_head x len x d
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)  # broadcast over the head axis

        out, attn = self.attention(q, k, v, mask=mask)

        # Move the head axis back and concatenate the heads: b x len_q x (n_head * d_v)
        out = out.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        out = self.dropout(self.fc(out))
        out = out + residual  # residual connection

        out = self.layer_norm(out)

        return out, attn


In [11]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network with residual connection and layer norm.

    Args:
        d_in: feature size of the input and of the output.
        d_hidden: feature size of the hidden layer.
        dropout: dropout probability applied to the second layer's output.
    """

    def __init__(self, d_in, d_hidden, dropout=0.1):
        super().__init__()
        self.w1 = nn.Linear(d_in, d_hidden)  # expansion layer
        self.w2 = nn.Linear(d_hidden, d_in)  # projection back to d_in
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``layer_norm(x + dropout(w2(relu(w1(x)))))``."""
        hidden = F.relu(self.w1(x))
        projected = self.dropout(self.w2(hidden))
        # Residual connection around the feed-forward path, then normalise.
        return self.layer_norm(projected + x)

In [12]:
def get_pad_mask(seq, pad_idx):
    """Return a boolean mask that is True where `seq` differs from `pad_idx`.

    A singleton axis is inserted before the last axis of the result.
    """
    not_padding = seq.ne(pad_idx)
    return torch.unsqueeze(not_padding, dim=-2)


def get_subsequent_mask(seq):
    """Build a (1, len_s, len_s) boolean mask that hides subsequent positions.

    Entry [0, i, j] is True when j <= i. `len_s` is read from axis 2 of a
    4-D `seq` and from axis 1 of a 3-D `seq`.
    """
    if seq.dim() == 4:
        _, _, len_s, _ = seq.size()
    else:
        _, len_s, _ = seq.size()
    all_ones = torch.ones((1, len_s, len_s), device=seq.device)
    # Keep the lower triangle (diagonal included); everything above it is masked out.
    return torch.tril(all_ones).bool()

In [13]:
class Encoder(nn.Module):
    """An encoder made of a stack of self-attention layers.

    Args:
        n_layers: number of stacked `EncoderLayer` blocks.
        n_head, d_k, d_v, d_model, d_inner: forwarded to every `EncoderLayer`.
        pad_idx: accepted for interface compatibility; not used here.
        dropout: dropout probability for the input and for every layer.
        n_position: accepted for interface compatibility; not used here.
        scale_emb: accepted for interface compatibility; not used here.
    """

    def __init__(
            self, n_layers, n_head, d_k, d_v,
            d_model, d_inner, pad_idx, dropout=0.1, n_position=200, scale_emb=False):

        super().__init__()

        # PositionalEncoding requires a dropout rate and applies it itself.
        # Pass 0.0 because `forward` applies `self.dropout` right after the
        # encoding, so the input is dropped once rather than twice.
        self.position_enc = PositionalEncoding(d_model, dropout=0.0)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, src_seq, src_mask, return_attns=False):
        """Encode `src_seq`.

        Args:
            src_seq: input sequence of feature vectors (no embedding lookup is
                performed here).
            src_mask: mask handed to every layer's self-attention.
            return_attns: when True, also return the per-layer attention weights.

        Returns:
            The encoder output, or ``(output, [attn_layer_0, ...])`` when
            `return_attns` is True.
        """
        enc_slf_attn_list = []

        # Positional encoding -> dropout -> layer norm on the raw input.
        enc_output = self.layer_norm(self.dropout(self.position_enc(src_seq)))

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
            if return_attns:
                enc_slf_attn_list.append(enc_slf_attn)

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output

In [None]:
class Decoder(nn.Module):
    """A decoder made of a stack of self-attention / encoder-attention layers.

    Args:
        d_word_vec: feature size used by the positional encoding (must match
            `d_model`, since the encoded input is layer-normed with `d_model`).
        n_layers: number of stacked `DecoderLayer` blocks.
        n_head, d_k, d_v, d_model, d_inner: forwarded to every `DecoderLayer`.
        pad_idx: accepted for interface compatibility; not used here.
        n_position: accepted for interface compatibility; not used here.
        dropout: dropout probability for the input and for every layer.
        scale_emb: stored on the instance; not used in `forward`.
    """

    def __init__(
            self, d_word_vec, n_layers, n_head, d_k, d_v,
            d_model, d_inner, pad_idx, n_position=200, dropout=0.1, scale_emb=False):

        super().__init__()

        # PositionalEncoding requires a dropout rate and applies it itself.
        # Pass 0.0 because `self.dropout` is applied right after the encoding,
        # so the input is dropped once rather than twice.
        self.position_enc = PositionalEncoding(d_word_vec, dropout=0.0)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([
            DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.scale_emb = scale_emb
        self.d_model = d_model

    def _run_stack(self, trg_seq, trg_mask, enc_output, src_mask):
        """One full pass: positional encoding, dropout, layer norm, then every decoder layer."""
        dec_output = self.layer_norm(self.dropout(self.position_enc(trg_seq)))
        for dec_layer in self.layer_stack:
            dec_output, _, _ = dec_layer(
                dec_output, enc_output, slf_attn_mask=trg_mask, dec_enc_attn_mask=src_mask)
        return dec_output

    def forward(self, trg_seq, trg_mask, enc_output, src_mask, inference):
        """Decode `trg_seq` against the encoder output.

        Args:
            trg_seq: target sequence of shape (batch, len, features).
            trg_mask: mask for the decoder self-attention.
            enc_output: output of the encoder (keys/values of the cross-attention).
            src_mask: mask for the encoder-decoder attention.
            inference: False for a single teacher-forced pass; True to fill the
                target step by step with the decoder's own predictions.

        Returns:
            Training: the decoder output for the whole target sequence.
            Inference: the positions 1.. of the progressively filled target,
            followed by the last prediction (same length as `trg_seq`).

        Raises:
            ValueError: in inference mode when `trg_seq` has no time steps.
        """
        if not inference:
            return self._run_stack(trg_seq, trg_mask, enc_output, src_mask)

        n_steps = trg_seq.shape[1]
        if n_steps == 0:
            raise ValueError('trg_seq must contain at least one time step')

        # Work on a copy so the caller's tensor is not overwritten in place.
        trg_seq = trg_seq.clone()
        last_step = n_steps - 1
        for i in range(n_steps):
            dec_output = self._run_stack(trg_seq, trg_mask, enc_output, src_mask)
            if i != last_step:
                # Feed the prediction back as the next input position.
                # NOTE(review): the LAST position of the output is used at every
                # step (not position i) - confirm this is the intended behaviour.
                trg_seq[:, i + 1, :] = dec_output[:, -1, :]

        # Drop the initial position and append the final prediction.
        return torch.cat((trg_seq[:, 1:, :], dec_output[:, -1, :].unsqueeze(1)), dim=1)