## Training / Fine-tuning Process

In [None]:
task = "ner"
model_checkpoint = "bert-base-multilingual-cased" # mBERT pre-trained from HuggingFace Hub
batch_size = 16

### Loading the dataset

In [None]:
from datasets import load_dataset

datasets = load_dataset("conll2003")
# datasets = load_dataset("conll2002", 'nl')

In [None]:
label_list = datasets["train"].features[f"{task}_tags"].feature.names
label_list

### Processing the data

In [None]:
from transformers import AutoTokenizer
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
label_all_tokens = True

def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples[f"{task}_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
            # ignored in the loss function.
            if word_idx is None:
                label_ids.append(-100)
            # We set the label for the first token of each word.
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            # For the other tokens in a word, we set the label to either the current label or -100, depending on
            # the label_all_tokens flag.
            else:
                label_ids.append(label[word_idx] if label_all_tokens else -100)
            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [None]:
tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True)

### Fine-tuning

In [None]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))

In [None]:
args = TrainingArguments(
    output_dir=f"test-{task}",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=10,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
)

In [None]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer)

In [None]:
metric = load_metric("seqeval")

In [None]:
import numpy as np

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
        "LOC-f1": results['LOC']["f1"],
        "LOC-precision": results['LOC']["precision"],
        "LOC-recall": results['LOC']["recall"],
    }

In [None]:
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
# Add early stopping to trainer

from transformers import EarlyStoppingCallback

trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3))

In [None]:
trainer.train()

### Save best model to disk

```python
model_dir = 'ner-multilingual-bert-fine-tuned'

model.save_pretrained(model_dir)

tokenizer.save_pretrained(model_dir)
```

In [None]:
# Evaluate using trainer.evaluate method
trainer.evaluate(tokenized_datasets['test'])

In [None]:
# Evaluate using trainer.predict method
predictions, labels, _ = trainer.predict(tokenized_datasets["validation"])
predictions = np.argmax(predictions, axis=2)

# Remove ignored index (special tokens)
true_predictions = [
    [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]
true_labels = [
    [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]

results = metric.compute(predictions=true_predictions, references=true_labels)
results

## Loading Fine-tuned model and predictions for non-labelled dataset

In [1]:
model_path = 'ner-multilingual-bert-fine-tuned'
label_list = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']

In [2]:
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained(model_path, num_labels=len(label_list))

### Loading the dataset

In [3]:
import os
import re
import xml.etree.ElementTree as et

In [4]:
# Get file path LGL dataset
file_path = '../../data/TR-News/TR-News.xml'

# Load the data
tree = et.parse(file_path)
root = tree.getroot()

# Grab example title
example = root[0][0].text
example

"Policeman shot dead after assassinating Russian ambassador to Turkey, shouting 'Don't forget Aleppo!'\n        "

In [5]:
all_ground_truth = []

for article in root:
    
    gold_truth = {'text': re.sub(' +', ' ', article.find('text').text.replace('\n', ' ')),
                  'entities': sorted([{'text': top.find('phrase').text,
                                'start_pos': int(top.find('start').text),
                                'end_pos': int(top.find('end').text)} for top in article.findall('toponyms/toponym')
#                                  if top.find('gaztag/lat') != None and top.find('gaztag/lon') != None
                                     ], key=lambda k: k['start_pos'])}
    
    
    all_ground_truth.append(gold_truth)


### Processing the data

In [6]:
from transformers import AutoTokenizer
    
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [7]:
# Format dataset into Hugginface Dataset structure
from datasets import Dataset

list_input_data = [i['text'] for i in all_ground_truth]

TRN = Dataset.from_dict({'tokens': list_input_data})

In [8]:
TRN = TRN.map(tokenizer, input_columns='tokens', batched=True, fn_kwargs={'truncation': True})

TRN

  0%|          | 0/1 [00:00<?, ?ba/s]

Dataset({
    features: ['attention_mask', 'input_ids', 'token_type_ids', 'tokens'],
    num_rows: 118
})

#### Fix weird characters in dataset

In [None]:
# articles = []

# for art_idx, article in enumerate(TRN['tokens']):
    
#     check = tokenizer.tokenize(article)
    
    
#     if '[UNK]' in check:
        
#         print('article', art_idx)
        
#         locations = []
        
#         for idx, token in enumerate(check):
#             if token == '[UNK]':
                
#                 print(check[idx-3:idx+3])
                
#                 locations.append(idx)
        
#         articles.append((art_idx, locations))
        

# articles

### Prepare evaluation trainer

In [9]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer)

In [10]:
from transformers import Trainer

test_trainer = Trainer(model, 
                       data_collator=data_collator,
#                        compute_metrics=compute_metrics
                      )

In [11]:
raw_pred, _, _ = test_trainer.predict(TRN)

In [12]:
import numpy as np

predictions = np.argmax(raw_pred, axis=2)

In [13]:
predictions[0]

array([0, 0, 7, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5,
       0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0,
       0, 0, 5, 0, 0, 0, 0, 0, 5, 0, 5, 0, 0, 0, 0, 5, 0, 0, 7, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0,
       0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
       1, 1, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 5, 0, 0, 0, 0, 0, 1, 1, 2,
       0, 0, 7, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 5, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 7, 0, 0, 0, 0, 0, 0, 0, 1, 1,
       0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [14]:
def convert_predictions(prediction_label_ids, tokenized_input_ids):
    
    prediction_labels = [[label_list[p] for t, p in zip(tokens, pred) if t != 101 and t != 102] for tokens, pred in zip(tokenized_input_ids, prediction_label_ids)]
    tokens = [tokenizer.convert_ids_to_tokens(i, skip_special_tokens=True) for i in tokenized_input_ids]

    predictions = []

    for token_set, label_set in zip(tokens, prediction_labels):

        text = tokenizer.convert_tokens_to_string(token_set)

        pred = {'text': text, 'entities': []}

        adjust_start_pos = 0

        for idx in range(len(token_set)):
            if label_set[idx] == 'B-LOC' or label_set[idx] == 'I-LOC':


                if idx == len(label_set)-1:
                    pass

                # Case 1: B-LOC followed by I-LOC --> CONTINUE
                elif label_set[idx+1] == 'I-LOC':
                    adjust_start_pos += 1
                    continue

                # Case 2: B-LOC followed by other B-LOC (together) --> CONTINUE    
                elif label_set[idx+1] == 'B-LOC' and '#' in token_set[idx+1]:
                    adjust_start_pos += 1
                    continue

                current_pos = idx            
                toponym_tokens = tokenizer.convert_tokens_to_string(token_set[current_pos-adjust_start_pos:current_pos+1])
                sub_sentence = tokenizer.convert_tokens_to_string(token_set[:current_pos+1])
                end = len(sub_sentence)
                start = end - len(toponym_tokens)

                pred['entities'].append({'text': toponym_tokens, 'start_pos': start, 'end_pos': end})

                adjust_start_pos = 0


        predictions.append(pred)
        
    return predictions 

In [25]:
def process_pred_results(pred_results, original_text_inputs):
    
    final_results = [align_pred_and_original_text(pred_result, original_text) 
                     for pred_result, original_text in zip(pred_results, original_text_inputs)]
    
    
    return final_results

def align_pred_and_original_text(pred_result, original_text):
    
    pred_text = pred_result['text']
    
    idx = 0
    removed_indices = []
    add_indices = []

    while pred_text != original_text:
        
        char_post, char_original = pred_text[idx], original_text[idx]
        
        if char_post != char_original:
            
            if char_original == ' ':
                pred_text = pred_text[:idx] + ' ' + pred_text[idx:]
                
                add_indices.append(idx)
                
                continue
            
            pred_text = pred_text[:idx] + pred_text[idx+1:]

            removed_indices.append(idx)
            
            continue

        idx += 1

        if idx > len(pred_text) - 1:
            break
    
    
    pred_entities = pred_result['entities']
    
    
    for entity in pred_entities:
        for index in removed_indices:
            if index > entity['start_pos']:
                break
            else:  
                entity['start_pos'] -= 1
                entity['end_pos'] -= 1

    
    
    for entity in pred_entities:
        for index in add_indices:
            if index > entity['start_pos']:
                break
            else:  
                entity['start_pos'] += 1
                entity['end_pos'] += 1


    

    return {'text': pred_text, 'entities': pred_entities}
    

In [47]:
results = convert_predictions(predictions, TRN['input_ids'])

# Filter out results mistaken erronous results
results = [{'text': result['text'], 'entities': [entity for entity in result['entities'] if '#' not in entity['text']]} for result in results]

processed_results = process_pred_results(results, TRN['tokens'])

In [46]:
# # Check mistakes

# mistakes = []

# for idx, test in enumerate(testing):
    
#     for entity in test['entities']:
        
#         if test['text'][entity['start_pos']:entity['end_pos']] != entity['text']:
#             print(idx)
#             mistakes.append(idx)
            
#             print('position', entity['start_pos'])
#             print('entity:   ', entity['text'])
#             print('location: ',test['text'][entity['start_pos']:entity['end_pos']])
        
        

In [48]:
processed_results

[{'text': 'A Turkish policeman fatally shot Russia\'s ambassador to Turkey on Monday in front of a shocked gathering at a photo exhibit and then, pacing near the body of his victim, appeared to condemn Russia\'s military role in Syria, shouting: "Don\'t forget Aleppo! Don\'t forget Syria!" The leaders of Turkey and Russia said the attack in Ankara, the Turkish capital, was an attempt to disrupt efforts to repair ties between their countries, which have backed opposing sides in the Syrian civil war. An Associated Press photographer and others at the art gallery watched in horror as the gunman, who was wearing a dark suit and tie, fired at least eight shots, at one point walking around Ambassador Andrei Karlov as he lay motionless and shooting him again at close range. The assailant, who was identified as Mevlut Mert Altintas, a 22-year-old member of Ankara\'s riot police squad, was later killed in a shootout with police. Three other people were wounded in the attack, authorities said. T

In [49]:
import copy

def calc_precision(tp, fp):
    return tp/(tp + fp)

def calc_recall(tp, fn):
    return tp/(tp + fn)

def calc_fscore(precision, recall):
    return 2 * (precision * recall) / (precision + recall)

def evaluate(gold_truth_labels, predictions):
    # Counts of true positives, false positives & false negatives
    tp, fp, fn = 0, 0, 0
    
    # List with false positives and false negatives
    fps, fns = [], []
    
    for gold, pred in zip(gold_truth_labels, predictions):
        
        tp_tmp, fp_tmp, fn_tmp, fns_temp, fps_temp  = evaluate_one_article(gold, pred)
        
        tp += tp_tmp
        fp += fp_tmp
        fn += fn_tmp
        
        fns.extend(fns_temp)
        fps.extend(fps_temp) 
        
    precision = calc_precision(tp, fp)
    recall = calc_recall(tp, fn)
    f_score = calc_fscore(precision, recall)    
    
    print(f'fp: {fp} | tp: {tp} | fn: {fn}')
    print(f'precision: {precision:.3f} | recall: {recall:.3f} | f-score: {f_score:.3f}')
    
    return fps, fns  
    

def evaluate_one_article(gold_truth, prediction):
    
    gold = gold_truth['entities'].copy()
    pred = prediction['entities'].copy()
    
    # Counts of true positives, false positives & false negatives
    tp, fp, fn = 0, 0, 0
    
    # List with false positives and false negatives
    fps, fns = [], []
    
    
    i = 0
    
    while len(gold) > 0 and len(pred) > 0:
        i += 1

        # Check if the first two elements are the same
        if gold[0] == pred[0]:
            tp += 1
            gold.pop(0)
            pred.pop(0)
        
        else:
            # Grab the first appearing element
            element, source = (gold[0], 'gold') if gold[0]['start_pos'] < pred[0]['start_pos'] else (pred[0], 'pred')
            
            # Remove the element first appearing element
            if source == 'gold':
                fn += 1
                fns.append(element['text'])
                gold.remove(element)
            elif source == 'pred':
                fp += 1
                fps.append(element['text'])
                pred.remove(element)
    
    if len(gold) > 0:
        fn += 1
    elif len(pred) > 0:
        fp += 1
        
    return tp, fp, fn, fns, fps   

In [50]:
def load_file(file_path):
    """
    Loads file and returns all the articles
    """
    # Load the data
    tree = et.parse(file_path)
    root = tree.getroot()

    return root

def process_article(article, filtered, file_path):
    """
    Takes article and process into desired structure
    """
    if 'GeoWebNews' in file_path:
        if filtered:
            return {'text': article.find('text').text, 
                    'entities': sorted([{'text': top.find('extractedName').text, 
                                                      'start_pos': int(top.find('start').text), 
                                                      'end_pos': int(top.find('end').text)} for top in article.findall('toponyms/toponym') 
                                                     if top.find('latitude') != None and top.find('longitude') != None], key=lambda k: k['start_pos'])}
        
        else:
            return {'text': article.find('text').text, 
                    'entities': sorted([{'text': top.find('extractedName').text, 
                                                      'start_pos': int(top.find('start').text), 
                                                      'end_pos': int(top.find('end').text)} for top in article.findall('toponyms/toponym')], key=lambda k: k['start_pos'])}
    
    
    elif not filtered:
        return {'text': article.find('text').text,
                'entities': sorted([{'text': top.find('phrase').text,
                            'start_pos': int(top.find('start').text),
                            'end_pos': int(top.find('end').text)} for top in article.findall('toponyms/toponym')
                                 ], key=lambda k: k['start_pos'])}
        
    else:
        return {'text': article.find('text').text,
                'entities': sorted([{'text': top.find('phrase').text,
                            'start_pos': int(top.find('start').text),
                            'end_pos': int(top.find('end').text)} for top in article.findall('toponyms/toponym')
                             if top.find('gaztag/lat') != None and top.find('gaztag/lon') != None
                                 ], key=lambda k: k['start_pos'])}

def process_articles(root, filtered, file_path):
    """
    Takes articles and processes them into desired structure
    """
    data = []
    
    for article in root:
        
        data.append(process_article(article, filtered, file_path))
    
    return data

def prepare_data(file_path, filtered):
    
    root = load_file(file_path)
    
    data = process_articles(root, filtered, file_path)
    
    return data

In [None]:
processed_results

In [51]:
# Get file path LGL dataset
file_path = '../../data/TR-News/TR-News.xml'

data_all_toponyms = prepare_data(file_path, filtered=False)
data_filtered_toponyms = prepare_data(file_path, filtered=True)

In [52]:
# filtered toponyms
fps, fns = evaluate(data_filtered_toponyms, processed_results)

fp: 537 | tp: 518 | fn: 495
precision: 0.491 | recall: 0.511 | f-score: 0.501


In [53]:
fps

['White House',
 'New York',
 'Queens',
 'New York City',
 'U',
 '.',
 'S',
 '.',
 'Rose Garden',
 'Southern Poverty Law Center',
 'Ronald Reagan Building',
 'U',
 'U',
 '.',
 'S',
 '.',
 'U',
 '.',
 'S',
 '.',
 'U',
 '.',
 'S',
 '.',
 'U',
 '.',
 'S',
 '.',
 'Islamic State',
 'U',
 '.',
 'S',
 '.',
 'U',
 '.',
 'S',
 '.',
 'Southern',
 'United States',
 'Texas',
 'US',
 'Wisconsin',
 'Wisconsin',
 'Wisconsin',
 'Wisconsin',
 'Pennsylvania',
 'Michigan',
 'Washington',
 'South Carolina',
 'US',
 'Manhattan',
 'South Carolina',
 'Florida',
 'Texas',
 'US',
 'Cuba',
 'Cuba',
 'US',
 'Cuba',
 'America',
 'Michigan',
 'Pennsylvania',
 'Wisconsin',
 'US',
 'US',
 'Russia',
 'US',
 'Wisconsin',
 'Michigan',
 'Pennsylvania',
 'Redding',
 'California',
 'Wooster St.',
 'Cumberland Farms',
 'South Main St',
 'Torrington',
 'Calgary',
 'Calgary',
 '-',
 'North West',
 'Alberta',
 'Kremlin',
 'Kremlin',
 'Canada',
 'Vancouver',
 'Downtown Eastside',
 'US',
 'London City',
 'Heathrow',
 'Gatwick',

In [54]:
fns

['Turkish',
 'Turkish',
 'Syrian',
 'Syrian',
 'U.S.',
 'Turkish',
 'Turkish',
 'Russian',
 'Granville County',
 'New York',
 'Queens',
 'WASHINGTON',
 'U.S.',
 'Texas',
 'Texas',
 'Texas',
 'U.S.',
 'U.S.',
 'U.S.',
 'U.S.',
 'U.S.',
 'Xavier University',
 'British',
 'U.S.',
 'European',
 'U.S.',
 'U.S.',
 'U.S.',
 'DETROIT',
 'U.S.',
 'U.S.',
 'US',
 'Wisconsin',
 'Wisconsin',
 'Wisconsin',
 'New York',
 'Michigan',
 'Washington',
 'South Carolina',
 'US',
 'Manhattan',
 'US',
 'Cuba',
 'Cuba',
 'Cuban',
 'US',
 'Cuba',
 'Michigan',
 'Pennsylvania',
 'Wisconsin',
 'US',
 'US',
 'Russia',
 'US',
 'BANTAM',
 'Bantam',
 'New Milford',
 'TORRINGTON',
 'Torrington',
 'Calgary',
 'Calgary',
 'Russian',
 'Russian',
 'Russian',
 'Russia',
 'Russian',
 'Canada',
 'Vancouver',
 'London',
 'Heathrow',
 'Gatwick',
 'England',
 'Heathrow',
 'Heathrow',
 'London',
 'London',
 'London',
 'California',
 'California',
 'Kayseri',
 'Istanbul',
 'Kayseri',
 'Anatolia',
 'Istanbul',
 'German',
 'Iraqi'

In [55]:
# all toponyms
fps, fns = evaluate(data_all_toponyms, processed_results)

fp: 512 | tp: 545 | fn: 511
precision: 0.516 | recall: 0.516 | f-score: 0.516


In [56]:
fps

['White House',
 'New York',
 'Queens',
 'New York City',
 'U',
 '.',
 'S',
 '.',
 'Rose Garden',
 'Southern Poverty Law Center',
 'Ronald Reagan Building',
 'U',
 'U',
 '.',
 'S',
 '.',
 'U',
 '.',
 'S',
 '.',
 'U',
 '.',
 'S',
 '.',
 'U',
 '.',
 'S',
 '.',
 'Islamic State',
 'U',
 '.',
 'S',
 '.',
 'U',
 '.',
 'S',
 '.',
 'Southern',
 'United States',
 'Texas',
 'US',
 'Wisconsin',
 'Wisconsin',
 'Wisconsin',
 'Wisconsin',
 'Pennsylvania',
 'Michigan',
 'Washington',
 'South Carolina',
 'US',
 'Manhattan',
 'South Carolina',
 'Florida',
 'Texas',
 'US',
 'Cuba',
 'Cuba',
 'US',
 'Cuba',
 'America',
 'Michigan',
 'Pennsylvania',
 'Wisconsin',
 'US',
 'US',
 'Russia',
 'US',
 'Wisconsin',
 'Michigan',
 'Pennsylvania',
 'Redding',
 'California',
 'Wooster St.',
 'Cumberland Farms',
 'South Main St',
 'Torrington',
 'Calgary',
 'Calgary',
 '-',
 'North West',
 'Alberta',
 'Kremlin',
 'Kremlin',
 'Canada',
 'Vancouver',
 'Downtown Eastside',
 'US',
 'London City',
 'Heathrow',
 'Gatwick',

In [57]:
fns

['Turkish',
 'Turkish',
 'Syrian',
 'Syrian',
 'U.S.',
 'Turkish',
 'Kurdish',
 'Turkish',
 'Russian',
 'Granville County',
 'New York',
 'Queens',
 'WASHINGTON',
 'U.S.',
 'Texas',
 'Texas',
 'Texas',
 'U.S.',
 'U.S.',
 'U.S.',
 'U.S.',
 'U.S.',
 'Xavier University',
 'British',
 'U.S.',
 'European',
 'U.S.',
 'U.S.',
 'U.S.',
 'DETROIT',
 'U.S.',
 'U.S.',
 'US',
 'Wisconsin',
 'Wisconsin',
 'Wisconsin',
 'New York',
 'Michigan',
 'Washington',
 'South Carolina',
 'US',
 'Manhattan',
 'US',
 'Cuba',
 'Cuba',
 'Cuban',
 'US',
 'Cuba',
 'Michigan',
 'Pennsylvania',
 'Wisconsin',
 'US',
 'US',
 'Russia',
 'US',
 'BANTAM',
 'Bantam',
 'New Milford',
 'TORRINGTON',
 'Torrington',
 'Calgary',
 'Calgary',
 'Russian',
 'Russian',
 'Russian',
 'Russia',
 'Russian',
 'Canada',
 'Vancouver',
 'London',
 'Heathrow',
 'Gatwick',
 'England',
 'Heathrow',
 'Heathrow',
 'London',
 'London',
 'London',
 'California',
 'California',
 'Kayseri',
 'Istanbul',
 'Kayseri',
 'Anatolia',
 'Istanbul',
 'Germa

### Predict single example

In [None]:
import torch

In [None]:
testie_text = 'Hi, I am Francesca and I like to eat tacos in Mexico'

In [None]:
outputs = model(**tokenizer(testie_text, truncation=True, return_tensors='pt').to('cuda'))

In [None]:
input_ids = torch.tensor([tokenized_sentence]).cuda()

In [None]:
# outputs = model(input_ids)

In [None]:
import numpy as np

In [None]:
label_indices = np.argmax(outputs[0].to('cpu').detach().numpy(), axis=2)

In [None]:
# join bpe split tokens
tokens = tokenizer.convert_ids_to_tokens(input_ids.to('cpu').detach().numpy()[0])
new_tokens, new_labels = [], []
for token, label_idx in zip(tokens, label_indices[0]):
    if token.startswith("##"):
        new_tokens[-1] = new_tokens[-1] + token[2:]
    else:
        new_labels.append(label_list[label_idx])
        new_tokens.append(token)

In [None]:
for token, label in zip(new_tokens, new_labels):
    print("{}\t{}".format(label, token))