In [1]:
import torch
from torch import nn
from transformers import BertTokenizerFast, BertModel


In [2]:
%%time
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

CPU times: user 29.4 ms, sys: 0 ns, total: 29.4 ms
Wall time: 804 ms


In [3]:
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [4]:
import os
from pathlib import Path
import pandas as pd
import chardet

PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))

In [5]:
data_file = os.path.join(PROJECT_ROOT, 'local_only', 'ner_dataset.csv')
with open(data_file, 'rb') as f:
    result = chardet.detect(f.read())
    print(result)

{'encoding': 'Windows-1252', 'confidence': 0.73, 'language': ''}


In [6]:
df = pd.read_csv(data_file, encoding='Windows-1252', keep_default_na=False, na_values=[])
df['Sentence #'] = df['Sentence #'].ffill()

data = []
for index, sentence_list in df.groupby('Sentence #'):
    record = [(word, tag) for word,tag in sentence_list[['Word','Tag']].itertuples(index=False)]
    data.append(record)

In [7]:
tags = df['Tag'].unique()
tag2id = {tag: idx for idx, tag in enumerate(tags)}
id2tag = {idx: tag for tag, idx in tag2id.items()}

In [8]:
from torch.utils.data import Dataset, DataLoader
class NERDataset(Dataset):
    def __init__(self, data, tag2idx, tokenizer, max_len=128):
        self.data = data
        self.tag2idx = tag2idx
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sentence = self.data[idx]
        words = [w for w, t in sentence]
        labels = [self.tag2idx[t] for w, t in sentence]
    
        # Tokenize without tensors so word_ids() works
        encoding = self.tokenizer(words,
                                  is_split_into_words=True,
                                  return_offsets_mapping=False,
                                  truncation=True,
                                  padding='max_length',
                                  max_length=self.max_len)
    
        word_ids = encoding.word_ids()  # No [0] here — it's a list
        aligned_labels = []
    
        previous_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                aligned_labels.append(-100)
            elif word_idx != previous_word_idx:
                aligned_labels.append(labels[word_idx])
            else:
                aligned_labels.append(-100)  # subword continuation
            previous_word_idx = word_idx
    
        # Convert everything to tensors
        return {
            "input_ids": torch.tensor(encoding["input_ids"]),
            "attention_mask": torch.tensor(encoding["attention_mask"]),
            "labels": torch.tensor(aligned_labels)
        }

In [9]:
len(data)

47960

In [88]:
len(data)

3

In [10]:
dataset = NERDataset(data[:2000], tag2id, tokenizer)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

In [11]:
class BERT_NER(nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-cased")
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.classifier(outputs.last_hidden_state)  # [batch, seq_len, num_labels]
        return logits

In [28]:
num_labels = len(tags)
model = BERT_NER(num_labels=num_labels)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [31]:
model.to(device);

In [32]:
num_params = 0
for param in model.bert.parameters():
    if len(param.shape) == 2:
        num_params += param.shape[0] * param.shape[1] 
    else:
        num_params += param.shape[0]

In [33]:
num_params

108310272

In [34]:
for param in model.bert.parameters():
    param.requires_grad = False

In [35]:
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = torch.optim.AdamW(model.classifier.parameters(), lr=5e-4)

In [36]:
from tqdm import tqdm

In [38]:
%%time

EPOCHS = 20

num_labels = len(tags)
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    for batch in tqdm(loader, desc=f"Epoch {epoch+1}"):
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        logits = model(input_ids=input_ids, attention_mask=attention_mask)

        # Reshape for loss: (batch*seq_len, num_labels)
        loss = loss_fn(logits.view(-1, num_labels), labels.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}")

Epoch 1: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.01it/s]


Epoch 1 Average Loss: 0.5267


Epoch 2: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.15it/s]


Epoch 2 Average Loss: 0.4857


Epoch 3: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.14it/s]


Epoch 3 Average Loss: 0.4419


Epoch 4: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.11it/s]


Epoch 4 Average Loss: 0.4305


Epoch 5: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.13it/s]


Epoch 5 Average Loss: 0.4078


Epoch 6: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.14it/s]


Epoch 6 Average Loss: 0.3721


Epoch 7: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.10it/s]


Epoch 7 Average Loss: 0.3633


Epoch 8: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.13it/s]


Epoch 8 Average Loss: 0.3529


Epoch 9: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.12it/s]


Epoch 9 Average Loss: 0.3404


Epoch 10: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.10it/s]


Epoch 10 Average Loss: 0.3326


Epoch 11: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.11it/s]


Epoch 11 Average Loss: 0.3158


Epoch 12: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.11it/s]


Epoch 12 Average Loss: 0.3063


Epoch 13: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.11it/s]


Epoch 13 Average Loss: 0.2910


Epoch 14: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.14it/s]


Epoch 14 Average Loss: 0.2952


Epoch 15: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.10it/s]


Epoch 15 Average Loss: 0.2836


Epoch 16: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.11it/s]


Epoch 16 Average Loss: 0.2741


Epoch 17: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.11it/s]


Epoch 17 Average Loss: 0.2710


Epoch 18: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.07it/s]


Epoch 18 Average Loss: 0.2726


Epoch 19: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.12it/s]


Epoch 19 Average Loss: 0.2524


Epoch 20: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.10it/s]

Epoch 20 Average Loss: 0.2606
CPU times: user 3min 25s, sys: 1.23 s, total: 3min 26s
Wall time: 3min 25s





In [39]:
%%time

EPOCHS = 20

num_labels = len(tags)
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    for batch in tqdm(loader, desc=f"Epoch {epoch+1}"):
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        logits = model(input_ids=input_ids, attention_mask=attention_mask)

        # Reshape for loss: (batch*seq_len, num_labels)
        loss = loss_fn(logits.view(-1, num_labels), labels.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}")

Epoch 1: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.11it/s]


Epoch 1 Average Loss: 0.2455


Epoch 2: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.14it/s]


Epoch 2 Average Loss: 0.2455


Epoch 3: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.14it/s]


Epoch 3 Average Loss: 0.2350


Epoch 4: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.12it/s]


Epoch 4 Average Loss: 0.2366


Epoch 5: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.14it/s]


Epoch 5 Average Loss: 0.2330


Epoch 6: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.09it/s]


Epoch 6 Average Loss: 0.2272


Epoch 7: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.12it/s]


Epoch 7 Average Loss: 0.2302


Epoch 8: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.09it/s]


Epoch 8 Average Loss: 0.2237


Epoch 9: 100%|██████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.08it/s]


Epoch 9 Average Loss: 0.2270


Epoch 10: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.10it/s]


Epoch 10 Average Loss: 0.2188


Epoch 11: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.15it/s]


Epoch 11 Average Loss: 0.2184


Epoch 12: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.09it/s]


Epoch 12 Average Loss: 0.2218


Epoch 13: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.13it/s]


Epoch 13 Average Loss: 0.2060


Epoch 14: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.13it/s]


Epoch 14 Average Loss: 0.2093


Epoch 15: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.05it/s]


Epoch 15 Average Loss: 0.2075


Epoch 16: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.14it/s]


Epoch 16 Average Loss: 0.2043


Epoch 17: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.14it/s]


Epoch 17 Average Loss: 0.2081


Epoch 18: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.10it/s]


Epoch 18 Average Loss: 0.2040


Epoch 19: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.12it/s]


Epoch 19 Average Loss: 0.1996


Epoch 20: 100%|█████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.11it/s]

Epoch 20 Average Loss: 0.1901
CPU times: user 3min 24s, sys: 1.79 s, total: 3min 26s
Wall time: 3min 25s





In [46]:
data = dataset[1000]

In [54]:
input_ids = data['input_ids'].to(device)
attention_mask = data['attention_mask'].to(device)
labels = data['labels'].to(device)

In [56]:
batch = None
for i in loader:
    batch = i
    break

In [62]:
labels

tensor([[-100,    3, -100,  ..., -100, -100, -100],
        [-100,    0, -100,  ..., -100, -100, -100],
        [-100,    0, -100,  ..., -100, -100, -100],
        ...,
        [-100,    0, -100,  ..., -100, -100, -100],
        [-100,    0, -100,  ..., -100, -100, -100],
        [-100,    0, -100,  ..., -100, -100, -100]], device='cuda:0')

In [63]:
len(labels)

64

In [65]:
# ---------------------
# 5. Prediction Example
# ---------------------
model.eval()
with torch.no_grad():
    #test_input, _, test_len = dataset[0]
    
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)
    logits = model(input_ids=input_ids, attention_mask=attention_mask)
    """
    pred = torch.argmax(logits, dim=-1).squeeze().tolist()[:64]
    words = [w for (w, t) in data[0]]
    print("\nPredictions:")
    for w, t_idx in zip(words, pred):
        print(f"{w:10} -> {idx2tag[t_idx]}")
    """

In [71]:
len(logits[0])

128

In [72]:
logits[0][0]

tensor([ 1.9520,  0.6237,  0.1692,  3.2879, -5.1774,  0.6500, -4.8683, -2.6278,
        -4.8637, -5.2439, -4.8230, -5.0989, -4.7518, -2.8827, -4.4066, -4.9325,
        -4.7586], device='cuda:0')

In [73]:
torch.argmax(logits[0][0]).item() 

3

In [76]:
predicted_class_name = id2tag[3] 
predicted_class_name

'B-per'

In [80]:
input_ids.size()

torch.Size([64, 128])

In [81]:
logits.size()

torch.Size([64, 128, 17])

In [83]:
logits[0].size()

torch.Size([128, 17])

In [91]:
tokenizer.convert_ids_to_tokens(data['input_ids'])

['[CLS]',
 'C',
 '##lashes',
 '[SEP]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 

In [84]:
logits[0][0]

tensor([ 1.9520,  0.6237,  0.1692,  3.2879, -5.1774,  0.6500, -4.8683, -2.6278,
        -4.8637, -5.2439, -4.8230, -5.0989, -4.7518, -2.8827, -4.4066, -4.9325,
        -4.7586], device='cuda:0')