decoder_net.py
from typing import Tuple, Dict, Optional

import torch

from allennlp.common import Registrable


class DecoderNet(torch.nn.Module, Registrable):
    """
    This class abstracts the neural architectures used to decode encoded states and
    embedded previous-step predictions into a new sequence of output vectors.

    Implementations of `DecoderNet` are used by implementations of
    `allennlp.modules.seq2seq_decoders.seq_decoder.SeqDecoder`, such as
    `allennlp.modules.seq2seq_decoders.auto_regressive_seq_decoder.AutoRegressiveSeqDecoder`.
    The outputs of this module are typically consumed by the `SeqDecoder`, which applies
    the final output feedforward layer and softmax.

    A minimal illustrative subclass and a greedy-decoding usage sketch are appended at the
    end of this file.

    # Parameters

    decoding_dim : `int`, required
        Defines the dimensionality of the output vectors.
    target_embedding_dim : `int`, required
        Defines the dimensionality of the target embeddings. Since this model takes its output
        from a previous step as input to the following step, this is also the input dimensionality.
    decodes_parallel : `bool`, required
        Defines whether the decoder generates multiple next-step predictions in a single
        `forward` call.
    """

    def __init__(
        self, decoding_dim: int, target_embedding_dim: int, decodes_parallel: bool
    ) -> None:
        super().__init__()
        self.target_embedding_dim = target_embedding_dim
        self.decoding_dim = decoding_dim
        self.decodes_parallel = decodes_parallel

    def get_output_dim(self) -> int:
        """
        Returns the dimension of each vector in the sequence output by this `DecoderNet`.
        This is `not` the shape of the returned tensor, but the last element of that shape.
        """
        return self.decoding_dim

    def init_decoder_state(
        self, encoder_out: Dict[str, torch.LongTensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Initialize the encoded state to be passed to the first decoding time step.

        # Parameters

        encoder_out : `Dict[str, torch.LongTensor]`, required
            Dictionary with the encoded state, typically containing the encoder outputs
            and the source mask.

        # Returns

        `Dict[str, torch.Tensor]`
            Initial state
        """
        raise NotImplementedError()

    def forward(
        self,
        previous_state: Dict[str, torch.Tensor],
        encoder_outputs: torch.Tensor,
        source_mask: torch.BoolTensor,
        previous_steps_predictions: torch.Tensor,
        previous_steps_mask: Optional[torch.BoolTensor] = None,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """
        Performs a decoding step and returns a dictionary with the decoder hidden state (or cache)
        and the decoder output. The decoder output is a 3d tensor of shape
        (group_size, steps_count, decoder_output_dim) if `self.decodes_parallel` is `True`,
        otherwise it is a 2d tensor of shape (group_size, decoder_output_dim).

        # Parameters

        previous_state : `Dict[str, torch.Tensor]`, required
            Previous state of the decoder.
        encoder_outputs : `torch.Tensor`, required
            Vectors of all encoder outputs.
            Shape: (group_size, max_input_sequence_length, encoder_output_dim)
        source_mask : `torch.BoolTensor`, required
            Mask for each input sequence.
            Shape: (group_size, max_input_sequence_length)
        previous_steps_predictions : `torch.Tensor`, required
            Embeddings of the predictions from previous steps.
            Shape: (group_size, steps_count, decoder_output_dim)
        previous_steps_mask : `torch.BoolTensor`, optional
            Mask over the previous-step predictions.
            Shape: (group_size, steps_count)

        # Returns

        `Tuple[Dict[str, torch.Tensor], torch.Tensor]`
            Tuple of the new decoder state and the decoder output. The output should be used to
            generate the output sequence elements.
        """
        raise NotImplementedError()
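

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): a minimal concrete
# `DecoderNet` built around a single `torch.nn.LSTMCell`, decoding one step
# per `forward` call. The state keys ("decoder_hidden", "decoder_context"),
# the assumption that `encoder_out` carries "encoder_outputs" and
# "source_mask", and the assumption that encoder_output_dim == decoding_dim
# are choices made for this example only, not the library's own design.
# ---------------------------------------------------------------------------
class SimpleLstmCellDecoderNet(DecoderNet):
    def __init__(self, decoding_dim: int, target_embedding_dim: int) -> None:
        # One prediction per forward call, so decodes_parallel is False.
        super().__init__(decoding_dim, target_embedding_dim, decodes_parallel=False)
        self._decoder_cell = torch.nn.LSTMCell(target_embedding_dim, decoding_dim)

    def init_decoder_state(
        self, encoder_out: Dict[str, torch.LongTensor]
    ) -> Dict[str, torch.Tensor]:
        encoder_outputs = encoder_out["encoder_outputs"]
        source_mask = encoder_out["source_mask"]
        # Pick the encoder vector at the last unmasked position of every sequence.
        last_positions = source_mask.sum(dim=1) - 1
        batch_indices = torch.arange(encoder_outputs.size(0), device=encoder_outputs.device)
        final_encoder_output = encoder_outputs[batch_indices, last_positions]
        return {
            "decoder_hidden": final_encoder_output,                     # (batch, decoding_dim)
            "decoder_context": torch.zeros_like(final_encoder_output),  # (batch, decoding_dim)
        }

    def forward(
        self,
        previous_state: Dict[str, torch.Tensor],
        encoder_outputs: torch.Tensor,
        source_mask: torch.BoolTensor,
        previous_steps_predictions: torch.Tensor,
        previous_steps_mask: Optional[torch.BoolTensor] = None,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        # Only the most recent prediction embedding feeds the cell.
        last_prediction_embedding = previous_steps_predictions[:, -1, :]
        decoder_hidden, decoder_context = self._decoder_cell(
            last_prediction_embedding,
            (previous_state["decoder_hidden"], previous_state["decoder_context"]),
        )
        new_state = {"decoder_hidden": decoder_hidden, "decoder_context": decoder_context}
        # decodes_parallel is False, so the output is 2d: (group_size, decoding_dim).
        return new_state, decoder_hidden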
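

# ---------------------------------------------------------------------------
# Illustrative usage sketch (also not part of the original file): the
# state-in / state-out loop a `SeqDecoder` would follow when driving any
# single-step `DecoderNet` greedily. `target_embedder` and `output_projection`
# are hypothetical stand-ins for the target embedding layer and the final
# feedforward layer that the surrounding `SeqDecoder` normally owns, and
# `encoder_out` is again assumed to carry "encoder_outputs" and "source_mask".
# ---------------------------------------------------------------------------
def greedy_decode(
    decoder_net: DecoderNet,
    encoder_out: Dict[str, torch.Tensor],
    target_embedder: torch.nn.Embedding,
    output_projection: torch.nn.Linear,
    start_index: int,
    max_steps: int,
) -> torch.Tensor:
    encoder_outputs = encoder_out["encoder_outputs"]
    source_mask = encoder_out["source_mask"]
    batch_size = encoder_outputs.size(0)

    state = decoder_net.init_decoder_state(encoder_out)
    # Start every sequence in the batch with the start-of-sentence index.
    predictions = torch.full(
        (batch_size, 1), start_index, dtype=torch.long, device=encoder_outputs.device
    )

    for _ in range(max_steps):
        # Embed everything predicted so far; a single-step decoder will
        # typically only look at the last position.
        previous_steps_predictions = target_embedder(predictions)
        state, decoder_output = decoder_net(
            previous_state=state,
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
            previous_steps_predictions=previous_steps_predictions,
        )
        # Project to the vocabulary and take the argmax as the next token.
        next_token = output_projection(decoder_output).argmax(dim=-1, keepdim=True)
        predictions = torch.cat([predictions, next_token], dim=1)

    return predictions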