## Imports

In [19]:
# Transformers does not support keras 3, so using keras 2
import os
os.environ['TF_USE_LEGACY_KERAS'] = '1'

In [18]:
import numpy as np
import transformers
from transformers import AutoTokenizer, TFAutoModelForTokenClassification, pipeline
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
import pickle
import matplotlib.pyplot as plt
from IPython.display import display, HTML

## Data Processing

In [5]:
# Read the dataset: unpickle the contents of 'mountain_ner_data.pkl' into `data`.
# NOTE(review): pickle.load can execute arbitrary code while unpickling —
# only load this file if it comes from a trusted source.
with open('mountain_ner_data.pkl', 'rb') as file:
    data = pickle.load(file)

In [6]:
# Check the first record
data[0]

{'text': "Hey, did you know that [Chokai] is actually a volcano? It's so cool! I've always wanted to visit [Tsukuba], maybe we should plan a trip? My friend went hiking in [the Andes] last year and said it was incredible. Have you ever seen a photo of [Hood]? It looks so dangerous. Oh, and I read that the Nile River is the longest in the world.\n",
 'text_format': 'whatsup conversation',
 'text_theme': 'mountains',
 'text_size': 'small',
 'is_lower': False}

In [7]:
# Load the tokenizer of the base model. The module-level name `tokenizer` is
# what text_to_labels() reads and what is passed to evaluate() further down.
tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")

In [20]:
def tokens_to_labels(tokens, b_label_id = 1, i_label_id = 2):
    """
    Converts a sequence of tokens into label IDs based on the B-I-O (Begin-Inside-Outside) labeling scheme.

    Target entities are delimited in the token stream by a standalone '[' token
    before the entity and a standalone ']' token after it. The bracket tokens
    themselves never receive a label, so the result has one label per
    non-bracket token, plus one leading and one trailing 0 for the SOS/EOS
    tokens.

    Args:
        tokens (list of str): A list of tokens (e.g., words or subwords) to be labeled.
        b_label_id (int, optional): Label ID for the beginning of a named entity (default: 1).
        i_label_id (int, optional): Label ID for the inside of a named entity (default: 2).

    Returns:
        np.array: An array of label IDs corresponding to the input tokens.
    """
    labels = [0]  # label for the SOS token
    inside_entity = False
    at_entity_start = False

    for token in tokens:
        # Brackets are matched exactly: special tokens such as '[UNK]' also
        # start with '[' and end with ']' but are ordinary tokens that need
        # a label of their own.
        if token == '[':
            # opening bracket: the next labeled token begins a target entity
            inside_entity = True
            at_entity_start = True
        elif token == ']':
            # closing bracket: the target entity has ended
            inside_entity = False
        elif inside_entity:
            if at_entity_start:
                labels.append(b_label_id)  # begin of the target entity
                at_entity_start = False
            else:
                labels.append(i_label_id)  # inside of the target entity
        else:
            labels.append(0)  # outside of any target entity

    labels.append(0)  # label for the EOS token
    # Labels are collected in a list and converted once; np.append would copy
    # the whole array on every token.
    return np.array(labels)

def text_to_labels(text, seq_len = 512):
    """
    Process text to tokens and labels

    Uses the module-level `tokenizer`. Target entities in `text` are expected
    to be wrapped in [ and ]; the brackets are used to derive the labels and
    are removed before the final tokenization.

    Args:
    text (str): text for processing
    seq_len (int): length of the returned sequences; shorter inputs are padded,
        longer ones are truncated.

    Returns:
    Tokenizer object {'input_ids': np.array([]), 'attention_mask': np.array([]), 'token_type_ids': np.array([])},
    Numpy array of labels (int32, length seq_len)

    Raises:
    ValueError: if the number of labels derived from the bracketed text differs
        from the number of token ids of the text without brackets.
    """
    # Receive labels from the bracket-annotated text
    str_tokens = tokenizer.tokenize(text)
    labels = tokens_to_labels(str_tokens)

    # Delete brackets from the text
    text = text.replace("[", "").replace("]", "")

    # Check that there is exactly one label per token id of the clean text
    token_obj = tokenizer(text, return_attention_mask = True, return_tensors="np")
    token_len = len(token_obj['input_ids'][0])
    if len(labels) != token_len:
        raise ValueError(f"Lens don`t match. label len = {len(labels)} and token len = {token_len}")

    # Pad (with zeros) or truncate the labels to seq_len
    labels_padded = np.zeros(seq_len, dtype = np.int32)
    pad_id = min(len(labels), seq_len)
    labels_padded[:pad_id] = labels[:pad_id]
    if len(labels) > seq_len:
        # After truncation the tokenizer is expected to keep the closing
        # special token in the last position, so that position must not carry
        # an entity label left over from the cut-off text.
        labels_padded[-1] = 0

    # Tokenize to the same length as the labels (seq_len, not a fixed 512)
    token_obj_padded = tokenizer(text, return_attention_mask = True, return_tensors="np",
                                 padding = 'max_length', max_length=seq_len, truncation=True)

    return token_obj_padded, labels_padded

def create_tf_dataset(input_ids, attention_mask, token_type_ids, labels, batch_size):
    """
    Build a batched, prefetching tensorflow dataset of (features, labels) pairs.

    The features are a dict with the keys 'input_ids', 'attention_mask' and
    'token_type_ids'; all inputs are sliced along their first dimension.
    """
    features = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'token_type_ids': token_type_ids,
    }

    # slice -> group into batches -> preload upcoming batches to speed up training
    return (
        tf.data.Dataset.from_tensor_slices((features, labels))
        .batch(batch_size)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

def data_process(data, batch_size = 1, seq_len = 512, test_part = 0.15):
    """
    Processing data to tensorflow dataset

    Every entry's text is converted with text_to_labels, the results are
    stacked into (len(data), seq_len) int32 arrays, and the arrays are split by
    position: the first int(len(data) * (1 - test_part)) entries form the train
    dataset, the remaining ones the test dataset. No shuffling is done here.

    Args:
    data (list): list of dictionaries, that have the 'text' column
    batch_size (int): batch size
    seq_len (int): len of the target sequence
    test_part (float): part of the data for test dataset

    Returns:
    train_dataset,
    test_dataset
    """
    data_len = len(data)
    train_len = int(data_len * (1-test_part))

    # initial arrays to save features
    input_ids_array = np.zeros((data_len, seq_len), dtype = np.int32)
    attention_mask_array = np.zeros((data_len, seq_len), dtype = np.int32)
    token_type_ids_array = np.zeros((data_len, seq_len), dtype = np.int32)
    labels_array = np.zeros((data_len, seq_len), dtype = np.int32)

    # iterate for every entry in data
    for i, entry in enumerate(data):
        # receive features from data; seq_len is forwarded so the returned
        # sequences match the width of the arrays allocated above
        tokens, labels = text_to_labels(entry['text'], seq_len)
        # record features
        input_ids_array[i] = tokens['input_ids']
        attention_mask_array[i] = tokens['attention_mask']
        token_type_ids_array[i] = tokens['token_type_ids']
        labels_array[i] = labels

    # creating datasets: same argument order as create_tf_dataset expects
    arrays = (input_ids_array, attention_mask_array, token_type_ids_array, labels_array)
    train_part = [array[:train_len] for array in arrays]
    test_part_arrays = [array[train_len:] for array in arrays]

    train_ds = create_tf_dataset(*train_part, batch_size)
    test_ds = create_tf_dataset(*test_part_arrays, batch_size)

    return train_ds, test_ds

In [21]:
# Build the train/test datasets with the data_process defaults
# (batch_size=1, seq_len=512, test_part=0.15).
train_ds, test_ds = data_process(data)

# Training

In [11]:
# Load the pretrained token-classification model
# Model page: https://huggingface.co/dslim/bert-base-NER
model = TFAutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
# Replace the last layer with a new 3-unit Dense layer, one unit per label id
# produced by tokens_to_labels (0 = outside, 1 = begin, 2 = inside).
# NOTE(review): only the layer is swapped; nothing here updates the model's
# config (e.g. its number of labels) — confirm this is fine for save/reload.
model.classifier = tf.keras.layers.Dense(3)

All PyTorch model weights were used when initializing TFBertForTokenClassification.

All the weights of TFBertForTokenClassification were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForTokenClassification for predictions without further training.


In [12]:
# Compile the model: Adam optimizer with learning rate 3e-5, tracking accuracy.
# NOTE(review): no `loss` is passed — presumably the model falls back to its
# own built-in loss; confirm.
model.compile(optimizer=Adam(3e-5), metrics = 'accuracy')

In [13]:
# Fit the model for 2 epochs, using test_ds as validation data
history = model.fit(train_ds, validation_data=test_ds, epochs = 2) 

Epoch 1/2
Cause: for/else statement not yet supported
Cause: for/else statement not yet supported


I0000 00:00:1736273300.792010     858 service.cc:145] XLA service 0x7f90784acdd0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1736273300.792064     858 service.cc:153]   StreamExecutor device (0): NVIDIA GeForce GTX 1080, Compute Capability 6.1
2025-01-07 18:08:20.863387: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-01-07 18:08:21.111765: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8906
I0000 00:00:1736273301.304152     858 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Epoch 2/2


In [14]:
# Save the fine-tuned model under 'mount_ner_model'
model.save('mount_ner_model')

INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 3), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f906f145090>, 140260622580768), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 3), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f906f145090>, 140260622580768), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(3,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f906f001690>, 140259541139616), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(3,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f906f001690>, 140259541139616), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 3), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f906f145090>, 140260622580768), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 3), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f906f145090>, 140260622580768), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(3,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f906f001690>, 140259541139616), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(3,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f906f001690>, 140259541139616), {}).


INFO:tensorflow:Assets written to: mount_ner_model/assets


INFO:tensorflow:Assets written to: mount_ner_model/assets


# Evaluate

In [37]:
def extract_spans(labels, tokens):
    """
    Extract spans of labeled tokens.

    Label 1 begins a span, label 2 continues it, any other label ends it. A
    label 1 that arrives while a span is still open closes that span and
    starts a new one, so two adjacent entities are reported separately.
    Labels beyond the last token (e.g. predictions for padding positions) are
    ignored because there is no text they could refer to.

    Span boundaries are character offsets: the length of all preceding tokens
    concatenated without a separator (the same measure get_char_index uses).

    Args:
        labels: List of labels for each token.
        tokens: List of tokens.

    Returns:
        list: List of (start_index, end_index) tuples representing spans.
    """
    def char_offset(token_count):
        # Kept identical to get_char_index so that offsets stay comparable;
        # the replace only has an effect if tokens themselves contain spaces.
        return len("".join(tokens[:token_count]).replace(" ##", ""))

    spans = []
    current_span_start = None
    processed = 0  # number of (label, token) pairs seen so far

    # zip stops at the shorter sequence, dropping labels that have no token
    for i, (label, _token) in enumerate(zip(labels, tokens)):
        processed = i + 1
        if label == 1:  # Start of a span
            if current_span_start is not None:
                # previous span ends right where the new one begins
                spans.append((current_span_start, char_offset(i)))
            current_span_start = char_offset(i)
        elif label != 2 and current_span_start is not None:  # End of a span
            spans.append((current_span_start, char_offset(i)))
            current_span_start = None

    if current_span_start is not None:  # Span still open at the end of the sequence
        spans.append((current_span_start, char_offset(processed)))
    return spans

def get_char_index(tokens, token_index):
    """
    Return the character offset at which the token at `token_index` starts.

    The offset is the length of all preceding tokens concatenated without a
    separator, after removing every " ##" occurrence from that concatenation.

    Args:
        tokens: List of tokens.
        token_index: Index of the token.

    Returns:
        int: Character index.
    """
    preceding = tokens[:token_index]
    merged = "".join(preceding)
    return len(merged.replace(" ##", ""))

def evaluate(model, tokenizer, sample):
    """
    Evaluate model by 1 entry from test data

    Runs the model on the sample, extracts the predicted and the ground-truth
    spans for the first sequence of the batch, displays the text as HTML with
    each token colored by agreement, and prints the predicted spans.

    Colors: green = predicted and true, red = predicted but not true,
    yellow = true but not predicted, uncolored = neither.

    Args:
        model: tuned model
        tokenizer: model tokenizer
        sample: entry from data, a (features, labels) pair; only the first
            sequence of the batch is evaluated

    Returns:
        list: list of predicted selections.
    """
    import html  # local import: only needed to escape token text for rendering

    # Get model predictions; softmax is monotonic, so taking the argmax of the
    # raw logits yields the same classes as taking it after a softmax.
    pred = model.predict(sample[0])
    predicted_classes = np.argmax(pred.logits, axis=-1)[0]
    true_classes = sample[1].numpy()[0]

    # Convert token IDs to tokens and remove special tokens
    tokens = tokenizer.convert_ids_to_tokens(sample[0]['input_ids'][0])
    special_tokens = ("[CLS]", "[SEP]", "[PAD]")
    tokens_without_special = [token for token in tokens if token not in special_tokens]

    # Extract predicted and true spans ([1:] skips the label of the leading special token)
    predicted_spans = extract_spans(predicted_classes[1:], tokens_without_special)
    true_spans = extract_spans(true_classes[1:], tokens_without_special)

    # Tokens starting with one of these are glued to the previous token
    no_space_before = ('##', '.', ',', '"', '\'')
    last_token_index = len(tokens_without_special) - 1

    # Iterate through tokens and apply coloring based on predictions and ground truth
    pieces = []
    for i, token in enumerate(tokens_without_special):
        start_index = get_char_index(tokens_without_special, i)

        is_predicted = any(p_start <= start_index < p_end for p_start, p_end in predicted_spans)
        is_true = any(t_start <= start_index < t_end for t_start, t_end in true_spans)

        # Escape the token so characters like '<' or '&' are shown literally
        # instead of being interpreted as markup.
        token_text = html.escape(token.replace("##", ""))

        # Pick the color from the prediction / truth combination
        if is_predicted and is_true:
            color = "green"
        elif is_predicted:
            color = "red"
        elif is_true:
            color = "yellow"
        else:
            color = None

        if color is None:
            pieces.append(token_text)
        else:
            pieces.append(f"<span style='color: {color};'>{token_text}</span>")

        # Add space between tokens unless the next token is a subword or punctuation
        if i < last_token_index and not tokens_without_special[i+1].startswith(no_space_before):
            pieces.append(" ")

    colored_text = "".join(pieces)

    display(HTML(colored_text))
    print(predicted_spans)
    return predicted_spans

In [38]:
# Iterator over the batches of the test dataset
ds = iter(test_ds)

In [39]:
# Legend: green = predicted and true, red = predicted but not true,
# yellow = true but not predicted.
print('Green if correct, red if incorect true, yellow if incorect false')
# Show up to 10 samples. Pairing with range(10) ends the loop cleanly when ds
# has fewer batches left, instead of next(ds) raising StopIteration.
for _, sample in zip(range(10), ds):
    evaluate(model, tokenizer, sample)

Green if correct, red if incorect true, yellow if incorect false


[(83, 88), (189, 198)]


[(43, 57), (215, 230), (267, 273), (320, 335), (368, 373)]


[(254, 257), (308, 320), (523, 542)]


[(37, 44), (154, 172), (268, 281)]


[(37, 41), (151, 164), (372, 389), (463, 471), (512, 523), (695, 700), (721, 725)]


[(156, 161), (271, 275), (493, 506), (544, 561), (866, 880), (885, 899)]


[(16, 21), (227, 232), (502, 509), (747, 761), (810, 819), (881, 895)]


[(98, 112), (265, 273), (455, 463)]


[(10, 19), (65, 76), (89, 94), (166, 177), (300, 311)]


[(151, 164), (248, 254), (446, 453), (540, 564)]
