# BERT Trading Signal: News and Corporate Actions

```bibtex
@inproceedings{zhou-etal-2021-trade,
    title = "Trade the Event: Corporate Events Detection for News-Based Event-Driven Trading",
    author = "Zhou, Zhihan  and
      Ma, Liqian  and
      Liu, Han",
    editor = "Zong, Chengqing  and
      Xia, Fei  and
      Li, Wenjie  and
      Navigli, Roberto",
    booktitle = "Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021",
    month = aug,
    year = "2021",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2021.findings-acl.186",
    doi = "10.18653/v1/2021.findings-acl.186",
    pages = "2114--2124",
}
```

# Notebook Environment

In [None]:
UPGRADE_PY = False
INSTALL_DEPS = False
if INSTALL_DEPS:
  # %pip install -q tensorboard==2.15.2
  # %pip install -q tensorflow[and-cuda]==2.15.1
  # %pip install -q tensorflow==2.15.0
  # %pip install -q tensorflow-io-gcs-filesystem==0.36.0
  # %pip install -q tensorflow-text==2.15.0
  # %pip install -q tf_keras==2.15.1
  # %pip install -q tokenizers==0.15.2
  # %pip install -q torch==2.2.0+cpu
  # %pip install -q torch-xla==2.2.0+libtpu
  # %pip install -q torchdata==0.7.1
  %pip install -q transformers==4.38.2

if UPGRADE_PY:
  !mamba create -n py311 -y
  !source /opt/conda/bin/activate py312 && mamba install python=3.11 jupyter mamba -y

  !sudo rm /opt/conda/bin/python3
  !sudo ln -sf /opt/conda/envs/py312/bin/python3 /opt/conda/bin/python3
  !sudo rm /opt/conda/bin/python3.10
  !sudo ln -sf /opt/conda/envs/py312/bin/python3 /opt/conda/bin/python3.10
  !sudo rm /opt/conda/bin/python
  !sudo ln -sf /opt/conda/envs/py312/bin/python3 /opt/conda/bin/python

!python --version

In [None]:
import os
import sys
import warnings
warnings.filterwarnings("ignore")

# Transformers cannot use Keras 3: force the legacy tf-keras implementation.
# These must be set before TensorFlow is first imported (next cell).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_USE_LEGACY_KERAS'] = '1'
IN_KAGGLE = IN_COLAB = False
# `!export VAR=...` only changes a throw-away subshell and never reaches this kernel, so
# CUDA_LAUNCH_BLOCKING / XLA_FLAGS were never in effect. If they are really wanted,
# assign them through os.environ here, before TensorFlow is imported.

MODEL_PATH = "google-bert/bert-base-cased"
try:
    # https://www.tensorflow.org/install/pip#windows-wsl2
    import google.colab  # noqa: F401 -- only importable inside Colab
    from google.colab import drive
except ImportError:
    # Only a missing Colab package means "not in Colab".
    IN_COLAB = False
else:
    # A failing Drive mount is a real error and is no longer silently ignored.
    drive.mount('/content/drive')
    # NOTE(review): the loaders read '<DATA_PATH>/Event_detection/train.txt'; confirm this folder
    # (rather than "/content/drive/MyDrive/EDT dataset") contains it. MODEL_PATH also stays a Hub id
    # here although later cells append '/tokenizer' and '/model' to it — verify.
    DATA_PATH = "/content/drive/MyDrive/investopediaBERT"
    IN_COLAB = True
    print('Colab!')

if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ and not IN_COLAB:
    print('Running in Kaggle...')
    for dirname, _, filenames in os.walk('/kaggle/input'):
        for filename in filenames:
            print(os.path.join(dirname, filename))
    DATA_PATH = "/kaggle/input/uscorpactionnews"
    MODEL_PATH = "/kaggle/input/finbert/tensorflow2/basevocab-uncased-conditioned-investopedia/1/models"
    IN_KAGGLE = True
    print('Kaggle!')
elif not IN_COLAB:
    # Neither Colab nor Kaggle: local paths.
    DATA_PATH = "./data/"
    print('Localhost!')
    MODEL_PATH = "./models/conditioned"


# Accelerators Configuration

In [None]:
import numpy as np
import pandas as pd

from pathlib import Path
import re
import pickle
from copy import deepcopy

from tqdm import tqdm
import tensorflow as tf
from tensorflow.keras import mixed_precision

print(f'Tensorflow version: [{tf.__version__}]')

tf.get_logger().setLevel('INFO')

#tf.config.set_soft_device_placement(True)
#tf.config.experimental.enable_op_determinism()
#tf.random.set_seed(1)
strategy = None
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()

    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except Exception as e:
    # Not an error: no TPU is available, so fall back to GPU(s), then CPU.
    print(e)
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if len(gpus) > 0:
        # https://www.tensorflow.org/guide/mixed_precision
        # float16 compute only pays off on accelerators; CPU-only hosts keep float32.
        policy = mixed_precision.Policy('mixed_float16')
        mixed_precision.set_global_policy(policy)
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, False)
            # Cap the first GPU at 12288 MB.
            tf.config.experimental.set_virtual_device_configuration(
                gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=12288)])
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            strategy = tf.distribute.MirroredStrategy()

            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
        except RuntimeError as e:
            # Typically raised when devices were already initialised (e.g. cell re-run).
            print(e)
        finally:
            print("Running on", len(tf.config.list_physical_devices('GPU')), "GPU(s)")
    else:
        print("Running on CPU")

if strategy is None:
    # CPU-only host, or the GPU configuration above failed: use the default
    # strategy rather than leaving `strategy` undefined for later cells.
    strategy = tf.distribute.get_strategy()


def is_tpu_strategy(strategy):
    """Return True when `strategy` distributes work over TPU cores."""
    return isinstance(strategy, tf.distribute.TPUStrategy)

print("Number of accelerators:", strategy.num_replicas_in_sync)
os.getcwd()

# Tokens, Sequences, and NER

Our corpus is processed and labelled with 11 types of corporate events, plus an extra "no event" label (12 labels in total):
1. Acquisition(A)
1. Clinical Trial(CT)
1. Regular Dividend(RD)
1. Dividend Cut(DC)
1. Dividend Increase(DI)
1. Guidance Increase(GI)
1. New Contract(NC)
1. Reverse Stock Split(RSS)
1. Special Dividend(SD)
1. Stock Repurchase(SR)
1. Stock Split(SS)
1. No event (O)

Articles are structured as follows:

```json
'title': 'Title',
'text': 'Text Body',
'pub_time': 'Published datetime',
'labels': {
    'ticker': 'Security symbol',
    'start_time': 'First trade after article published',
    'start_price_open': 'The "Open" price at start_time',
    'start_price_close': 'The "Close" price at start_time',
    'end_price_nday': 'The "Close" price at the last minute of the following 1-3 trading days. If the article is published earlier than 4pm ET, it is the same day; otherwise, it refers to the next trading day.',
    'end_time_1-3day': 'The time corresponding to end_price_1day',
    'highest_price_nday': 'The highest price in the following 1-3 trading days',
    'highest_time_nday': 'The time corresponding to highest_price_1-3day',
    'lowest_price_nday': 'The lowest price in the following 1-3 trading days',
    'lowest_time_nday': 'The time corresponding to lowest_price_1-3day',
}
```

In [None]:
# Label-space constants shared by the tokenisation, loss and evaluation cells.
NUM_LABELS = 12               # 11 corporate-event types + 'O' (no event).
SPECIAL_TOKEN = 'CLS'         # Used for classification and as hidden-state placeholder.
UNK_ID, UNK = -100, 'UNK'     # Label id ignored by the losses, and its display name.
OTHER_ID, OTHER = 11, 'O'     # NOTE(review): OTHER_ID assumes 'O' gets id 11 in tag2id
                              # (ids follow sorted tag order) — verify against id2tag.

### Tokenizing News Text

In [None]:
from transformers import BertTokenizerFast, TFBertModel, BertConfig

# https://huggingface.co/transformers/v3.0.2/model_doc/bert.html#berttokenizerfast
# https://huggingface.co/transformers/v3.0.2/model_doc/bert.html#berttokenizerfast
# Tokenizer and TF BERT backbone live in separate sub-folders of MODEL_PATH.
tokenizer = BertTokenizerFast.from_pretrained(f'{MODEL_PATH}/tokenizer')
model = TFBertModel.from_pretrained(f'{MODEL_PATH}/model')

# Two sample text fragments reused by the demo cells below.
text = [
    "When taken as a whole, the evidence suggests Cramer recommends “hot” stocks",
    "lending credence to the Hot Hand Fallacy in this context.",
]

tokenized_sequence = tokenizer.tokenize(text)
print(tokenized_sequence)

In [None]:
MAX_LEN = 256  # Sequence length used everywhere (default 256; BERT allows at most 512).
# NOTE(review): `text` is a list of two strings; presumably encode_plus treats it as a
# sentence pair — the token_type_ids shown below confirm how it was encoded.
sample_inputs = inputs = tokenizer.encode_plus(
    text,
    max_length=MAX_LEN,        # pad / truncate to exactly MAX_LEN ids
    padding='max_length',
    truncation=True,
    add_special_tokens=True,   # insert [CLS] and [SEP]
    return_tensors='tf',
)
sample_inputs

In [None]:
# Round-trip the first encoded row back to text to check special tokens and padding.
tokenizer.decode(sample_inputs["input_ids"].numpy()[0])

In [None]:
# Presumably (1, MAX_LEN): one encoded example padded to MAX_LEN.
sample_inputs["attention_mask"].shape

In [None]:
# 1 marks a real token, 0 marks padding.
sample_inputs["attention_mask"]

In [None]:
# Segment ids: 0 for the first segment, 1 for a second one if the input was encoded as a pair.
sample_inputs['token_type_ids']

In [None]:
# Feed the full encoding (input ids + attention mask + segment ids): with only
# input_ids the model would attend to the padding positions as if they were text.
outputs = model(sample_inputs)
hidden_state = outputs.last_hidden_state
embedding = hidden_state[:, 0, :]  # hidden vector at position 0, i.e. the [CLS] token
embedding

## Create Datasets for Training

In [None]:
def read_wnut(file_path):
    """Read a WNUT-style ``token<TAB>tag`` file into parallel per-document lists.

    Documents are separated by a blank line (which may contain a single tab).
    Empty lines left inside a chunk (runs of 3+ newlines) and chunks that contain
    no tokens at all are skipped.

    Args:
        file_path: path (str or Path) to a UTF-8 encoded file.

    Returns:
        (token_docs, tag_docs): one list of tokens and one list of tags per document.

    Raises:
        ValueError: if a non-empty line has no tab separator.
    """
    # Explicit encoding: the default is platform dependent (e.g. cp1252 on Windows).
    raw_text = Path(file_path).read_text(encoding='utf-8').strip()
    token_docs = []
    tag_docs = []
    for doc in re.split(r'\n\t?\n', raw_text):
        tokens = []
        tags = []
        for line in doc.split('\n'):
            if not line.strip():
                continue
            # Split on the LAST tab so the tag is always the final column.
            token, tag = line.rsplit('\t', 1)
            tokens.append(token)
            tags.append(tag.strip())
        if tokens:
            token_docs.append(tokens)
            tag_docs.append(tags)

    return token_docs, tag_docs

train_ner_texts, train_ner_tags = read_wnut(os.path.join(DATA_PATH, 'Event_detection/train.txt'))
test_ner_texts, test_ner_tags = read_wnut(os.path.join(DATA_PATH, 'Event_detection/dev.txt'))

# Show the first training document containing at least one event tag (anything but 'O').
event_index = next(
    (i for i, doc_tags in enumerate(train_ner_tags) if any(tag != OTHER for tag in doc_tags)),
    None,
)
if event_index is None:
    print("No event-tagged document found in the training split.")
else:
    print(f"event found at index: {event_index}")
    print(*train_ner_texts[event_index])
    print(*train_ner_tags[event_index])

In [None]:
from transformers import BertTokenizerFast

def encode_tags(tags, encodings, tag2id, unk=UNK_ID):
    """Align word-level tags to sub-word token positions.

    Only the first sub-word of every word receives that word's tag id. Continuation
    sub-words, special tokens and padding receive ``unk`` so the loss ignores them.
    Words removed by truncation simply do not appear.

    Args:
        tags: list of documents, each a list of string tags (one per word).
        encodings: BatchEncoding from a *fast* tokenizer called with
            ``is_split_into_words=True`` (``word_ids`` must be available).
        tag2id: mapping from tag string to integer id.
        unk: id written at positions that carry no label.

    Returns:
        One list of int labels per document, as long as that document's encoding.
    """
    encoded_labels = []
    for doc_index, doc_tags in enumerate(tqdm(tags, desc="encode_tags")):
        # word_ids() maps each token position to its source word index (None for
        # special/padding tokens). Unlike matching offsets that start at 0, it keeps
        # labels aligned even if some word produces no tokens at all.
        word_ids = encodings.word_ids(doc_index)
        doc_enc_labels = [unk] * len(word_ids)
        previous_word = None
        for position, word_id in enumerate(word_ids):
            if word_id is not None and word_id != previous_word:
                doc_enc_labels[position] = tag2id[doc_tags[word_id]]
            previous_word = word_id
        encoded_labels.append(doc_enc_labels)

    return encoded_labels

def encode_sequence_labels(ner_tags, tag2id, num_labels=NUM_LABELS):
    """Build one multi-hot, document-level label vector per document.

    A document is tagged with every event type among its word tags. A document
    without any event, including an empty document, is tagged OTHER ('O') only.

    Args:
        ner_tags: list of documents, each a list of word-level string tags.
        tag2id: mapping from tag string to column index.
        num_labels: length of each vector, i.e. number of tag ids (not the sequence length).

    Returns:
        List of ``num_labels``-long lists of floats (1.0 = tag present).
    """
    seq_labels = []

    for doc_tags in ner_tags:
        current_label = np.zeros([num_labels])
        event_tags = set(doc_tags) - {OTHER}
        if event_tags:
            # The 'O' column stays 0 once the document has at least one event.
            for event_tag in event_tags:
                current_label[tag2id[event_tag]] = 1
        else:
            current_label[tag2id[OTHER]] = 1
        seq_labels.append(list(current_label))

    return seq_labels

def load_and_cache_dataset(train_ner_texts, train_ner_tags,
                           test_ner_texts, test_ner_tags,
                           bert_model_tok=f'{MODEL_PATH}/tokenizer',
                           max_len=MAX_LEN,
                           num_labels=NUM_LABELS):
    """Tokenise both splits and derive token-level and document-level labels.

    Nothing is cached despite the name; every call re-tokenises.

    Returns:
        (train_encodings, train_ner_labels, test_encodings, test_ner_labels,
         train_seq_labels, test_seq_labels, tag2id, id2tag)
    """
    tokenizer = BertTokenizerFast.from_pretrained(bert_model_tok)

    # Tag vocabulary over both splits; ids follow alphabetical tag order.
    vocabulary = {tag for doc in train_ner_tags + test_ner_tags for tag in doc}
    tag2id = {tag: idx for idx, tag in enumerate(sorted(vocabulary))}
    id2tag = {idx: tag for tag, idx in tag2id.items()}

    def encode_split(texts, ner_tags):
        """Tokenise one split; return (encodings, token labels, document labels)."""
        encodings = tokenizer(texts,
                              is_split_into_words=True,
                              return_offsets_mapping=True,
                              padding='max_length',
                              truncation=True,
                              max_length=max_len)
        ner_labels = encode_tags(ner_tags, encodings, tag2id, UNK_ID)
        seq_labels = encode_sequence_labels(ner_tags, tag2id, num_labels=num_labels)
        return encodings, ner_labels, seq_labels

    train_encodings, train_ner_labels, train_seq_labels = encode_split(train_ner_texts, train_ner_tags)
    test_encodings, test_ner_labels, test_seq_labels = encode_split(test_ner_texts, test_ner_tags)

    # Offsets only served label alignment; the model does not take them as input.
    for encodings in (train_encodings, test_encodings):
        encodings.pop("offset_mapping")

    return (train_encodings, train_ner_labels,
            test_encodings, test_ner_labels,
            train_seq_labels, test_seq_labels,
            tag2id, id2tag)

(train_encodings, train_ner_labels,
 test_encodings, test_ner_labels,
 train_seq_labels, test_seq_labels,
 tag2id, id2tag) = load_and_cache_dataset(train_ner_texts,
                                            train_ner_tags,
                                            test_ner_texts,
                                            test_ner_tags)

# Arrays used for inspection and for the class-weight statistics in the next cells.
# They come from the TRAINING split: class weights derived from the dev split would
# leak evaluation data into training.
input_ids = np.array(train_encodings['input_ids'])
attention_mask = np.array(train_encodings['attention_mask'])
token_type_ids = np.array(train_encodings['token_type_ids']) if 'token_type_ids' in train_encodings else None
ner = np.array(train_ner_labels)
seq_label = np.array(train_seq_labels)
print("input_ids shape:", input_ids.shape)
print("attention_mask shape:", attention_mask.shape)
if token_type_ids is not None:
    print("token_type_ids shape:", token_type_ids.shape)

print("ner_labels shape:", ner.shape)
print("train_seq_labels shape:", seq_label.shape)

In [None]:
# Mapping from integer class id to tag name (ids follow alphabetical tag order).
id2tag

In [None]:
# First 10 encoded token-id rows of the split loaded above.
print(input_ids[:10].shape)
input_ids[:10]

In [None]:
# Matching token-level labels; UNK_ID marks positions the loss ignores.
print(ner[:10].shape)
ner[:10]

In [None]:
# Integer ids can never be NaN, so also check that every id lies inside the tokenizer vocabulary.
assert not np.isnan(input_ids).any()
assert input_ids.min() >= 0 and input_ids.max() < len(tokenizer), "token id outside the tokenizer vocabulary"

## Imbalanced Dataset

In [None]:
from sklearn.utils.class_weight import compute_class_weight

unique, counts = np.unique(ner, return_counts=True)

# https://www.tensorflow.org/tutorials/structured_data/imbalanced_data
# https://scikit-learn.org/stable/modules/generated/sklearn.utils.class_weight.compute_class_weight.html
# "Balanced" weights are computed over real token labels only: UNK_ID marks padding,
# special tokens and continuation sub-words, which the loss ignores anyway.
labelled = ner[ner != UNK_ID]
label_classes = np.unique(labelled)
weights = compute_class_weight(class_weight="balanced", classes=label_classes, y=labelled)
# Key every weight by its class id (weights come back in `classes` order). Keying by
# position would shift each weight onto the next class id.
weights_dict = dict(zip(label_classes.tolist(), weights))

df_tags = pd.DataFrame({'Tag ID': unique,
                        'Tag': [id2tag.get(tag_id, UNK) for tag_id in unique],
                        'Count': counts})
df_tags['Weight'] = df_tags['Tag ID'].map(lambda i: weights_dict.get(i, 0.))

df_ner_weights = df_tags.sort_values(by='Tag ID', ascending=True)
df_ner_weights.loc[df_ner_weights['Tag ID'] == UNK_ID, 'Weight'] = 0.  # Unknown tokens are ignored entirely.
# 'O' dominates the labelled tokens: cap its weight. Look its id up rather than assuming OTHER_ID.
df_ner_weights.loc[df_ner_weights['Tag ID'] == tag2id[OTHER], 'Weight'] = 0.05

ner_weights = df_ner_weights[['Tag ID', 'Weight']].set_index('Tag ID').to_dict()['Weight']

df_ner_weights

In [None]:
class_sums = np.sum(seq_label, axis=0)
total_samples = seq_label.shape[0]
num_seq_classes = seq_label.shape[1]
unique = np.arange(num_seq_classes)

# sklearn-style "balanced" weights: n_samples / (n_classes * n_positive_samples).
# The class count comes from the label matrix rather than a hard-coded 12, and
# classes that never occur get weight 0 (no division by zero).
class_weights = {
    i: (total_samples / (num_seq_classes * class_sum)) if class_sum != 0 else 0
    for i, class_sum in enumerate(class_sums)
}

df_seq_weights = pd.DataFrame({
    'Tag ID': unique,
    'Tag': [id2tag.get(tag_id, UNK) for tag_id in unique],
    'Count': class_sums,
    'Weight': [class_weights[i] for i in unique],
})

seq_weights = df_seq_weights[['Tag ID', 'Weight']].set_index('Tag ID').to_dict()['Weight']

df_seq_weights

# Model Build and Training Loop

In [None]:
# Optimiser and training-loop settings.
LEARN_RATE = 5e-5   # AdamW learning rate
LR_FACTOR = 0.1     # NOTE(review): intended for ReduceLROnPlateau, which is not in the fit() callbacks
LR_MINDELTA = 1e-4  # NOTE(review): same as above
EPOCHS = 100
PATIENCE = 10       # early-stopping patience, in epochs
# Global batch = per-replica batch (8 on CPU/GPU, 4 on TPU) x number of replicas.
# TPU see: https://github.com/tensorflow/tensorflow/issues/41635
BATCH_SIZE = strategy.num_replicas_in_sync * (4 if is_tpu_strategy(strategy) else 8)

In [None]:
from tensorflow.keras import Model
from tensorflow.keras.optimizers import AdamW, Adam
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras.losses import Loss, SparseCategoricalCrossentropy, CategoricalFocalCrossentropy, CategoricalCrossentropy, BinaryCrossentropy
from tensorflow.keras.metrics import Metric, SparseCategoricalAccuracy, Precision, Recall, BinaryAccuracy
from tensorflow.keras.callbacks import EarlyStopping, TensorBoard, Callback, ReduceLROnPlateau, TerminateOnNaN
from tensorflow.keras.initializers import GlorotUniform
from tensorflow.keras.mixed_precision import LossScaleOptimizer
from tensorflow.keras.utils import register_keras_serializable

from transformers import TFBertModel, BertConfig

def _dense_class_weights(class_weight):
    """Convert a ``{class_id: weight}`` dict into a float32 vector indexed by class id.

    ``tf.gather(vector, label)`` must return the weight of ``label`` itself.
    Building the vector positionally from the sorted keys breaks this as soon as
    the dict holds a negative id such as UNK_ID (-100): every weight shifts by one
    slot. Negative ids are therefore dropped, and ids missing from the dict get 0.

    Returns None when ``class_weight`` is None.
    """
    if class_weight is None:
        return None
    valid_ids = [int(class_id) for class_id in class_weight if int(class_id) >= 0]
    size = max(valid_ids) + 1 if valid_ids else 0
    dense = [float(class_weight.get(class_id, 0.)) for class_id in range(size)]
    return tf.convert_to_tensor(dense, dtype=tf.dtypes.float32)


@register_keras_serializable(package='Custom', name='MaskedWeightedMultiClassBCE')
class MaskedWeightedMultiClassBCE(Loss):
    """Multi-label (per-class sigmoid) cross-entropy with optional per-class weights.

    ``y_pred`` is always treated as logits (``tf.nn.sigmoid_cross_entropy_with_logits``).
    ``from_logits``, ``null_class`` and ``focal_gamma`` are stored but not used by ``call``.
    Entries of ``y_true`` outside ``[0, labels_len - 1)`` are masked out.
    """
    def __init__(self,
                 from_logits=False,
                 name='masked_weighted_multi_bce',
                 class_weight=None,
                 labels_len=MAX_LEN,
                 null_class=UNK_ID,
                 focal_gamma=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.from_logits = from_logits
        self.null_class = tf.cast(null_class, tf.float32)
        self.labels_len = labels_len
        # Dense per-class weight vector indexed by class id, or None.
        self.class_weight = _dense_class_weights(class_weight)
        self.focal_gamma = focal_gamma

        # https://github.com/tensorflow/tensorflow/issues/27190 still does reduction internally!
        # self.loss_fn = BinaryCrossentropy(from_logits=self.from_logits,
        #                                   reduction=tf.keras.losses.Reduction.NONE)

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        mask = tf.logical_and(tf.greater_equal(y_true, 0), tf.less(y_true, self.labels_len - 1))
        y_true_masked = tf.where(mask, y_true, tf.zeros_like(y_true))
        y_true_masked = tf.cast(y_true_masked, tf.float32)
        # https://github.com/tensorflow/tensorflow/issues/27190 still does reduction internally!
        # loss = self.loss_fn(y_true_masked, y_pred)
        loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true_masked, logits=y_pred)

        # Debug output only when running eagerly (e.g. tf.config.run_functions_eagerly).
        if tf.executing_eagerly():
            tf.print("y_true shape:", tf.shape(y_true_masked))
            tf.print("y_pred shape:", tf.shape(y_pred))
            tf.print("mask shape:", tf.shape(mask))
            tf.print("loss shape:", tf.shape(loss))
            tf.print("loss:", loss)
            tf.debugging.assert_greater(tf.reduce_sum(tf.cast(mask, tf.int32)),
                                        0, message="All data are masked!")

        if self.class_weight is not None:
            # Broadcasts the per-class vector over the last (class) axis.
            loss *=  self.class_weight
            if tf.executing_eagerly():
                tf.print("class_weight shape:", tf.shape(self.class_weight))
                tf.print("loss after weights:", loss)

        mask = tf.cast(mask, tf.float32)
        loss *= mask

        sum_mask = tf.reduce_sum(mask, axis=-1)
        loss = tf.reduce_mean(loss)
        if tf.executing_eagerly():
            tf.print("sum_mask:", sum_mask)
            tf.print("loss after mask and reduc:", loss)
            tf.debugging.assert_positive(loss, message="Loss masked to zero.")

        return loss

# https://www.tensorflow.org/api_docs/python/tf/keras/losses/CategoricalFocalCrossentropy
@register_keras_serializable(package='Custom', name='MaskedWeightedSCCE')
class MaskedWeightedSCCE(Loss):
    """Masked sparse categorical cross-entropy with optional class weights / focal term.

    Positions whose label is outside ``[0, labels_len - 1)`` (e.g. UNK_ID = -100) are
    excluded; the result is the sum of unmasked losses divided by the unmasked count.
    With ``focal_gamma`` set, each loss is scaled by ``(1 - p_t) ** focal_gamma``.
    """
    def __init__(self,
                 from_logits=False,
                 name='masked_weighted_scce',
                 class_weight=None,
                 labels_len=MAX_LEN,
                 null_class=UNK_ID,
                 focal_gamma=None,
                 **kwargs):
        super().__init__(name=name, **kwargs)
        self.from_logits = from_logits
        self.null_class = tf.cast(null_class, tf.float32)
        self.labels_len = labels_len
        # Dense per-class weight vector indexed by class id, or None.
        self.class_weight = _dense_class_weights(class_weight)
        self.focal_gamma = focal_gamma

        # https://www.tensorflow.org/api_docs/python/tf/keras/losses/SparseCategoricalCrossentropy
        self.loss_fn = SparseCategoricalCrossentropy(from_logits=self.from_logits,
                                                reduction=tf.keras.losses.Reduction.NONE)

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        mask = tf.logical_and(tf.greater_equal(y_true, 0), tf.less(y_true, self.labels_len - 1))
        # Replace masked labels by 0 so they are valid gather indices; their loss is zeroed below.
        y_true_masked = tf.where(mask, y_true, tf.zeros_like(y_true))
        y_true_masked = tf.cast(y_true_masked, tf.float32)
        if tf.executing_eagerly():
            tf.debugging.assert_greater(tf.reduce_sum(tf.cast(mask, tf.int32)),
                                        0, message="All data are masked!")

        if self.focal_gamma is not None:
            # inspired by: https://github.com/artemmavrin/focal-loss/blob/master/src/focal_loss/_categorical_focal_loss.py
            loss = self.loss_fn(y_true_masked, y_pred)
            y_pred = tf.clip_by_value(y_pred, clip_value_min=-100., clip_value_max=100.)
            proba = tf.nn.softmax(y_pred)
            y_true_rank = y_true_masked.shape.rank

            p_t = tf.gather(proba, tf.cast(y_true_masked, tf.int32),
                            axis=-1, batch_dims=y_true_rank)
            focal_modulation = tf.cast((1. - tf.clip_by_value(p_t, 0.01, 0.99)) ** self.focal_gamma, tf.float32)
            loss *= focal_modulation
            if self.class_weight is not None:
                loss *= tf.gather(self.class_weight, tf.cast(y_true_masked, tf.int32))
            if tf.executing_eagerly():
                tf.debugging.assert_all_finite(focal_modulation, "Focal contains NaN or Inf")
        else:
          # Class weights are applied as per-position sample weights looked up by label id.
          loss = self.loss_fn(y_true_masked, y_pred,
                             sample_weight=tf.gather(self.class_weight,
                                                 tf.cast(y_true_masked, tf.int32)) if self.class_weight is not None
                                                 else None)
        loss = tf.cast(loss, tf.float32)
        loss *=  tf.cast(mask, tf.float32)
        # Avoid div by 0.
        sum_mask = tf.reduce_sum(tf.cast(mask, tf.float32))
        if tf.executing_eagerly():
            tf.debugging.assert_positive(sum_mask, message="sum_mask zeroed.")
        loss = (tf.reduce_sum(loss) / sum_mask
                      if sum_mask > 0.
                      else tf.constant(0., dtype=tf.float32))
        if tf.executing_eagerly():
            tf.debugging.assert_positive(loss, message="Loss masked to zero.")

        return loss

In [None]:
# https://www.tensorflow.org/text/tutorials/bert_glue
def create_model(bert_model,
                 config,
                 num_labels=NUM_LABELS,
                 max_len=MAX_LEN,
                 unk=UNK_ID,
                 class_weight=None,
                 strategy=strategy):
    """Build and compile the two-headed (token NER + document events) BERT model.

    Heads (both emit raw logits, no activation):
      * ``ner_output``: per-token logits over ``num_labels`` tags.
      * ``seq_output``: per-document multi-label logits over ``num_labels`` tags, fed by
        the pooled output concatenated with the flattened NER logits.

    Args:
        bert_model: TFBertModel backbone.
        config: its BertConfig (only ``hidden_dropout_prob`` is read here).
        num_labels: number of tag ids for both heads.
        max_len: length of every input sequence.
        unk: NOTE(review): currently unused.
        class_weight: NOTE(review): currently unused; the losses read the module-level
            ``ner_weights`` / ``seq_weights`` dicts.
        strategy: distribution strategy; loss scaling is skipped on TPU.

    Returns:
        A compiled Keras ``Model``.
    """
    input_ids = Input(shape=(max_len,), dtype=tf.int32, name='input_ids')
    attention_mask = Input(shape=(max_len,), dtype=tf.int32, name='attention_mask')
    token_type_ids = Input(shape=(max_len,), dtype=tf.int32, name='token_type_ids')

    bert_outputs = bert_model({"input_ids": input_ids,
                                "attention_mask": attention_mask,
                                "token_type_ids": token_type_ids},
                            return_dict=True)
    bert_sequence_output = tf.cast(bert_outputs.last_hidden_state, tf.float32)
    bert_pooled_output = tf.cast(bert_outputs.pooler_output, tf.float32)

    # Zero the hidden states at padding positions before the NER head
    # (the Dense bias still makes their logits non-zero).
    mask = tf.cast(attention_mask, tf.float32 )
    mask = tf.expand_dims(mask, -1)
    masked_output = bert_sequence_output * mask

    ner_logits = Dropout(config.hidden_dropout_prob, name='Dropout_ner_1')(masked_output)
    ner_logits = Dense(2048, name='Dense_ner_1', kernel_initializer=GlorotUniform())(ner_logits)
    ner_logits = Dropout(config.hidden_dropout_prob, name='Dropout_ner_2')(ner_logits)
    ner_output = Dense(num_labels, name='ner_output', dtype='float32')(ner_logits)

    # combine NER predictions with entire sequence
    # NER shape is [batch_size, sequence_length, num_classes (12)].
    seq_input = tf.reshape(ner_output,
                           [tf.shape(ner_output)[0], tf.shape(ner_output)[1] * tf.shape(ner_output)[2]])
    seq_input = tf.concat([bert_pooled_output, seq_input], axis=1)

    seq_logits = Dropout(config.hidden_dropout_prob, name='Dropout_seq_1')(seq_input)
    seq_logits = Dense(2048, name='Dense_seq_1', kernel_initializer=GlorotUniform())(seq_logits)
    seq_logits = Dropout(config.hidden_dropout_prob, name='Dropout_seq_2')(seq_logits)
    # One logit per tag id, 'O' included: an article carries one or more events, or none ('O').
    seq_output = Dense(num_labels, name='seq_output', dtype='float32')(seq_logits)

    model = Model(inputs=[input_ids, attention_mask, token_type_ids],
                  outputs=[ner_output, seq_output])

    # https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalAccuracy
    optimizer = AdamW(learning_rate=LEARN_RATE, clipnorm=1.0)
    if not is_tpu_strategy(strategy):
      # TPUs already use bfloat16
      optimizer = LossScaleOptimizer(optimizer, dynamic=True)
    model.compile(optimizer=optimizer,
            loss={"ner_output": MaskedWeightedSCCE(from_logits=True, class_weight=ner_weights),
                  "seq_output": MaskedWeightedMultiClassBCE(from_logits=True, class_weight=seq_weights)},
            # NOTE: sparse_categorical_accuracy also counts UNK_ID positions (always "wrong").
            # seq_output emits logits, so the decision threshold is logit 0 (= probability 0.5).
            metrics={"ner_output": ['sparse_categorical_accuracy'],
                     "seq_output": [BinaryAccuracy(threshold=0.0)]})
    return model

def get_tf_datasets(train_encodings, test_encodings, buffer_size=10000, batch_size=BATCH_SIZE):
    """Build batched tf.data pipelines for training and evaluation.

    Labels are read from the module-level ``train_ner_labels`` / ``train_seq_labels``
    and ``test_ner_labels`` / ``test_seq_labels``.

    The training pipeline caches the raw examples *before* shuffling, so every epoch
    (and every fresh iteration) draws a new order. Anything that needs a stable sample
    from it must cache or materialise that sample. The evaluation pipeline is not
    shuffled, so its order matches ``test_encodings`` and the test label lists.
    """
    def create_dataset(encodings, ner_labels, seq_labels):
        """Turn one split's encodings and labels into an unbatched Dataset."""
        features = {
            'input_ids': np.array(encodings['input_ids']),
            'attention_mask': np.array(encodings['attention_mask']),
        }
        # Check the split being converted, not always the training split.
        if 'token_type_ids' in encodings:
            features['token_type_ids'] = np.array(encodings['token_type_ids'])
        return tf.data.Dataset.from_tensor_slices((
            features,
            {
                'seq_output': np.array(seq_labels),
                'ner_output': np.array(ner_labels),
            },
        ))

    train_dataset = (create_dataset(train_encodings, train_ner_labels, train_seq_labels)
                     .cache()
                     .shuffle(buffer_size=buffer_size)
                     .batch(batch_size)
                     .prefetch(tf.data.AUTOTUNE))
    test_dataset = (create_dataset(test_encodings, test_ner_labels, test_seq_labels)
                    .batch(batch_size)
                    .cache()
                    .prefetch(tf.data.AUTOTUNE))

    return train_dataset, test_dataset

with strategy.scope():
    train_dataset, test_dataset = get_tf_datasets(train_encodings, test_encodings)

    # NOTE(review): the config is read from MODEL_PATH while the weights come from
    # f'{MODEL_PATH}/model' — confirm both describe the same checkpoint.
    config = BertConfig.from_pretrained(MODEL_PATH)
    config.num_labels = NUM_LABELS
    bert_model = TFBertModel.from_pretrained(f'{MODEL_PATH}/model', config=config)

    model = create_model(bert_model,
                         config,
                         num_labels=len(id2tag), class_weight=class_weights)
    # https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/TensorBoard
    tensorboard_callback = TensorBoard(log_dir='./logs',
                                        histogram_freq=2,
                                        embeddings_freq=2)
    # https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping
    # When early stopping triggers, roll back to the best val_loss epoch so the model
    # saved afterwards is the best one rather than the last one.
    early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=PATIENCE,
                                   start_from_epoch=1, restore_best_weights=True)

    # tf.debugging.enable_check_numerics() # - Assert if no Infs or NaNs go through. not for TPU!
    # tf.config.run_functions_eagerly(not is_tpu_strategy(strategy)) # - Easy debugging
    # https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
    history = model.fit(train_dataset,
                        epochs=EPOCHS,
                        callbacks=[tensorboard_callback, early_stopping, TerminateOnNaN()],
                        verbose="auto",
                        validation_data=test_dataset)

## Save Model

In [None]:
from tensorflow.keras.models import load_model
from tensorflow.keras.models import save_model

import zipfile

def zip_models(directory, output_filename):
    """Recursively archive `directory` into a deflate-compressed zip file.

    Archive names are relative to the parent of `directory`, so the top-level
    folder name itself (e.g. 'models/...') is kept inside the archive.
    """
    parent = os.path.join(directory, '..')
    with zipfile.ZipFile(output_filename, 'w', zipfile.ZIP_DEFLATED) as archive:
        for folder, _subdirs, names in os.walk(directory):
            for name in names:
                full_path = os.path.join(folder, name)
                archive.write(full_path, os.path.relpath(full_path, parent))

MODEL_SAVE_PATH = './models/bert_news'
# The custom losses must be handed to load_model explicitly, keyed by class name.
custom_objects = {
    loss_cls.__name__: loss_cls
    for loss_cls in (MaskedWeightedMultiClassBCE, MaskedWeightedSCCE)
}
model.save(MODEL_SAVE_PATH, save_format='tf')

ZIP_MODEL = True  # The archive may be very large!
if ZIP_MODEL:
    zip_models('./models', 'models.zip')

loaded_model = load_model(MODEL_SAVE_PATH, custom_objects=custom_objects)
loaded_model.summary()

## Evaluate NER Classifier

In [None]:
# A single training example. `.cache()` pins it after the first full pass, so the
# prediction below and the labels read from `traindata` in later cells refer to the
# same example even if the training pipeline reshuffles on every iteration.
traindata = train_dataset.unbatch().batch(1).take(1).cache()

y1 = loaded_model.predict(traindata)
print(f"NER labels shape: {y1[0].shape}")
print(f"Sequence labels shape: {y1[1].shape}")

In [None]:
predicted_classes = np.argmax(y1[0], axis=-1)

print("Logits (NER):", y1[0])
print("Predicted classes:", predicted_classes)

# seq_output is a logit layer (no activation): convert to probabilities before
# applying the 0.5 decision threshold.
predicted_events = (tf.nn.sigmoid(y1[1]).numpy() > 0.5).astype(int)

print("Logits (SEQ):", y1[1])
predicted_event_names = [[id2tag[i] for i, present in enumerate(article) if present == 1] for article in predicted_events]
print(f"Predicted ({len(predicted_event_names[0])}) Event(s) in article: ({', '.join(predicted_event_names[0])})")

In [None]:
inputs, labels = next(iter(traindata))
print(f"NER Labels found: {labels['ner_output']}")
print(f"Article NER sequence: {labels['seq_output']}")

# Decode each multi-hot article label row into the names of the tags that are set.
article_events_tags = [
    [id2tag[idx] for idx, flag in enumerate(row.numpy()) if flag == 1]
    for row in labels['seq_output']
]
print(f"Article Events: {article_events_tags}")

In [None]:
# Token-level cross-entropy averaged over labelled positions only.
mask = tf.cast(
    tf.logical_and(labels['ner_output'] >= 0, labels['ner_output'] < MAX_LEN),
    tf.float32,
)
losses = tf.keras.losses.sparse_categorical_crossentropy(
    labels['ner_output'], y1[0], from_logits=True, ignore_class=UNK_ID
) * mask
mean_loss = tf.reduce_sum(losses) / tf.reduce_sum(mask)

print(f"Mask: {mask}")
print(f"Masked Losses: {losses.numpy()}")
print(f"Mean Loss: {mean_loss.numpy()}")

# Document-level multi-label loss; seq_output is in logit space.
binary_losses = tf.keras.losses.binary_crossentropy(labels['seq_output'], y1[1], from_logits=True)
mean_binary_loss = tf.reduce_mean(binary_losses)
print(f"Bin Losses: {binary_losses.numpy()}")
print(f"Mean Binary Loss: {mean_binary_loss.numpy()}")

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

def plot_classification(history):
    """Plot loss curves (left panel) and accuracy curves (right panel) from a Keras History.

    Expects the keys produced by the two-headed model: total and per-head losses for
    'ner_output' / 'seq_output', sparse categorical accuracy for NER and binary
    accuracy for SEQ, each with its 'val_' counterpart.
    """
    curves = history.history
    panels = [
        ('Training vs. Validation Loss', 'Loss', [
            ('loss', 'Training Loss'),
            ('val_loss', 'Validation Loss'),
            ('ner_output_loss', 'Training NER'),
            ('val_ner_output_loss', 'Validation NER'),
            ('seq_output_loss', 'Training SEQ'),
            ('val_seq_output_loss', 'Validation SEQ'),
        ]),
        ('Training vs. Validation Accuracy', 'Accuracy', [
            ('ner_output_sparse_categorical_accuracy', 'Training NER Accuracy'),
            ('val_ner_output_sparse_categorical_accuracy', 'Validation NER Accuracy'),
            ('seq_output_binary_accuracy', 'Training SEQ Accuracy'),
            ('val_seq_output_binary_accuracy', 'Validation SEQ Accuracy'),
        ]),
    ]

    plt.figure(figsize=(12, 6))
    for panel_index, (title, y_label, series) in enumerate(panels, start=1):
        plt.subplot(1, 2, panel_index)
        for key, label in series:
            plt.plot(curves[key], label=label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()
    plt.tight_layout()
    plt.show()

plot_classification(history)

In [None]:
from sklearn.metrics import f1_score, precision_recall_fscore_support, classification_report, roc_auc_score, hamming_loss, jaccard_score, log_loss

def print_classification_performanca(predictions, true_labels, id2tag, max_len = MAX_LEN, binary=False,
                                     ignore_label=UNK_ID):
    """Print classification metrics and, for token-level labels, plot a confusion matrix.

    Args:
        predictions: predicted class ids (token level, ``binary=False``) or a 0/1
            indicator matrix (multi-label, ``binary=True``).
        true_labels: ground truth with the same layout as ``predictions``.
        id2tag: mapping class id -> tag name, used for report rows and axis labels.
        max_len: unused; kept for backward compatibility.
        binary: treat inputs as a multi-label indicator matrix.
        ignore_label: in token-level mode, positions whose TRUE label equals this
            value (padding, special tokens, continuation sub-words) are excluded.
            Pass None to keep every position.
    """
    if true_labels.ndim > 2:
        true_labels = true_labels.reshape(true_labels.shape[0], -1)
    if predictions.ndim > 2:
        predictions = predictions.reshape(predictions.shape[0], -1)

    if not binary:
        true_labels = true_labels.flatten()
        predictions = predictions.flatten()
    print(f"Shapes: {true_labels.shape} and {predictions.shape}")
    assert true_labels.shape == predictions.shape, "Shape mismatch between labels and predictions"

    if not binary and ignore_label is not None:
        # Unlabelled positions would otherwise count as false positives of whatever
        # class the model predicts there and add an extra row to the confusion matrix.
        keep = true_labels != ignore_label
        true_labels = true_labels[keep]
        predictions = predictions[keep]

    total = true_labels.size  # every cell, so multi-label matrices are counted correctly
    if total == 0:
        print("No labelled positions to evaluate.")
        return

    class_ids = list(range(len(id2tag)))
    tag_names = [id2tag[i] for i in class_ids]

    print(classification_report(true_labels, predictions, labels=class_ids, target_names=tag_names, zero_division=0))

    weighted_precision, weighted_recall, weighted_f1, _ = precision_recall_fscore_support(
        true_labels, predictions,
        labels=class_ids,
        average='weighted',
        zero_division=0)

    correct = np.sum(predictions == true_labels)
    h_loss = hamming_loss(true_labels, predictions)

    print('Accuracy: {:.2f}%'.format(100. * correct / total))
    print('Hamming: {:.2f}%'.format(100. * h_loss))
    print(f"Precision: {100. * weighted_precision:.2f}%, Recall: {100. * weighted_recall:.2f}%, F1-Score: {100. * weighted_f1:.2f}%")
    if not binary:
        # Fixed label list keeps the matrix shape in sync with the tick labels.
        cm = confusion_matrix(true_labels, predictions, labels=class_ids)
        plt.figure(figsize=(10, 7))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=tag_names,
                    yticklabels=tag_names)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.show()

# Predict on the test split in its original order, built directly from the encodings,
# so predictions line up row-for-row with `test_ner_labels` / `test_seq_labels`
# whatever shuffling the tf.data pipelines apply.
test_inputs = {name: np.array(test_encodings[name])
               for name in ('input_ids', 'attention_mask', 'token_type_ids')}
predictions = model.predict(test_inputs, batch_size=BATCH_SIZE)
predicted_label_indices = np.argmax(predictions[0], axis=-1)
print_classification_performanca(predicted_label_indices, np.array(test_ner_labels), id2tag)

## Evaluate Sequence Classifier

In [None]:
seq_pred = predictions[1]
seq_labels = np.array(test_seq_labels)

# An article counts as correct only when its predicted event set (sigmoid > 0.5)
# equals its true multi-hot label set exactly.
seq_correct = sum(
    set(np.where(prob_row > 0.5)[0]) == set(np.where(label_row == 1)[0])
    for prob_row, label_row in zip(tf.nn.sigmoid(seq_pred).numpy(), seq_labels)
)

event_accuracy_ratio = seq_correct / len(seq_labels) if len(seq_labels) > 0 else 0
print(f"Accuracy Ratio: {event_accuracy_ratio:.2f}, correct predictions ({seq_correct} out of {len(seq_labels)})")