
Transformer: move Transformer related classes into onmt/modules/Transformer.py

1) Move TransformerDecoder and TransformerDecoderState to
   onmt/modules/Transformer.py.
2) Rename the original TransformerDecoder to TransformerDecoderLayer.
3) Move DecoderState to onmt/Translator.py; otherwise we would have a circular
   dependency problem between onmt/Models.py and onmt/modules/Transformer.py.
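
To make point 3 concrete, here is a sketch of the import cycle being avoided (the module names are the real files touched by this commit; the arrows are illustrative):

# Before: keeping DecoderState in onmt/Models.py would require
#
#   onmt/Models.py              -> onmt/modules/Transformer.py  (for TransformerDecoder)
#   onmt/modules/Transformer.py -> onmt/Models.py               (for DecoderState)  # cycle!
#
# After: DecoderState lives in onmt/Translator.py, which imports neither of the
# other two modules, so both can depend on it without a cycle:
#
#   onmt/Models.py              -> onmt/Translator.py, onmt/modules/Transformer.py
#   onmt/modules/Transformer.py -> onmt/Translator.py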
JianyuZhan committed Sep 10, 2017
1 parent 5c6e8f8 commit 82a07b3
Showing 3 changed files with 166 additions and 157 deletions.
158 changes: 6 additions & 152 deletions onmt/Models.py
@@ -2,13 +2,17 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.nn.utils.rnn import pack_padded_sequence as pack

import onmt
import onmt.modules
from onmt.IO import ONMTDataset
from onmt.modules import aeq
from onmt.modules.Gate import ContextGateFactory
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.nn.utils.rnn import pack_padded_sequence as pack
from onmt.Translator import DecoderState
from onmt.modules.Transformer import TransformerDecoder, \
TransformerDecoderState


def build_embeddings(opt, word_pad_ix, feat_pad_ix, num_word_embeddings,
@@ -447,114 +451,6 @@ def _input_size(self):
return self.embeddings.embedding_dim + self.hidden_size


class TransformerDecoder(nn.Module):
"""
Transformer Decoder. Wrapper around
onmt.modules.TransformerDecoder.
"""

def __init__(self, num_layers, hidden_size, attn_type,
copy_attn, dropout, embeddings):
"""
See make_decoder() comment for arguments description.
"""
super(TransformerDecoder, self).__init__()

# Basic attributes.
self.decoder_type = 'transformer'
self.num_layers = num_layers
self.embeddings = embeddings

# Build TransformerDecoder.
self.transformer = nn.ModuleList(
[onmt.modules.TransformerDecoder(hidden_size, dropout)
for _ in range(num_layers)])

# TransformerDecoder has its own attention mechanism.
# Set up a separated copy attention layer, if needed.
self._copy = False
if copy_attn:
self.copy_attn = onmt.modules.GlobalAttention(
hidden_size, attn_type=attn_type)
self._copy = True

def forward(self, input, context, state):
"""
Forward through the TransformerDecoder.
Args:
input (LongTensor): a sequence of input tokens tensors
of size (len x batch x nfeats).
context (FloatTensor): output(tensor sequence) from the Encoder
RNN of size (src_len x batch x hidden_size).
state (FloatTensor): hidden state from the Encoder RNN for
initializing the decoder.
Returns:
outputs (FloatTensor): a Tensor sequence of output from the Decoder
of shape (len x batch x hidden_size).
state (FloatTensor): final hidden state from the Decoder.
attns (dict of (str, FloatTensor)): a dictionary of different
type of attention Tensor from the Decoder
of shape (src_len x batch).
"""
# CHECKS
assert isinstance(state, TransformerDecoderState)
input_len, input_batch, _ = input.size()
contxt_len, contxt_batch, _ = context.size()
aeq(input_batch, contxt_batch)

if state.previous_input is not None:
input = torch.cat([state.previous_input, input], 0)

src = state.src
src_words = src[:, :, 0].transpose(0, 1)
tgt_words = input[:, :, 0].transpose(0, 1)
src_batch, src_len = src_words.size()
tgt_batch, tgt_len = tgt_words.size()
aeq(input_batch, contxt_batch, src_batch, tgt_batch)
aeq(contxt_len, src_len)
# aeq(input_len, tgt_len)
# END CHECKS

# Initialize return variables.
outputs = []
attns = {"std": []}
if self._copy:
attns["copy"] = []

# Run the forward pass of the TransformerDecoder.
emb = self.embeddings(input)
assert emb.dim() == 3 # len x batch x embedding_dim

output = emb.transpose(0, 1).contiguous()
src_context = context.transpose(0, 1).contiguous()

padding_idx = self.embeddings.padding_idx
src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \
.expand(src_batch, tgt_len, src_len)
tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1) \
.expand(tgt_batch, tgt_len, tgt_len)

for i in range(self.num_layers):
output, attn \
= self.transformer[i](output, src_context,
src_pad_mask, tgt_pad_mask)

# Process the result and update the attentions.
outputs = output.transpose(0, 1).contiguous()
if state.previous_input is not None:
outputs = outputs[state.previous_input.size(0):]
attn = attn[:, state.previous_input.size(0):].squeeze()
attn = torch.stack([attn])
attns["std"] = attn
if self._copy:
attns["copy"] = attn

# Update the TransformerDecoderState.
state = TransformerDecoderState(src, input)

return outputs, state, attns


class CNNDecoder(nn.Module):
"""
CNN Decoder. Wrapper around onmt.modules.ConvDecoder.
@@ -717,24 +613,6 @@ def forward(self, src, tgt, lengths, dec_state=None):
return out, attns, dec_state


class DecoderState(object):
def detach(self):
for h in self.all:
if h is not None:
h.detach_()

def repeatBeam_(self, beamSize):
self._resetAll([Variable(e.data.repeat(1, beamSize, 1))
for e in self.all])

def beamUpdate_(self, idx, positions, beamSize):
for e in self.all:
a, br, d = e.size()
sentStates = e.view(a, beamSize, br // beamSize, d)[:, :, idx]
sentStates.data.copy_(
sentStates.data.index_select(1, positions))


class RNNDecoderState(DecoderState):
def __init__(self, rnnstate, input_feed=None, coverage=None):
"""
Expand Down Expand Up @@ -768,30 +646,6 @@ def _resetAll(self, all):
self.all = self.hidden + (self.input_feed,)


class TransformerDecoderState(DecoderState):
def __init__(self, src, input=None):
"""
Args:
src (FloatTensor): a sequence of source words tensors
with optional feature tensors, of size (len x batch).
input (LongTensor): a sequence of input tokens tensors
of size (len x batch).
"""
self.src = src
self.previous_input = input
self.all = (self.previous_input, self.src)

def _resetAll(self, all):
vars = [(Variable(a.data if isinstance(a, Variable) else a,
volatile=True))
for a in all]
self.previous_input = vars[0]
self.all = (self.previous_input,)

def repeatBeam_(self, beamSize):
self.src = Variable(self.src.data.repeat(1, beamSize, 1))


class CNNDecoderState(DecoderState):
def __init__(self, input=None):
self.init_src = None
22 changes: 22 additions & 0 deletions onmt/Translator.py
@@ -6,6 +6,28 @@
from torch.autograd import Variable


class DecoderState(object):
"""
DecoderState is a base class for model decoder states, used during
translation to store and manipulate the decoding state.
"""
def detach(self):
for h in self.all:
if h is not None:
h.detach_()

def repeatBeam_(self, beamSize):
self._resetAll([Variable(e.data.repeat(1, beamSize, 1))
for e in self.all])

def beamUpdate_(self, idx, positions, beamSize):
for e in self.all:
a, br, d = e.size()
sentStates = e.view(a, beamSize, br // beamSize, d)[:, :, idx]
sentStates.data.copy_(
sentStates.data.index_select(1, positions))
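
As an aside (not part of the commit), the reshape in beamUpdate_ treats the flattened beam*batch dimension as (beam, batch), so one sentence's beam hypotheses can be reordered in place. A standalone sketch with plain tensors and assumed sizes (the Variable wrapper is elided):

import torch

layers, beam_size, batch, dim = 1, 3, 2, 4
e = torch.arange(float(layers * beam_size * batch * dim)) \
         .view(layers, beam_size * batch, dim)

idx = 0                                  # which sentence of the batch to update
positions = torch.LongTensor([2, 0, 1])  # new order of that sentence's beams

# View (layers, beam*batch, dim) as (layers, beam, batch, dim) and slice
# out sentence idx; index_select builds a reordered copy, and copy_
# writes it back through the view.
sent_states = e.view(layers, beam_size, batch, dim)[:, :, idx]
sent_states.copy_(sent_states.index_select(1, positions))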


class Translator(object):
def __init__(self, opt, dummy_opt={}):
# Add in default model arguments, possibly added since training.
143 changes: 138 additions & 5 deletions onmt/modules/Transformer.py
@@ -4,8 +4,11 @@

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np

import onmt.modules
from onmt.Translator import DecoderState
from onmt.modules import aeq


@@ -64,9 +67,9 @@ def forward(self, input, mask):
return out


class TransformerDecoder(nn.Module):
class TransformerDecoderLayer(nn.Module):
"""
The Transformer Decoder from paper "Attetion is all you need".
The Transformer Decoder Layer from "Attention is all you need".
"""
def __init__(self, size, dropout,
head_count=8, hidden_size=2048):
@@ -79,7 +82,7 @@ def __init__(self, size, dropout,
head_count(int): the number of head for MultiHeadedAttention.
hidden_size(int): the second-layer of the PositionwiseFeedForward.
"""
super(TransformerDecoder, self).__init__()
super(TransformerDecoderLayer, self).__init__()
self.self_attn = onmt.modules.MultiHeadedAttention(
head_count, size, p=dropout)
self.context_attn = onmt.modules.MultiHeadedAttention(
Expand All @@ -89,8 +92,8 @@ def __init__(self, size, dropout,
dropout)
self.dropout = dropout
mask = self._get_attn_subsequent_mask(MAX_SIZE)
# Register self.mask as a buffer in TransformerDecoder, so
# it gets TransformerDecoder's cuda behavior automatically.
# Register self.mask as a buffer in TransformerDecoderLayer, so
# it gets TransformerDecoderLayer's cuda behavior automatically.
self.register_buffer('mask', mask)

def forward(self, input, context, src_pad_mask, tgt_pad_mask):
@@ -133,3 +136,133 @@ def _get_attn_subsequent_mask(self, size):
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
subsequent_mask = torch.from_numpy(subsequent_mask)
return subsequent_mask
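
For intuition (an illustration, not part of the diff): np.triu with k=1 keeps only the entries strictly above the diagonal, so the resulting byte mask flags exactly the future positions that decoder self-attention must not see. For size=4:

import numpy as np
import torch

attn_shape = (1, 4, 4)
mask = torch.from_numpy(np.triu(np.ones(attn_shape), k=1).astype('uint8'))
# mask[0]:
#   0 1 1 1
#   0 0 1 1
#   0 0 0 1
#   0 0 0 0
# Row i has zeros at columns <= i, so position i may attend only to
# itself and earlier positions.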


class TransformerDecoder(nn.Module):
"""
Transformer Decoder.
"""
def __init__(self, num_layers, hidden_size, attn_type,
copy_attn, dropout, embeddings):
"""
See the make_decoder() comment for a description of the arguments.
"""
super(TransformerDecoder, self).__init__()

# Basic attributes.
self.decoder_type = 'transformer'
self.num_layers = num_layers
self.embeddings = embeddings

# Build the stack of TransformerDecoderLayers.
self.transformer_layers = nn.ModuleList(
[TransformerDecoderLayer(hidden_size, dropout)
for _ in range(num_layers)])

# TransformerDecoder has its own attention mechanism.
# Set up a separate copy attention layer, if needed.
self._copy = False
if copy_attn:
self.copy_attn = onmt.modules.GlobalAttention(
hidden_size, attn_type=attn_type)
self._copy = True

def forward(self, input, context, state):
"""
Forward through the TransformerDecoder.
Args:
input (LongTensor): a sequence of input token tensors
of size (len x batch x nfeats).
context (FloatTensor): output (tensor sequence) from the encoder
of size (src_len x batch x hidden_size).
state (TransformerDecoderState): decoder state object used to
carry the source and the previously decoded input.
Returns:
outputs (FloatTensor): a sequence of output tensors from the decoder
of shape (len x batch x hidden_size).
state (TransformerDecoderState): the updated decoder state.
attns (dict of (str, FloatTensor)): a dictionary of attention
tensors of different types from the decoder,
of shape (src_len x batch).
"""
# CHECKS
assert isinstance(state, TransformerDecoderState)
input_len, input_batch, _ = input.size()
contxt_len, contxt_batch, _ = context.size()
aeq(input_batch, contxt_batch)

if state.previous_input is not None:
input = torch.cat([state.previous_input, input], 0)

src = state.src
src_words = src[:, :, 0].transpose(0, 1)
tgt_words = input[:, :, 0].transpose(0, 1)
src_batch, src_len = src_words.size()
tgt_batch, tgt_len = tgt_words.size()
aeq(input_batch, contxt_batch, src_batch, tgt_batch)
aeq(contxt_len, src_len)
# aeq(input_len, tgt_len)
# END CHECKS

# Initialize return variables.
outputs = []
attns = {"std": []}
if self._copy:
attns["copy"] = []

# Run the forward pass of the TransformerDecoder.
emb = self.embeddings(input)
assert emb.dim() == 3 # len x batch x embedding_dim

output = emb.transpose(0, 1).contiguous()
src_context = context.transpose(0, 1).contiguous()

padding_idx = self.embeddings.padding_idx
src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \
.expand(src_batch, tgt_len, src_len)
tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1) \
.expand(tgt_batch, tgt_len, tgt_len)

for i in range(self.num_layers):
output, attn \
= self.transformer_layers[i](output, src_context,
src_pad_mask, tgt_pad_mask)

# Process the result and update the attentions.
outputs = output.transpose(0, 1).contiguous()
if state.previous_input is not None:
outputs = outputs[state.previous_input.size(0):]
attn = attn[:, state.previous_input.size(0):].squeeze()
attn = torch.stack([attn])
attns["std"] = attn
if self._copy:
attns["copy"] = attn

# Update the TransformerDecoderState.
state = TransformerDecoderState(src, input)

return outputs, state, attns
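
An aside on the pad masks built in this forward pass (an illustration with made-up token ids; the real pad index comes from self.embeddings.padding_idx): each mask is a byte tensor of shape (batch x tgt_len x len) that flags padding positions, repeated across every query position.

import torch

padding_idx = 0                               # assumed pad index
src_words = torch.LongTensor([[4, 7, 0, 0]])  # (batch=1, src_len=4)
tgt_len = 3
src_pad_mask = src_words.eq(padding_idx).unsqueeze(1) \
                        .expand(1, tgt_len, 4)
# src_pad_mask[0] is three identical rows [0, 0, 1, 1]: every target
# position is blocked from attending to the two padding slots.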


class TransformerDecoderState(DecoderState):
def __init__(self, src, input=None):
"""
Args:
src (FloatTensor): a sequence of source word tensors
with optional feature tensors, of size (len x batch x nfeats).
input (LongTensor): a sequence of input token tensors
of size (len x batch x nfeats).
"""
self.src = src
self.previous_input = input
self.all = (self.previous_input, self.src)

def _resetAll(self, all):
vars = [(Variable(a.data if isinstance(a, Variable) else a,
volatile=True))
for a in all]
self.previous_input = vars[0]
self.all = (self.previous_input,)

def repeatBeam_(self, beamSize):
self.src = Variable(self.src.data.repeat(1, beamSize, 1))
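
To see how the moved pieces fit together at translation time, here is a hedged usage sketch (the decoder hyperparameters, the embeddings module, and the tensor shapes below are assumptions for illustration, not part of this commit):

# Hypothetical wiring; shapes follow the docstrings above.
decoder = TransformerDecoder(num_layers=6, hidden_size=512,
                             attn_type='general', copy_attn=False,
                             dropout=0.1, embeddings=embeddings)

state = TransformerDecoderState(src)      # src: (src_len x batch x nfeats)
for t in range(tgt.size(0)):              # feed one target step at a time
    outputs, state, attns = decoder(tgt[t].unsqueeze(0), context, state)
    # state.previous_input now holds the whole decoded prefix, so the next
    # step re-runs self-attention over everything decoded so far.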
