## Reimplement ModelStage to work with Sentence Transformers

1. We create a new stage that extends `EmbeddingModelStage` and overrides `setup()` to use SentenceTransformers
2. We create a new composite stage which replaces `EmbeddingCreatorStage`

In [None]:
from dataclasses import dataclass
from typing import Literal

import pandas as pd
import torch
from sentence_transformers import SentenceTransformer

from nemo_curator.backends.base import WorkerMetadata
from nemo_curator.stages.base import CompositeStage, ProcessingStage
from nemo_curator.stages.text.embedders.base import EmbeddingModelStage
from nemo_curator.stages.text.models.tokenizer import TokenizerStage
from nemo_curator.tasks import DocumentBatch


class SentenceTransformerEmbeddingModelStage(EmbeddingModelStage):
    """Embedding model stage whose model is a ``SentenceTransformer``.

    Constructor arguments are forwarded to ``EmbeddingModelStage``. ``setup``
    is overridden to load the model through ``SentenceTransformer``, and
    ``process_model_output`` reads the ``"sentence_embedding"`` entry of the
    model output.
    """

    def __init__(  # noqa: PLR0913
        self,
        model_identifier: str,
        embedding_field: str = "embeddings",
        hf_token: str | None = None,
        model_inference_batch_size: int = 1024,
        has_seq_order: bool = True,
        padding_side: Literal["left", "right"] = "right",
        autocast: bool = True,
    ):
        """Initialise the parent stage and record the output field name.

        Args:
            model_identifier: Name or path handed to ``SentenceTransformer`` in ``setup``.
            embedding_field: Name stored on the stage and reported by ``outputs``.
            hf_token: Forwarded to the parent constructor.
            model_inference_batch_size: Forwarded to the parent constructor.
            has_seq_order: Forwarded to the parent constructor.
            padding_side: Forwarded to the parent constructor.
            autocast: Forwarded to the parent constructor.
        """
        parent_kwargs = {
            "model_identifier": model_identifier,
            "hf_token": hf_token,
            "model_inference_batch_size": model_inference_batch_size,
            "has_seq_order": has_seq_order,
            "padding_side": padding_side,
            "autocast": autocast,
        }
        super().__init__(**parent_kwargs)
        # EmbeddingModelStage sets unpack_inference_batch to True; this stage needs it False.
        self.unpack_inference_batch = False
        self.embedding_field = embedding_field

    def outputs(self) -> tuple[list[str], list[str]]:
        """Return ``["data"]`` paired with a one-element list holding the embedding field name."""
        first = ["data"]
        second = [self.embedding_field]
        return first, second

    def setup(self, _: WorkerMetadata | None = None) -> None:
        """Load the model for inference.

        The model is created with ``local_files_only=True``, switched to eval
        mode and moved to the ``"cuda"`` device.
        """
        # NOTE(review): local_files_only=True presumably relies on the files having been
        # fetched beforehand (e.g. by setup_on_node) - confirm against the base class.
        self.model = SentenceTransformer(self.model_identifier, local_files_only=True)
        self.model.eval()
        self.model.to("cuda")

    def process_model_output(
        self,
        outputs: torch.Tensor,
        model_input_batch: dict[str, torch.Tensor] | None = None,  # noqa: ARG002
    ) -> torch.Tensor:
        """Pick the sentence embedding out of the model output and move it to the CPU.

        Args:
            outputs: Model output; indexed with the key ``"sentence_embedding"``.
            model_input_batch: Unused.

        Returns:
            The ``"sentence_embedding"`` entry after ``.cpu()``.
        """
        # NOTE(review): `outputs` is annotated as a Tensor but is indexed by a string key,
        # so it is presumably a dict-like mapping - confirm against the parent signature.
        sentence_embedding = outputs["sentence_embedding"]
        return sentence_embedding.cpu()


@dataclass(kw_only=True)
class SentenceTransformerEmbeddingCreatorStage(CompositeStage[DocumentBatch, DocumentBatch]):
    """Composite stage chaining a ``TokenizerStage`` and a SentenceTransformer model stage.

    The two sub-stages are built once in ``__post_init__`` from the dataclass
    fields and returned, in that order, by ``decompose``.
    """

    # Passed to both the tokenizer stage and the model stage.
    model_identifier: str = "sentence-transformers/all-MiniLM-L6-v2"
    # Passed to the tokenizer stage as ``text_field``.
    text_field: str = "text"
    # Passed to the model stage as ``embedding_field``.
    embedding_field: str = "embeddings"
    # Passed to the tokenizer stage unchanged.
    max_chars: int | None = None
    # Passed to the tokenizer stage unchanged.
    max_seq_length: int | None = None
    # Passed to both sub-stages.
    padding_side: Literal["left", "right"] = "right"
    # Passed to the model stage.
    model_inference_batch_size: int = 1024

    # Passed to the model stage.
    autocast: bool = True
    # Passed to the tokenizer as ``sort_by_length`` and to the model stage as ``has_seq_order``.
    sort_by_length: bool = True
    # Passed to both sub-stages.
    hf_token: str | None = None

    def __post_init__(self) -> None:
        """Initialise the parent class and build the two sub-stages."""
        super().__init__()

        tokenizer = TokenizerStage(
            model_identifier=self.model_identifier,
            hf_token=self.hf_token,
            text_field=self.text_field,
            max_chars=self.max_chars,
            max_seq_length=self.max_seq_length,
            padding_side=self.padding_side,
            sort_by_length=self.sort_by_length,
        )
        embedder = SentenceTransformerEmbeddingModelStage(
            model_identifier=self.model_identifier,
            embedding_field=self.embedding_field,
            hf_token=self.hf_token,
            model_inference_batch_size=self.model_inference_batch_size,
            # The model stage's has_seq_order follows the tokenizer's sort_by_length setting.
            has_seq_order=self.sort_by_length,
            padding_side=self.padding_side,
            autocast=self.autocast,
        )
        self.stages = [tokenizer, embedder]

    def decompose(self) -> list[ProcessingStage]:
        """Return the list of sub-stages built in ``__post_init__`` (tokenizer first, then model)."""
        return self.stages

## Set up the stages

In [None]:
# Build the composite stage for the chosen model and split it into its two sub-stages.
model_name = "google/embeddinggemma-300m"
st_composite_stage = SentenceTransformerEmbeddingCreatorStage(model_identifier=model_name)

# decompose() returns [tokenizer stage, model stage].
st_tokenizer_stage, st_model_stage = st_composite_stage.decompose()

# For each sub-stage, run the node-level setup followed by the worker-level setup.
for _stage in (st_tokenizer_stage, st_model_stage):
    _stage.setup_on_node(None)
    _stage.setup(None)

Fetching 19 files:   0%|          | 0/19 [00:00<?, ?it/s]

Fetching 19 files:   0%|          | 0/19 [00:00<?, ?it/s]

### Run using the stage

In [None]:
from nemo_curator.tasks import DocumentBatch

# Sample sentences in several languages used as input for both code paths.
input_text = [
    "绝不能放弃，世界上没有失败，只有放弃。",  # noqa: RUF001
    'is there any doubt about it "None whatsoever"',
    "세상 어떤 짐승이 이를 드러내고 사냥을 해? 약한 짐승이나 몸을 부풀리지, 진짜 짐승은 누구보다 침착하지.",
    "そのように二番目に死を偽装して生き残るようになったイタドリがどうして初めて見る自分をこんなに気遣ってくれるのかと尋ねると「私が大切にする人たちがあなたを大切にするから」と答えては",
]

# Wrap the sentences in a single-column frame and hand it to a DocumentBatch.
sample_frame = pd.DataFrame({"text": input_text})
dummy_batch = DocumentBatch(
    data=sample_frame,
    task_id="dummy_task",
    dataset_name="dummy_dataset",
)

# Run the tokenizer stage first, then feed its result to the model stage.
tokenized_output_task = st_tokenizer_stage.process(dummy_batch)
model_output_task = st_model_stage.process(tokenized_output_task)

# Last expression of the cell: display the resulting frame.
model_output_task.data

Unnamed: 0,text,input_ids,attention_mask,embeddings
0,绝不能放弃，世界上没有失败，只有放弃。,"[2, 239306, 17055, 91435, 236900, 82255, 8939,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, ...","[-0.15909849107265472, 0.0327397957444191, 0.0..."
1,"is there any doubt about it ""None whatsoever""","[2, 511, 993, 1027, 9370, 1003, 625, 623, 9336...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, ...","[-0.17032906413078308, 0.03656821325421333, 0...."
2,"세상 어떤 짐승이 이를 드러내고 사냥을 해? 약한 짐승이나 몸을 부풀리지, 진짜 짐...","[2, 238040, 237774, 51955, 236743, 242596, 239...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-0.07146891951560974, 0.012988940812647343, 0..."
3,そのように二番目に死を偽装して生き残るようになったイタドリがどうして初めて見る自分をこんなに...,"[2, 9266, 19164, 237725, 238508, 143926, 23854...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[-0.08917465806007385, 0.03781035169959068, 0...."


### Run using raw SentenceTransformer

In [15]:
# Reference path: load the same model directly, put it in eval mode and move it to the GPU.
st_raw_model = SentenceTransformer(model_name)
st_raw_model.eval()
st_raw_model.to("cuda")

# Encode the same sentences under CUDA autocast.
with torch.autocast(device_type="cuda"):
    st_raw_output = st_raw_model.encode(input_text)

### Compare Results

In [26]:
import numpy as np

# Stack the per-row embeddings produced by the stage into one array, then check
# it against the raw SentenceTransformer output (default tolerances).
stage_embeddings = np.asarray(model_output_task.data["embeddings"].tolist())
np.testing.assert_allclose(stage_embeddings, st_raw_output)