In [6]:
!pip install -q tf-models-official==2.7.0

In [5]:
!pip install -q testresources

In [7]:
# Standard library.
# NOTE(review): `os` and `shutil` are not referenced in the cells visible
# here; confirm they are used further down before keeping them.
import os
import shutil

# TensorFlow stack.
import tensorflow as tf
import tensorflow_hub as hub
# NOTE(review): `text` is never referenced by name in the visible cells;
# presumably it is imported for its side effect of registering the ops the
# TF Hub preprocessing model needs -- confirm before removing.
import tensorflow_text as text
from official.nlp import optimization  # to create AdamW optimizer

import matplotlib.pyplot as plt

# Set the TensorFlow logger's level to 'ERROR' to quiet lower-severity output.
tf.get_logger().setLevel('ERROR')

# Data handling.
import numpy as np
import pandas as pd

In [9]:
# Load the AITA training CSV; later cells read its `title`, `text` and
# `flair` columns.
aita_train = pd.read_csv(filepath_or_buffer="../data/aita_train.csv")

In [10]:
# Replace missing titles and bodies with empty strings so the two columns
# can be joined as plain strings in a later cell.
aita_train["title"] = aita_train["title"].fillna("")
aita_train["text"] = aita_train["text"].fillna("")

In [20]:
# Model input: each post's title and body joined into one string with the
# literal separator " [SEP] ".
# Vectorised concatenation replaces the row-wise `.agg(join, axis=1)`, which
# runs a Python-level call per row. It assumes NaNs in both columns were
# already replaced by "" (done in the preceding cell).
# NOTE(review): the separator is inserted as raw text; whether the TF Hub
# preprocessing model treats it as BERT's special [SEP] token, or tokenises
# it like ordinary punctuation, should be verified.
aita_train_x = aita_train['title'] + ' [SEP] ' + aita_train['text']

In [19]:
# Classification target: the `flair` column of the training frame.
aita_train_y = aita_train["flair"]

In [24]:
# Number of distinct label values found in the training target.
n_classes = np.unique(aita_train_y).size

In [17]:
# TF Hub handles used throughout the notebook: the text preprocessing model
# and the BERT encoder it feeds.
tfhub_handle_preprocess = (
    "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
)
tfhub_handle_encoder = (
    "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1"
)

In [14]:
# Stand-alone copy of the preprocessing layer, used to inspect what it
# produces before building the full model.
bert_preprocess_model = hub.KerasLayer(handle=tfhub_handle_preprocess)

In [15]:
# Run the preprocessing layer on the first 100 combined title/text strings.
# `aita_train_x` is the combined text series built above; the name used here
# previously (`aita_input`) is not defined by any cell, so the cell failed
# with NameError on a fresh kernel.
aita_preprocessed = bert_preprocess_model(tf.constant(aita_train_x[:100]))

In [28]:
def build_classifier_model(dropout_rate=0.1, num_outputs=None):
    """Assemble an end-to-end text classifier on top of a TF Hub BERT encoder.

    The model takes a batch of raw strings, runs them through the TF Hub
    preprocessing layer (``tfhub_handle_preprocess``), feeds the result to
    the trainable encoder (``tfhub_handle_encoder``) and classifies the
    encoder's ``pooled_output`` with dropout followed by a softmax ``Dense``
    layer.

    Args:
        dropout_rate: Rate passed to the ``Dropout`` layer applied to the
            pooled output. Defaults to 0.1.
        num_outputs: Width of the softmax layer. When ``None`` (the
            default) the module-level ``n_classes - 1`` is used.

    Returns:
        A ``tf.keras.Model`` mapping a string tensor of shape ``(batch,)``
        to a float tensor of shape ``(batch, num_outputs)``.

    Raises:
        ValueError: If the resolved output width is smaller than 1.
    """
    if num_outputs is None:
        # NOTE(review): a softmax classifier normally needs one output per
        # class (`n_classes`), not `n_classes - 1`. The default is kept
        # as-is to preserve existing behaviour -- confirm whether one label
        # value is intentionally excluded.
        num_outputs = n_classes - 1
    if num_outputs < 1:
        raise ValueError(
            f"classifier needs at least one output unit, got {num_outputs}"
        )

    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
    preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
    encoder_inputs = preprocessing_layer(text_input)
    # trainable=True: the encoder weights are fine-tuned, not frozen.
    encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')
    outputs = encoder(encoder_inputs)
    net = outputs['pooled_output']
    net = tf.keras.layers.Dropout(dropout_rate)(net)
    net = tf.keras.layers.Dense(num_outputs, activation='softmax', name='classifier')(net)
    return tf.keras.Model(text_input, net)

In [29]:
# Build the classifier and run one forward pass on the first 100 combined
# title/text strings as a smoke test of the untrained model.
# `aita_train_x` is the combined text series built above; the name used here
# previously (`aita_input`) is not defined by any cell, so the cell failed
# with NameError on a fresh kernel.
classifier_model = build_classifier_model()
bert_raw_result = classifier_model(tf.constant(aita_train_x[:100]))

<tf.Tensor: shape=(100, 7), dtype=float32, numpy=
array([[0.09873739, 0.4149776 , 0.11706309, 0.1775362 , 0.02471722,
        0.1372928 , 0.0296757 ],
       [0.03865398, 0.57699466, 0.05307658, 0.18564259, 0.02128549,
        0.10765703, 0.01668964],
       [0.11257312, 0.3952986 , 0.26283446, 0.09921625, 0.04017548,
        0.04190616, 0.04799591],
       [0.03908107, 0.49494225, 0.03198351, 0.24051629, 0.0240746 ,
        0.14741789, 0.0219843 ],
       [0.04048315, 0.5204714 , 0.03739785, 0.19000363, 0.01572124,
        0.17658475, 0.01933807],
       [0.06878754, 0.4777585 , 0.12847108, 0.19623636, 0.02889715,
        0.07240301, 0.02744634],
       [0.05758617, 0.5375704 , 0.07005867, 0.16507852, 0.03209681,
        0.10762423, 0.02998531],
       [0.0660321 , 0.49129766, 0.0790427 , 0.24699928, 0.01614131,
        0.08564931, 0.01483771],
       [0.04999016, 0.51675576, 0.06935373, 0.21750031, 0.01561399,
        0.11462566, 0.01616047],
       [0.17862524, 0.31255242, 0.2233710