In [1]:
from datasets import load_dataset

In [2]:
# Load the NER dataset from the Hugging Face Hub.
# The cell output below shows it resolves to a DatasetDict with
# train / validation / test splits, each exposing 'tokens' and 'tags' columns.
dataset = load_dataset("yash-iitk/my-ner-dataset")

In [3]:
# Display the split names, column names and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['tokens', 'tags'],
        num_rows: 28775
    })
    validation: Dataset({
        features: ['tokens', 'tags'],
        num_rows: 9592
    })
    test: Dataset({
        features: ['tokens', 'tags'],
        num_rows: 9592
    })
})

In [4]:
# Select the training split.
# NOTE(review): despite the `_df` suffix this is the `datasets` split object
# (see the DatasetDict output above), not a pandas DataFrame.
train_df = dataset['train']

In [5]:
# Peek at the first training example only: the loop breaks after one iteration,
# leaving `i` bound to that example and `tags` to its tag sequence.
# NOTE(review): `tags` is rebound on every iteration of the label-vocabulary
# loop further down, so the value set here does not survive past that cell.
for i in train_df:
    tags = i['tags']
    break

In [6]:
# Build one record per training sentence: a list of (token, tag) pairs,
# pairing each entry of 'tokens' with the entry of 'tags' at the same position.
records = []
for x in train_df:
    # zip stops at the shorter of the two sequences, exactly like the
    # pairwise comprehension it replaces.
    record = list(zip(x['tokens'], x['tags']))
    records.append(record)

In [7]:
# Gather every distinct tag value that occurs in the training split.
unique_tags = set()
for i in train_df:
    tags = i['tags']  # left bound to the last example's tags after the loop
    unique_tags |= set(tags)

# Assign consecutive integer ids in sorted tag order, and build the inverse map.
sorted_tags = sorted(unique_tags)
tag2id = dict(zip(sorted_tags, range(len(sorted_tags))))
id2tag = dict(enumerate(sorted_tags))

In [9]:
from functions import tokenize_and_align_labels, NERDataset, tokenize_for_inference, BERT_NER

In [17]:
import torch

In [14]:
from transformers import BertTokenizerFast, BertModel
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# Load the fast tokenizer for the "bert-base-cased" checkpoint.
# NOTE(review): presumably this must match the encoder checkpoint used inside
# BERT_NER (defined in functions.py, not visible here) — confirm.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

In [15]:
# Wrap the first 2000 (token, tag) records in an NERDataset and serve them in
# shuffled mini-batches of 64.
# NOTE(review): this rebinds `dataset`, shadowing the DatasetDict loaded at the
# top of the notebook; re-running `train_df = dataset['train']` after this cell
# would index the NERDataset instead.
dataset = NERDataset(records[:2000], tag2id, tokenizer)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

In [18]:
num_labels = len(tags)
model = BERT_NER(num_labels=num_labels)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [19]:
model.to(device);

In [20]:
for param in model.bert.parameters():
    param.requires_grad = False

In [22]:
from torch import nn

In [23]:
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = torch.optim.AdamW(model.classifier.parameters(), lr=5e-4)

In [None]:
checkpoint_name = "classifier_weights.pth"

In [24]:
from tqdm import tqdm

In [25]:
%%time

EPOCHS = 10

num_labels = len(tags)
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    for batch in tqdm(loader, desc=f"Epoch {epoch+1}"):
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        logits = model(input_ids=input_ids, attention_mask=attention_mask)

        # Reshape for loss: (batch*seq_len, num_labels)
        loss = loss_fn(logits.view(-1, num_labels), labels.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}")

Epoch 1: 100%|████████████████████████████████████████████████████████████████████████| 32/32 [00:08<00:00,  3.99it/s]


Epoch 1 Average Loss: 2.2105


Epoch 2: 100%|████████████████████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.26it/s]


Epoch 2 Average Loss: 0.9013


Epoch 3: 100%|████████████████████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.25it/s]


Epoch 3 Average Loss: 0.6169


Epoch 4: 100%|████████████████████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.25it/s]


Epoch 4 Average Loss: 0.5006


Epoch 5: 100%|████████████████████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.24it/s]


Epoch 5 Average Loss: 0.4269


Epoch 6: 100%|████████████████████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.22it/s]


Epoch 6 Average Loss: 0.3777


Epoch 7: 100%|████████████████████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.22it/s]


Epoch 7 Average Loss: 0.3423


Epoch 8: 100%|████████████████████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.22it/s]


Epoch 8 Average Loss: 0.3172


Epoch 9: 100%|████████████████████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.21it/s]


Epoch 9 Average Loss: 0.2975


Epoch 10: 100%|███████████████████████████████████████████████████████████████████████| 32/32 [00:07<00:00,  4.21it/s]

Epoch 10 Average Loss: 0.2826
CPU times: user 1min 15s, sys: 361 ms, total: 1min 16s
Wall time: 1min 16s





In [27]:
# Persist only the classifier head's weights to `checkpoint_name`;
# the rest of the model is not included in this file.
classifier_state = model.classifier.state_dict()
torch.save(classifier_state, checkpoint_name)