In [1]:
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer, TrainingArguments, DataCollatorForTokenClassification
import pandas as pd
from datasets import Dataset
from tqdm import trange, tqdm
import evaluate
import numpy as np
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [3]:
label_names = ['O', 'B-NAME_STUDENT', 'I-NAME_STUDENT', 'B-URL_PERSONAL',
       'B-EMAIL', 'B-ID_NUM', 'I-URL_PERSONAL', 'B-USERNAME',
       'B-PHONE_NUM', 'I-PHONE_NUM', 'B-STREET_ADDRESS',
       'I-STREET_ADDRESS']
id2label = {i: label for i, label in enumerate(label_names)}
label2id = {v: k for k, v in id2label.items()}

In [4]:
model_name =  "distilbert/distilbert-base-cased"
model = AutoModelForTokenClassification.from_pretrained(model_name, id2label = id2label, label2id = label2id).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert/distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
def align_labels_with_tokens(labels, word_ids):
    new_labels = []
    for word_id in word_ids:
        if word_id is None:
            new_labels.append(0)
        else:
            new_labels.append(labels[word_id])
            # if(labels[word_id] != 0):
            #     print(labels[word_id])
    return new_labels

In [6]:
def tokenize_and_align(example):
    tokenized_inputs = tokenizer(example["tokens"], truncation=True, is_split_into_words=True, padding= True, max_length = 512)
    all_labels = example["tags"]
    new_labels = []
    for i, labels in enumerate(all_labels):
        word_ids = tokenized_inputs.word_ids(i)
        new_labels.append(align_labels_with_tokens(labels, word_ids))
    tokenized_inputs["labels"] = new_labels
    return tokenized_inputs

In [7]:
def tag2num(tags):
    return [label2id[tag] for tag in tags]

In [8]:
total_df = pd.DataFrame({'tokens': [], 'tags': []})
file_name = 'train.csv'
df = pd.read_csv(file_name).dropna()
df = df.drop(columns=['Unnamed: 0'])
doc_ids = list(df["doc_id"].unique())
size = len(doc_ids)*9//10
train_doc_ids = doc_ids[0:size]
eval_doc_ids = doc_ids[size:]
for doc_id in tqdm(train_doc_ids):
    tokens = df[df['doc_id'] == doc_id]['token'].to_list()
    tags = df[df['doc_id'] == doc_id]['label'].to_list()
    tags = tag2num(tags)
    cur = 0
    while(cur + 200 < len(tags)):
        total_df.loc[len(total_df)] = [tokens[cur:cur+200], tags[cur:cur+200]]
        cur+=200
    total_df.loc[len(total_df)] = [tokens[cur:], tags[cur:]]


train_dataset = Dataset.from_pandas(total_df).remove_columns('__index_level_0__')

total_eval_df = pd.DataFrame({'tokens': [], 'tags': []})

for doc_id in tqdm(eval_doc_ids):
    eval_tokens = df[df['doc_id'] == doc_id]['token'].to_list()
    eval_tags = df[df['doc_id'] == doc_id]['label'].to_list()
    eval_tags = tag2num(eval_tags)
    cur = 0
    while(cur + 200 < len(tags)):
        total_eval_df.loc[len(total_eval_df)] = [eval_tokens[cur:cur+200], eval_tags[cur:cur+200]]
        cur+=200
    total_eval_df.loc[len(total_eval_df)] = [eval_tokens[cur:], eval_tags[cur:]]

eval_dataset = Dataset.from_pandas(total_eval_df).remove_columns('__index_level_0__')

100%|██████████| 3246/3246 [00:14<00:00, 229.84it/s]
100%|██████████| 361/361 [00:01<00:00, 283.76it/s]


In [9]:
train_dataset = train_dataset.map(tokenize_and_align, batched = True,remove_columns=['tokens', 'tags'])
eval_dataset = eval_dataset.map(tokenize_and_align, batched =True, remove_columns=['tokens', 'tags'])

Map:   7%|▋         | 1000/13647 [00:00<00:08, 1505.58 examples/s]

1
1
1
2
2
2
1
1
1
2
2
2
1
1
1
2
2
2
1
2
2
2
1
2
2
2
1
1
2
2
2
1
1
2
2
1
1
2
1
1
1
2
2
1
1
1
2
2
2
1
1
1
2
1
2
2
2
1
2
2
2
1
1
1
1
2
2
2
1
1
1
1
1
2
1
2
2
2
1
1
2
1
2
2
1
2
2
1
2
2
1
2
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
1
1
2
2
1
1
2
1
2
2
1
2
2
1
1
2
1
1
1
1
2
2
4
4
4
4
4
4
4
4
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
1
1
2
2
1
1
1
1
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
2
2
1
1
2
2
2
1
2
5
5
5
5
5
5
5
5
1
1
1
1
2
2
2
5
5
5
5
5
5
5
5
5
5
5
5
5
5
1
1
2
2
2
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
1
1
2
5
5
5
5
5
5
5
5
1
1
1
2
2
2
5
5
5
5
5
5
5
1
2
2
2
5
5
5
5
5
5
5
5
1
2
5
5
5
5
5
5
5
5
1
1
2
1
1
1
1
1
1
1
2
2
1
2
2
2
2
1
2
2
2
2
1
1
2
2
1
1
1
1
2
1
2
2
1
2
2
1
2
2
1
1
2
2
1
2
2
2
1
1
1
2
2
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
2
1
1
1
2
1
2
1
2
2
2
2
1
1
2
2
1
1
2
2
1
1
2
2
1
2
2
1
2
2
1
1
2
2
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
1
2
2
2
2
1
1
2
2
1
1
2
2
1
1
1
2
1
2
2
2
1
1
1
1
2
1
1
2
2
2
2
1
2
1
1
2
1
1
2
1
1
1
2
1
1
1
2
2
2
1
2
2
2
2
1
1
2
2
2
1
2
1
1
1
1
1
1
2
2


Map:  15%|█▍        | 2000/13647 [00:01<00:06, 1758.10 examples/s]

1
2
2
2
1
1
1
2
2
2
2
4
4
4
4
4
4
4
4
4
4
4
8
9
9
9
9
9
9
9
1
2
1
2
1
2
1
1
1
1
2
1
1
2
1
2
1
1
2
1
2
2
1
2
2
1
2
2
1
2
2
2
2
2
2
1
2
2
2
2
2
2
1
1
2
5
5
5
5
5
5
5
5
1
1
2
1
1
1
2
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
1
1
2
1
1
1
2
2
1
2
2
2
2
1
2
2
2
2
1
2
1
2
1
1
1
1
2
2
2
1
2
5
5
5
5
5
5
5
5
5
1
1
2
5
5
5
5
5
5
5
5
5
1
1
2
5
5
5
5
5
5
5
5
1
2
1
1
1
2
2
2
2
1
1
1
1
2
2
2
1
1
1
1
2
2
2
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
1
2
1
2
1
2
1
2
2
1
2
5
5
5
5
5
5
5
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
2
1
1
1
2
1
1
2
2
2
2
1
1
2
2
2
2
1
1
1
2
2
2
1
2
2
1
2
1
2
1
2
1
1
1
1
2
2
2
1
1
2
2
1
1
2
2
1
1
2
2
1
1
2
2
1
1
2
2
1
1
2
2
1
1
1
1
1
2
1
2
2
1
2
2
2
1
1
1
2
2
2
1
1
2
2
1
1
1
2
1
1
2
2
2
4
4
4
4
4
4
4
4
4
1
2
2
1
2
1
2
1
1
2
1
1
2
1
2
2
2
1
1
2
1
1
2
1
2
1
2
1
2
1
2
1
1
2
1
1
1
2
4
4
4
4
4
4
4
4
4
4
1
1
1
1
2
1
2
2
2
1
2
1
2
2
1
2
1
2
2
1
1
2
2
1
2
1
1
1
1
2
2
1
2
2
2
2
1
1
1
1
2
2
1
1
2
2
1
1
1
2
2
1
1
2
2
1
1
2
1
2
1
2
1
2
2
1
2
2
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
1
2


Map:  22%|██▏       | 3000/13647 [00:01<00:05, 1946.71 examples/s]

1
1
1
2
2
1
2
2
1
1
1
1
2
2
2
1
1
1
1
2
2
2
1
2
1
2
1
2
2
1
2
2
1
2
2
2
2
2
1
1
2
2
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
2
2
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
2
2
2
1
1
2
2
2
1
2
2
1
2
1
1
1
2
2
2
1
1
1
2
2
2
1
1
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
2
2
2
1
2
2
1
2
2
1
2
1
1
2
1
2
1
2
1
2
1
2
2
2
2
1
2
1
2
2
2
1
1
1
2
2
2
1
1
1
2
2
2
1
1
1
2
2
2
1
2
1
2
1
2
1
1
1
2
2
2
1
1
2
2
2
1
1
1
2
1
1
2
1
1
2
2
2
1
2
2
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
2
1
1
1
2
2
2
1
2
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
2
1
2
2
2
1
2
2
2
1
2
2
2
1
2
2
2
1
2
2
1
2
2
1
1
2
2
1
2
1
2
1
2
1
1
2
1
1
2
1
2
1
2
1
2
1
2
2
2
2
7
7
7
7
7
1
1
2
1
1
1
2
2
2
1
1
1
2
2
2
1
2
1
2
2
2
2
1
2
1
2
1
1
2
1
1
2
1
2
2
2
1
1
2
1
2
2
1
2
2
1
2
1
2
1
2
1
2
1
2
2
1
2
2
1
1
1
2
1
1
2
2
1
1
2
2
2
1
1
2
2
2
1
1
2
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
2
2
1
1
2
1
1
1
2
2
2
1
1
1
2
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
1
3
3
3


Map:  29%|██▉       | 4000/13647 [00:02<00:04, 2038.59 examples/s]

1
1
2
2
1
1
1
2
2
1
1
2
2
2
1
2
2
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
2
2
1
2
2
1
1
1
2
2
1
1
2
2
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
2
2
2
4
4
4
4
4
4
4
4
4
4
5
5
5
5
5
5
5
5
5
5
5
5
5
1
1
2
2
2
1
1
2
2
2
4
4
4
4
4
4
4
4
4
4
5
5
5
5
5
5
5
5
5
5
5
5
5
1
1
1
2
2
2
1
2
1
1
1
2
2
1
1
1
2
2
1
1
1
2
2
1
1
1
2
2
1
2
1
2
1
1
1
2
1
2
1
1
2
2
2
1
1
2
1
2
1
1
2
1
1
2
1
1
1
1
1
1
2
1
1
2
1
1
1
2
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
1
1
1
1
1
2
1
2
1
2
1
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
2
2
1
1
1
1
1
1
1
1
2
2
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
3
3
3
3
3
3
3
3
3
3
3
3
1
2
2
1
2
1
2


Map:  37%|███▋      | 5000/13647 [00:02<00:04, 2124.88 examples/s]

1
1
1
2
2
1
1
2
1
2
1
2
1
2
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
1
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
1
2
2
2
1
1
1
2
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
1
2
1
2
1
1
2
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
2
2
2
5
5
5
5
5
5
5
5
5
5
5
5
5
1
2
2
1
1
1
2
2
2
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
2
1
2
2
1
2
1
2
1
2
2
2
2
1
1
1
2
2
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
1
1
1
2
1
1
2
2
2
1
2
2
1
1
1
2
1
2
2
1
1
2
2
5
5
5
5
5
5
1
1
2
2
5
5
5
5
5
5
1
1
5
5
5
5
5
5
5
5
1
1
5
5
5
5
5
5
5
1
5
5
5
5
5
5
5
5
1
1
2
1
1
1
1
1
1
1
2
2
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
2
1
1
2
2
2
1
2
1
1
1
2
2
2
10
10
11
11
11
11
11
11
11
11
11
11
11
11
11
11
11
11
8
8
8
8
8
8
8
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
2
2
2
1
2
2
2
1
2
2
2
1
1
1
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
2


Map:  44%|████▍     | 6000/13647 [00:02<00:03, 2129.86 examples/s]

1
1
1
1
1
1
1
2
2
1
2
2
2
1
1
1
1
2
2
1
1
2
2
2
1
2
2
2
1
2
2
1
2
2
1
1
2
2
1
1
1
1
2
2
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
2
1
2
2
1
2
2
2
1
1
2
2
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
2
1
2
1
2
1
2
1
2
2
1
1
2
2
1
2
1
2
1
1
2
2
1
1
1
2
2
2
2
1
2
2
2
1
2
2
1
2
1
2
2
2
1
2
2
2
1
1
2
1
2
1
2
1
2
1
2
1
2
1
1
1
1
2
2
2
2
1
1
2
1
1
2
1
2
1
1
1
2


Map:  51%|█████▏    | 7000/13647 [00:03<00:03, 2176.40 examples/s]

1
1
2
1
2
2
2
1
1
1
2
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
2
2
1
2
2
1
1
2
2
5
5
5
5
5
1
2
1
1
2
2
2
1
2
2
2
2
1
2
2
5
5
5
1
1
2
1
2
1
2
2
1
2
2
1
2
1
2
1
2
1
2
1
2
2
2
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
1
2
2
2
1
1
1
2
2
2
1
1
2
2
1
2
1
1
2
1
1
2
1
2
1
2
1
2
1
1
2
2
1
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
1
2
1
1
1
2
1
1
1
2
1
1
2
1
1
2
1
1
1
2
2
1
2
2
1
1
2
2
2
1
1
1
2
2
2
2


Map:  59%|█████▊    | 8000/13647 [00:03<00:02, 2177.62 examples/s]

4
4
4
4
4
4
4
4
4


Map:  73%|███████▎  | 10000/13647 [00:04<00:01, 2232.45 examples/s]

5
5
5
5
5
5
5
5


Map: 100%|██████████| 13647/13647 [00:06<00:00, 2158.80 examples/s]
Map: 100%|██████████| 1444/1444 [00:00<00:00, 1932.04 examples/s]


In [10]:
data_collator = DataCollatorForTokenClassification(tokenizer = tokenizer)

In [11]:
metric = evaluate.load("seqeval")

In [12]:
training_args = TrainingArguments(
    evaluation_strategy = "epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    output_dir="./results",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [13]:
trainer.train()

  0%|          | 0/5118 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
results = trainer.predict(eval_dataset).predictions

In [None]:
results[0][0]

In [None]:
trainer.predict(eval_dataset)

In [None]:
correct_labels = []
predict_labels = []
for text_pos, text in enumerate(tqdm(results)):
    labels = np.argmax(text, axis = 1)
    predict_labels.extend(labels)
    correct_labels.extend(eval_dataset['labels'][text_pos])

In [None]:
from sklearn.metrics import f1_score

f1_score(correct_labels, predict_labels, average = "micro")

In [None]:
for text_pos, text in enumerate(results):
    for position,token in enumerate(text):
      label = id2label[np.argmax(token)]
      if label != 'O':
        print(tokenizer.convert_ids_to_tokens(eval_dataset['input_ids'][text_pos])[position], label)

In [None]:
from collections import Counter
# Counter(predict_labels).keys()
Counter(correct_labels).keys()