# ConNER: Text Preprocessing & Output Processing

This notebook prepares a pipeline that (i) preprocesses an incoming text (abstract) and updates it to the format expected by ConNER; and (ii) processed the output of the model to extract the entities.

## 1. Imports

In [216]:
## Model definition related
import os
import bs4
import numpy as np
import json
import pandas as pd
import re
import string
from pprint import pprint

## 2. Loading the data
- Expects an incoming text with only the fields of "title" and "abstract"

In [30]:
## Loads the processed CSV files from Nomita and Woojae.
## If we want to use data from PubMed database direct, we will have to build a processing pipeline for that.
train_path = "./data/OfficialTrainingSet1.csv"
test_path = "./data/OfficialTestSet1.csv"
val_path = "./data/OfficialValidationSet1.csv"

# Reading the files but only retaining the title and abstract columns
df_train = pd.read_csv(train_path)[['title', 'abstract']]
df_val = pd.read_csv(test_path)[['title', 'abstract']]
df_test = pd.read_csv(val_path)[['title', 'abstract']]

# Forming a new column with the merged texts
df_train['text'] = df_train["title"] + " " + df_train["abstract"]
df_val['text'] = df_val["title"] + " " + df_val["abstract"]
df_test['text'] = df_test["title"] + " " + df_test["abstract"]

# This will be the starting point for further preprocessing.
df_test['text'][0]

'Tricuspid valve regurgitation and lithium carbonate toxicity in a newborn infant. A newborn with massive tricuspid regurgitation, atrial flutter, congestive heart failure, and a high serum lithium level is described. This is the first patient to initially manifest tricuspid regurgitation and atrial flutter, and the 11th described patient with cardiac disease among infants exposed to lithium compounds in the first trimester of pregnancy. Sixty-three percent of these infants had tricuspid valve involvement. Lithium carbonate may be a factor in the increasing incidence of congenital heart disease when taken during early pregnancy. It also causes neurologic depression, cyanosis, and cardiac arrhythmia when consumed prior to delivery.'

In [223]:
def convert_text_to_ConNER_format(df, tokenizer):
    '''
    - Takes in a dataframe and returns an "example" object that can be taken by the
      "load_and_cache_examples" function from the data_utils.py function from the ConNER repo.
    - Also outputs a mapping dictionary linking the word indices to token indices
    
    Inputs:
    - df: dataframe with a "text" column that contains the paragraph combining the title and abstract of a journal
    - the same tokenizer to be used in the NER model.
    
    Output:
    - example object derived from the InputExample function
    - mapping dictionary linking words to token indices
    
    Other Prerequisites:
    - InputExample functionm, imported from the data_utils.py module
    '''
    from data_utils import InputExample
    import copy
    
    mode = "doc_dev"  ## just inherited from the ConNER codes, stands for document-based evaluation for the dev set.
    
    texts = df['text']
    guid_index = 1
    examples = []
    mapping_dict = []
      
    for text in texts:
        
        example_dict = {} # mapping info for the current example (i.e. "text") only.  
        
        # handling the words.  For handling punctuation, referenced:
        # https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation
        words =  text.translate(str.maketrans({key: " {0} ".format(key) for key in string.punctuation})).split()
        labels = [0] * len(words)  ## just set labels to 0 as a dummy as we are doing inference
        
        token_index = 1   # starts from 1 because token 0 is the CLS token for Bert and Roberta
        word_index = 0    # counts from 0, as words in the Scibert model is also counted from 0
        
        for word in words:
            # doing a test tokenization to see the length
            tokenized_word = tokenizer(word)['input_ids'][1:-1]   # skips the 1st & last input ids as these are the CLS and SEP tokens
            num_of_tokens = len(tokenized_word)      # this is the no. of sub-word tokens that the current word has
        
            assert len(example_dict) == word_index, "Error in code, word index is probably wrong"
        
            # append a sub-list showing the list of corresponding token ids
            token_indices = list(range(token_index, token_index + num_of_tokens))
            example_dict[word_index] = {'word': word, 'token_idx': token_indices}
            
            word_index += 1
            token_index += num_of_tokens
            
        mapping_dict.append(example_dict)
        
        if "tags_hp" in labels:
            hp_labels = item["tags_hp"]
        else:
            hp_labels = [None]*len(labels)
                
        examples.append(InputExample(guid="%s-%d".format(mode, guid_index),
                                     words=words,
                                     labels=labels,
                                     hp_labels=hp_labels))
        guid_index += 1
        
    return examples, mapping_dict

In [163]:
model_path = "./ConNER"
tokenizer = AutoTokenizer.from_pretrained(model_path)

test_set, word_to_token_map = convert_text_to_ConNER_format(df_test, tokenizer)
test_set



[<data_utils.InputExample at 0x21ebd18a190>,
 <data_utils.InputExample at 0x21ebd18a160>,
 <data_utils.InputExample at 0x21ebd18abe0>,
 <data_utils.InputExample at 0x21ebd18acd0>,
 <data_utils.InputExample at 0x21ebd18ae80>,
 <data_utils.InputExample at 0x21ebd18a9a0>,
 <data_utils.InputExample at 0x21ebd18a520>,
 <data_utils.InputExample at 0x21ec2beb6a0>,
 <data_utils.InputExample at 0x21ec2beb9a0>,
 <data_utils.InputExample at 0x21ec2beb880>,
 <data_utils.InputExample at 0x21ec2beb730>,
 <data_utils.InputExample at 0x21ec2beb520>,
 <data_utils.InputExample at 0x21ec2bebb20>,
 <data_utils.InputExample at 0x21ec2beb7f0>,
 <data_utils.InputExample at 0x21ec2bebd30>,
 <data_utils.InputExample at 0x21d0b8ae490>,
 <data_utils.InputExample at 0x21d0b8ae3d0>,
 <data_utils.InputExample at 0x21d0b8aee50>,
 <data_utils.InputExample at 0x21d0b8ae610>,
 <data_utils.InputExample at 0x21d0b8ae850>,
 <data_utils.InputExample at 0x21d0b8ae700>,
 <data_utils.InputExample at 0x21d0b8ae370>,
 <data_uti

In [224]:
(word_to_token_map[0])

{0: {'word': 'Tricuspid', 'token_idx': [1, 2, 3]},
 1: {'word': 'valve', 'token_idx': [4, 5]},
 2: {'word': 'regurgitation', 'token_idx': [6, 7, 8]},
 3: {'word': 'and', 'token_idx': [9]},
 4: {'word': 'lithium', 'token_idx': [10, 11]},
 5: {'word': 'carbonate', 'token_idx': [12, 13]},
 6: {'word': 'toxicity', 'token_idx': [14]},
 7: {'word': 'in', 'token_idx': [15]},
 8: {'word': 'a', 'token_idx': [16]},
 9: {'word': 'newborn', 'token_idx': [17, 18]},
 10: {'word': 'infant.', 'token_idx': [19, 20]},
 11: {'word': 'A', 'token_idx': [21]},
 12: {'word': 'newborn', 'token_idx': [22, 23]},
 13: {'word': 'with', 'token_idx': [24]},
 14: {'word': 'massive', 'token_idx': [25, 26]},
 15: {'word': 'tricuspid', 'token_idx': [27, 28, 29]},
 16: {'word': 'regurgitation,', 'token_idx': [30, 31, 32, 33]},
 17: {'word': 'atrial', 'token_idx': [34]},
 18: {'word': 'flutter,', 'token_idx': [35, 36, 37]},
 19: {'word': 'congestive', 'token_idx': [38, 39]},
 20: {'word': 'heart', 'token_idx': [40]},
 21

## 3. Converting the Dataset Format & Loading Model
- Built based on the load_and_cache_examples from the data_utils.py from the ConNER repo.

In [50]:
from transformers import BertPreTrainedModel,BertForTokenClassification, BertModel, RobertaModel, RobertaTokenizer, BertPreTrainedModel, RobertaConfig
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
import torch.nn as nn
import torch.nn.functional as F
from  torch.nn.utils.rnn  import pack_padded_sequence

from torch.autograd import Variable
from torch.nn import CrossEntropyLoss, KLDivLoss

from transformers import BertConfig, RobertaConfig

In [70]:
## Eval related
import argparse
import logging
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

#Remember to copy the "data_utils.py" file from ConNER's repo
from data_utils import tag_to_id, get_chunks, get_labels, convert_examples_to_features
from flashtool import Logger
logger = logging.getLogger(__name__)

In [59]:
from ConNER_model_definition import RobertaForTokenClassification_v2

## Loading model
model_path = "./ConNER"

## It appears the checkpoint is a Roberta-based model as loading it using BERT model yields an error.
#test_model  = BERTForTokenClassification_v2.from_pretrained(model_path)

test_model = RobertaForTokenClassification_v2.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)



In [225]:
def load_and_cache_examples(args, df, tokenizer, labels, pad_token_label_id, mode,
                            entity_name='bc5cdr', remove_labels=False):
    
    examples, word_to_token_map = convert_text_to_ConNER_format(df, tokenizer)
    features = convert_examples_to_features(
        examples,
        labels,
        args.max_seq_length,
        tokenizer,
        cls_token_at_end=bool(args.model_type in ["xlnet"]),
        # xlnet has a cls token at the end
        cls_token=tokenizer.cls_token,
        cls_token_segment_id=2 if args.model_type in ["xlnet"] else 0,
        sep_token=tokenizer.sep_token,
        sep_token_extra=bool(args.model_type in ["roberta"]),
        # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
        pad_on_left=bool(args.model_type in ["xlnet"]),
        # pad on the left for xlnet
        pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
        pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
        pad_token_label_id=pad_token_label_id,
        entity_name=entity_name,
    )

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)
    all_full_label_ids = torch.tensor([f.full_label_ids for f in features], dtype=torch.long)
    all_hp_label_ids = torch.tensor([f.hp_label_ids for f in features], dtype=torch.long)
    all_entity_ids = torch.tensor([f.entity_ids for f in features], dtype=torch.long)
    if remove_labels:
        all_full_label_ids.fill_(pad_token_label_id)
        all_hp_label_ids.fill_(pad_token_label_id)
    all_ids = torch.tensor([f for f in range(len(features))], dtype=torch.long)
    dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_full_label_ids, all_hp_label_ids, all_entity_ids, all_ids)
    
    return dataset, word_to_token_map

In [271]:
device = torch.device("cuda")
pad_token_label_id = CrossEntropyLoss().ignore_index
labels = ['O', 'B-Chemical', 'B-Disease', 'I-Chemical', 'I-Disease']

parser = argparse.ArgumentParser()
args = parser.parse_args("")

args.model_type = "roberta"
args.model_name_or_path = "./ConNER"
args.max_seq_length = 512   ## modified from 128, ## Using 768 leads to errors!
args.per_gpu_train_batch_size = 8
args.per_gpu_eval_batch_size = 8
args.n_gpu = 1
args.device = device
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
args.local_rank = -1

args.gradient_accumulation_steps = 1
args.learning_rate = 5e-5
args.weight_decay = 0.0
args.adam_epsilon = 1e-8
args.adam_beta1 = 0.9
args.adam_beta2 = 0.98
args.max_grad_norm = 1.0
args.num_train_epochs = 3.0
args.max_steps = -1
args.warmup_steps = 0
args.logging_steps = 10000
args.save_steps = 10000
args.seed = 1


eval_dataset, word_to_token_map = load_and_cache_examples(args, df_test, tokenizer, labels, pad_token_label_id, mode="doc_dev")
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

In [272]:
word_to_token_map[0]

{0: {'word': 'Tricuspid', 'token_idx': [1, 2, 3]},
 1: {'word': 'valve', 'token_idx': [4, 5]},
 2: {'word': 'regurgitation', 'token_idx': [6, 7, 8]},
 3: {'word': 'and', 'token_idx': [9]},
 4: {'word': 'lithium', 'token_idx': [10, 11]},
 5: {'word': 'carbonate', 'token_idx': [12, 13]},
 6: {'word': 'toxicity', 'token_idx': [14]},
 7: {'word': 'in', 'token_idx': [15]},
 8: {'word': 'a', 'token_idx': [16]},
 9: {'word': 'newborn', 'token_idx': [17, 18]},
 10: {'word': 'infant', 'token_idx': [19]},
 11: {'word': '.', 'token_idx': [20]},
 12: {'word': 'A', 'token_idx': [21]},
 13: {'word': 'newborn', 'token_idx': [22, 23]},
 14: {'word': 'with', 'token_idx': [24]},
 15: {'word': 'massive', 'token_idx': [25, 26]},
 16: {'word': 'tricuspid', 'token_idx': [27, 28, 29]},
 17: {'word': 'regurgitation', 'token_idx': [30, 31, 32]},
 18: {'word': ',', 'token_idx': [33]},
 19: {'word': 'atrial', 'token_idx': [34]},
 20: {'word': 'flutter', 'token_idx': [35, 36]},
 21: {'word': ',', 'token_idx': [37

## 4. Testing Out Inferencing
- Things look fine.

In [273]:
test_model.to(device)

test_model.eval()

nb_eval_steps = 0
preds = None
out_label_ids = None

for batch in tqdm(eval_dataloader, desc="Evaluating"):
    batch = tuple(t.to(args.device) for t in batch)

    with torch.no_grad():
        inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
        if args.model_type != "distilbert":
            inputs["token_type_ids"] = (
                batch[2] if args.model_type in ["bert", "xlnet"] else None
            )  # XLM and RoBERTa don"t use segment_ids
        outputs = test_model(**inputs)
        tmp_eval_loss, logits = outputs[:2]

        if args.n_gpu > 1:
            tmp_eval_loss = tmp_eval_loss.mean()

    nb_eval_steps += 1
    if preds is None:
        preds = logits.detach().cpu().numpy()
        out_label_ids = inputs["labels"].detach().cpu().numpy()
    else:
        preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
        out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)

preds = np.argmax(preds, axis=2)

Evaluating: 100%|██████████| 63/63 [00:24<00:00,  2.58it/s]


In [274]:
## This is the result from this live run
preds

array([[0, 2, 4, ..., 0, 0, 0],
       [0, 1, 3, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 2, 4, ..., 0, 0, 0]], dtype=int64)

In [275]:
preds.shape

(500, 512)

In [276]:
print(f"Shape of the prediction numpy array: {preds[0].shape}")
preds[0]

Shape of the prediction numpy array: (512,)


array([0, 2, 4, 4, 4, 4, 4, 4, 4, 0, 1, 3, 3, 3, 2, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 2, 4, 4, 4, 4, 4, 0, 2, 4, 4, 0, 2, 4, 4, 4, 4, 0,
       0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4,
       4, 4, 4, 4, 0, 2, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0, 0, 0,
       0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 1, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4,
       4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 4, 0, 2, 4, 0, 0, 2,
       4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

## 5. Post-processing
- This takes in the preds (i.e. the model predictions), and combines it with the word_to_token_map to identify the entities

In [280]:
def extract_entity_and_word_location(preds, word_to_token_map):
    '''
    Takes in the model predictions from ConNER and the word_to_token_map to extract
    the identified chemicals and diseases'
     
    Inputs:
    - preds: model predictions
    - word_to_token_map: a list of dictionaries showing the words and their mapping to tokenized IDs
    
    Output:
    - a list of lists, showing the identfied chemical and disease entities and where to find them in the text.
    
    '''
    assert preds.shape[0] == len(word_to_token_map), "Number of samples in the predictions and mapping are different"
    
    overall_entities_identified = []
    
    for i in range(preds.shape[0]):
        sample_entities_identified = {}
        pred = preds[i]
        mapping = word_to_token_map[i]
        
        current_entity_type = None
        current_entity = None
        
        for word_idx in mapping.keys():
            word = mapping[word_idx]['word']
            tokens = mapping[word_idx]['token_idx']
            
            
            word_pred = pred[tokens[0]:tokens[-1]+1]
            
            ## cases when a starting word is found:
            if 2 in word_pred and 1 not in word_pred:
                assert 3 not in word_pred, f"Error in prediction for sample {i}, the word {word} with prdictions {word_pred} predicted as B-Disease but contains I-Chemical"
                
                # handles the case when a new entity immediately follows another entity
                if current_entity_type is not None:
                    entity_word_idx = [entity_start, entity_end]
                    entity_word_idx_RE = [entity_start, entity_end+1]
                    if current_entity in sample_entities_identified.keys():
                        sample_entities_identified[current_entity]['type'].append(current_entity_type)
                        sample_entities_identified[current_entity]['word_loc'].append(entity_word_idx)
                        sample_entities_identified[current_entity]['RE_word_loc'].append(entity_word_idx_RE)
                    else:
                        sample_entities_identified[current_entity] = {'type': [current_entity_type],
                                                                      'word_loc': [entity_word_idx],
                                                                      'RE_word_loc': [entity_word_idx_RE]}
                current_entity_type = 'disease'    
                current_entity = word
                entity_start = word_idx
                entity_end = word_idx
                
            elif 1 in word_pred and 2 not in word_pred:
                assert 4 not in word_pred, f"Error in prediction for sample {i}, the word {word} with prdictions {word_pred} predicted as B-Chemical but contains I-Disease"
                
                # handles the case when a new entity immediately follows another entity
                if current_entity_type is not None:
                    entity_word_idx = [entity_start, entity_end]
                    entity_word_idx_RE = [entity_start, entity_end+1]
                    if current_entity in sample_entities_identified.keys():
                        sample_entities_identified[current_entity]['type'].append(current_entity_type)
                        sample_entities_identified[current_entity]['word_loc'].append(entity_word_idx)
                        sample_entities_identified[current_entity]['RE_word_loc'].append(entity_word_idx_RE)
                    else:
                        sample_entities_identified[current_entity] = {'type': [current_entity_type],
                                                                      'word_loc': [entity_word_idx],
                                                                      'RE_word_loc': [entity_word_idx_RE]}
                current_entity_type = 'chemical'    
                current_entity = word
                entity_start = word_idx
                entity_end = word_idx
            
            ## cases when a middle word is found:
            elif 2 not in word_pred and 4 in word_pred:
                assert 3 not in word_pred, f"Error in prediction for sample {i}, the word {word} with prdictions {word_pred} predicted as having multiple classes"
                assert current_entity_type == 'disease' , f"Error in prediction for sample {i}, the word {word} with prdictions {word_pred} predicted as disease but follows a previous word of {current_entity_type} class"
                current_entity = current_entity + " " + word
                entity_end = word_idx
            
            elif 1 not in word_pred and 3 in word_pred:
                assert 4 not in word_pred, f"Error in prediction for sample {i}, the word {word} with prdictions {word_pred} predicted as having multiple classes"
                assert current_entity_type == 'chemical' , f"Error in prediction for sample {i}, the word {word} with prdictions {word_pred} predicted as chemical but follows a previous word of {current_entity_type} class"
                current_entity = current_entity + " " + word
                entity_end = word_idx
                
            ## cases when a middle word is found:
            elif np.mean(word_pred) == 0:
                if current_entity_type is not None:
                    entity_word_idx = [entity_start, entity_end]
                    entity_word_idx_RE = [entity_start, entity_end+1]
                    if current_entity in sample_entities_identified.keys():
                        sample_entities_identified[current_entity]['type'].append(current_entity_type)
                        sample_entities_identified[current_entity]['word_loc'].append(entity_word_idx)
                        sample_entities_identified[current_entity]['RE_word_loc'].append(entity_word_idx_RE)
                    else:
                        sample_entities_identified[current_entity] = {'type': [current_entity_type],
                                                                      'word_loc': [entity_word_idx],
                                                                      'RE_word_loc': [entity_word_idx_RE]}
                        
                current_entity_type = None
                current_entity = None

            else:
                print(f"Unexpected prediction case for word {word}, at tokens {tokens}, from sample {i}")
                
        overall_entities_identified.append(sample_entities_identified)   
        
    return overall_entities_identified

In [281]:
testing = extract_entity_and_word_location(preds, word_to_token_map)

Unexpected prediction case for word range, at tokens [512], from sample 22
Unexpected prediction case for word 84, at tokens [513], from sample 22
Unexpected prediction case for word to, at tokens [514], from sample 22
Unexpected prediction case for word 98, at tokens [515], from sample 22
Unexpected prediction case for word ), at tokens [516], from sample 22
Unexpected prediction case for word ., at tokens [517], from sample 22
Unexpected prediction case for word The, at tokens [518], from sample 22
Unexpected prediction case for word mean, at tokens [519], from sample 22
Unexpected prediction case for word BIS, at tokens [520], from sample 22
Unexpected prediction case for word score, at tokens [521], from sample 22
Unexpected prediction case for word corresponding, at tokens [522], from sample 22
Unexpected prediction case for word to, at tokens [523], from sample 22
Unexpected prediction case for word the, at tokens [524], from sample 22
Unexpected prediction case for word first, a

AssertionError: Error in prediction for sample 28, the word decrease with prdictions [0 4] predicted as disease but follows a previous word of None class

In [289]:
preds[28]

array([0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 0, 0, 1, 3, 3, 3, 0, 0, 0, 0, 0,
       1, 3, 3, 0, 0, 2, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 2, 4, 0, 0, 1, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 3,
       3, 0, 0, 0, 0, 2, 4, 0, 0, 0, 0, 1, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 3, 3,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4, 4, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 1, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0, 1, 3,
       3, 3, 0, 0, 0, 0, 2, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 3,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 1, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [282]:
word_to_token_map[28]

{0: {'word': 'Chronic', 'token_idx': [1]},
 1: {'word': 'effects', 'token_idx': [2]},
 2: {'word': 'of', 'token_idx': [3]},
 3: {'word': 'a', 'token_idx': [4]},
 4: {'word': 'novel', 'token_idx': [5]},
 5: {'word': 'synthetic', 'token_idx': [6, 7]},
 6: {'word': 'anthracycline', 'token_idx': [8, 9]},
 7: {'word': 'derivative', 'token_idx': [10, 11]},
 8: {'word': '(', 'token_idx': [12]},
 9: {'word': 'SM', 'token_idx': [13]},
 10: {'word': '-', 'token_idx': [14]},
 11: {'word': '5887', 'token_idx': [15, 16]},
 12: {'word': ')', 'token_idx': [17]},
 13: {'word': 'on', 'token_idx': [18]},
 14: {'word': 'normal', 'token_idx': [19]},
 15: {'word': 'heart', 'token_idx': [20]},
 16: {'word': 'and', 'token_idx': [21]},
 17: {'word': 'doxorubicin', 'token_idx': [22, 23, 24]},
 18: {'word': '-', 'token_idx': [25]},
 19: {'word': 'induced', 'token_idx': [26]},
 20: {'word': 'cardiomyopathy', 'token_idx': [27, 28, 29]},
 21: {'word': 'in', 'token_idx': [30]},
 22: {'word': 'beagle', 'token_idx': 

In [257]:
s = '''
Valproic acid (VPA) was given to 24 epileptic patients who were already being treated with other antiepileptic drugs. A standardized loading dose of VPA was administered, and venous blood was sampled at 0, 1, 2, 3, and 4 hours. Ammonia (NH3) was higher in patients who, during continuous therapy, complained of drowsiness (7 patients) than in those who were symptom-free (17 patients), although VPA plasma levels were similar in both groups. By measuring VPA-induced changes of blood NH3 content, it may be possible to identify patients at higher risk of obtundation when VPA is given chronically.
'''
s.translate(str.maketrans({key: " {0} ".format(key) for key in string.punctuation})).split()


['Valproic',
 'acid',
 '(',
 'VPA',
 ')',
 'was',
 'given',
 'to',
 '24',
 'epileptic',
 'patients',
 'who',
 'were',
 'already',
 'being',
 'treated',
 'with',
 'other',
 'antiepileptic',
 'drugs',
 '.',
 'A',
 'standardized',
 'loading',
 'dose',
 'of',
 'VPA',
 'was',
 'administered',
 ',',
 'and',
 'venous',
 'blood',
 'was',
 'sampled',
 'at',
 '0',
 ',',
 '1',
 ',',
 '2',
 ',',
 '3',
 ',',
 'and',
 '4',
 'hours',
 '.',
 'Ammonia',
 '(',
 'NH3',
 ')',
 'was',
 'higher',
 'in',
 'patients',
 'who',
 ',',
 'during',
 'continuous',
 'therapy',
 ',',
 'complained',
 'of',
 'drowsiness',
 '(',
 '7',
 'patients',
 ')',
 'than',
 'in',
 'those',
 'who',
 'were',
 'symptom',
 '-',
 'free',
 '(',
 '17',
 'patients',
 ')',
 ',',
 'although',
 'VPA',
 'plasma',
 'levels',
 'were',
 'similar',
 'in',
 'both',
 'groups',
 '.',
 'By',
 'measuring',
 'VPA',
 '-',
 'induced',
 'changes',
 'of',
 'blood',
 'NH3',
 'content',
 ',',
 'it',
 'may',
 'be',
 'possible',
 'to',
 'identify',
 'patients',

## 6. Comparing Against Old Results
- Looks the same, so the current pipeline is the correct implementation

In [62]:
# Compare against the old results inherited from the previous notebook
preds  ## DO NOT RE-RUN THIS CELL.  THIS SIMPLY SHOWS THE ACTUAL RESULTS WHEN FEEDING IN THE DEFAULT DATASET

array([[0, 2, 4, ..., 0, 0, 0],
       [0, 1, 3, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 2, 4, ..., 0, 0, 0]], dtype=int64)

In [63]:
## 500 samples, with each having 512 tokens (max token length):
preds.shape

(500, 512)

In [67]:
## Let's also look at the output for the first sample:
print(f"Shape of the prediction numpy array: {preds[0].shape}")
preds[0]

Shape of the prediction numpy array: (512,)


array([0, 2, 4, 4, 4, 4, 4, 4, 4, 0, 1, 3, 3, 3, 2, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 2, 4, 4, 4, 4, 4, 0, 2, 4, 4, 0, 2, 4, 4, 4, 4, 0,
       0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4,
       4, 4, 4, 4, 0, 2, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0, 0, 0,
       0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 1, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4,
       4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 4, 0, 2, 4, 0, 0, 2,
       4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,