Skip to content

Commit

Permalink
feat: example text embedding functions
Browse files: browse the repository at this point in the history
  • Loading branch information
HLasse committed Feb 3, 2023
1 parent 89e5542 commit f4ce9a2
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
24 changes: 24 additions & 0 deletions src/timeseriesflattener/text_embedding_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pandas as pd
import torch
from pandas import Series
from sentence_transformers import SentenceTransformer
from transformers import pipeline


def huggingface_embedding(text_series: pd.Series, model_name: str) -> pd.DataFrame:
    """Embed each text with a Hugging Face feature-extraction pipeline.

    Every text is run through the model and its per-token features are
    averaged into a single vector. To use this in timeseriesflattener, you
    need to write a wrapper that fixes ``model_name`` (for example with
    ``functools.partial``).

    Args:
        text_series: The texts to embed, one per element.
        model_name: Model identifier handed to ``transformers.pipeline``.

    Returns:
        A float DataFrame with one row per input text and one column per
        embedding dimension. The frame gets a fresh default index; it does
        not reuse ``text_series.index``.
    """
    extractor = pipeline(model=model_name, task="feature-extraction")
    token_features = extractor(text_series.to_list(), return_tensors=True)
    # Assumes each pipeline output is shaped (1, n_tokens, hidden_size) --
    # confirm against the transformers docs. Averaging over dim=1 pools the
    # tokens; squeeze(0) then drops only the leading axis, whereas a bare
    # squeeze() would also collapse any other size-1 axis.
    # tolist() hands pandas plain Python floats instead of 0-d tensor objects.
    pooled = [
        torch.mean(features, dim=1).squeeze(0).tolist()
        for features in token_features
    ]
    return pd.DataFrame(pooled).astype(float)


def sentence_transformers_embedding(
    text_series: pd.Series,
    model_name: str,
) -> pd.DataFrame:
    """Embed each text with a sentence-transformers model.

    To use this in timeseriesflattener, you need to write a wrapper that
    fixes ``model_name`` to the model you want.

    Args:
        text_series: The texts to embed, one per element.
        model_name: Model identifier handed to ``SentenceTransformer``.

    Returns:
        A DataFrame built from the output of ``SentenceTransformer.encode``
        for the list of texts.
    """
    encoder = SentenceTransformer(model_name)
    texts = text_series.to_list()
    encoded = encoder.encode(texts)
    return pd.DataFrame(encoded)
31 changes: 28 additions & 3 deletions tests/test_timeseriesflattener/test_embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,36 @@

from timeseriesflattener.testing.text_embedding_functions import bow_test_embedding
from timeseriesflattener.testing.utils_for_testing import synth_text_data
from timeseriesflattener.text_embedding_functions import (
huggingface_embedding,
sentence_transformers_embedding,
)


def test_embedding_fn(synth_text_data):
    """Test that the synthetic `bow_test_embedding` function works as expected.

    Rows with a missing ``text`` value are dropped first; the result must
    then have one row per remaining text and 10 columns.
    """
    df = synth_text_data.dropna(subset="text")
    # Call the embedding function directly, once. The earlier form aliased it
    # to a local and computed the same embedding a second time.
    embedding_df = bow_test_embedding(df["text"])
    assert embedding_df.shape == (df.shape[0], 10)


def test_huggingface_embedding(synth_text_data):
    """Check that `huggingface_embedding` returns one 384-column row per text.

    Only the first five rows with a non-missing ``text`` value are embedded.
    """
    sample = synth_text_data.dropna(subset="text").head(5)
    result = huggingface_embedding(
        sample["text"],
        model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    )
    expected_shape = (sample.shape[0], 384)
    assert result.shape == expected_shape


def test_sentence_transformer_embedding(synth_text_data):
    """Check that `sentence_transformers_embedding` returns one 384-column row per text.

    Only the first five rows with a non-missing ``text`` value are embedded.
    """
    sample = synth_text_data.dropna(subset="text").head(5)
    result = sentence_transformers_embedding(
        sample["text"],
        model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    )
    expected_shape = (sample.shape[0], 384)
    assert result.shape == expected_shape

0 comments on commit f4ce9a2

Please sign in to comment.