This model works well, but I have noticed that it sometimes misrecognizes tokens because of commas or other non-alphanumeric characters (though the miss rate is not high). To fix this (on initial impressions), we should preprocess the text by removing those characters with a regex; the names can then be found by matching the index of each entry in the returned array of labels against the list of words (split on newlines and whitespace). We also need to accommodate the apostrophe being stripped, so that "Justin's" ends up as "Justins".

In [None]:
"""
Dataset References:
- CoNLL-2003: Tjong Kim Sang, Erik F., and De Meulder, Fien.
  "Introduction to the CoNLL-2003 shared task: Language-independent named entity recognition." 
  Proceedings of the Seventh Conference on Natural Language Learning at HLT-NAACL 2003.
  https://www.aclweb.org/anthology/W03-0419
- OntoNotes 5.0: Weischedel, Ralph, et al. "OntoNotes Release 5.0." LDC2013T19, Linguistic Data Consortium, 2013.
  https://aclanthology.org/W13-3516
"""

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from transformers import AlbertTokenizer, AlbertModel, AdamW
from datasets import load_dataset

import numpy as np
import pandas as pd
import os
import ast
from tqdm import tqdm
import random

In [None]:
# Load the two NER corpora by name. trust_remote_code=True is passed so that
# the datasets' own loading code is allowed to run.
conll_dataset = load_dataset("conll2003", trust_remote_code=True)
ontonotes_dataset = load_dataset("ontonotes/conll2012_ontonotesv5", "english_v12", trust_remote_code=True)
# Tokenizer for the same 'albert-base-v2' checkpoint that NameClassifier loads by default.
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')

In [None]:
# Binary label maps: original tag id -> 1 (name) or 0 (not a name).
# In both tag sets only ids 1 and 2 map to 1 (presumably the B-/I- person
# tags; confirm against each dataset's tag list). CoNLL-2003 uses ids 0-8,
# OntoNotes uses ids 0-36.
CONLL_2003_LABEL_MAP = {tag_id: int(tag_id in (1, 2)) for tag_id in range(9)}
ONTONOTES_LABEL_MAP = {tag_id: int(tag_id in (1, 2)) for tag_id in range(37)}

# Number of neighbouring words kept on each side of the target word.
WINDOW_SIZE = 5

In [None]:
class NameDataset(Dataset):
    """Dataset wrapper around a sequence of pre-tokenized examples.

    Every element of ``data`` is expected to be a dict with the keys
    ``"input_ids"``, ``"attention_mask"`` and ``"label"``.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, label)`` for example ``idx``.

        Dimension 0 of the two token tensors is squeezed away; the label is
        returned unchanged.
        """
        example = self.data[idx]
        input_ids = example["input_ids"].squeeze(0)
        attention_mask = example["attention_mask"].squeeze(0)
        return input_ids, attention_mask, example["label"]
    
def create_context_window(words, i, window_size=None):
    """Build the marked-up context window around ``words[i]``.

    The result is a list of strings laid out as::

        [CLS] <left pad> <left words> <w> target </w> <right words> <right pad> [SEP]

    Up to ``window_size`` words are taken on each side of the target. When
    the sentence has fewer words on a side, that side is filled with
    ``"[PAD]"`` strings so both sides always hold exactly ``window_size`` items.

    Args:
        words: List of word strings for one sentence.
        i: Index of the target word; must satisfy ``0 <= i < len(words)``.
        window_size: Words to keep on each side. ``None`` (the default) uses
            the module-level ``WINDOW_SIZE`` as it is at call time.

    Returns:
        The context window as a list of strings.

    Raises:
        IndexError: If ``i`` is outside ``words``. A negative ``i`` is
            rejected explicitly, because Python's negative indexing would
            otherwise produce a wrapped-around window that looks valid.
    """
    if window_size is None:
        window_size = WINDOW_SIZE
    if not 0 <= i < len(words):
        raise IndexError(f"word index {i} out of range for {len(words)} words")

    left_words = words[max(0, i - window_size):i]
    right_words = words[i + 1:i + window_size + 1]

    # Pad whichever side ran off the edge of the sentence up to full width.
    left_pad = ["[PAD]"] * (window_size - len(left_words))
    right_pad = ["[PAD]"] * (window_size - len(right_words))

    return [
        "[CLS]",
        *left_pad,
        *left_words,
        "<w>",
        words[i],
        "</w>",
        *right_words,
        *right_pad,
        "[SEP]",
    ]

def create_training_data(dataset_split):
    """Turn sentences into one tokenized classification example per word.

    Args:
        dataset_split: Iterable of dicts with the keys ``"dataset"``
            (``'conll'`` selects the CoNLL label map, anything else the
            OntoNotes one), ``"tokens"`` and ``"ner_tags"``.

    Returns:
        A list of dicts with ``"input_ids"`` and ``"attention_mask"`` (the
        tokenizer's tensors, padded/truncated to 35 tokens) and ``"label"``
        (a float tensor: 1 = name, 0 = not a name).
    """
    examples = []

    for sentence in tqdm(dataset_split):
        if sentence['dataset'] == 'conll':
            tag_to_binary = CONLL_2003_LABEL_MAP
        else:
            tag_to_binary = ONTONOTES_LABEL_MAP
        tokens = sentence["tokens"]
        tags = sentence["ner_tags"]

        for position in range(len(tokens)):
            # Collapse the original tag into the binary name / not-name label.
            binary_label = tag_to_binary[tags[position]]

            window = create_context_window(tokens, position)

            # NOTE(review): the window already contains literal "[CLS]"/"[SEP]"
            # strings and the tokenizer is called with its default special-token
            # handling; verify that the markers are not duplicated.
            encoded = tokenizer(
                window,
                padding="max_length",
                max_length=35,
                truncation=True,
                is_split_into_words=True,
                return_tensors="pt",
            )

            examples.append({
                "input_ids": encoded["input_ids"],
                "attention_mask": encoded["attention_mask"],
                "label": torch.tensor(binary_label, dtype=torch.float),
            })

    return examples

In [None]:
def _balance_name_sentences(sentences, keep_non_name_ratio):
    """Keep all sentences with a name plus a random share of the others.

    A sentence "has a name" when tag id 1 or 2 occurs in its ``"ner_tags"``.
    The number of name-free sentences kept is
    ``int(n_with_names * keep_non_name_ratio)``, capped at how many exist.
    The combined list is shuffled in place with the global ``random`` state.
    """
    with_names = []
    without_names = []
    for sentence in sentences:
        tags = sentence["ner_tags"]
        if 1 in tags or 2 in tags:
            with_names.append(sentence)
        else:
            without_names.append(sentence)

    quota = int(len(with_names) * keep_non_name_ratio)
    kept_without_names = random.sample(without_names, min(quota, len(without_names)))

    balanced = with_names + kept_without_names
    random.shuffle(balanced)
    return balanced

def filter_conll(dataset_split, keep_non_name_ratio=0.5):
    """Down-sample name-free sentences of a CoNLL-style split.

    Args:
        dataset_split: Iterable of examples with ``"tokens"`` and
            ``"ner_tags"`` keys.
        keep_non_name_ratio: Name-free sentences to keep, as a fraction of
            the number of sentences that contain a name.

    Returns:
        A shuffled list of ``{"dataset": 'conll', "tokens", "ner_tags"}`` dicts.
    """
    sentences = [
        {"dataset": 'conll', "tokens": example["tokens"], "ner_tags": example["ner_tags"]}
        for example in dataset_split
    ]
    return _balance_name_sentences(sentences, keep_non_name_ratio)

def filter_ontonotes(dataset_split, keep_non_name_ratio=0.5):
    """Down-sample name-free sentences of an OntoNotes-style split.

    Args:
        dataset_split: Iterable of documents, each with a ``"sentences"``
            list whose items carry ``"words"`` and ``"named_entities"``.
        keep_non_name_ratio: Name-free sentences to keep, as a fraction of
            the number of sentences that contain a name.

    Returns:
        A shuffled list of ``{"dataset": 'ontonotes', "tokens", "ner_tags"}``
        dicts, i.e. the same schema that ``filter_conll`` produces.
    """
    sentences = []
    for document in dataset_split:
        for sentence in document["sentences"]:
            sentences.append({
                "dataset": 'ontonotes',
                "tokens": sentence["words"],
                "ner_tags": sentence["named_entities"],
            })
    return _balance_name_sentences(sentences, keep_non_name_ratio)

def filter_dataset(dataset_split, keep_non_name_ratio=0.5, label_map='conll'):
    """Dispatch to the corpus-specific sentence filter.

    Args:
        dataset_split: Split to filter, in the schema of the chosen corpus.
        keep_non_name_ratio: Forwarded unchanged to the selected filter.
        label_map: ``'conll'`` or ``'ontonotes'``.

    Returns:
        Whatever the selected filter function returns.

    Raises:
        ValueError: If ``label_map`` is neither ``'conll'`` nor ``'ontonotes'``.
    """
    if label_map not in ('conll', 'ontonotes'):
        raise ValueError("Invalid label_map")

    selected_filter = filter_conll if label_map == 'conll' else filter_ontonotes
    return selected_filter(dataset_split, keep_non_name_ratio)

In [None]:
# Build the train / validation / test example lists and their loaders.
# NOTE: filter_dataset samples and shuffles with the global `random` state, so
# the order of the calls below influences which sentences end up in each list.
print("Creating training data")
# Keep every sentence containing a name, plus name-free sentences numbering up to 20% of that count.
filtered_conll_train = filter_dataset(conll_dataset["train"], keep_non_name_ratio=0.2, label_map='conll')
filtered_ontonotes_train = filter_dataset(ontonotes_dataset["train"], keep_non_name_ratio=0.2, label_map='ontonotes')
# Only the first 3000 filtered sentences of each corpus are used for training.
train_data = create_training_data(filtered_conll_train[:3000]) + create_training_data(filtered_ontonotes_train[:3000])
print("Creating validation data")
filtered_conll_val = filter_dataset(conll_dataset["validation"], keep_non_name_ratio=0.2, label_map='conll')
filtered_ontonotes_val = filter_dataset(ontonotes_dataset["validation"], keep_non_name_ratio=0.2, label_map='ontonotes')
# Validation and test use all filtered sentences (no 3000-sentence cap).
val_data = create_training_data(filtered_conll_val) + create_training_data(filtered_ontonotes_val)
print("Creating test data")
filtered_conll_dataset = filter_dataset(conll_dataset["test"], keep_non_name_ratio=0.2, label_map='conll')
filtered_ontonotes_dataset = filter_dataset(ontonotes_dataset["test"], keep_non_name_ratio=0.2, label_map='ontonotes')
test_data = create_training_data(filtered_conll_dataset) + create_training_data(filtered_ontonotes_dataset)

train_dataset = NameDataset(train_data)
val_dataset = NameDataset(val_data)
test_dataset = NameDataset(test_data)

# Only the training loader shuffles; validation and test keep their order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
class NameClassifier(nn.Module):
    """ALBERT encoder with a one-unit linear head followed by a sigmoid.

    ``forward`` returns one value in (0, 1) per input sequence, computed from
    the encoder's hidden state at sequence position 0.
    """

    def __init__(self, bert_model_name="albert-base-v2"):
        super().__init__()
        self.bert = AlbertModel.from_pretrained(bert_model_name)
        # Single output unit: the model is a binary classifier.
        self.fc = nn.Linear(self.bert.config.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return sigmoid scores with a trailing dimension of size 1."""
        encoded = self.bert(input_ids, attention_mask)
        first_position_state = encoded.last_hidden_state[:, 0, :]
        return self.sigmoid(self.fc(first_position_state))

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = NameClassifier().to(device)
# The model already applies a sigmoid, so plain binary cross-entropy is used.
criterion = nn.BCELoss()
# Use PyTorch's AdamW rather than the deprecated `transformers.AdamW`.
# weight_decay and eps are set explicitly to the defaults of the transformers
# implementation this replaces (torch's own defaults are 0.01 and 1e-8), so
# the optimisation settings stay the same.
optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.0, eps=1e-6)

In [None]:
# Train with early stopping on the validation loss.
EPOCHS = 50    # upper bound on the number of epochs
PATIENCE = 3   # epochs without validation improvement before stopping

best_val_loss = np.inf
patience_counter = 0
best_model_state = None

for epoch in range(EPOCHS):
    print(f"Training Epoch {epoch+1}")
    model.train()
    total_loss = 0

    for batch in tqdm(train_loader):
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        # squeeze(-1) drops only the trailing size-1 dimension. A bare squeeze()
        # would also collapse a final batch of size 1 to a 0-d tensor, which no
        # longer matches the (1,) label tensor.
        loss = criterion(outputs.squeeze(-1), labels.float())
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Training Loss: {avg_loss}")

    print(f"Validation Epoch {epoch+1}")
    model.eval()
    total_loss = 0
    total_correct = 0

    with torch.no_grad():
        for batch in tqdm(val_loader):
            input_ids, attention_mask, labels = batch
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs.squeeze(-1), labels.float())
            total_loss += loss.item()

            preds = torch.round(outputs.squeeze(-1))
            total_correct += (preds == labels).sum().item()

        avg_loss = total_loss / len(val_loader)
        accuracy = total_correct / len(val_dataset)
        print(f"Epoch {epoch+1}, Validation Loss: {avg_loss}, Validation Accuracy: {accuracy}")

    if avg_loss < best_val_loss:
        best_val_loss = avg_loss
        patience_counter = 0
        # state_dict() returns references to the live parameter tensors, which
        # the optimizer keeps updating in place. Clone them so the snapshot
        # really holds this epoch's weights rather than the final ones.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        print("New best model found")
    else:
        patience_counter += 1

    if patience_counter >= PATIENCE:
        print("Early stopping")
        break

In [None]:
# Restore the weights recorded for the lowest validation loss, if any epoch
# recorded one (best_model_state stays None otherwise).
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Loaded best model")

In [None]:
# Evaluate the model on the held-out test loader.
print("Testing on testing set")
model.eval()
total_loss = 0
total_correct = 0
# No gradients are needed for evaluation; skipping them avoids building the
# autograd graph for every batch.
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)

        outputs = model(input_ids, attention_mask)
        # Drop only the trailing size-1 dimension so scores line up with labels,
        # including for a final batch of size 1.
        scores = outputs.squeeze(-1)
        loss = criterion(scores, labels.float())
        total_loss += loss.item()

        # Compare per-example predictions with labels of the same shape.
        # Comparing the un-squeezed (batch, 1) outputs with (batch,) labels
        # broadcasts to (batch, batch) and over-counts the correct predictions.
        predictions = (scores > 0.5).float()
        total_correct += (predictions == labels).sum().item()

avg_loss = total_loss / len(test_loader)
accuracy = total_correct / len(test_dataset)
print(f"Loss: {avg_loss}, Accuracy: {accuracy}")

In [None]:
# Quick qualitative check: classify every word of a hand-written paragraph.
test_string = """
Matthew and Chloe ran into William at the museum last weekend. They all decided to explore a new photography exhibit together. On their way out, they saw Isabella and Daniel, who invited them to a rooftop dinner later that evening.
"""
import re

# Remove everything that is neither a word character nor whitespace, then split on whitespace.
test_string = re.sub(r'[^\w\s]', '', test_string)
test_tokens = test_string.split()

# The ner_tags are placeholders; only the model's predictions are used below.
# NOTE: this rebinds test_data / test_dataset / test_loader, so the test split
# built earlier is no longer reachable through those names.
test_data = create_training_data([{"dataset": 'conll', "tokens": test_tokens, "ner_tags": [0] * len(test_tokens)}])
test_dataset = NameDataset(test_data)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
prediction = []
model.eval()
with torch.no_grad():
    for batch in test_loader:
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)

        outputs = model(input_ids, attention_mask)
        predictions = (outputs > 0.5).float()
        # squeeze(-1) keeps the batch dimension, so tolist() always returns a
        # list. A bare squeeze() turns a batch of size 1 into a scalar, and
        # extend() then fails on the resulting float.
        prediction.extend(predictions.squeeze(-1).tolist())

for token, pred in zip(test_tokens, prediction):
    print(f"{token}\t\t{pred}")

In [None]:
# Persist the current model weights (state dict only) to a file in the working directory.
torch.save(model.state_dict(), "name_classifier_2.pth")

In [None]:
# Read the sample text file located under the parent of the current working directory.
parent_dir = os.path.dirname(os.getcwd())
# Despite its name, txt_dir is the path of the text file itself, not a directory.
txt_dir = os.path.join(parent_dir, "Regexs", "data", "ehr JMS.txt")
full_text = str()
# NOTE(review): no explicit encoding is passed, so the platform default is
# used; confirm the file's actual encoding.
with open(txt_dir, "r") as f:
    full_text = f.read()

In [None]:
# Normalise the raw text into a list of tokens.
full_text = full_text.replace("\n", " ")
# Strip every character that is neither a word character nor whitespace
# (this also removes apostrophes, so "Justin's" becomes "Justins").
full_text = re.sub(r'[^\w\s]', '', full_text)
# From here on full_text is a list of whitespace-separated tokens, not a string.
full_text = full_text.split()

In [None]:
def predict_name(model, text):
    """Classify every token of ``text`` as name (1.0) or not a name (0.0).

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns one score per example; scores above 0.5 count as a name.
        text: List of word strings, treated as a single sentence.

    Returns:
        A list of floats (``1.0`` / ``0.0``), one per token and in token
        order. Empty when ``text`` is empty.
    """
    # The ner_tags are placeholders: labels are not used for prediction.
    data = create_training_data([{"dataset": 'conll', "tokens": text, "ner_tags": [0] * len(text)}])
    loader = DataLoader(NameDataset(data), batch_size=32, shuffle=False)

    model.eval()
    prediction = []
    with torch.no_grad():
        for input_ids, attention_mask, _ in loader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            outputs = model(input_ids, attention_mask)
            # reshape(-1) always yields a 1-D tensor, so tolist() returns a
            # list even for a final batch of size 1. A bare squeeze() would
            # give a scalar there, and extend() would fail on the float.
            flags = (outputs > 0.5).float().reshape(-1)
            prediction.extend(flags.tolist())

    return prediction

In [None]:
# Run the classifier over the tokenised file and print one "token<TAB><TAB>prediction" line per token.
predictions = predict_name(model, full_text)
for token, pred in zip(full_text, predictions):
    print(f"{token}\t\t{pred}")