In [2]:
# Automatically reload imported modules that changed on disk before running each cell.
%load_ext autoreload
%autoreload 2

In [3]:
import copy
import random
from typing import Any, Dict, List

import spacy
from edsnlp.pipelines.misc.external_model import ModelWrapper

## 1. Define a model

This model can be anything (PyTorch, scikit-learn, etc.).  
Let's build an extremely basic, keyword-based sentiment analysis model.

In [4]:
class SentimentAnalysisModel:
    """Toy keyword-based sentiment model.

    A text is flagged ``good``, ``bad`` and/or ``neutral`` when the matching
    keyword occurs in it as a case-sensitive substring. It stands in for any
    real model (PyTorch, scikit-learn, ...) exposing a ``predict`` method.
    """

    # Keywords looked up in each text; also the keys of every prediction dict.
    LABELS = ("good", "bad", "neutral")

    def __init__(self):
        # Stateless: there is nothing to initialise.
        pass

    def batching(
        self,
        data: Dict[str, List[Any]],
        batch_size: int,
    ):
        """Yield successive slices of ``data``, ``batch_size`` items at a time.

        Parameters
        ----------
        data : Dict[str, List[Any]]
            Column-oriented inputs. The number of items is taken from
            ``data["text"]``; every column is sliced the same way.
        batch_size : int
            Maximum number of items per batch. Must be strictly positive.

        Yields
        ------
        Dict[str, List[Any]]
            A dict with the same keys as ``data``, each holding one slice.

        Raises
        ------
        ValueError
            When iteration starts, if ``batch_size`` is not strictly positive
            (a non-positive step would never move past the end of the data).
        """
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )
        n_items = len(data["text"])
        for start in range(0, n_items, batch_size):
            yield {
                key: values[start : start + batch_size]
                for key, values in data.items()
            }

    def forward(
        self,
        batch,
    ):
        """Predict keyword flags for every text of one batch.

        Returns one ``{"good": bool, "bad": bool, "neutral": bool}`` dict per
        entry of ``batch["text"]``, in the same order.
        """
        return [
            {label: label in txt for label in self.LABELS}
            for txt in batch["text"]
        ]

    def predict(
        self,
        data: Dict[str, List[Any]],
        batch_size: int,
    ):
        """Run :meth:`forward` over ``data`` batch by batch.

        Returns the list of per-text prediction dicts, in input order.
        Raises ``ValueError`` if ``batch_size`` is not strictly positive.
        """
        preds = []
        for batch in self.batching(data, batch_size):
            preds.extend(self.forward(batch))
        return preds


We will test our model on some toy data:

In [5]:
# Toy corpus: each text contains exactly one of the model's three keywords.
SEED = 42  # fixed seed so the shuffled order (and every output below) is reproducible

good_texts = [
    "It was a very good movie !",
    "The cheese was pretty good.",
    "It's gonna be a very good year. Very good!",
]

bad_texts = [
    "I have a bad feeling about this",
    "This was pretty bad.",
]

neutral_texts = [
    "This is a neutral statement.",
]

texts = 50 * good_texts + 40 * bad_texts + 30 * neutral_texts

# Shuffle with a dedicated RNG so the global `random` state is left untouched.
random.Random(SEED).shuffle(texts)

data = dict(text=texts)

In [6]:
model = SentimentAnalysisModel()
# Bare-model predictions on the whole toy dataset: one dict per text.
preds_from_model = model.predict(data=data, batch_size=16)

In [10]:
# Show the first prediction next to the text it was computed from.
first_pred, first_text = preds_from_model[0], data["text"][0]
first_pred, first_text

({'good': False, 'bad': False, 'neutral': True},
 'This is a neutral statement.')

## 2. Wrap the model

To use this model with EDS-NLP, wrap it with the dedicated `ModelWrapper` class.  
Two parameters are available here:
- `span_getter`: tells the wrapper how to build your model's inference data from a spaCy document.
- `annotation_setter`: tells the wrapper how to set your model's output predictions on the original spaCy `Doc`, `Span` or `Token`.

When creating your wrapper that inherits from `ModelWrapper`, you can either:
- use a pre-registered function for those two parameters, or
- use your own by redefining `self.span_getter` or `self.annotation_setter`.

In [12]:
# Default configs, referring to pre-registered EDS-NLP functions:
# - the span getter feeds the model with the document's sentences;
# - the annotation setter writes each predicted key to the `_` extension
#   named in `mapping`.
DEFAULT_SPAN_GETTER = {
    "@span_getters": "sentences",
}

DEFAULT_ANNOTATION_SETTER = {
    "@annotation_setters": "from-mapping",
    "mapping": {
        "good": "_.good",
        "bad": "_.bad",
        "neutral": "_.neutral",
    },
}


class SentimentAnalysisWrapper(ModelWrapper):
    """EDS-NLP wrapper around :class:`SentimentAnalysisModel`.

    Parameters
    ----------
    model : SentimentAnalysisModel
        The model to wrap.
    span_getter : optional
        Span-getter config, e.g. a registry dict such as
        ``DEFAULT_SPAN_GETTER`` (the default).
    annotation_setter : optional
        Annotation-setter config, e.g. a registry dict such as
        ``DEFAULT_ANNOTATION_SETTER`` (the default).
    """

    def __init__(
        self,
        model: SentimentAnalysisModel,
        span_getter=DEFAULT_SPAN_GETTER,
        annotation_setter=DEFAULT_ANNOTATION_SETTER,
    ):
        # The defaults are module-level dicts shared by every instance. Hand
        # the base class private copies of dict configs so that nothing done to
        # them can leak back into DEFAULT_* (and thus into later wrappers).
        # Non-dict values (e.g. callables) are passed through untouched.
        span_getter, annotation_setter = (
            copy.deepcopy(config) if isinstance(config, dict) else config
            for config in (span_getter, annotation_setter)
        )
        super().__init__(model, span_getter, annotation_setter)

In [13]:
# Wrap the trained model, keeping the default span getter and annotation setter.
wrap = SentimentAnalysisWrapper(model=model)

Finally, we save the wrapped model to disk:

In [14]:
# The `eds.external-model` pipe below is configured with this path.
# Only load pickles you trust: unpickling can execute arbitrary code.
model_path = "./model.pkl"
wrap.to_pickle(model_path)

## 3. Use the wrapped model in a pipe

Use the `eds.external-model` pipe and pass the path of the pickled model in its configuration.

In [16]:
nlp = spacy.blank("eds")

# Sentence segmentation: presumably what the wrapper's "sentences" span
# getter relies on to find its input spans.
nlp.add_pipe("eds.sentences")
# Load the pickled wrapper. The trailing `;` stops the notebook from
# displaying the ExternalModel object that `add_pipe` returns.
nlp.add_pipe("eds.external-model", config={"model": "./model.pkl"});

<edsnlp.pipelines.misc.external_model.external_model.ExternalModel at 0x7fe4fa313650>

Now simply use `nlp.pipe`:

In [17]:
# Run the whole pipeline (sentence splitting + wrapped model) on every text.
spacy_preds = [doc for doc in nlp.pipe(texts)]

## 4. Sanity check

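Let us check that the output of the model and the output of `nlp.pipe` match. Note that the wrapper runs the model on each sentence (`sentences` span getter), whereas `preds_from_model` was computed on whole texts: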

In [18]:
# Turn the per-sentence annotations into one prediction per text.
# The wrapper annotates sentences, while `preds_from_model` was computed on
# whole texts. For this substring-based model, a keyword occurs in a text iff
# it occurs in one of its sentences, so `any` over `doc.sents` gives the
# matching text-level value (reading only `doc[0].sent` ignored every later
# sentence). Assumes the "sentences" span getter annotates every sentence.
# Docs are taken straight from `nlp.pipe`, not from the previous cell's
# `spacy_preds`. Rebinding that name from Docs to dicts made a re-run of this
# cell fail.
spacy_preds = [
    {
        label: any(getattr(sent._, label) for sent in doc.sents)
        for label in ("good", "bad", "neutral")
    }
    for doc in nlp.pipe(texts)
]

In [19]:
# Fail loudly (e.g. under "Restart & Run All") if the pipeline disagrees with
# the bare model, then still display the comparison result.
assert preds_from_model == spacy_preds, "nlp.pipe predictions differ from the bare model's"
preds_from_model == spacy_preds

True