In [1]:
task = "ner" # Should be one of "ner", "pos" or "chunk"
model_checkpoint = "distilbert-base-uncased"

In [2]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [3]:
import transformers
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

In [4]:
def gather_outputs(outputs: list) -> list:
    # Group entities by their sequence
    grouped_entities = []
    current_group = []
    for entity in outputs:
        if not current_group or entity['start'] == current_group[-1]['end']:
            current_group.append(entity)
        else:
            grouped_entities.append(current_group)
            current_group = [entity]

    # Append the last group
    if current_group:
        grouped_entities.append(current_group)

    return grouped_entities

def transform_sentence_from_outputs(sentence: str, outputs: list) -> list:
    groups = gather_outputs(outputs)
    locations = [{"label": group[0]["entity_group"], "city": sentence["text"][group[0]["start"]:group[-1]["end"]] } for group in groups]
    sentence = {
        "id": str(sentence["id"]),
        "locations": locations
    }

    sentence["locations"] = sorted(sentence["locations"], key=lambda group: group["label"], reverse=True)
    return sentence

def format_sentence_output(sentence_output: list) -> str:
    return f"{sentence_output['id']},{','.join([location['city'] for location in sentence_output['locations']])}"

In [5]:
from transformers import pipeline
sentences = [
    {"id": 1, "text": "Je veux aller de Port-Boulet à Le Havre."},
    {"id": 2, "text": "Peux-tu m'aider à trouver mon chemin vers Paris en partant d'Épierre ?"},
    {"id": 3, "text": "Je cherche un moyen d'aller de Margny-Lès-Compiègne à Saarbrücken /Sarrebruck."}
]

for sentence in sentences:
    token_classifier = pipeline("token-classification", model="models/distilbert-finetuned-token-classification-ner-trip", aggregation_strategy="simple")
    outputs = token_classifier(sentence["text"])
    print(format_sentence_output(transform_sentence_from_outputs(sentence, outputs)))

1,Port-Boulet,Le Havre
2,'Épierre,Paris
3,Margny-Lès-Compiègne,Saarbrücken,/Sarrebruck
