forked from domanjiri/joint-bert-with-tf2
-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
127 lines (99 loc) · 4.59 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
import tensorflow as tf
from transformers import TFBertModel
import config
from preprocess import Process
class CustomBertLayer(tf.keras.layers.Layer):
    """Keras layer wrapping a pretrained TFBertModel.

    Exists as a single place to customize BERT's build/call behavior
    without touching the rest of the model.
    """

    def __init__(self, **kwargs):
        super(CustomBertLayer, self).__init__(**kwargs)
        # Load the pretrained encoder eagerly so the layer is usable
        # immediately after construction.
        self._bert = self._load_bert()

    def _load_bert(self):
        # Pretrained checkpoint name comes from the project config module.
        bert = TFBertModel.from_pretrained(config.bert_model_name)
        logging.info('BERT weights loaded')
        return bert

    def build(self, input_shape):
        # Nothing extra to build; defer entirely to the base class.
        super(CustomBertLayer, self).build(input_shape)

    def call(self, inputs):
        # Pass-through to BERT; the output format is whatever the installed
        # transformers version returns for TFBertModel.
        return self._bert(inputs=inputs)
class CustomModel(tf.keras.Model):
    """Joint intent-classification / slot-filling model on top of BERT.

    Args:
        intents_num(int):
            Number of intents in the working dataset that used in softmax layer.
        slots_num(int):
            Number of slots labels in the working dataset that used in softmax layer.
    """

    def __init__(self,
                 intents_num : int,
                 slots_num : int):
        super().__init__(name="joint_intent_slot")
        self._bert_layer = CustomBertLayer()
        self._dropout = tf.keras.layers.Dropout(rate=config.dropout_rate)
        # One softmax head per task: sentence-level intents, per-token slots.
        self._intent_classifier = tf.keras.layers.Dense(
            intents_num, activation='softmax', name='intent_classifier')
        self._slot_classifier = tf.keras.layers.Dense(
            slots_num, activation='softmax', name='slot_classifier')

    def call(self, inputs, training=False, **kwargs):
        # NOTE(review): assumes the BERT layer's output unpacks as
        # (sequence_output, pooled_output) — true when transformers returns a
        # tuple; verify against the pinned transformers version.
        sequence_output, pooled_output = self._bert_layer(inputs, **kwargs)
        # Per-token head over the full sequence output.
        slot_logits = self._slot_classifier(
            self._dropout(sequence_output, training))
        # Sentence-level head over the pooled [CLS] representation.
        intent_logits = self._intent_classifier(
            self._dropout(pooled_output, training))
        return slot_logits, intent_logits
class JointCategoricalBert(object):
    """Wrapper to model functions. The Model compiles with hyper-parameters and
    will be ready for fit.

    Args:
        train(preprocess.Process):
            Holds the training part of samples.
        validation(preprocess.Process):
            Holds the validation part of samples.
        intents_num(int):
            Number of intents in the working dataset which will be used in softmax layer.
        slots_num(int):
            Number of slot lables in the working dataset which will be used in softmax layer.
    """

    def __init__(self,
                 train : Process,
                 validation : Process,
                 intents_num : int,
                 slots_num : int):
        self._dataset = {'train': train, 'validation': validation}
        self._model = CustomModel(intents_num=intents_num, slots_num=slots_num)
        self._compile()

    def _compile(self):
        """Compile the model with hyper-parameters that defined in the config file.
        """
        optimizer = tf.keras.optimizers.Adam(learning_rate=config.learning_rate)
        # BUG FIX: the model's classifier heads use activation='softmax', so
        # the model outputs probabilities, not logits. These losses previously
        # used from_logits=True, which makes Keras apply a second softmax
        # internally and yields incorrect loss values and gradients.
        losses = [tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)]
        # Order matches the model's (slot_logits, intent_logits) output order.
        loss_weights = [config.loss_weights['slot'], config.loss_weights['intent']]
        metrics = [tf.keras.metrics.SparseCategoricalAccuracy('accuracy')]
        self._model.compile(optimizer=optimizer,
                            loss=losses,
                            loss_weights=loss_weights,
                            metrics=metrics)
        logging.info("model compiled")

    def fit(self):
        """Fit the compiled model to the dataset. Hyper-parameters such as number of
        epochs defined in the config file.

        Returns:
            The fitted Keras model.
        """
        logging.info('before fit model')
        # Targets are supplied as a (slots, intents) pair, matching the order
        # of the model's two outputs and the loss/loss_weights lists above.
        self._model.fit(
            self._dataset['train'].get_tokens(),
            (self._dataset['train'].get_slots(), self._dataset['train'].get_intents()),
            validation_data=(
                self._dataset['validation'].get_tokens(),
                (self._dataset['validation'].get_slots(),
                 self._dataset['validation'].get_intents())),
            epochs=config.epochs_num,
            batch_size=config.batch_size)
        return self._model