Skip to content

Commit

Permalink
fix filter and filterfalse in SentenceEmbeddingFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
svirpioj committed May 2, 2024
1 parent b69f72d commit fb00871
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 5 deletions.
8 changes: 7 additions & 1 deletion opusfilter/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ def accept(self, score):

def filter(self, pairs):
    """Yield the pairs whose score passes ``self.accept``.

    The input is processed in chunks of ``self.chunksize`` items; each
    chunk is scored with a single ``self._score_chunk`` call.

    Args:
        pairs: iterable of segment pairs to filter.

    Yields:
        Each input pair for which ``self.accept(score)`` is true, in
        input order.
    """
    for chunk in grouper(pairs, self.chunksize):
        # Scores must be matched against the items of the chunk they were
        # computed from. Zipping against the whole `pairs` input instead
        # misaligns pairs and scores (and re-consumes the input).
        for pair, score in zip(chunk, self._score_chunk(chunk)):
            if self.accept(score):
                yield pair

def filterfalse(self, pairs):
    """Yield the pairs whose score is rejected by ``self.accept``.

    This is the complement of ``filter``: the input is processed in chunks
    of ``self.chunksize`` items, each scored with a single
    ``self._score_chunk`` call.

    Args:
        pairs: iterable of segment pairs to filter.

    Yields:
        Each input pair for which ``self.accept(score)`` is false, in
        input order.
    """
    for chunk in grouper(pairs, self.chunksize):
        # Pair every score with the item of the same chunk it belongs to.
        for pair, score in zip(chunk, self._score_chunk(chunk)):
            if not self.accept(score):
                yield pair
36 changes: 32 additions & 4 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import shutil
import tempfile
import unittest
from unittest import mock

from opusfilter import ConfigurationError
from opusfilter.embeddings import *
Expand All @@ -17,6 +18,15 @@
logging.warning("Could not load laserembeddings, LASER filtering not supported")


def mocked_score_chunk(obj, chunk):
    """Return scores for a chunk of data and count the call.

    Test replacement for ``_score_chunk``: it increments
    ``mocked_score_chunk.counter`` on every call so that tests can check
    how many chunks were scored, then delegates to the scoring method
    selected by ``obj.nn_model``.

    Args:
        obj: the filter object the patched method is called on.
        chunk: the chunk of data to score.

    Returns:
        ``obj._cosine_similarities(chunk)`` when ``obj.nn_model`` is None,
        otherwise ``obj._normalized_similarities(chunk)``.
    """
    mocked_score_chunk.counter += 1
    if obj.nn_model is None:
        return obj._cosine_similarities(chunk)
    return obj._normalized_similarities(chunk)


# Number of calls made so far; tests reset this to 0 before exercising a filter.
mocked_score_chunk.counter = 0


@unittest.skipIf('laserembeddings' not in globals(), 'laserembeddings package not installed')
class TestSentenceEmbeddingFilter(unittest.TestCase):

Expand Down Expand Up @@ -57,19 +67,37 @@ def test_train_nn_model(self):
dist, ind = nn_model.query([pair[1] for pair in self.bi_inputs], 'en', n_neighbors=2)
self.assertEqual(ind.shape, (4, 2))

@mock.patch('opusfilter.embeddings.SentenceEmbeddingFilter._score_chunk', mocked_score_chunk)
def test_bilingual_score(self):
    """Check accept decisions from score() and the number of scored chunks."""
    mocked_score_chunk.counter = 0
    testfilter = SentenceEmbeddingFilter(languages=self.bi_langs, threshold=0.4, chunksize=2)
    expected = [True, True, False, False]
    results = [testfilter.accept(x) for x in testfilter.score(self.bi_inputs)]
    # Compare whole lists: a zip() loop would stop at the shorter sequence
    # and silently ignore missing or extra results.
    self.assertEqual(results, expected)
    # With chunksize=2 the scorer must run once per started pair of inputs.
    self.assertEqual(mocked_score_chunk.counter, (len(self.bi_inputs) + 1) // 2)

@mock.patch('opusfilter.embeddings.SentenceEmbeddingFilter._score_chunk', mocked_score_chunk)
def test_bilingual_filter(self):
    """Check that filter() yields the accepted pairs, scoring once per chunk."""
    mocked_score_chunk.counter = 0
    testfilter = SentenceEmbeddingFilter(languages=self.bi_langs, threshold=0.4, chunksize=2)
    expected = [self.bi_inputs[0], self.bi_inputs[1]]
    # Materialize the generator so that the comparison covers every yielded
    # pair; zipping it against `expected` would hide missing or extra items.
    results = list(testfilter.filter(self.bi_inputs))
    self.assertEqual(results, expected)
    # With chunksize=2 the scorer must run once per started pair of inputs.
    self.assertEqual(mocked_score_chunk.counter, (len(self.bi_inputs) + 1) // 2)

@mock.patch('opusfilter.embeddings.SentenceEmbeddingFilter._score_chunk', mocked_score_chunk)
def test_bilingual_filterfalse(self):
    """Check that filterfalse() yields the rejected pairs, scoring once per chunk."""
    mocked_score_chunk.counter = 0
    testfilter = SentenceEmbeddingFilter(languages=self.bi_langs, threshold=0.4, chunksize=2)
    expected = [self.bi_inputs[2], self.bi_inputs[3]]
    results = list(testfilter.filterfalse(self.bi_inputs))
    # Compare whole lists: a zip() loop would stop at the shorter sequence
    # and silently ignore missing or extra results.
    self.assertEqual(results, expected)
    # With chunksize=2 the scorer must run once per started pair of inputs.
    self.assertEqual(mocked_score_chunk.counter, (len(self.bi_inputs) + 1) // 2)

def test_bilingual_margin_ratios(self):
nn_model = self._train_nn_model()
Expand All @@ -82,7 +110,7 @@ def test_bilingual_margin_ratios(self):
for result, correct in zip(results, expected):
self.assertEqual(result, correct)

def test_chunking(self):
def test_pipeline_chunking(self):
testfilter = SentenceEmbeddingFilter(languages=self.bi_langs, threshold=0.4, chunksize=19)
inputs = 50 * self.bi_inputs
expected = 50 * [True, True, False, False]
Expand Down

0 comments on commit fb00871

Please sign in to comment.