In [None]:
%%capture
# Install Hugging Face transformers; %%capture hides the installer output.
pip install transformers

In [None]:

import pandas as pd
import torch 
import numpy as np
from transformers import BertTokenizerFast, BertForTokenClassification
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.optim import SGD

In [None]:
# Location of the NER CSV (columns: text, labels). This appears to be a Colab
# Google-Drive mount path; change it to run the notebook elsewhere.
DATA_PATH = '/content/drive/MyDrive/data/ner.csv'

df = pd.read_csv(DATA_PATH)
df.head()

Unnamed: 0,text,labels
0,Thousands of demonstrators have marched throug...,O O O O O O B-geo O O O O O B-geo O O O O O B-...
1,Iranian officials say they expect to get acces...,B-gpe O O O O O O O O O O O O O O B-tim O O O ...
2,Helicopter gunships Saturday pounded militant ...,O O B-tim O O O O O B-geo O O O O O B-org O O ...
3,They left after a tense hour-long standoff wit...,O O O O O O O O O O O
4,U.N. relief coordinator Jan Egeland said Sunda...,B-geo O O B-per I-per O B-tim O B-geo O B-gpe ...


In [None]:
# Split each whitespace-separated label string into a list of per-word tags
labels = [i.split() for i in df['labels'].values.tolist()]

# Collect every distinct tag that occurs anywhere in the dataset
# (a set union instead of a side-effect-only list comprehension).
unique_labels = set().union(*labels)

print(unique_labels)

{'B-art', 'I-org', 'O', 'B-org', 'I-art', 'I-gpe', 'B-geo', 'I-tim', 'B-eve', 'B-tim', 'I-geo', 'I-eve', 'I-nat', 'B-nat', 'B-gpe', 'B-per', 'I-per'}


In [None]:
# Build both directions of the tag <-> integer-id mapping over the *sorted* tag
# set, so the ids do not depend on set iteration order.
ids_to_labels = dict(enumerate(sorted(unique_labels)))
labels_to_ids = {label: idx for idx, label in ids_to_labels.items()}
print(labels_to_ids)

{'B-art': 0, 'B-eve': 1, 'B-geo': 2, 'B-gpe': 3, 'B-nat': 4, 'B-org': 5, 'B-per': 6, 'B-tim': 7, 'I-art': 8, 'I-eve': 9, 'I-geo': 10, 'I-gpe': 11, 'I-nat': 12, 'I-org': 13, 'I-per': 14, 'I-tim': 15, 'O': 16}


In [None]:
# Inspect how a single sentence is preprocessed, using the sentence at index 36.
text = df['text'].tolist()
example = text[36]
print(example)

Prime Minister Geir Haarde has refused to resign or call for early elections .


In [None]:
from transformers import BertTokenizerFast

# Cased checkpoint: capitalisation is a useful signal for named entities.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')

# Pad/truncate to 512 tokens and return PyTorch tensors.
text_tokenized = tokenizer(
    example,
    truncation=True,
    padding='max_length',
    max_length=512,
    return_tensors="pt",
)


Downloading (…)okenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [None]:
# Show the raw encoding returned by the tokenizer (input_ids, attention_mask, ... tensors).
print(text_tokenized)

{'input_ids': tensor([[  101,  3460,  2110,   144,  6851,  1197, 11679,  2881,  1162,  1144,
          3347,  1106, 13133,  1137,  1840,  1111,  1346,  3212,   119,   102,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,  

In [None]:
# Decode the ids back to text to reveal the added [CLS]/[SEP] markers and [PAD] padding.
print(tokenizer.decode(text_tokenized["input_ids"][0]))

[CLS] Prime Minister Geir Haarde has refused to resign or call for early elections. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD

In [None]:
# Show the individual WordPiece tokens; continuation sub-words carry a '##' prefix.
print(tokenizer.convert_ids_to_tokens(text_tokenized.input_ids[0]))

['[CLS]', 'Prime', 'Minister', 'G', '##ei', '##r', 'Ha', '##ard', '##e', 'has', 'refused', 'to', 'resign', 'or', 'call', 'for', 'early', 'elections', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]',

In [None]:
# Map every token position to the index of the word it came from
# (None for [CLS]/[SEP]/[PAD]), shown next to the tokens themselves.
word_ids = text_tokenized.word_ids(batch_index=0)
print(tokenizer.convert_ids_to_tokens(text_tokenized.input_ids[0]))
print(word_ids)

['[CLS]', 'Prime', 'Minister', 'G', '##ei', '##r', 'Ha', '##ard', '##e', 'has', 'refused', 'to', 'resign', 'or', 'call', 'for', 'early', 'elections', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]',

In [None]:
def align_label_example(tokenized_input, labels, label_map=None, label_all=None):
    """Align word-level NER labels to the sub-word tokens of one encoded sentence.

    Args:
        tokenized_input: a fast-tokenizer encoding exposing ``word_ids()``
            (e.g. the result of ``tokenizer(...)`` for a single sentence).
        labels: list of label strings, one per word.
        label_map: label-string -> id mapping; defaults to the notebook-level
            ``labels_to_ids``.
        label_all: if True every sub-word of a word gets that word's label id;
            if False only the first sub-word does and the rest get -100.
            Defaults to the notebook-level ``label_all_tokens`` flag.

    Returns:
        list[int] with one entry per token. Special/padding tokens (word id
        None), words past the end of ``labels`` and unknown label strings map
        to -100, the value the training loop filters out.

    Note:
        ``word_ids()`` indexes words as split by the tokenizer's own
        pre-tokenizer (BERT also splits on punctuation), so it only lines up
        with a whitespace-split ``labels`` list when the text was encoded with
        ``is_split_into_words=True``.
    """
    if label_map is None:
        label_map = labels_to_ids
    if label_all is None:
        label_all = label_all_tokens

    def _label_id(word_idx):
        # Words beyond the label list (IndexError) and unknown tags (KeyError)
        # are ignored rather than crashing; any other error propagates.
        try:
            return label_map[labels[word_idx]]
        except (IndexError, KeyError):
            return -100

    label_ids = []
    previous_word_idx = None
    for word_idx in tokenized_input.word_ids():
        if word_idx is None:
            label_ids.append(-100)
        elif word_idx != previous_word_idx:
            # First sub-word of a word always carries the word's label.
            label_ids.append(_label_id(word_idx))
        else:
            # Continuation sub-word: labelled only when label_all is set.
            label_ids.append(_label_id(word_idx) if label_all else -100)
        previous_word_idx = word_idx
    return label_ids

In [None]:
# Word-level tags for the same sentence (index 36) as the tokenized example.
label = labels[36]

# With label_all_tokens = True, every sub-word token inherits its word's tag.
label_all_tokens = True
new_label = align_label_example(text_tokenized, label)

print(new_label)
print(tokenizer.convert_ids_to_tokens(text_tokenized.input_ids[0]))

[-100, 16, 16, 6, 6, 6, 14, 14, 14, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 

In [None]:
# With label_all_tokens = False, only the first sub-word of each word keeps its
# tag; the remaining sub-words are set to -100.
label_all_tokens = False
new_label = align_label_example(text_tokenized, label)

print(new_label)
print(tokenizer.convert_ids_to_tokens(text_tokenized.input_ids[0]))

[-100, 16, 16, 6, -100, -100, 14, -100, -100, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -1

In [None]:
import torch

def align_label(texts, labels):
    """Tokenize one sentence and align its word-level labels to the tokens.

    Args:
        texts: one sentence, as a whitespace-separated string or a list/tuple
            of words (one per entry of ``labels``). Other values are coerced
            with ``str()``.
        labels: list of label strings, one per whitespace-separated word.

    Returns:
        list[int] with one id per token position (padded to 512). Labelled
        tokens get their ``labels_to_ids`` id; special/padding tokens, words
        past the end of ``labels``, unknown tags and, unless the notebook-level
        ``label_all_tokens`` is True, non-initial sub-words get -100.
    """
    # Encode the whitespace-separated words explicitly so that word_ids() indexes
    # the same words ``labels`` describes. Encoding the raw string lets BERT's
    # pre-tokenizer also split on punctuation ("U.N." -> "U", ".", "N", "."),
    # which shifts every later label. For this BERT WordPiece tokenizer the
    # token ids themselves are the same either way.
    if isinstance(texts, (list, tuple)):
        words = [str(w) for w in texts]
    else:
        words = str(texts).split()
    tokenized_inputs = tokenizer(words, is_split_into_words=True,
                                 padding='max_length', max_length=512, truncation=True)

    def _label_id(word_idx):
        # Missing labels (IndexError) and unknown tags (KeyError) are ignored.
        try:
            return labels_to_ids[labels[word_idx]]
        except (IndexError, KeyError):
            return -100

    label_ids = []
    previous_word_idx = None
    for word_idx in tokenized_inputs.word_ids():
        if word_idx is None:
            label_ids.append(-100)
        elif word_idx != previous_word_idx:
            label_ids.append(_label_id(word_idx))
        else:
            label_ids.append(_label_id(word_idx) if label_all_tokens else -100)
        previous_word_idx = word_idx
    return label_ids

class DataSequence(torch.utils.data.Dataset):
    """Map-style dataset pairing each encoded sentence with its aligned label ids.

    ``__getitem__`` returns ``(encoding, labels)``: ``encoding`` is the
    tokenizer output for one sentence (tensors padded/truncated to 512 tokens,
    with a leading dimension of 1, hence the ``squeeze(1)`` in the loops) and
    ``labels`` is a LongTensor of per-token ids from ``align_label`` with -100
    on ignored positions.

    Args:
        df: DataFrame with a ``text`` column (sentences) and a ``labels``
            column (whitespace-separated tags, one per word).
    """

    def __init__(self, df):
        # Coerce once so the token ids and the label alignment are computed from
        # exactly the same string.
        texts = [str(t) for t in df['text'].values.tolist()]
        word_labels = [lb.split() for lb in df['labels'].values.tolist()]

        self.texts = [tokenizer(t, padding='max_length', max_length=512,
                                truncation=True, return_tensors="pt")
                      for t in texts]
        self.labels = [align_label(t, lb) for t, lb in zip(texts, word_labels)]

    def __len__(self):
        return len(self.labels)

    def get_batch_data(self, idx):
        """Return the tokenizer encoding of item ``idx``."""
        return self.texts[idx]

    def get_batch_labels(self, idx):
        """Return the label ids of item ``idx`` as an int64 tensor."""
        return torch.tensor(self.labels[idx], dtype=torch.long)

    def __getitem__(self, idx):
        return self.get_batch_data(idx), self.get_batch_labels(idx)


In [None]:
import numpy as np

# Keep the first 1,000 sentences so an epoch stays short, shuffle once with a
# fixed seed, then cut 80/10/10 into train / validation / test.
# Positional iloc slicing replaces np.split on a DataFrame, which relies on the
# deprecated DataFrame.swapaxes; the resulting splits are identical.
df = df.iloc[:1000]
shuffled = df.sample(frac=1, random_state=42)
train_end, val_end = int(.8 * len(df)), int(.9 * len(df))
df_train = shuffled.iloc[:train_end]
df_val = shuffled.iloc[train_end:val_end]
df_test = shuffled.iloc[val_end:]

In [None]:

from transformers import BertForTokenClassification

class BertModel(torch.nn.Module):
    """Thin wrapper around ``BertForTokenClassification`` for NER.

    NOTE(review): the class name coincides with ``transformers.BertModel``;
    avoid importing that name into this notebook.

    Args:
        model_name: pretrained checkpoint to load (default ``'bert-base-cased'``).
        num_labels: size of the tag set; defaults to ``len(unique_labels)``
            from the notebook.
    """

    def __init__(self, model_name='bert-base-cased', num_labels=None):
        super().__init__()
        if num_labels is None:
            num_labels = len(unique_labels)
        self.bert = BertForTokenClassification.from_pretrained(model_name, num_labels=num_labels)

    def forward(self, input_id, mask, label=None):
        """Run the token classifier.

        Returns a plain tuple (``return_dict=False``): ``(loss, logits)`` when
        ``label`` is given, ``(logits,)`` when it is None.
        """
        return self.bert(input_ids=input_id, attention_mask=mask, labels=label, return_dict=False)

In [None]:
def _batch_to_device(batch_data, batch_label, device):
    """Return ``(input_ids, attention_mask, labels)`` of one collated batch on ``device``.

    Each stored encoding has a leading dimension of 1, so the collated tensors
    are (batch, 1, seq_len); ``squeeze(1)`` drops that extra axis.
    """
    input_id = batch_data['input_ids'].squeeze(1).to(device)
    mask = batch_data['attention_mask'].squeeze(1).to(device)
    return input_id, mask, batch_label.to(device)


def _sum_sentence_accuracy(logits, labels):
    """Sum, over the batch, each sentence's token accuracy on positions whose label is not -100."""
    total = 0.0
    for sent_logits, sent_labels in zip(logits, labels):
        keep = sent_labels != -100
        predictions = sent_logits[keep].argmax(dim=1)
        total += (predictions == sent_labels[keep]).float().mean().item()
    return total


def train_loop(model, df_train, df_val):
    """Fine-tune ``model`` on ``df_train`` and report metrics on ``df_val`` every epoch.

    Uses SGD and reads the notebook-level ``LEARNING_RATE``, ``EPOCHS`` and
    ``BATCH_SIZE``. Each epoch prints, for training and validation, the batch
    loss weighted by batch size and the mean per-sentence token accuracy
    (positions labelled -100 are skipped), both averaged over the DataFrame size.

    Args:
        model: a ``BertModel`` whose ``forward(input_id, mask, label)`` returns
            ``(loss, logits)`` when labels are given.
        df_train, df_val: DataFrames with ``text`` and ``labels`` columns.
    """
    train_dataloader = DataLoader(DataSequence(df_train), num_workers=4,
                                  batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(DataSequence(df_val), num_workers=4,
                                batch_size=BATCH_SIZE)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the model first so the optimizer is built over the on-device parameters.
    model = model.to(device)
    optimizer = SGD(model.parameters(), lr=LEARNING_RATE)

    for epoch_num in range(EPOCHS):

        model.train()
        total_acc_train = 0.0
        total_loss_train = 0.0

        for train_data, train_label in tqdm(train_dataloader):
            input_id, mask, train_label = _batch_to_device(train_data, train_label, device)

            optimizer.zero_grad()
            loss, logits = model(input_id, mask, train_label)

            # Count the batch loss once per sentence in the batch so that dividing
            # by len(df_train) yields a batch-size-weighted mean.
            total_loss_train += loss.item() * logits.shape[0]
            total_acc_train += _sum_sentence_accuracy(logits.detach(), train_label)

            loss.backward()
            optimizer.step()

        model.eval()
        total_acc_val = 0.0
        total_loss_val = 0.0

        # Validation needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            for val_data, val_label in val_dataloader:
                input_id, mask, val_label = _batch_to_device(val_data, val_label, device)
                loss, logits = model(input_id, mask, val_label)
                total_loss_val += loss.item() * logits.shape[0]
                total_acc_val += _sum_sentence_accuracy(logits, val_label)

        print(
            f'Epochs: {epoch_num + 1} | Loss: {total_loss_train / len(df_train): .3f} | Accuracy: {total_acc_train / len(df_train): .3f} | Val_Loss: {total_loss_val / len(df_val): .3f} | Accuracy: {total_acc_val / len(df_val): .3f}')

LEARNING_RATE = 5e-3
EPOCHS = 100
BATCH_SIZE = 2

# Seed PyTorch's global RNG so the freshly initialised classification head and
# the DataLoader shuffling are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

# NOTE(review): in the recorded run, validation loss bottoms out around epoch 11
# (0.282) and then climbs while training accuracy approaches 1.0, so fewer epochs
# or keeping the best-validation checkpoint looks advisable.
model = BertModel()
train_loop(model, df_train, df_val)


Downloading pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

Epochs: 1 | Loss:  0.621 | Accuracy:  0.849 | Val_Loss:  0.451 | Accuracy:  0.888


100%|██████████| 400/400 [01:12<00:00,  5.48it/s]


Epochs: 2 | Loss:  0.444 | Accuracy:  0.881 | Val_Loss:  0.387 | Accuracy:  0.905


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 3 | Loss:  0.382 | Accuracy:  0.897 | Val_Loss:  0.362 | Accuracy:  0.910


100%|██████████| 400/400 [01:13<00:00,  5.44it/s]


Epochs: 4 | Loss:  0.343 | Accuracy:  0.907 | Val_Loss:  0.344 | Accuracy:  0.908


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 5 | Loss:  0.314 | Accuracy:  0.914 | Val_Loss:  0.334 | Accuracy:  0.915


100%|██████████| 400/400 [01:13<00:00,  5.45it/s]


Epochs: 6 | Loss:  0.285 | Accuracy:  0.921 | Val_Loss:  0.311 | Accuracy:  0.925


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 7 | Loss:  0.260 | Accuracy:  0.927 | Val_Loss:  0.328 | Accuracy:  0.927


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 8 | Loss:  0.233 | Accuracy:  0.934 | Val_Loss:  0.303 | Accuracy:  0.921


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 9 | Loss:  0.210 | Accuracy:  0.937 | Val_Loss:  0.296 | Accuracy:  0.919


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 10 | Loss:  0.190 | Accuracy:  0.943 | Val_Loss:  0.313 | Accuracy:  0.925


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 11 | Loss:  0.174 | Accuracy:  0.948 | Val_Loss:  0.282 | Accuracy:  0.930


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 12 | Loss:  0.157 | Accuracy:  0.954 | Val_Loss:  0.309 | Accuracy:  0.929


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 13 | Loss:  0.137 | Accuracy:  0.958 | Val_Loss:  0.289 | Accuracy:  0.927


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 14 | Loss:  0.130 | Accuracy:  0.962 | Val_Loss:  0.348 | Accuracy:  0.931


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 15 | Loss:  0.111 | Accuracy:  0.967 | Val_Loss:  0.350 | Accuracy:  0.930


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 16 | Loss:  0.100 | Accuracy:  0.968 | Val_Loss:  0.341 | Accuracy:  0.930


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 17 | Loss:  0.100 | Accuracy:  0.969 | Val_Loss:  0.351 | Accuracy:  0.934


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 18 | Loss:  0.088 | Accuracy:  0.974 | Val_Loss:  0.358 | Accuracy:  0.930


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 19 | Loss:  0.075 | Accuracy:  0.978 | Val_Loss:  0.346 | Accuracy:  0.930


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 20 | Loss:  0.069 | Accuracy:  0.980 | Val_Loss:  0.351 | Accuracy:  0.930


100%|██████████| 400/400 [01:13<00:00,  5.45it/s]


Epochs: 21 | Loss:  0.062 | Accuracy:  0.982 | Val_Loss:  0.392 | Accuracy:  0.931


100%|██████████| 400/400 [01:13<00:00,  5.45it/s]


Epochs: 22 | Loss:  0.056 | Accuracy:  0.983 | Val_Loss:  0.394 | Accuracy:  0.934


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 23 | Loss:  0.048 | Accuracy:  0.987 | Val_Loss:  0.428 | Accuracy:  0.933


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 24 | Loss:  0.041 | Accuracy:  0.989 | Val_Loss:  0.382 | Accuracy:  0.937


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 25 | Loss:  0.038 | Accuracy:  0.990 | Val_Loss:  0.392 | Accuracy:  0.931


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 26 | Loss:  0.038 | Accuracy:  0.989 | Val_Loss:  0.418 | Accuracy:  0.929


100%|██████████| 400/400 [01:13<00:00,  5.45it/s]


Epochs: 27 | Loss:  0.033 | Accuracy:  0.991 | Val_Loss:  0.441 | Accuracy:  0.932


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 28 | Loss:  0.031 | Accuracy:  0.992 | Val_Loss:  0.439 | Accuracy:  0.929


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 29 | Loss:  0.025 | Accuracy:  0.994 | Val_Loss:  0.502 | Accuracy:  0.930


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 30 | Loss:  0.022 | Accuracy:  0.995 | Val_Loss:  0.502 | Accuracy:  0.931


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 31 | Loss:  0.020 | Accuracy:  0.996 | Val_Loss:  0.492 | Accuracy:  0.930


100%|██████████| 400/400 [01:13<00:00,  5.45it/s]


Epochs: 32 | Loss:  0.017 | Accuracy:  0.996 | Val_Loss:  0.444 | Accuracy:  0.926


100%|██████████| 400/400 [01:13<00:00,  5.45it/s]


Epochs: 33 | Loss:  0.019 | Accuracy:  0.995 | Val_Loss:  0.483 | Accuracy:  0.935


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 34 | Loss:  0.019 | Accuracy:  0.996 | Val_Loss:  0.479 | Accuracy:  0.932


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 35 | Loss:  0.020 | Accuracy:  0.995 | Val_Loss:  0.491 | Accuracy:  0.932


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 36 | Loss:  0.016 | Accuracy:  0.996 | Val_Loss:  0.463 | Accuracy:  0.931


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 37 | Loss:  0.016 | Accuracy:  0.996 | Val_Loss:  0.484 | Accuracy:  0.932


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 38 | Loss:  0.015 | Accuracy:  0.997 | Val_Loss:  0.496 | Accuracy:  0.929


100%|██████████| 400/400 [01:13<00:00,  5.45it/s]


Epochs: 39 | Loss:  0.012 | Accuracy:  0.997 | Val_Loss:  0.513 | Accuracy:  0.932


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 40 | Loss:  0.012 | Accuracy:  0.998 | Val_Loss:  0.503 | Accuracy:  0.935


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 41 | Loss:  0.010 | Accuracy:  0.998 | Val_Loss:  0.508 | Accuracy:  0.932


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 42 | Loss:  0.013 | Accuracy:  0.998 | Val_Loss:  0.510 | Accuracy:  0.931


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 43 | Loss:  0.010 | Accuracy:  0.998 | Val_Loss:  0.541 | Accuracy:  0.930


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 44 | Loss:  0.010 | Accuracy:  0.998 | Val_Loss:  0.533 | Accuracy:  0.932


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 45 | Loss:  0.009 | Accuracy:  0.998 | Val_Loss:  0.530 | Accuracy:  0.932


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 46 | Loss:  0.008 | Accuracy:  0.998 | Val_Loss:  0.525 | Accuracy:  0.932


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 47 | Loss:  0.009 | Accuracy:  0.998 | Val_Loss:  0.512 | Accuracy:  0.934


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 48 | Loss:  0.008 | Accuracy:  0.998 | Val_Loss:  0.560 | Accuracy:  0.935


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 49 | Loss:  0.007 | Accuracy:  0.999 | Val_Loss:  0.518 | Accuracy:  0.935


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 50 | Loss:  0.006 | Accuracy:  0.999 | Val_Loss:  0.554 | Accuracy:  0.935


100%|██████████| 400/400 [01:13<00:00,  5.45it/s]


Epochs: 51 | Loss:  0.006 | Accuracy:  0.999 | Val_Loss:  0.596 | Accuracy:  0.935


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 52 | Loss:  0.008 | Accuracy:  0.998 | Val_Loss:  0.568 | Accuracy:  0.939


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 53 | Loss:  0.007 | Accuracy:  0.999 | Val_Loss:  0.544 | Accuracy:  0.937


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 54 | Loss:  0.004 | Accuracy:  0.999 | Val_Loss:  0.555 | Accuracy:  0.934


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 55 | Loss:  0.006 | Accuracy:  0.999 | Val_Loss:  0.566 | Accuracy:  0.933


100%|██████████| 400/400 [01:13<00:00,  5.45it/s]


Epochs: 56 | Loss:  0.005 | Accuracy:  0.999 | Val_Loss:  0.558 | Accuracy:  0.931


100%|██████████| 400/400 [01:13<00:00,  5.45it/s]


Epochs: 57 | Loss:  0.006 | Accuracy:  0.999 | Val_Loss:  0.582 | Accuracy:  0.934


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 58 | Loss:  0.005 | Accuracy:  0.999 | Val_Loss:  0.596 | Accuracy:  0.936


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 59 | Loss:  0.006 | Accuracy:  0.999 | Val_Loss:  0.572 | Accuracy:  0.937


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 60 | Loss:  0.005 | Accuracy:  0.999 | Val_Loss:  0.529 | Accuracy:  0.933


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 61 | Loss:  0.004 | Accuracy:  0.999 | Val_Loss:  0.551 | Accuracy:  0.938


100%|██████████| 400/400 [01:13<00:00,  5.45it/s]


Epochs: 62 | Loss:  0.004 | Accuracy:  1.000 | Val_Loss:  0.585 | Accuracy:  0.935


100%|██████████| 400/400 [01:13<00:00,  5.47it/s]


Epochs: 63 | Loss:  0.005 | Accuracy:  0.999 | Val_Loss:  0.602 | Accuracy:  0.936


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 64 | Loss:  0.005 | Accuracy:  0.999 | Val_Loss:  0.586 | Accuracy:  0.935


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


Epochs: 65 | Loss:  0.005 | Accuracy:  0.999 | Val_Loss:  0.615 | Accuracy:  0.935


100%|██████████| 400/400 [01:13<00:00,  5.46it/s]


In [None]:
def evaluate(model, df_test):
    """Print and return ``model``'s mean per-sentence token accuracy on ``df_test``.

    Each sentence's accuracy is computed over tokens whose label is not -100,
    and the sentence accuracies are averaged over ``len(df_test)``.

    Args:
        model: a ``BertModel`` returning ``(loss, logits)`` when given labels.
        df_test: DataFrame with ``text`` and ``labels`` columns.

    Returns:
        float: the test accuracy that is printed.
    """
    test_dataloader = DataLoader(DataSequence(df_test), num_workers=4, batch_size=1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Disable dropout so predictions are deterministic, whatever mode the model was left in.
    model.eval()

    total_acc_test = 0.0
    with torch.no_grad():
        for test_data, test_label in test_dataloader:
            test_label = test_label.to(device)
            mask = test_data['attention_mask'].squeeze(1).to(device)
            input_id = test_data['input_ids'].squeeze(1).to(device)

            _, logits = model(input_id, mask, test_label)

            for sent_logits, sent_labels in zip(logits, test_label):
                keep = sent_labels != -100
                predictions = sent_logits[keep].argmax(dim=1)
                total_acc_test += (predictions == sent_labels[keep]).float().mean().item()

    test_accuracy = total_acc_test / len(df_test)
    print(f'Test Accuracy: {test_accuracy: .3f}')
    return test_accuracy


evaluate(model, df_test);

Test Accuracy:  0.933


In [None]:
def align_word_ids(texts):
    """Build an inference mask selecting which tokens of ``texts`` to read predictions from.

    Args:
        texts: one sentence, as a whitespace-separated string or a list/tuple
            of words.

    Returns:
        list[int] with one entry per token position (padded to 512): 1 for the
        first sub-word of each whitespace-separated word (and for every
        sub-word when the notebook-level ``label_all_tokens`` is True), -100
        for special/padding tokens and other sub-words.
    """
    # Encode whitespace words explicitly so the mask yields one prediction per
    # whitespace word, the same unit the training labels are defined on.
    if isinstance(texts, (list, tuple)):
        words = [str(w) for w in texts]
    else:
        words = str(texts).split()
    tokenized_inputs = tokenizer(words, is_split_into_words=True,
                                 padding='max_length', max_length=512, truncation=True)

    label_ids = []
    previous_word_idx = None
    for word_idx in tokenized_inputs.word_ids():
        if word_idx is None:
            label_ids.append(-100)
        elif word_idx != previous_word_idx or label_all_tokens:
            label_ids.append(1)
        else:
            label_ids.append(-100)
        previous_word_idx = word_idx
    return label_ids


def evaluate_one_text(model, sentence):
    """Print ``sentence`` and the predicted tag for each word selected by ``align_word_ids``.

    Returns:
        list[str]: the predicted tags (also printed).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Inference only: disable dropout.
    model.eval()

    text = tokenizer(sentence, padding='max_length', max_length=512, truncation=True, return_tensors="pt")
    mask = text['attention_mask'].to(device)
    input_id = text['input_ids'].to(device)
    # (1, 512) boolean mask of the token positions whose prediction is reported.
    keep = torch.tensor(align_word_ids(sentence), device=device).unsqueeze(0) != -100

    with torch.no_grad():
        # Without labels the wrapper returns a 1-tuple: (logits,).
        logits = model(input_id, mask, None)[0]

    predictions = logits[keep].argmax(dim=1).tolist()
    prediction_label = [ids_to_labels[i] for i in predictions]
    print(sentence)
    print(prediction_label)
    return prediction_label


evaluate_one_text(model, 'Mark Zuckerberg is one of the founders of Facebook, a company from the United States');


Mark Zuckerberg is one of the founders of Facebook, a company from the United States
['B-per', 'I-per', 'O', 'O', 'O', 'O', 'O', 'O', 'B-org', 'O', 'O', 'O', 'O', 'O', 'B-geo', 'I-geo']
