In [1]:
import torch
import json
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer, TrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


### Training

In [2]:
tokenizer = AutoTokenizer.from_pretrained("dslim/distilbert-NER")

label_list = ["O", "B-MOUNT", "I-MOUNT"]
label_to_id = {label: i for i, label in enumerate(label_list)}
id_to_label = {i: label for i, label in enumerate(label_list)}

# Convert labels to IDs
def encode_labels(labels):
    ids = [label_to_id[label] for label in labels]
    padding_length = tokenizer.model_max_length - len(ids)
    ids += [label_list.index('O')] * padding_length
    return ids

def encode_tokens(tokens):
    ids = tokenizer.convert_tokens_to_ids(tokens)
    padding_length = tokenizer.model_max_length - len(ids)
    ids += [tokenizer.pad_token_id] * padding_length
    return ids

def prepare_dataset(file_path):
    """
    Function to read the json file with data and convert
    it to a Dataset object with encoded labels and tokens
    """
    with open(file_path, "r") as f:
        dataset = json.load(f)
    dataset = Dataset.from_list(dataset)
    dataset = dataset.map(lambda x: {'labels': encode_labels(x['labels']),
                                     'input_ids': encode_tokens(x['tokens'])})
    return dataset

In [3]:
train_dataset = prepare_dataset("train_dataset.json")
val_dataset = prepare_dataset("val_dataset.json")

Map: 100%|██████████| 1000/1000 [00:00<00:00, 3064.36 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 2857.00 examples/s]


In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [5]:
model = AutoModelForTokenClassification.from_pretrained("dslim/distilbert-NER", num_labels=len(label_list), ignore_mismatched_sizes=True).to(device)

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at dslim/distilbert-NER and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([9]) in the checkpoint and torch.Size([3]) in the model instantiated
- classifier.weight: found shape torch.Size([9, 768]) in the checkpoint and torch.Size([3, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

In [7]:
# Train the model
trainer.train()

                                                  
 10%|█         | 63/630 [06:56<48:10,  5.10s/it]

{'eval_loss': 0.002114757662639022, 'eval_runtime': 2.8179, 'eval_samples_per_second': 35.488, 'eval_steps_per_second': 2.484, 'epoch': 1.0}


                                                  
 20%|██        | 126/630 [13:53<42:40,  5.08s/it]

{'eval_loss': 0.0008995987591333687, 'eval_runtime': 2.7405, 'eval_samples_per_second': 36.49, 'eval_steps_per_second': 2.554, 'epoch': 2.0}


                                                 
 30%|███       | 189/630 [20:50<37:16,  5.07s/it]

{'eval_loss': 0.0007847411907278001, 'eval_runtime': 2.7604, 'eval_samples_per_second': 36.227, 'eval_steps_per_second': 2.536, 'epoch': 3.0}


                                                 
 40%|████      | 252/630 [27:47<31:59,  5.08s/it]

{'eval_loss': 0.0007230451446957886, 'eval_runtime': 2.7648, 'eval_samples_per_second': 36.168, 'eval_steps_per_second': 2.532, 'epoch': 4.0}


                                                 
 50%|█████     | 315/630 [34:44<26:40,  5.08s/it]

{'eval_loss': 0.000689782144036144, 'eval_runtime': 2.7596, 'eval_samples_per_second': 36.237, 'eval_steps_per_second': 2.537, 'epoch': 5.0}


                                                 
 60%|██████    | 378/630 [41:41<21:20,  5.08s/it]

{'eval_loss': 0.000546968134585768, 'eval_runtime': 2.7599, 'eval_samples_per_second': 36.233, 'eval_steps_per_second': 2.536, 'epoch': 6.0}


                                                 
 70%|███████   | 441/630 [48:38<16:01,  5.09s/it]

{'eval_loss': 0.0008653226541355252, 'eval_runtime': 2.7623, 'eval_samples_per_second': 36.202, 'eval_steps_per_second': 2.534, 'epoch': 7.0}


 79%|███████▉  | 500/630 [55:11<14:25,  6.66s/it]

{'loss': 0.0048, 'grad_norm': 0.0009105164790526032, 'learning_rate': 4.126984126984127e-06, 'epoch': 7.94}


                                                 
 80%|████████  | 504/630 [55:36<10:51,  5.17s/it]

{'eval_loss': 0.0006288138101808727, 'eval_runtime': 2.7556, 'eval_samples_per_second': 36.289, 'eval_steps_per_second': 2.54, 'epoch': 8.0}


                                                   
 90%|█████████ | 567/630 [1:02:33<05:19,  5.08s/it]

{'eval_loss': 0.0006100510363467038, 'eval_runtime': 2.7601, 'eval_samples_per_second': 36.231, 'eval_steps_per_second': 2.536, 'epoch': 9.0}


                                                   
100%|██████████| 630/630 [1:09:31<00:00,  6.62s/it]

{'eval_loss': 0.000589806295465678, 'eval_runtime': 2.5816, 'eval_samples_per_second': 38.735, 'eval_steps_per_second': 2.711, 'epoch': 10.0}
{'train_runtime': 4171.1197, 'train_samples_per_second': 2.397, 'train_steps_per_second': 0.151, 'train_loss': 0.003826779875135611, 'epoch': 10.0}





TrainOutput(global_step=630, training_loss=0.003826779875135611, metrics={'train_runtime': 4171.1197, 'train_samples_per_second': 2.397, 'train_steps_per_second': 0.151, 'total_flos': 1306554624000000.0, 'train_loss': 0.003826779875135611, 'epoch': 10.0})

In [8]:
# Save the fine-tuned model
trainer.save_model("./distilbert-ner-tuned")

### Inference

In [16]:
def decode_predictions(predictions, id_to_label):
    predicted_labels = []
    for pred in predictions:
        label_ids = [id_to_label[label_id] for label_id in pred]
        predicted_labels.append(label_ids)
    return predicted_labels

def run_inference(sample):
    tokenized_sentence = tokenizer.tokenize(sample)
    ids = tokenizer.convert_tokens_to_ids(tokenized_sentence)
    model_input = Dataset.from_list([{'input_ids': ids}])
    predictions = trainer.predict(model_input)
    preds = predictions.predictions.argmax(-1)
    decoded_predictions = decode_predictions(preds, id_to_label)
    return tokenized_sentence, decoded_predictions

In [11]:
test_sentence = """The highest mountain on Earth is Mount Everest in the Himalayas of Asia, whose summit is 8,850 m (29,035 ft) above mean sea level.
                    The highest known mountain on any planet in the Solar System is Olympus Mons on Mars at 21,171 m (69,459 ft).
                    The tallest mountain including submarine terrain is Mauna Kea in Hawaii from its underwater base at 9,330 m (30,610 ft)
                    and some scientists consider it to be the tallest on earth."""

In [17]:
tokenized_sentence, decoded_predictions = run_inference(test_sentence)
print(f"Tokens: {tokenized_sentence}")
print(f"Predicted labels: {decoded_predictions[0]}")
print()

100%|██████████| 1/1 [00:00<00:00, 1000.07it/s]

Tokens: ['The', 'highest', 'mountain', 'on', 'Earth', 'is', 'Mount', 'Everest', 'in', 'the', 'Him', '##alaya', '##s', 'of', 'Asia', ',', 'whose', 'summit', 'is', '8', ',', '850', 'm', '(', '29', ',', '03', '##5', 'ft', ')', 'above', 'mean', 'sea', 'level', '.', 'The', 'highest', 'known', 'mountain', 'on', 'any', 'planet', 'in', 'the', 'Solar', 'System', 'is', 'O', '##ly', '##mpus', 'Mon', '##s', 'on', 'Mars', 'at', '21', ',', '171', 'm', '(', '69', ',', '45', '##9', 'ft', ')', '.', 'The', 'tallest', 'mountain', 'including', 'submarine', 'terrain', 'is', 'Ma', '##una', 'Ke', '##a', 'in', 'Hawaii', 'from', 'its', 'underwater', 'base', 'at', '9', ',', '330', 'm', '(', '30', ',', '610', 'ft', ')', 'and', 'some', 'scientists', 'consider', 'it', 'to', 'be', 'the', 'tallest', 'on', 'earth', '.']
Predicted labels: [['O', 'O', 'O', 'O', 'O', 'O', 'B-MOUNT', 'I-MOUNT', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'




In [26]:
def pretty_print(tokens, labels, entities):
    # Print tokens with corresponding labels
    print("Tokens and Labels:")
    for token, label in zip(tokens, labels):
        if label == "B-MOUNT":
            print()
            print(f"{token:12} -> {label}")
        if label == "I-MOUNT":
            print(f"{token:12} -> {label}")
    
    # Print the extracted entities
    print("\nDetected Entities:\n")
    for entity in entities:
        string = tokenizer.convert_tokens_to_string(entity.split(" "))
        print(f"Entity: {string}")
    
def extract_entities(tokens, labels):
    entities = []
    current_entity = []
    
    for token, label in zip(tokens, labels):
        if label == "B-MOUNT":  # Beginning of a new entity
            if current_entity:  # Add the previous entity to the list if exists
                entities.append(" ".join(current_entity))
            current_entity = [token]  # Start a new entity
        elif label == "I-MOUNT":  # Continuation of the current entity
            current_entity.append(token)
        else:
            if current_entity:  # Add the entity to the list if exists
                entities.append(" ".join(current_entity))
                current_entity = []  # Reset the current entity
    
    # Catch any remaining entity at the end
    if current_entity:
        entities.append(" ".join(current_entity))
    
    return entities

entities = extract_entities(tokenized_sentence, decoded_predictions[0])
pretty_print(tokenized_sentence, decoded_predictions[0], entities)

Tokens and Labels:

Mount        -> B-MOUNT
Everest      -> I-MOUNT

O            -> B-MOUNT
##ly         -> I-MOUNT
##mpus       -> I-MOUNT
Mon          -> I-MOUNT
##s          -> I-MOUNT
on           -> I-MOUNT

Ma           -> B-MOUNT
##una        -> I-MOUNT
Ke           -> I-MOUNT
##a          -> I-MOUNT

Detected Entities:

Entity: Mount Everest
Entity: Olympus Mons on
Entity: Mauna Kea
