In [1]:
import os
import sys

# Local modules live one directory above this notebook; put that directory on
# the import search path (once) so the project-local imports below resolve.
module_path = os.path.abspath('..')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

import json
import datetime
from transformers import BertTokenizer, TFBertForSequenceClassification, TFTrainer, TFTrainingArguments
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dropout, Dense
import tensorflow as tf
from evaluation import mean_average_precision
from utils import batch_predict

# Record the TensorFlow version in the cell output so results can be tied to
# the library version they were produced with.
print('Tensorflow Version: {}'.format(tf.__version__))
# Load the TensorBoard notebook extension
# (provides the `%tensorboard` magic used in the training cell).
%load_ext tensorboard
    
# --- Paths -------------------------------------------------------------------
# JSON files read by `read_data`; each record is accessed via the keys
# 'org', 'lbl' and 'article_id' in `to_dataset`.
DATA_PATH_FORMATED_TRAIN = '../data/GermanFakeNC_FORMATED_TRAIN.json'
DATA_PATH_FORMATED_TEST = '../data/GermanFakeNC_FORMATED_TEST.json'
# Directory the tokenizer and the base BERT model are loaded from.
MODEL_PATH_BERT = '../models/bert-base-german-cased/'
# Weights checkpoint: written by the ModelCheckpoint callback during training
# and read back by `load_weights`.
MODEL_PATH_BERT_TUNED = '../models/bert-base-german-cased-tuned/checkpoint.ckpt'

# --- Data split / training hyper-parameters ----------------------------------
# Fraction of the training records used for fitting; the remainder becomes the
# dev (validation) split.
DATASET_DEV_SPLIT = 0.8
# Batch size applied to both the train and dev datasets.
BATCH_SIZE = 32
# Learning rate passed to the Adam optimizer.
LEARNING_RATE = 5e-5
# Decision threshold used by the BinaryAccuracy metric.
BINACC_THRESHOLD = 0.1
# Decision thresholds at which the Precision and Recall metrics are reported.
PRECISION_RECALL_THRESHOLDS = [0.05, 0.1, 0.2, 0.5]
# Number of epochs passed to `fit`.
EPOCHS = 5

# Tokenizer shared by every call to `encode`; loaded from the same directory
# as the base model.
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH_BERT)

def load_bert_model():
    """Build the sequence-classification BERT model from ``MODEL_PATH_BERT``.

    The classifier head's activation is replaced with a sigmoid, so the values
    the model exposes as ``logits`` have already passed through a sigmoid.

    Returns:
        The loaded ``TFBertForSequenceClassification`` instance.
    """
    model = TFBertForSequenceClassification.from_pretrained(MODEL_PATH_BERT)
    # Swap the head activation so outputs can be used directly with a
    # probability-based loss / thresholded metrics.
    model.classifier.activation = tf.keras.activations.sigmoid
    return model

Tensorflow Version: 2.4.1


In [2]:
def read_data(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded explicitly as UTF-8 instead of relying on the
    platform's default encoding, so non-ASCII text (e.g. German umlauts) is
    read identically on every operating system.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The deserialised JSON content (whatever ``json.load`` produces).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(path, encoding='utf-8') as json_file:
        return json.load(json_file)
        
def encode(sentences):
    """Tokenize ``sentences`` with the module-level ``tokenizer``.

    Sequences are truncated to at most 128 tokens, padded, and returned as
    TensorFlow tensors.

    Args:
        sentences: The text(s) to tokenize.

    Returns:
        The tokenizer's output for ``sentences``.
    """
    tokenizer_options = {
        'max_length': 128,
        'truncation': True,
        'padding': True,
        'return_tensors': 'tf',
    }
    return tokenizer(sentences, **tokenizer_options)
        
def to_dataset(data):
    """Turn a list of records into a zipped ``tf.data.Dataset``.

    Each record is read through the keys ``'org'`` (text to encode),
    ``'lbl'`` (label) and ``'article_id'``.

    Args:
        data: Iterable of dict-like records with the keys listed above.

    Returns:
        A dataset yielding ``(article_id, encoding_dict, label)`` tuples, where
        ``label`` has been reshaped to shape ``[1]``.
    """
    from_slices = tf.data.Dataset.from_tensor_slices

    texts, labels, article_ids = [], [], []
    for record in data:
        texts.append(record['org'])
        labels.append(record['lbl'])
        article_ids.append(record['article_id'])

    # Re-pack every encoded example as a plain dictionary keyed by field name.
    features_ds = from_slices(encode(texts)).map(
        lambda enc: {key: enc[key] for key in enc}
    )
    # Give each label an explicit length-1 axis.
    targets_ds = from_slices(labels).map(lambda label: tf.reshape(label, [1]))
    article_ids_ds = from_slices(article_ids)

    return tf.data.Dataset.zip((article_ids_ds, features_ds, targets_ds))
        
# Raw records for the train and test partitions.
train_data = read_data(DATA_PATH_FORMATED_TRAIN)
test_data = read_data(DATA_PATH_FORMATED_TEST)

# Training only needs (inputs, label); evaluation keeps the article id and
# unwraps the length-1 label into a scalar.
train_ds = to_dataset(train_data).map(
    lambda _article_id, inputs, label: (inputs, label)
)
test_ds = to_dataset(test_data).map(
    lambda article_id, inputs, label: (article_id, inputs, label[0])
)

# The leading DATASET_DEV_SPLIT share of the training records is used for
# fitting (shuffled each epoch); everything after it forms the dev split.
num_train_examples = int(DATASET_DEV_SPLIT * len(train_data))
train_ds_split = (
    train_ds
    .take(num_train_examples)
    .shuffle(100, reshuffle_each_iteration=True)
    .batch(BATCH_SIZE)
)
dev_ds_split = train_ds.skip(num_train_examples).batch(BATCH_SIZE)

### Load initial pretrained model

In [3]:
# Instantiate the base model (sigmoid classifier head) from MODEL_PATH_BERT;
# `cbert_model` is the module-level model used by the cells below.
cbert_model = load_bert_model()

All model checkpoint layers were used when initializing TFBertForSequenceClassification.

All the layers of TFBertForSequenceClassification were initialized from the model checkpoint at ../models/bert-base-german-cased/.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForSequenceClassification for predictions without further training.


### Load fine-tuned weights

In [4]:
# Restore previously fine-tuned weights into the model. The checkpoint at
# MODEL_PATH_BERT_TUNED is the same path the training cell's ModelCheckpoint
# callback writes to.
# NOTE(review): in top-to-bottom order this cell runs before the training
# cell, so it presumably relies on a checkpoint from an earlier run - confirm.
cbert_model.load_weights(MODEL_PATH_BERT_TUNED)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f448f4ec760>

### Training

In [None]:
def get_checkpoint_callback(model_path, monitor_value):
    """Create a Keras ``ModelCheckpoint`` callback that keeps the best weights.

    Args:
        model_path: Where the weights checkpoint is written.
        monitor_value: Name of the logged quantity to monitor; a larger value
            counts as better (``mode='max'``).

    Returns:
        A ``tf.keras.callbacks.ModelCheckpoint`` that saves weights only, and
        only when the monitored quantity improves.
    """
    checkpoint_options = dict(
        save_weights_only=True,
        monitor=monitor_value,
        verbose=1,
        save_best_only=True,
        mode='max',
    )
    return tf.keras.callbacks.ModelCheckpoint(model_path, **checkpoint_options)


optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
loss = tf.keras.losses.BinaryCrossentropy()
precision = tf.keras.metrics.Precision(thresholds=PRECISION_RECALL_THRESHOLDS)
recall = tf.keras.metrics.Recall(thresholds=PRECISION_RECALL_THRESHOLDS)
binacc = tf.keras.metrics.BinaryAccuracy(threshold=BINACC_THRESHOLD)
metrics = [precision, recall, binacc]
cbert_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

checkpoint_callback = get_checkpoint_callback(MODEL_PATH_BERT_TUNED, 'val_binary_accuracy')
log_dir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir, histogram_freq=1)

%tensorboard --logdir logs --bind_all
history = cbert_model.fit(train_ds_split,
                epochs=EPOCHS,
                validation_data=dev_ds_split,
                callbacks=[checkpoint_callback, tensorboard_callback])

Epoch 1/5
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module, class, method, function, traceback, frame, or code object was expected, got cython_function_or_method
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module, class, method, function, traceback, frame, or code object was expected, got cython_function_or_method
  3/276 [..............................] - ETA: 2:07:32 - loss: 0.6342 - precision: 0.0365 - precision_1: 0.0365 - binary_accuracy: 0.0486

In [4]:
def prediction_func(inps):
    """Score a batch of inputs with the module-level ``cbert_model``.

    Args:
        inps: Model inputs, passed straight to ``cbert_model.predict``.

    Returns:
        A list holding the first component of each row of the prediction's
        ``logits`` field, one entry per input example.
    """
    logits = cbert_model.predict(inps).logits
    scores = []
    for row in logits:
        scores.append(row[0])
    return scores

# Run the model over the test set via `batch_predict` (second argument: 100)
# and report the mean average precision of the collected predictions.
eval_data_bert = batch_predict(test_ds, 100, prediction_func)
bert_map = mean_average_precision(eval_data_bert)
print('BERT/MAP: {}'.format(bert_map))

BERT/MAP: 0.45336887554833294


### Results
|     | BERT | |
|-----|------|---------|
| MAP |   0.45336887554833294   |        |
| P@1 |      |         |

In [7]:
# Manual inspection: print every evaluation entry from index 300 onwards,
# one per line.
for entry in eval_data_bert[300:]:
    print(entry)

(403, 0.034096032, 0.0)
(403, 0.19367969, 0.0)
(403, 0.12565112, 0.0)
(403, 0.035704404, 0.0)
(403, 0.03312677, 0.0)
(403, 0.04049006, 0.0)
(403, 0.18144748, 1.0)
(403, 0.19050762, 0.0)
(403, 0.058366567, 0.0)
(403, 0.16559353, 0.0)
(403, 0.17332748, 0.0)
(403, 0.18369743, 1.0)
(403, 0.15356398, 0.0)
(403, 0.036789805, 0.0)
(403, 0.05365494, 0.0)
(403, 0.033312917, 0.0)
(403, 0.16307601, 0.0)
(403, 0.18631351, 0.0)
(403, 0.03222406, 0.0)
(403, 0.17469653, 0.0)
(403, 0.18352109, 0.0)
(403, 0.04777342, 0.0)
(403, 0.034762055, 0.0)
(403, 0.032609522, 0.0)
(404, 0.033434033, 0.0)
(404, 0.16163138, 0.0)
(404, 0.18444419, 0.0)
(404, 0.03590226, 0.0)
(404, 0.18821302, 0.0)
(404, 0.124162525, 0.0)
(404, 0.13487288, 0.0)
(404, 0.13059324, 0.0)
(404, 0.03324333, 0.0)
(404, 0.033144325, 0.0)
(404, 0.044724017, 0.0)
(404, 0.15069014, 0.0)
(404, 0.13944602, 0.0)
(404, 0.04741457, 0.0)
(404, 0.16879544, 0.0)
(404, 0.15064421, 0.0)
(404, 0.03439766, 0.0)
(404, 0.16012469, 0.0)
(404, 0.054718316, 0.0)