-
Notifications
You must be signed in to change notification settings - Fork 0
/
bert_classifier.py
49 lines (37 loc) · 2.23 KB
/
bert_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import tensorflow as tf
from tensorflow.python.ops.array_ops import unstack
import tensorflow_addons as tfa
from bert import PretrainerBERT
class ClassifierBERTv2(tf.keras.models.Model):
    """Classification head: dropout followed by a linear projection.

    The input goes through a ``Dropout`` layer and then a ``Dense`` layer with
    ``num_class`` units and no activation, so the outputs are unnormalized
    scores (logits). Presumably the input is a pooled encoder representation
    (e.g. the [CLS] vector) -- TODO confirm with callers.

    Args:
        num_class: Number of units in the final ``Dense`` layer.
        num_layers, vocab_size, seq_len, hidden_size, dff, num_heads:
            Accepted but not used by this head; presumably kept so the model
            classes in this module share one constructor signature.
        dropout_rate: Rate given to the ``Dropout`` layer (default 0.1).
    """

    def __init__(self, num_class, num_layers, vocab_size, seq_len, hidden_size, dff, num_heads, dropout_rate=0.1):
        super().__init__()
        # Creation order (dropout, then fc) determines Keras' auto-generated
        # layer names and weight order; keep it stable for checkpoints.
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        kernel_init = tf.keras.initializers.TruncatedNormal(stddev=0.02)
        self.fc = tf.keras.layers.Dense(num_class, kernel_initializer=kernel_init)

    def call(self, prediction, training=True):
        """Return logits of shape ``(..., num_class)`` for ``prediction``.

        NOTE(review): ``training`` defaults to True, so a bare ``model(x)``
        call with no enclosing training context applies dropout. Pass
        ``training=False`` explicitly for inference.
        """
        dropped = self.dropout(prediction, training=training)
        return self.fc(dropped)
class ClassifierBERT(tf.keras.models.Model):
    """Classification head: tanh projection, dropout, then linear logits.

    Pipeline: ``Dense(hidden_size, activation='tanh')`` -> ``Dropout`` ->
    ``Dense(num_class)`` with no activation. The first stage looks like the
    BERT "pooler" layer -- presumably applied to the [CLS] vector; TODO
    confirm with callers.

    Args:
        num_class: Number of units in the final ``Dense`` layer.
        num_layers, vocab_size, seq_len, dff, num_heads:
            Accepted but not used by this head; presumably kept so the model
            classes in this module share one constructor signature.
        hidden_size: Width of the intermediate tanh ``Dense`` layer.
        dropout_rate: Rate given to the ``Dropout`` layer (default 0.1).
    """

    def __init__(self, num_class, num_layers, vocab_size, seq_len, hidden_size, dff, num_heads, dropout_rate=0.1):
        super().__init__()

        def truncated_normal():
            # One fresh initializer instance per layer, never shared.
            return tf.keras.initializers.TruncatedNormal(stddev=0.02)

        # Creation order (dense, dropout, fc) determines Keras' auto-generated
        # layer names and weight order; keep it stable for checkpoints.
        self.dense = tf.keras.layers.Dense(hidden_size, activation='tanh', kernel_initializer=truncated_normal())
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.fc = tf.keras.layers.Dense(num_class, kernel_initializer=truncated_normal())

    def call(self, prediction, training=True):
        """Return logits of shape ``(..., num_class)`` for ``prediction``.

        NOTE(review): ``training`` defaults to True, so a bare ``model(x)``
        call with no enclosing training context applies dropout. Pass
        ``training=False`` explicitly for inference.
        """
        pooled = self.dense(prediction)  # tanh is applied inside self.dense
        pooled = self.dropout(pooled, training=training)
        return self.fc(pooled)
class SquadBERT(tf.keras.models.Model):
    """Span-prediction head that emits two score tensors per position.

    Pipeline: ``Dense(hidden_size, activation='gelu')`` -> ``Dropout`` ->
    ``Dense(2)`` with no activation. The size-2 last axis is then split into
    two tensors. Presumably these are the start- and end-position logits for
    SQuAD-style extractive QA -- TODO confirm with the loss/decoding code.

    Args:
        num_layers, vocab_size, seq_len, dff, num_heads:
            Accepted but not used by this head; presumably kept so the model
            classes in this module share one constructor signature.
        hidden_size: Width of the intermediate gelu ``Dense`` layer.
        dropout_rate: Rate given to the ``Dropout`` layer (default 0.1).
    """

    def __init__(self, num_layers, vocab_size, seq_len, hidden_size, dff, num_heads, dropout_rate=0.1):
        super().__init__()
        # Creation order (dense, dropout, fc) determines Keras' auto-generated
        # layer names and weight order; keep it stable for checkpoints.
        self.dense = tf.keras.layers.Dense(
            hidden_size,
            activation='gelu',
            kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
        )
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.fc = tf.keras.layers.Dense(
            2,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
        )

    def call(self, prediction, training=True):
        """Score every position and return the two channels separately.

        Args:
            prediction: Input features; the ``Dense`` layers act on the last
                axis (e.g. ``(batch, seq_len, hidden)`` encoder output --
                assumed, confirm with callers).
            training: Forwarded to ``Dropout``. NOTE(review): defaults to
                True, so a bare ``model(x)`` call with no enclosing training
                context applies dropout. Pass ``training=False`` for inference.

        Returns:
            A 2-tuple of tensors, each with shape ``prediction.shape[:-1]``
            (``(batch, seq_len)`` for a rank-3 input).
        """
        hidden = self.dense(prediction)
        hidden = self.dropout(hidden, training=training)
        logits = self.fc(hidden)  # (..., 2)
        # Split the size-2 last axis directly instead of transposing to
        # (2, ...) and indexing. The values are identical for rank-3 input,
        # but this works for any input rank and avoids a full transposed copy.
        first, second = tf.unstack(logits, num=2, axis=-1)
        return first, second