This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
seq_decoder.py
76 lines (58 loc) · 2.75 KB
/
seq_decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from typing import Dict, Optional
import torch
from torch.nn import Module
from allennlp.common import Registrable
from allennlp.modules import Embedding
class SeqDecoder(Module, Registrable):
    """
    Abstract base class for the full decoder half (embedding plus neural network) of a
    Seq2Seq architecture.

    Intended for use with `allennlp.models.encoder_decoder.composed_seq2seq.ComposedSeq2Seq`.
    Concrete subclasses are expected to delegate the actual decoding to a
    `allennlp.modules.seq2seq_decoders.decoder_net.DecoderNet`. For most purposes the
    registered `default_implementation`,
    `allennlp.modules.seq2seq_decoders.seq_decoder.auto_regressive_seq_decoder.AutoRegressiveSeqDecoder`,
    is sufficient and writing a new subclass is rarely necessary.

    # Parameters

    target_embedder : `Embedding`, required
        Embedder for target-side tokens. Kept on the base class so that
        embedding / output-layer weight tying can be supported.
    """

    default_implementation = "auto_regressive_seq_decoder"

    def __init__(self, target_embedder: Embedding) -> None:
        super().__init__()
        # Stored here (rather than on subclasses) so weight tying can reach it.
        self.target_embedder = target_embedder

    def forward(
        self,
        encoder_out: Dict[str, torch.LongTensor],
        target_tokens: Optional[Dict[str, torch.LongTensor]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Decode from the encoded state into an output sequence, additionally computing
        the loss when `target_tokens` is provided.

        # Parameters

        encoder_out : `Dict[str, torch.LongTensor]`, required
            Dictionary holding the encoder state, ideally the encoded vectors together
            with the source mask.
        target_tokens : `Dict[str, torch.LongTensor]`, optional
            The output of `TextField.as_array()` applied on the target `TextField`.
        """
        raise NotImplementedError

    def get_output_dim(self) -> int:
        """
        Size of each timestep of the hidden state in the layer just before the final
        softmax. Required to verify compatibility for embedding/output weight tying.
        """
        raise NotImplementedError

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Compute metrics against the target tokens; computing them is the decoder's
        responsibility.
        """
        raise NotImplementedError

    def post_process(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Convert raw decoder outputs into predictions at inference time.

        Composing models such as
        `allennlp.models.encoder_decoders.composed_seq2seq.ComposedSeq2Seq` may invoke
        this from their `decode` method.
        """
        raise NotImplementedError