/
multi_label.py
75 lines (64 loc) · 2.86 KB
/
multi_label.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import text_problems
from tensor2tensor.layers import modalities
import tensorflow as tf
class Text2MultiLabelProblem(text_problems.Text2TextProblem):
  """Base class for text multi-labeling problems.

  Subclasses implement `generate_samples` and `num_classes` (and optionally
  override `class_labels`) to define a problem whose inputs are raw text and
  whose targets are a variable-length list of class ids.
  """

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    """Generate samples of text and label pairs.

    Each yielded dict will be a single example. The inputs should be raw
    text. The label should be an int array with each element in
    [0, self.num_classes).

    Args:
      data_dir: final data directory. Typically only used in this method to
        copy over user-supplied vocab files (for example, if vocab_type ==
        VocabType.TOKEN).
      tmp_dir: temporary directory that you can use for downloading and
        scratch.
      dataset_split: problem.DatasetSplit, which data split to generate
        samples for (for example, training and evaluation).

    Yields:
      {"inputs": text, "labels": int[]}
    """
    raise NotImplementedError()

  # START: Additional subclass interface
  @property
  def num_classes(self):
    """The number of classes."""
    raise NotImplementedError()

  def class_labels(self, data_dir):
    """String representation of the classes."""
    del data_dir
    return ["ID_%d" % i for i in range(self.num_classes)]
  # END: Additional subclass interface

  def generate_text_for_vocab(self, data_dir, tmp_dir):
    """Yield training-set input texts for vocab construction.

    Capped at `max_samples_for_vocab` examples when that attribute is set.
    """
    for i, sample in enumerate(
        self.generate_samples(data_dir, tmp_dir, problem.DatasetSplit.TRAIN)):
      yield sample["inputs"]
      if self.max_samples_for_vocab and (i + 1) >= self.max_samples_for_vocab:
        break

  def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
    """Encode raw text inputs with the vocab; label lists pass through."""
    generator = self.generate_samples(data_dir, tmp_dir, dataset_split)
    encoder = self.get_or_create_vocab(data_dir, tmp_dir)
    for sample in generator:
      inputs = encoder.encode(sample["inputs"])
      inputs.append(text_encoder.EOS_ID)
      # Multi-label targets are already a list of int class ids; no encoding.
      yield {"inputs": inputs, "targets": sample["labels"]}

  def feature_encoders(self, data_dir):
    """Return the feature-name -> encoder mapping used for decoding.

    BUG FIX: the previous version built one ClassLabelEncoder per label
    *string*. ClassLabelEncoder expects the full list of class labels, and
    callers expect a single encoder under "targets" (cf.
    text_problems.Text2ClassProblem.feature_encoders).
    """
    encoder = self.get_or_create_vocab(data_dir, None, force_get=True)
    return {
        "inputs": encoder,
        "targets": text_encoder.ClassLabelEncoder(self.class_labels(data_dir)),
    }

  def hparams(self, defaults, unused_model_hparams):
    """Set modality and vocab sizes; mutates `defaults` in place."""
    p = defaults
    p.modality = {"inputs": modalities.ModalityType.SYMBOL,
                  "targets": modalities.ModalityType.MULTI_LABEL}
    p.vocab_size = {"inputs": self._encoders["inputs"].vocab_size,
                    "targets": self.num_classes}

  def example_reading_spec(self):
    """TF-example parsing spec.

    Both features are variable-length int64 sequences: multi-label targets
    may contain any number of class ids, so VarLenFeature (not the
    FixedLenFeature([1]) used by single-label problems) is required.
    """
    data_fields = {
        "inputs": tf.VarLenFeature(tf.int64),
        "targets": tf.VarLenFeature(tf.int64),
    }
    data_items_to_decoders = None
    return (data_fields, data_items_to_decoders)