In [1]:
import torch
from torch import nn
from transformers import BertTokenizerFast, BertModel

In [2]:
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [3]:
%%time
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

CPU times: user 31.2 ms, sys: 10.6 ms, total: 41.8 ms
Wall time: 918 ms


In [4]:
import os
from pathlib import Path
import pandas as pd
import chardet

PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))

In [6]:
data_file = os.path.join(PROJECT_ROOT, 'local_only', 'ner_dataset.csv')
with open(data_file, 'rb') as f:
    encoding = chardet.detect(f.read())
    print(encoding)

{'encoding': 'Windows-1252', 'confidence': 0.73, 'language': ''}


In [41]:
df = pd.read_csv(data_file, encoding=encoding['encoding'], keep_default_na=False, na_values=[])
df['Sentence #']  = df['Sentence #'].replace('', None)
df['Sentence #'] = df['Sentence #'].ffill()

word_tag_list = []
for index, sentence_list in df.groupby('Sentence #'):
    record = [(word, tag) for word,tag in sentence_list[['Word','Tag']].itertuples(index=False)]
    word_tag_list.append(record)

In [42]:
df.head()

Unnamed: 0,Sentence #,Word,POS,Tag
0,Sentence: 1,Thousands,NNS,O
1,Sentence: 1,of,IN,O
2,Sentence: 1,demonstrators,NNS,O
3,Sentence: 1,have,VBP,O
4,Sentence: 1,marched,VBN,O


In [46]:
df.columns

Index(['Sentence #', 'Word', 'POS', 'Tag'], dtype='object')

In [47]:
tags = df['Tag'].unique()
tag2id = {tag: idx for idx, tag in enumerate(tags)}
id2tag = {idx: tag for tag, idx in tag2id.items()}

In [48]:
import torch
from torch.utils.data import Dataset

def tokenize_and_align_labels(sentence, tokenizer, tag2idx, max_len):
    words = [w for w, t in sentence]
    labels = [tag2idx[t] for w, t in sentence]

    encoding = tokenizer(
        words,
        is_split_into_words=True,
        return_offsets_mapping=True,
        truncation=True,
        padding='max_length',
        max_length=max_len
    )

    word_ids = encoding.word_ids()
    aligned_labels = []

    previous_word_idx = None
    for word_idx in word_ids:
        if word_idx is None:
            aligned_labels.append(-100)
        elif word_idx != previous_word_idx:
            aligned_labels.append(labels[word_idx])
        else:
            aligned_labels.append(-100)
        previous_word_idx = word_idx

    return {
        "input_ids": torch.tensor(encoding["input_ids"]),
        "attention_mask": torch.tensor(encoding["attention_mask"]),
        "labels": torch.tensor(aligned_labels)
    }

class NERDataset(Dataset):
    def __init__(self, data, tag2idx, tokenizer, max_len=128):
        self.data = data
        self.tag2idx = tag2idx
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sentence = self.data[idx]
        return tokenize_and_align_labels(sentence, self.tokenizer, self.tag2idx, self.max_len)


In [49]:
def tokenize_for_inference(words, tokenizer, max_len):
    encoding = tokenizer(
        words,
        is_split_into_words=True,
        return_offsets_mapping=True,
        truncation=True,
        padding='max_length',
        max_length=max_len
    )

    return {
        "input_ids": torch.tensor([encoding["input_ids"]]),
        "attention_mask": torch.tensor([encoding["attention_mask"]]),
        "word_ids": encoding.word_ids()  # useful for decoding output
    }


In [50]:
class BERT_NER(nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-cased")
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.classifier(outputs.last_hidden_state)  # [batch, seq_len, num_labels]
        return logits

In [51]:
len(word_tag_list)

47959

In [52]:
dataset = NERDataset(word_tag_list[:2000], tag2id, tokenizer)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

In [53]:
num_labels = len(tags)
model = BERT_NER(num_labels=num_labels)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [54]:
model.to(device);

In [55]:
num_params = 0
for param in model.bert.parameters():
    if len(param.shape) == 2:
        num_params += param.shape[0] * param.shape[1] 
    else:
        num_params += param.shape[0]

In [56]:
num_params

108310272

In [57]:
for param in model.bert.parameters():
    param.requires_grad = False

In [58]:
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = torch.optim.AdamW(model.classifier.parameters(), lr=5e-4)

In [59]:
from tqdm import tqdm

In [60]:
%%time

EPOCHS = 20

num_labels = len(tags)
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    for batch in tqdm(loader, desc=f"Epoch {epoch+1}"):
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        logits = model(input_ids=input_ids, attention_mask=attention_mask)

        # Reshape for loss: (batch*seq_len, num_labels)
        loss = loss_fn(logits.view(-1, num_labels), labels.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}")

Epoch 1: 100%|██████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.05it/s]


Epoch 1 Average Loss: 1.7173


Epoch 2: 100%|██████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.21it/s]


Epoch 2 Average Loss: 0.7625


Epoch 3: 100%|██████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.22it/s]


Epoch 3 Average Loss: 0.5798


Epoch 4: 100%|██████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.22it/s]


Epoch 4 Average Loss: 0.4800


Epoch 5: 100%|██████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.13it/s]


Epoch 5 Average Loss: 0.4172


Epoch 6: 100%|██████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.17it/s]


Epoch 6 Average Loss: 0.3712


Epoch 7: 100%|██████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.19it/s]


Epoch 7 Average Loss: 0.3366


Epoch 8: 100%|██████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.19it/s]


Epoch 8 Average Loss: 0.3093


Epoch 9: 100%|██████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.11it/s]


Epoch 9 Average Loss: 0.2913


Epoch 10: 100%|█████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.16it/s]


Epoch 10 Average Loss: 0.2756


Epoch 11: 100%|█████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.18it/s]


Epoch 11 Average Loss: 0.2632


Epoch 12: 100%|█████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.18it/s]


Epoch 12 Average Loss: 0.2510


Epoch 13: 100%|█████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.09it/s]


Epoch 13 Average Loss: 0.2459


Epoch 14: 100%|█████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.15it/s]


Epoch 14 Average Loss: 0.2364


Epoch 15: 100%|█████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.16it/s]


Epoch 15 Average Loss: 0.2330


Epoch 16: 100%|█████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.18it/s]


Epoch 16 Average Loss: 0.2229


Epoch 17: 100%|█████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.08it/s]


Epoch 17 Average Loss: 0.2187


Epoch 18: 100%|█████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.11it/s]


Epoch 18 Average Loss: 0.2118


Epoch 19: 100%|█████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.14it/s]


Epoch 19 Average Loss: 0.2095


Epoch 20: 100%|█████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.16it/s]

Epoch 20 Average Loss: 0.2063
CPU times: user 2min 33s, sys: 591 ms, total: 2min 34s
Wall time: 2min 34s





In [64]:
words = ' '.join([word for word, tag in word_tag_list[11000]])
new_records = tokenize_for_inference([words], tokenizer=tokenizer, max_len=128)

In [65]:
words

"The FAO 's estimate includes damage to fishing industries in Indonesia , Maldives , Somalia , Sri Lanka and Thailand ."

In [80]:
new_records

{'input_ids': tensor([[  101,  1109,  6820,  2346,   112,   188, 10301,  2075,  3290,  1106,
           5339,  7519,  1107,  5572,   117, 18880, 27943,   117, 15350,   117,
           4471,  6722,  1105,  5872,   119,   102,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,    

In [79]:
new_records['word_ids']

[None,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None]

In [76]:
sentence = [word for word, tag in word_tag_list[11000]]
features = tokenize_for_inference(sentence, tokenizer, max_len=128)
# Move inputs to the model's device
input_ids = features["input_ids"].to(device)
attention_mask = features["attention_mask"].to(device)

model.eval()
with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs#.logits
    predictions = torch.argmax(logits, dim=2).squeeze().tolist()

# Map predictions back to words
word_ids = features["word_ids"]
final_tags = []
prev_word_idx = None
for idx, word_idx in enumerate(word_ids):
    if word_idx is None or word_idx == prev_word_idx:
        continue
    final_tags.append((sentence[word_idx], id2tag[predictions[idx]]))
    prev_word_idx = word_idx

print(final_tags)
# → [('Barack', 'B-PER'), ('Obama', 'I-PER'), ('visited', 'O'), ('Paris', 'B-LOC')]


[('The', 'O'), ('FAO', 'B-org'), ("'s", 'O'), ('estimate', 'O'), ('includes', 'O'), ('damage', 'O'), ('to', 'O'), ('fishing', 'O'), ('industries', 'O'), ('in', 'O'), ('Indonesia', 'B-geo'), (',', 'O'), ('Maldives', 'B-geo'), (',', 'O'), ('Somalia', 'B-geo'), (',', 'O'), ('Sri', 'B-geo'), ('Lanka', 'I-geo'), ('and', 'O'), ('Thailand', 'B-geo'), ('.', 'O')]


In [85]:
import nltk
nltk.download('punkt_tab')  

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [86]:
from nltk.tokenize import word_tokenize

In [91]:
sentence = '''Britain's prime minister said "ABC XYZ"'''
words =  word_tokenize(sentence)

In [92]:
words

['Britain', "'s", 'prime', 'minister', 'said', '``', 'ABC', 'XYZ', "''"]

In [68]:
tokenizer.convert_ids_to_tokens(new_records['input_ids'][0])

['[CLS]',
 'The',
 'FA',
 '##O',
 "'",
 's',
 'estimate',
 'includes',
 'damage',
 'to',
 'fishing',
 'industries',
 'in',
 'Indonesia',
 ',',
 'Mal',
 '##dives',
 ',',
 'Somalia',
 ',',
 'Sri',
 'Lanka',
 'and',
 'Thailand',
 '.',
 '[SEP]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]'