### Import Libraries and Load Dataset

In [2]:
import torch
from transformers import DistilBertTokenizerFast, DistilBertForTokenClassification, DataCollatorForTokenClassification
from transformers import Trainer, TrainingArguments
from transformers import EarlyStoppingCallback
from datasets import load_dataset, Dataset
import evaluate

# Load dataset
dataset = load_dataset("telord/mountains-ner-dataset")

# Load seqeval for evaluation
seqeval = evaluate.load("seqeval")

Downloading readme:   0%|          | 0.00/578 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/811k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/106k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/109k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3827 [00:00<?, ? examples/s]

Generating val split:   0%|          | 0/478 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/479 [00:00<?, ? examples/s]

Downloading builder script:   0%|          | 0.00/6.34k [00:00<?, ?B/s]

### Tokenizing the Data

In [3]:
# Load the fast tokenizer
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')

label_all_tokens = True 

def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples['tokens'], truncation=True, is_split_into_words=True)
    labels = []
    for i, label in enumerate(examples['labels']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        label_ids = []
        previous_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            else:
                label_ids.append(label[word_idx] if label_all_tokens else -100)
            previous_word_idx = word_idx
        labels.append(label_ids)
    
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

# Apply tokenization and alignment to the dataset
tokenized_datasets = dataset.map(tokenize_and_align_labels, batched=True)

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/465 [00:00<?, ?B/s]



Map:   0%|          | 0/3827 [00:00<?, ? examples/s]

Map:   0%|          | 0/478 [00:00<?, ? examples/s]

Map:   0%|          | 0/479 [00:00<?, ? examples/s]

In [4]:
tokenized_datasets['train']

Dataset({
    features: ['sentence', 'tokens', 'labels', 'input_ids', 'attention_mask'],
    num_rows: 3827
})

### Define the Model

In [5]:
# Define label list for mountain NER
label_list = ["O", "B-mountain", "I-mountain"]

# Creating label2id and id2label mappings
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for idx, label in enumerate(label_list)}

# Output the mappings
print("label2id:", label2id)
print("id2label:", id2label)

label2id: {'O': 0, 'B-mountain': 1, 'I-mountain': 2}
id2label: {0: 'O', 1: 'B-mountain', 2: 'I-mountain'}


In [6]:
# Load the DistilBERT model for token classification
model = DistilBertForTokenClassification.from_pretrained(
    'distilbert-base-cased',
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id
)

model.safetensors:   0%|          | 0.00/263M [00:00<?, ?B/s]

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Define Metrics

In [7]:
def compute_metrics(p):
    predictions, labels = p
    predictions = predictions.argmax(axis=2)
    
    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    
    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

### Seting Up Training Arguments

In [8]:
training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy='epoch',
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    num_train_epochs=15,
    weight_decay=0.01,
    save_total_limit=1,
    logging_dir='./logs',
    logging_steps=10,
    load_best_model_at_end=True,  # Load the best model at the end of training
    metric_for_best_model="accuracy",  # Metric to track for early stopping
    greater_is_better=True,  # Whether higher metric values indicate better performance
)


In [9]:
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

### Data Collator for Token Classification

In [10]:
# Use DataCollatorForTokenClassification for dynamic padding and batching
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

### Initializing Trainer

In [11]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['val'],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
    callbacks=[early_stopping],
)

### Training the Model

In [12]:
# Use GPU if available
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device);

In [13]:
trainer.train()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,0.1309,0.10606,0.590625,0.747036,0.659686,0.960875
2,0.0619,0.080223,0.738386,0.795784,0.766011,0.972138
3,0.0562,0.084651,0.802721,0.777339,0.789826,0.973521
4,0.0419,0.084558,0.795892,0.816864,0.806242,0.975761
5,0.0205,0.093601,0.780976,0.822134,0.801027,0.975102
6,0.0256,0.098681,0.777506,0.837945,0.806595,0.975168
7,0.0117,0.110843,0.812,0.802372,0.807157,0.975234


TrainOutput(global_step=840, training_loss=0.061114613556613524, metrics={'train_runtime': 181.2013, 'train_samples_per_second': 316.802, 'train_steps_per_second': 9.934, 'total_flos': 579190305903750.0, 'train_loss': 0.061114613556613524, 'epoch': 7.0})

### Saving the Model

In [14]:
model.save_pretrained('./ner_model')
tokenizer.save_pretrained('./ner_model')

('./ner_model/tokenizer_config.json',
 './ner_model/special_tokens_map.json',
 './ner_model/vocab.txt',
 './ner_model/added_tokens.json',
 './ner_model/tokenizer.json')

### Eval on test

In [15]:
test_dataset = tokenized_datasets['test']

# Evaluate the model on the test set
test_results = trainer.evaluate(test_dataset)

# Print the results
print("Test set evaluation results:")
for key, value in test_results.items():
    print(f"{key}: {value:.4f}")

Test set evaluation results:
eval_loss: 0.0903
eval_precision: 0.8028
eval_recall: 0.8085
eval_f1: 0.8056
eval_accuracy: 0.9731
eval_runtime: 1.0102
eval_samples_per_second: 474.1840
eval_steps_per_second: 7.9200
epoch: 7.0000


### Run Inference

In [16]:
def predict(text):
    tokens = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
    tokens = {key: val.to(device) for key, val in tokens.items()}  # Move tokens to GPU if available
    output = model(**tokens)
    logits = output.logits
    predictions = torch.argmax(logits, dim=2)
    return predictions


In [17]:
text = "At Heartbreak Hill the field thinned."
predictions = predict(text)
print(predictions)

tensor([[0, 0, 1, 1, 2, 0, 0, 0, 0, 0, 0]], device='cuda:0')


In [18]:
#Formating inference results to better interpretable
def format_inference_results(text, predictions, tokenizer, id2label):
    # Tokenize input text to get the tokenization alignment
    tokens = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
    tokens_decoded = tokenizer.convert_ids_to_tokens(tokens['input_ids'][0])
    
    # Convert predictions from tensor to list
    predictions = predictions[0].cpu().numpy()
    
    # Map the predictions to the corresponding labels
    predicted_labels = [id2label[pred] for pred in predictions]
    
    # Combine tokens with their corresponding labels
    result = []
    for token, label in zip(tokens_decoded, predicted_labels):
        # Ignore special tokens like [CLS] and [SEP]
        if token not in tokenizer.all_special_tokens:
            result.append((token, label))
    
    return result

In [19]:
# Example usage
text = "Treetops Hotel is a hotel in Aberdare National Park in Kenya near the township of Nyeri , 1,966 m ( 6,450 ft ) above sea level on the Aberdare Range and in sight of Mount Kenya ."
tokens = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
tokens = {key: val.to(device) for key, val in tokens.items()}  # Ensure tokens are on the correct device
output = model(**tokens)
logits = output.logits
predictions = torch.argmax(logits, dim=2)

# Call the function to format results
formatted_results = format_inference_results(text, predictions, tokenizer, id2label)

# Print the formatted results
for token, label in formatted_results:
    print(f"{token}: {label}")


Tree: O
##top: O
##s: O
Hotel: O
is: O
a: O
hotel: O
in: O
Abe: O
##rda: O
##re: O
National: O
Park: O
in: O
Kenya: O
near: O
the: O
township: O
of: O
N: O
##yer: O
##i: O
,: O
1: O
,: O
96: O
##6: O
m: O
(: O
6: O
,: O
450: O
ft: O
): O
above: O
sea: O
level: O
on: O
the: O
Abe: B-mountain
##rda: B-mountain
##re: B-mountain
Range: I-mountain
and: O
in: O
sight: O
of: O
Mount: B-mountain
Kenya: I-mountain
.: O
