# NER with DistilBERT and WikiAnn Dataset
This notebook demonstrates how to build a Named Entity Recognition (NER) model using DistilBERT on the WikiAnn dataset.

In [3]:
!pip install transformers datasets torch scikit-learn
# Import necessary libraries
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
from transformers import DistilBertTokenizerFast, DistilBertModel, AdamW, get_scheduler
from datasets import load_dataset
from sklearn.metrics import classification_report




In [4]:
# Load the WikiAnn dataset for NER
dataset = load_dataset('wikiann', 'en')

In [5]:
# Initialize the DistilBERT tokenizer
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')

# Tokenize and align labels
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(
        examples['tokens'], truncation=True, padding=True, is_split_into_words=True
    )
    labels = []
    for i, label in enumerate(examples['ner_tags']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        label_ids = [-100 if word_id is None else label[word_id] for word_id in word_ids]
        labels.append(label_ids)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs

# Apply tokenization to the dataset
tokenized_dataset = dataset.map(tokenize_and_align_labels, batched=True)

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/465 [00:00<?, ?B/s]



Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/20000 [00:00<?, ? examples/s]

In [6]:
# Define a custom collate function
def collate_fn(batch):
    input_ids = [torch.tensor(item['input_ids']) for item in batch]
    attention_masks = [torch.tensor(item['attention_mask']) for item in batch]
    labels = [torch.tensor(item['labels']) for item in batch]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0)
    labels = pad_sequence(labels, batch_first=True, padding_value=-100)

    return {'input_ids': input_ids, 'attention_mask': attention_masks, 'labels': labels}

# Create DataLoaders
train_loader = DataLoader(tokenized_dataset['train'], batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(tokenized_dataset['validation'], batch_size=16, collate_fn=collate_fn)

In [7]:
# Define the NER model using DistilBERT
class NERModel(nn.Module):
    def __init__(self, num_labels):
        super(NERModel, self).__init__()
        self.bert = DistilBertModel.from_pretrained('distilbert-base-cased')
        self.classifier = nn.Linear(self.bert.config.dim, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        logits = self.classifier(outputs.last_hidden_state)
        return logits

# Initialize the model
num_labels = len(dataset['train'].features['ner_tags'].feature.names)
model = NERModel(num_labels)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

model.safetensors:   0%|          | 0.00/263M [00:00<?, ?B/s]

NERModel(
  (bert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(in_feat

In [8]:
# Define the loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=5e-5)
scheduler = get_scheduler('linear', optimizer=optimizer, num_warmup_steps=0, num_training_steps=1000)



In [9]:
# Define training and evaluation functions
def compute_loss(logits, labels):
    active_loss = labels.view(-1) != -100
    active_logits = logits.view(-1, logits.shape[-1])[active_loss]
    active_labels = labels.view(-1)[active_loss]
    return loss_fn(active_logits, active_labels)

def train_epoch(model, dataloader, optimizer):
    model.train()
    total_loss = 0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = compute_loss(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

def evaluate(model, dataloader):
    model.eval()
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            logits = model(input_ids, attention_mask)
            preds = torch.argmax(logits, dim=-1)
            predictions.extend(preds.cpu().numpy().tolist())
            true_labels.extend(labels.cpu().numpy().tolist())
    return predictions, true_labels

In [10]:
from sklearn.metrics import classification_report

def align_predictions(predictions, labels):
    """Align predictions with true labels by flattening sequences and ignoring -100 padding."""
    pred_list = []
    true_list = []

    for pred, label in zip(predictions, labels):
        # Remove padding (-100) from both predictions and labels
        active_preds = [
            p for (p, l) in zip(pred, label) if l != -100
        ]
        active_labels = [
            l for l in label if l != -100
        ]

        pred_list.extend(active_preds)
        true_list.extend(active_labels)

    return true_list, pred_list

def evaluate_and_report(model, dataloader):
    """Evaluate the model and generate a classification report."""
    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Get predictions
            logits = model(input_ids, attention_mask)
            preds = torch.argmax(logits, dim=-1)

            # Convert predictions and labels to CPU and store them
            all_predictions.extend(preds.cpu().numpy().tolist())
            all_labels.extend(labels.cpu().numpy().tolist())

    # Align predictions and labels
    true_labels, pred_labels = align_predictions(all_predictions, all_labels)

    # Generate classification report
    print(classification_report(true_labels, pred_labels, target_names=dataset["train"].features["ner_tags"].feature.names))


In [11]:
# Evaluate the model after training
evaluate_and_report(model, val_loader)


              precision    recall  f1-score   support

           O       0.56      0.41      0.48     51513
       B-PER       0.09      0.54      0.15      6919
       I-PER       0.10      0.00      0.00     13546
       B-ORG       0.07      0.25      0.11      7707
       I-ORG       0.09      0.00      0.00     15323
       B-LOC       0.06      0.02      0.03      9838
       I-LOC       0.06      0.03      0.04      9225

    accuracy                           0.24    114071
   macro avg       0.15      0.18      0.12    114071
weighted avg       0.30      0.24      0.24    114071

