In [1]:
import nltk
# required to download at least once
# nltk.download('punkt')
# nltk.download('punkt_tab')

import re

import numpy as np
import torch
from transformers import (
    BertForTokenClassification,
    BertConfig, 
    AutoTokenizer, 
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer
)

from datasets import load_dataset

import evaluate

SEQEVAL = evaluate.load('seqeval')

TRAIN_SAMPLES = 10000
EVAL_SAMPLES = 1000
SEED = 42

LABELS = ['middle-of-token', 'end-of-token']

MAX_SEQ_LEN = 512 # this includes the EOS token


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def nltk_tokenize(text):
    """
    Tokenizes the text using nltk's word tokenizer (with some modification for quotes), 
    except for whitespace, where each single space, tab, etc. is treated as it's own token
    """

    # split on any whitespace character, also keeping the whitespace characters as tokens 
    tokens = re.split(r'(\s+)', text)

    # at this point though, some whitespace tokens contain multiple characters e.g. '  ', but I only want 1 char/whitespace token like ' ', ' '
    new_tokens = []
    for token in tokens:
        if token.isspace():
            for char in token:
                new_tokens.append(char)
        else:
            new_tokens.append(token)
    tokens = new_tokens

    # now tokenize each non-whitespace token using nltk's word tokenizer
    final_tokens = []
    for token in tokens:
        if token.isspace():
            final_tokens.append(token)
        else:
            final_tokens.extend(nltk.word_tokenize(token))

    # nltk also has an annoying 'feature' where it converts double quotes to either `` or '' in a destructive manner, but I can't have that
    # so I need to go through all of the tokens, check if it *should* be double quotes (and isn't) and update the tokens if that is the case
    for i, token in enumerate(final_tokens):
        if token in ['``', "''"] and (i == 0 or final_tokens[i-1] != '"') and (i == len(final_tokens) - 1 or final_tokens[i+1] != '"'):
            final_tokens[i] = '"'

    return final_tokens


In [3]:
def label_text(text: str) -> list[str]:
    """This is like function above, but it uses a different thing under the hood b/c the default tokenizer has a lot of problems for this application"""
    spans = nltk.tokenize.NLTKWordTokenizer().span_tokenize(text)

    labels = np.ones((len(text),), dtype=np.long) * -1 # unlabelled stuff will be -1 for now, will set to 1 later

    for span in spans:
        labels[span[0]:span[1] - 1] = 0 # set everything within the span to 0 
        labels[span[1] - 1] = 1 # then set the end of the span properly (the upper range of the span is exclusive, hence -1)

    labels[labels == -1] = 1 # anything that isn't in a span is probably whitespace, so those can all be labelled 

    return labels

def create_bert_model(vocab_size=256, hidden_size=384, num_hidden_layers=6, num_attention_heads=12, intermediate_size=1536):    
    config = BertConfig(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        num_hidden_layers=num_hidden_layers,
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size,
        pad_token_id=26, # in ascii this is the SUB character, which I will use as padding
        num_labels=2,  # the only two choices are not end of nltk token and end of nltk token
    )

    model = BertForTokenClassification(config)

    return model


def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    true_predictions = [
        [LABELS[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [LABELS[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = SEQEVAL.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }


In [4]:
byte_tokenize = AutoTokenizer.from_pretrained('google/byt5-small', clean_up_tokenization_spaces=False)
byte_tokenize.pad_token_id = 26
byte_tokenize.eos_token_id = 3  # in ascii this is the ETX character, which I will use as the end-of-sequence token

In [5]:
tiny_textbooks = load_dataset('nampdn-ai/tiny-textbooks')

tiny_textbooks = tiny_textbooks.shuffle(seed=42)
tiny_textbooks['train'] = tiny_textbooks['train'].select(range(TRAIN_SAMPLES))
tiny_textbooks['test'] = tiny_textbooks['test'].select(range(EVAL_SAMPLES))

# since we label a word/token by it's last character, we have to be careful how we tokenize the text so as not to loose 
# a label if the word gets truncated halfway through, which is why we truncate all texts ourselves to MAX_SEQ_LEN-1 (the -1 accounts for the EOS token)
tiny_textbooks = tiny_textbooks.map(
    lambda examples: {'text': examples['text'][:MAX_SEQ_LEN-1]} 
).map(
    lambda examples: {'labels': label_text(examples['text'])}
).map(
    lambda examples: byte_tokenize(examples['text'], truncation=True, max_length=MAX_SEQ_LEN, padding='do_not_pad'), batched=True, # the collator will take care of padding
    remove_columns=['text', 'source', 's', 'len', 'idx', 'textbook']
)

collator = DataCollatorForTokenClassification(tokenizer=byte_tokenize, padding=True, max_length=MAX_SEQ_LEN)


In [6]:
model = create_bert_model()

print(model.num_labels)


2


In [7]:
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    per_device_train_batch_size=128,
    per_device_eval_batch_size=64,
    learning_rate=1e-4,
    num_train_epochs=2,
    seed=SEED,
    eval_strategy='epoch',
    save_strategy='epoch',
    logging_strategy='epoch',
    bf16=True,
    lr_scheduler_type='cosine',
    report_to='wandb',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tiny_textbooks['train'],
    eval_dataset=tiny_textbooks['test'],
    tokenizer=byte_tokenize,
    data_collator=collator,
    compute_metrics=compute_metrics,
)




In [8]:
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjvp15[0m. Use [1m`wandb login --relogin`[0m to force relogin


  attn_output = torch.nn.functional.scaled_dot_product_attention(
 50%|█████     | 79/158 [00:23<00:21,  3.66it/s]

{'loss': 0.5384, 'grad_norm': 0.25033038854599, 'learning_rate': 5e-05, 'epoch': 1.0}


                                                
 50%|█████     | 79/158 [00:32<00:21,  3.66it/s]

{'eval_loss': 0.5161757469177246, 'eval_precision': 1.0, 'eval_recall': 1.0, 'eval_f1': 1.0, 'eval_accuracy': 0.7711452369915983, 'eval_runtime': 8.7463, 'eval_samples_per_second': 114.335, 'eval_steps_per_second': 1.829, 'epoch': 1.0}




{'loss': 0.5168, 'grad_norm': 0.2454090267419815, 'learning_rate': 0.0, 'epoch': 2.0}


                                                 
100%|██████████| 158/158 [01:19<00:00,  3.61it/s]

{'eval_loss': 0.5138279795646667, 'eval_precision': 1.0, 'eval_recall': 1.0, 'eval_f1': 1.0, 'eval_accuracy': 0.7714181370158784, 'eval_runtime': 13.2682, 'eval_samples_per_second': 75.368, 'eval_steps_per_second': 1.206, 'epoch': 2.0}


100%|██████████| 158/158 [01:19<00:00,  1.99it/s]

{'train_runtime': 80.3034, 'train_samples_per_second': 249.056, 'train_steps_per_second': 1.968, 'train_loss': 0.52763685395446, 'epoch': 2.0}





TrainOutput(global_step=158, training_loss=0.52763685395446, metrics={'train_runtime': 80.3034, 'train_samples_per_second': 249.056, 'train_steps_per_second': 1.968, 'total_flos': 654232903680000.0, 'train_loss': 0.52763685395446, 'epoch': 2.0})

In [11]:
def inference(model, text):
    tokenized = byte_tokenize(text, truncation=True, max_length=MAX_SEQ_LEN, return_tensors="pt").to('cuda')
    with torch.no_grad():
        outputs = model(**tokenized)
    
    predictions = torch.argmax(outputs.logits, dim=-1)
    characters = byte_tokenize.convert_ids_to_tokens(tokenized.input_ids[0])
    
    result = []
    token = ''
    for character, pred in zip(characters, predictions[0]):
        token += character
        if pred == 1:  # End of token
            result.append(token)
            token = ''
            
    print(predictions)
    return result


In [14]:
test_sentence = "This is a test sentence.  It's got some punctuation and whitespace."

test_sentence = """Following up on the post Kirkendall wrote last night about T.J.'s interview with ESPN 950 in Philadelphia, we now have the audio. One of the things that stood out to me was Houshmandzadeh's inclination about his performance against the Eagles during the infamous tie game could be inferred as a "job interview.". In the words of Goose from Top Gun, "tag him now or lose him forever.". [Editor's update: Below is an integrated player for the above interview in Philadelphia, thanks to Jason at Bleeding Green Nation]. Never miss Bengals news!"""
print(label_text(test_sentence))
print(nltk_tokenize(test_sentence))
print(inference(trainer.model, test_sentence))

[0 0 0 0 0 0 0 0 1 1 0 1 1 0 1 1 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 1 1 0
 0 0 0 1 1 0 0 0 1 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 0 1 0 1 1 0 0 0 0 0 0 0 0
 1 1 0 0 0 1 1 0 0 0 1 1 0 0 1 1 0 1 1 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1 1 0
 0 1 1 0 0 0 1 1 0 0 1 1 0 0 0 0 0 1 1 0 0 1 1 0 1 1 0 0 1 1 0 0 0 0 0 1 1
 0 0 0 1 1 0 0 0 0 1 1 0 0 1 1 0 1 1 0 1 1 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0
 0 1 0 1 1 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 1 1 0 0 1 1 0 0 0 0 0 0 0 0 0 0
 1 1 0 0 0 0 0 0 1 1 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0 1 1 0 0 1 1 0 0 0 0 0
 0 0 1 1 0 0 1 1 0 0 0 1 1 0 0 0 0 1 1 0 1 1 0 0 0 0 0 0 0 1 1 0 1 1 1 1 1
 0 0 1 1 0 0 0 0 0 0 0 0 0 1 1 1 1 0 1 1 0 0 1 1 0 0 0 0 1 1 0 1 1 0 0 0 0
 1 1 0 0 0 1 1 0 0 1 1 0 0 1 1 1 1 0 0 1 1 0 0 1 1 0 0 1 1 0 1 1 0 0 0 1 1
 0 0 1 1 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 1 0 1 1 0 0 0 0 0 1 1 1 0 0 0 0
 1 1 0 1 1 0 1 1 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 1 0 0 1 1 0 0 1 1 0 0 0
 0 1 1 0 0 0 0 0 0 0 0 1 1 0 1 1 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 1 1
 0 1 1 0 0 0 0 1 1 0 1 1 