# Latin-BERT training

In [1]:
!nvidia-smi

Wed Jul 12 12:37:13 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3090         Off| 00000000:01:00.0 Off |                  N/A |
| 30%   27C    P8               26W / 350W|      0MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090         Off| 00000000:41:0

In [2]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [3]:
import torch
torch.cuda.device_count()

1

In [4]:
# !pip install datasets
# !pip install transformers
# !pip install tensorflow
# !pip install tensor2tensor
# !pip install cltk
# !pip install seqeval
# !pip install wandb

## prepare data

In [5]:
import pandas as pd
train = pd.read_csv('./data/Latin_NER_train.csv', index_col=0)
test = pd.read_csv('./data/Latin_NER_test.csv', index_col=0)
val = pd.read_csv('./data/Latin_NER_eval.csv', index_col=0)

In [6]:
train['tag'].value_counts()

O         82696
B-PERS     2706
B-GRP      1271
B-LOC       839
I-PERS      618
I-LOC        31
I-GRP         4
Name: tag, dtype: int64

In [7]:
import json

label2idx = {'O': 0, 
 'B-PERS': 1, 
 'I-PERS': 2, 
 'B-LOC': 3, 
 'I-LOC': 4, 
 'B-GRP': 5, 
 'I-GRP': 6}

idx2label = {value: key for key, value in label2idx.items()}

In [8]:
from datasets import load_dataset

dataset = load_dataset('data/Latin_NER_json')

Found cached dataset json (/home/pricie/marijkeb/.cache/huggingface/datasets/json/Latin_NER_json-d0854115d27feb27/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


  0%|          | 0/3 [00:00<?, ?it/s]

## preparing the tokenizer

In [9]:
# download nessecary cltk models
# !pip install cltk
# import cltk
# from cltk.data.fetch import FetchCorpus
# corpus_downloader = FetchCorpus(language='lat')
# corpus_downloader.import_corpus('lat_models_cltk')


In [10]:
# install tensor to tensor encoder
# !pip install tensor2tensor --user
# !pip install tensorflow

# from cltk.tokenizers.lat.lat import LatinWordTokenizer as WordTokenizer
# from cltk.tokenizers.lat.lat import LatinPunktSentenceTokenizer as SentenceTokenizer
from tensor2tensor.data_generators import text_encoder

2023-07-12 12:37:32.291555: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [11]:
#copied this class and function from the latin-BERT repo
#[REDACTED]'s code
#made some adjustments

from transformers import BatchEncoding

class LatinTokenizer():
	def __init__(self, encoder):
		self.vocab={}
		self.reverseVocab={}
		self.encoder=encoder

		self.vocab["[PAD]"]=0
		self.vocab["[UNK]"]=1
		self.vocab["[CLS]"]=2
		self.vocab["[SEP]"]=3
		self.vocab["[MASK]"]=4
		self.model_max_length=256
		self.is_fast=False


		self.cls_token_id = self.vocab["[CLS]"]
		self.pad_token_id = self.vocab["[PAD]"]
		self.sep_token_id = self.vocab["[SEP]"]
        
		for key in self.encoder._subtoken_string_to_id:
			self.vocab[key]=self.encoder._subtoken_string_to_id[key]+5
			self.reverseVocab[self.encoder._subtoken_string_to_id[key]+5]=key


	def convert_tokens_to_ids(self, tokens):
		wp_tokens=[]
		for token in tokens:
			if token == "[PAD]":
				wp_tokens.append(0)
			elif token == "[UNK]":
				wp_tokens.append(1)
			elif token == "[CLS]":
				wp_tokens.append(2)
			elif token == "[SEP]":
				wp_tokens.append(3)
			elif token == "[MASK]":
				wp_tokens.append(4)

			else:
				wp_tokens.append(self.vocab[token])

		return wp_tokens

	def tokenize(self, text, split_on_tokens=True):
		if split_on_tokens:
			tokens = [token.lower() if token not in ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] else token for token in text]
		else: 
			tokens = text.split()

		wp_tokens=[] #word-piece tokens

		for token in tokens:
			# print(token)

			if token in {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"}:
				wp_tokens.append(token)
			else:

				wp_toks=self.encoder.encode(token)

				for wp in wp_toks:
					wp_tokens.append(self.reverseVocab[wp+5])

		return wp_tokens
	
	def calculate_attention_masks(self, wp_tokens):
		attention_masks = []
		
		for token in wp_tokens:
			if token == self.pad_token_id:
				attention_masks.append(0)
			else:
				attention_masks.append(1)
				
		return attention_masks
	
	def pad(self, features, padding=True, max_length=256, pad_to_multiple_of="", return_tensors=True):
		# TODO
		batch_outputs = {}
		
		for i in range(len(features)):
			for key, value in features[i].items():
	
				if key in batch_outputs:
					batch_outputs[key].append(value)
	
				else:
					batch_outputs[key] = [value]

		for k, v in batch_outputs.items():
			batch_outputs[k] = torch.tensor([x for x in v])

		return BatchEncoding(batch_outputs)
	
	def pad_max_length_and_add_specials_tokens_also(self, tokens, wp_tokens):

		MAX_LENGTH = 256
		wp_tokens.insert(0, self.cls_token_id)
		tokens.insert(0, '[CLS]')
		wp_tokens.append(self.sep_token_id)
		tokens.append('[SEP]')
		
		if len(wp_tokens) > 256:
			wp_tokens = wp_tokens[:256]
			tokens = tokens[:256]
		
		else:
			while len(wp_tokens) < 256:
				wp_tokens.append(self.pad_token_id)
				tokens.append('[PAD]')

		return tokens, wp_tokens
	
	def pad_max_length_and_add_specials(self, wp_tokens):

		MAX_LENGTH = 256
		wp_tokens.insert(0, self.cls_token_id)
		wp_tokens.append(self.sep_token_id)
		
		if len(wp_tokens) > 256:
			wp_tokens = wp_tokens[:256]
		
		else:
			while len(wp_tokens) < 256:
				wp_tokens.append(self.pad_token_id)

		return wp_tokens
	
	def decode_to_string(self, input_ids):
		tokens = [self.reverseVocab[x] for x in input_ids if x > 4]
		return "".join(tokens).replace('_', ' ')

	def save_pretrained(self, output_dir):
		pass




# def convert_to_toks(sents):

# 	sent_tokenizer = SentenceTokenizer()
# 	word_tokenizer = WordTokenizer()

# 	all_sents=[]

# 	for data in sents:
# 		text=data.lower()

# 		sents=sent_tokenizer.tokenize(text)
# 		for sent in sents:
# 			tokens=word_tokenizer.tokenize(sent)
# 			filt_toks=[]
# 			filt_toks.append("[CLS]")
# 			for tok in tokens:
# 				if tok != "":
# 					filt_toks.append(tok)
# 			filt_toks.append("[SEP]")

# 			all_sents.append(filt_toks)

# 	return all_sents

# def df_to_toks(df, sent_column="sentence_ids", word_column="words"):

# 	all_sents = []

# 	grouped = df.groupby(sent_column)
	
# 	for sent in grouped.groups:
# 		sent_df = grouped.get_group(sent)
		
# 		tokens = sent_df[word_column].values.tolist()
		
# 		filt_toks=[]
		
# 		filt_toks.append("[CLS]")
# 		for tok in tokens:
# 			if tok != "":
# 				filt_toks.append(tok)
# 		filt_toks.append("[SEP]")

# 		all_sents.append(filt_toks)

# 	return all_sents

In [12]:
#load the tokenizer
tokenizer = LatinTokenizer(text_encoder.SubwordTextEncoder('../latin-bert/models/subword_tokenizer_latin/latin.subword.encoder'))

test_sentence = train.groupby('sentence').get_group(1)['word'].values.tolist()

tokens = tokenizer.tokenize(test_sentence)

print(tokens)

#the output is unusual for huggingface, but it's in the vocab file this way so I leave it.

['ut_', 'vero_', 'ex_', 'litteris_', 'ad_', 'senatum_', 'referre', 'tur_', ',_', 'impetra', 'ri_', 'non_', 'potuit_', '._', 'hi_', 'omnes_', 'lingua_', ',_', 'institutis_', ',_', 'legibus_', 'inter_', 'se_', 'differunt_', '._', '"_', 'namque_', 'tu_', 'sole', 'bas_', 'nuga', 's_', 'esse_', 'aliquid_', 'meas_', 'putare_', ',_', '"_', 'ut_', 'obit', 'er_', 'emo', 'llia', 'm_', 'catull', 'um_', 'conter', 'rane', 'um_', 'meum_', '(_', 'agnosci', 's_', 'et_', 'hoc_', 'castr', 'ense_', 'verbum_', ')_', 'ille_', 'enim_', ',_', 'ut_', 'scis_', ',_', 'permutat', 'is_', 'prioribus_', 'syllab', 'is_', 'duri', 'uscul', 'um_', 'se_', 'fecit_', 'quam_', 'volebat_', 'existima', 'ri_', ',_', 'a_', 'vera', 'nio', 'lis_', 'suis_', 'et_', 'fabul', 'lis_', '._']


In [13]:
import re


def tokenize_adjust_labels(all_samples_per_split):
	
	pretokenized_samples = all_samples_per_split["tokens"]
	tokenized_samples = tokenizer.tokenize(pretokenized_samples) #create wordpiece tokens
	token_ids = tokenizer.convert_tokens_to_ids(tokenized_samples) #to ids
    #path both the tokens and the the token_ids, as the tokenids are on subwords
	padded_tokenized_samples, padded_token_ids = tokenizer.pad_max_length_and_add_specials_tokens_also(tokenized_samples, token_ids) #pad and add special tokens

	all_samples_per_split['input_ids'] = padded_token_ids
    
	all_samples_per_split['attention_mask'] = tokenizer.calculate_attention_masks(padded_token_ids)
	all_samples_per_split['wp_tokens'] = tokenized_samples
	all_samples_per_split['extra'] = padded_tokenized_samples

	#original
	orig_labels = all_samples_per_split['tags']


	# logic to adjust labels, 
	adjusted_labels = []
	label_idx = 0
	# print(len(pretokenized_samples))

	for token in padded_tokenized_samples:
		try:
            #The tokenizer always treats punctuation as a separate token
            #in most cases, this is not a problem as the punctuation is also seperately labeled in GWannotation, 
            #but there are a few exceptions
            #next statement catches those
            
			if token in ['[CLS]', '[SEP]', '[PAD]']:
				adjusted_labels.append(-100)
                
			elif re.match(r'\w+[\.\,]', pretokenized_samples[label_idx]):
				if token != '._' and token != ',_':
					adjusted_labels.append(orig_labels[label_idx])
				else:
					adjusted_labels.append(orig_labels[label_idx])
					label_idx += 1
		
			elif token.endswith('_'):
				adjusted_labels.append(orig_labels[label_idx])
				label_idx += 1

			else:
				adjusted_labels.append(orig_labels[label_idx])
		except IndexError:
			try :
				if token in ['[CLS]', '[SEP]', '[PAD]']:
					adjusted_labels.append(-100)
				else:
					adjusted_labels.append(orig_labels[label_idx])
			except IndexError:
				print('HERE')
				print(pretokenized_samples[:label_idx-1])
				print(orig_labels[:label_idx-1])
				print(token)
				print(list(zip(padded_tokenized_samples, adjusted_labels)))


	all_samples_per_split['labels'] = adjusted_labels

	try:
		assert len(adjusted_labels) == len(padded_tokenized_samples) == 256
	except AssertionError:
		print(all_samples_per_split)


	return all_samples_per_split


tokenized_dataset = dataset.map(tokenize_adjust_labels)

Map:   0%|          | 0/5815 [00:00<?, ? examples/s]

Map:   0%|          | 0/972 [00:00<?, ? examples/s]

Map:   0%|          | 0/3410 [00:00<?, ? examples/s]

## trainer test

In [14]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
import numpy as np
from datasets import load_metric
from sklearn.metrics import classification_report
from sklearn.preprocessing import MultiLabelBinarizer

metric = load_metric("seqeval")
def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [idx2label[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [idx2label[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    flattened_results = {
        "overall_precision": results["overall_precision"],
        "overall_recall": results["overall_recall"],
        "overall_f1": results["overall_f1"],
        "overall_accuracy": results["overall_accuracy"],
    }
    return flattened_results

  metric = load_metric("seqeval")


In [15]:
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()

In [16]:
# !pip install protobuf==3.20.* --user
# model = AutoModelForTokenClassification.from_pretrained('../latin-bert/models')
model = AutoModelForTokenClassification.from_pretrained('Herodotos_trained_lat_BERT_worked')


In [17]:
training_args = TrainingArguments(
    output_dir="./fine_tune_bert_output",
    seed=123
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [18]:
def extend_clear_list(temp, fixed, item):
    temp.append(item)
    fixed.append(int(np.mean(temp)))
    temp.clear()


def aggregate_ents(original_tokens, wp_tokens, preds, labels):
    #THE FUNCTION WORKS
    #Aggregates subword
    try:
        assert len(wp_tokens) == len(preds) and len(wp_tokens) == len(labels)
    except AssertionError:
        print('lenght tokens, predictions and labels are not equal')
        print(wp_tokens)
        
    fixed_preds = []
    fixed_labels = []
    
    temp_label = []
    temp_pred = []
    
    for i in range(len(wp_tokens)-1):
        if (wp_tokens[i+1] == '._') and (wp_tokens[i] != '._'):
            
            #check if word in orig_tokens is a \w+[,.] token (M. or H,)
            
            #if the original token IS one of the special cases, treat current token as not a comp
            if re.match(r'\w+\.', original_tokens[len(fixed_preds)]):
                temp_label.append(labels[i])
                temp_pred.append(preds[i])
            
            else:
                extend_clear_list(temp_label, fixed_labels, labels[i])
                extend_clear_list(temp_pred, fixed_preds, preds[i])
                
        elif wp_tokens[i+1] == ',_' and (wp_tokens[i] != ',_'):
            
            #check if word in orig_tokens is a \w+[,.] token (M. or H,)
            
            #if the original token IS one of the special cases, treat current token as not a comp
            if re.match(r'\w+\,', original_tokens[len(fixed_preds)]):
                temp_label.append(labels[i])
                temp_pred.append(preds[i])
            
            else:
                extend_clear_list(temp_label, fixed_labels, labels[i])
                extend_clear_list(temp_pred, fixed_preds, preds[i])
            
    
        elif wp_tokens[i].endswith('_') and len(temp_label) == 0:
            fixed_preds.append(preds[i])
            fixed_labels.append(labels[i])
            
        elif wp_tokens[i].endswith('_'):
            extend_clear_list(temp_label, fixed_labels, labels[i])
            extend_clear_list(temp_pred, fixed_preds, preds[i])

        else:
            temp_label.append(labels[i])
            temp_pred.append(preds[i])
            
            
    fixed_preds.append(preds[len(wp_tokens)-1])
    fixed_labels.append(labels[len(wp_tokens)-1])
    
            
    try:        
        assert len(original_tokens) == len(fixed_preds) and len(original_tokens) == len(fixed_labels)
    except AssertionError:
        original_tokens = original_tokens[:len(fixed_preds)]
        try:
            assert len(original_tokens) == len(fixed_preds) and len(original_tokens) == len(fixed_labels)
        except AssertionError:
            print('lenght of original tokens, aggregated predictions and labels are not equal')
            print(f'''originals = {original_tokens} \n 
                  tokenized = {wp_tokens}\n
                  predictions = {preds}\n
                  labels = {labels}\n
                  fixed_preds = {fixed_preds}
                  fixed_labels = {fixed_labels}''')
    
    return original_tokens, fixed_preds, fixed_labels

## test set

In [19]:
predictions = trainer.predict(tokenized_dataset["test"])

import numpy as np

preds = np.argmax(predictions.predictions, axis=-1)

true_predictions = [
        [p for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(preds, predictions.label_ids)
    ]
    
true_labels = [
        [l for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(preds, predictions.label_ids)
    ]

tokens = tokenized_dataset['test']['extra']
orig_tokens = tokenized_dataset['test']['tokens']
ids = tokenized_dataset['test']['id']

dct = {
    'orig_tokens_all': [],
    'agg_predictions_all': [],
    'agg_labels_all': [],
    'agg_ent_predictions_all': [],
    'agg_ent_labels_all': [],
    'all_ids': []
}

major_l = list(zip(ids, orig_tokens, tokens, true_predictions, true_labels))

for idd, original_tokens, wp_tokens, preds, labels in major_l:
    try:
        wp_tokens = [token for token in wp_tokens if token not in ['[CLS]', '[PAD]', '[SEP]']] 
        orig_tokens, fixed_preds, fixed_labels = aggregate_ents(original_tokens, wp_tokens, preds, labels)
        dct['orig_tokens_all'].append(orig_tokens)
        dct['agg_predictions_all'].append(fixed_preds)
        dct['agg_labels_all'].append(fixed_labels)
        dct['agg_ent_predictions_all'].append([idx2label[pred] for pred in fixed_preds])
        dct['agg_ent_labels_all'].append([idx2label[label] for label in fixed_labels])
        dct['all_ids'].append([idd] * len(fixed_labels))
    except AssertionError:
        print('\nOh no!\n')
        print(original_tokens)
        print(wp_tokens)
        pass

The following columns in the test set don't have a corresponding argument in `BertForTokenClassification.forward` and have been ignored: tags, tokens, id, extra, wp_tokens. If tags, tokens, id, extra, wp_tokens are not expected by `BertForTokenClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 3410
  Batch size = 8


In [20]:
from seqeval.metrics import classification_report

print(classification_report(dct['agg_ent_labels_all'], dct['agg_ent_predictions_all']))

              precision    recall  f1-score   support

         GRP       0.87      0.71      0.78       354
         LOC       0.77      0.71      0.74       305
        PERS       0.85      0.83      0.84       849

   micro avg       0.84      0.78      0.81      1508
   macro avg       0.83      0.75      0.79      1508
weighted avg       0.84      0.78      0.81      1508



In [21]:
unnested_dct = {
    key: sum(value, []) for key, value in dct.items()
}

In [22]:
for key, value in unnested_dct.items():
    print(len(value))

31788
31788
31788
31788
31788
31788


In [23]:
for key, value in dct.items():
    print(len(value))

3410
3410
3410
3410
3410
3410


In [24]:
dct['orig_tokens_all'][-1]

['Nam',
 'nos',
 'quoque',
 'tam',
 'numerosum',
 'agmen',
 'reorum',
 'ita',
 'demum',
 'videbamus',
 'posse',
 'superari',
 ',',
 'si',
 'per',
 'singulos',
 'carperetur',
 '.']

In [25]:
df = pd.DataFrame.from_dict(unnested_dct)

In [26]:
df

Unnamed: 0,orig_tokens_all,agg_predictions_all,agg_labels_all,agg_ent_predictions_all,agg_ent_labels_all,all_ids
0,timere,0,0,O,O,CW_11
1,Caesarem,1,1,B-PERS,B-PERS,CW_11
2,ereptis,0,0,O,O,CW_11
3,ab,0,0,O,O,CW_11
4,eo,0,0,O,O,CW_11
...,...,...,...,...,...,...
31783,si,0,0,O,O,PlinyYounger_1335
31784,per,0,0,O,O,PlinyYounger_1335
31785,singulos,0,0,O,O,PlinyYounger_1335
31786,carperetur,0,0,O,O,PlinyYounger_1335


In [27]:
from sklearn.metrics import classification_report

print(classification_report(df.agg_ent_labels_all, df.agg_ent_predictions_all))

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

       B-GRP       0.88      0.72      0.79       354
       B-LOC       0.80      0.72      0.76       305
      B-PERS       0.88      0.84      0.86       849
       I-GRP       0.00      0.00      0.00         3
       I-LOC       0.25      0.25      0.25         8
      I-PERS       0.78      0.91      0.84        99
           O       0.99      1.00      1.00     30170

    accuracy                           0.99     31788
   macro avg       0.66      0.63      0.64     31788
weighted avg       0.99      0.99      0.99     31788



In [32]:
df['text'] = df['all_ids'].apply(lambda x: x.split('_')[0])
df['sentence'] = df['all_ids'].apply(lambda x: x.split('_')[1])

In [33]:
df['text'].value_counts()

Ovid            17102
GW               7452
PlinyElder       4188
PlinyYounger     2454
CW                592
Name: text, dtype: int64

In [34]:
df['domain'] = df['text'].apply(lambda x: 'IN' if x != 'Ovid' else 'OUT')

In [35]:
df = df.rename(columns = {'orig_tokens_all': 'token',
                    'agg_ent_predictions_all': 'predictions',
                    'agg_ent_labels_all': 'labels',
                    'all_ids': 'sentence_ids',
                    'agg_predictions_all': 'prediction_id',
                    'agg_labels_all': 'label_id'})

In [36]:
df = df[['token', 'label_id', 'prediction_id', 'labels', 'predictions', 'sentence_ids', 'text', 'sentence', 'domain']]

In [37]:
df

Unnamed: 0,token,label_id,prediction_id,labels,predictions,sentence_ids,text,sentence,domain
0,timere,0,0,O,O,CW_11,CW,11,IN
1,Caesarem,1,1,B-PERS,B-PERS,CW_11,CW,11,IN
2,ereptis,0,0,O,O,CW_11,CW,11,IN
3,ab,0,0,O,O,CW_11,CW,11,IN
4,eo,0,0,O,O,CW_11,CW,11,IN
...,...,...,...,...,...,...,...,...,...
31783,si,0,0,O,O,PlinyYounger_1335,PlinyYounger,1335,IN
31784,per,0,0,O,O,PlinyYounger_1335,PlinyYounger,1335,IN
31785,singulos,0,0,O,O,PlinyYounger_1335,PlinyYounger,1335,IN
31786,carperetur,0,0,O,O,PlinyYounger_1335,PlinyYounger,1335,IN


In [295]:
# df.to_csv('errors_test_set_herodotos_FINAL.csv')

### check for match with csv

In [59]:
val = pd.read_csv('./data/Latin_NER_test.csv', index_col=0)

#the following is necessary because the eval.csv is not in ascending sentence order and the json is
grouped = val.groupby('sentence')

val2 = pd.DataFrame(columns = ['word', 'tag', 'sentence'])

for name, group in grouped:
    val2 = pd.concat([val2, group])


assert val2['tag'].values.tolist() == unnested_dct['agg_ent_labels_all']
assert val2['word'].values.tolist() == unnested_dct['orig_tokens_all']