Skip to content

Commit

Permalink
Adds more embedding lookup tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JSybrandt committed May 21, 2020
1 parent ffbf00e commit 0f41048
Showing 1 changed file with 47 additions and 2 deletions.
49 changes: 47 additions & 2 deletions agatha/ml/util/test_embedding_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,55 @@ def test_setup_lookup_data():
"B": [2, 3, 4],
"C": [5, 6, 7],
}
actual = EmbeddingLookupTable(*(setup_embedding_lookup_data(
actual = EmbeddingLookupTable(*setup_embedding_lookup_data(
expected,
test_name="test_setup_lookup_data",
num_parts=1
)))
))
assert_table_contains_embeddings(actual, expected)

def test_setup_lookup_data_two_parts():
expected = {
"A": [1, 2, 3],
"B": [2, 3, 4],
"C": [5, 6, 7],
}
actual = EmbeddingLookupTable(*setup_embedding_lookup_data(
expected,
test_name="test_setup_lookup_data_two_parts",
num_parts=2
))
assert_table_contains_embeddings(actual, expected)


def test_typical_embedding_lookup():
data = {
"A": [1, 2, 3],
"B": [2, 3, 4],
"C": [5, 6, 7],
}
embeddings = EmbeddingLookupTable(*setup_embedding_lookup_data(
data,
test_name="test_typical_embedding_lookup",
num_parts=2,
))
assert "A" in embeddings
assert list(embeddings["A"]) == data["A"]

assert "D" not in embeddings


def test_embedding_keys():
data = {
"A": [1, 2, 3],
"B": [2, 3, 4],
"C": [5, 6, 7],
"D": [6, 7, 8],
"E": [7, 8, 9],
}
embeddings = EmbeddingLookupTable(*setup_embedding_lookup_data(
data,
test_name="test_embedding_keys",
num_parts=2,
))
assert set(embeddings.keys()) == set(data.keys())

0 comments on commit 0f41048

Please sign in to comment.