Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sentence embedding field #158

Open
wants to merge 30 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
739d3f2
Added per-field custom datatype support
ivansmokovic Feb 19, 2020
7c6740c
WIP: TfIdfVectorizer update pending
ivansmokovic Feb 20, 2020
5532206
Added option to define custom missing data symbol
ivansmokovic Feb 24, 2020
280cc43
Optimized handling of missing value rows in Iterator
ivansmokovic Feb 24, 2020
2f505b2
Made TfIdf vectorizer not support fields with missing data
ivansmokovic Feb 27, 2020
8596aec
Merge branches 'master' and 'missing-data-token' of github.com:FilipB…
ivansmokovic Feb 27, 2020
6080172
Merge remote-tracking branch 'origin/master' into missing-data-token
ivansmokovic Feb 28, 2020
cb833dc
Added missing data support to subclasses of Field
ivansmokovic Feb 28, 2020
0c20b24
Merge branches 'master' and 'missing-data-token' of github.com:FilipB…
ivansmokovic Mar 18, 2020
00dc2fc
Fixed a test
ivansmokovic Mar 18, 2020
c4a1184
Added custom padding token to field for use with custom_numericalize
ivansmokovic Mar 18, 2020
a51b3b4
flake8
ivansmokovic Mar 18, 2020
5bff1b0
WIP, testing
ivansmokovic Mar 18, 2020
feeac4e
flake8
ivansmokovic Mar 18, 2020
c119625
Added documentation
ivansmokovic Mar 18, 2020
c5fa93e
Added per-field custom datatype support
ivansmokovic Feb 19, 2020
3a7efc3
WIP: TfIdfVectorizer update pending
ivansmokovic Feb 20, 2020
318a36b
Added option to define custom missing data symbol
ivansmokovic Feb 24, 2020
41b1aa7
Made TfIdf vectorizer not support fields with missing data
ivansmokovic Feb 27, 2020
dcaba7c
Fixed a test
ivansmokovic Mar 18, 2020
09cbfb6
Added custom padding token to field for use with custom_numericalize
ivansmokovic Mar 18, 2020
39f322f
flake8
ivansmokovic Mar 18, 2020
56fb576
WIP, testing
ivansmokovic Mar 18, 2020
f914f8f
flake8
ivansmokovic Mar 18, 2020
37c0ed5
Added documentation
ivansmokovic Mar 18, 2020
19141e0
rebased to master
ivansmokovic Apr 2, 2020
a2b0363
Merge remote-tracking branch 'origin/SentenceEmbeddingField' into Sen…
ivansmokovic Apr 2, 2020
c67546f
Added language, vocab, is_target, allow_missing_data
ivansmokovic Apr 17, 2020
66592d6
Merge branch 'master' of github.com:FilipBolt/takepod into SentenceEm…
ivansmokovic Apr 17, 2020
17980ba
merged master
ivansmokovic Apr 17, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions podium/datasets/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,6 @@ def _create_batch(self, examples):
for field in self._dataset.fields:
if field.is_numericalizable and field.batch_as_matrix:
# If this field is numericalizable, generate a possibly padded matrix

# the length to which all the rows are padded (or truncated)
pad_length = Iterator._get_pad_length(field, examples)

Expand All @@ -268,7 +267,7 @@ def _create_batch(self, examples):
matrix = None # np.empty(shape=(n_rows, pad_length))

# non-sequential fields all have length = 1, no padding necessary
should_pad = True if field.is_sequential else False
should_pad = field.is_sequential

for i, example in enumerate(examples):

Expand Down Expand Up @@ -321,14 +320,14 @@ def _create_batch(self, examples):

@staticmethod
def _get_pad_length(field, examples):
if not field.is_sequential:
return 1

# the fixed_length attribute of Field has priority over the max length
# of all the examples in the batch
if field.fixed_length is not None:
return field.fixed_length

if not field.is_sequential:
return 1

# if fixed_length is None, then return the maximum length of all the
# examples in the batch
def length_of_field(example):
Expand Down
4 changes: 2 additions & 2 deletions podium/storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .example_factory import ExampleFactory, ExampleFormat
from .field import Field, TokenizedField, MultilabelField, MultioutputField, \
unpack_fields, LabelField
unpack_fields, LabelField, SentenceEmbeddingField
from .resources.downloader import (BaseDownloader, SCPDownloader, HttpDownloader,
SimpleHttpDownloader)
from .resources.large_resource import LargeResource, SCPLargeResource
Expand All @@ -21,6 +21,6 @@

__all__ = ["BaseDownloader", "SCPDownloader", "HttpDownloader", "SimpleHttpDownloader",
"Field", "TokenizedField", "LabelField", "MultilabelField", "MultioutputField",
"unpack_fields", "LargeResource", "SCPLargeResource",
"unpack_fields", "LargeResource", "SCPLargeResource", "SentenceEmbeddingField",
"VectorStorage", "BasicVectorStorage", "SpecialVocabSymbols", "Vocab",
"ExampleFactory", "ExampleFormat", "TfIdfVectorizer"]
59 changes: 52 additions & 7 deletions podium/storage/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import itertools
from collections import deque
from typing import Callable

import numpy as np

Expand Down Expand Up @@ -249,13 +250,13 @@ def __init__(self,
If true, the output of the tokenizer is presumed to be a list of tokens and
will be numericalized using the provided Vocab or custom_numericalize.
For numericalizable fields, Iterator will generate batch fields containing
numpy matrices.
numpy matrices.

If false, the output of the tokenizer is presumed to be a custom datatype.
Posttokenization hooks aren't allowed to be added as they can't be called
on custom datatypes. For non-numericalizable fields, Iterator will generate
batch fields containing lists of these custom data type instances returned
by the tokenizer.
If false, the output of the tokenizer is presumed to be a custom datatype.
Posttokenization hooks aren't allowed to be added as they can't be called
on custom datatypes. For non-numericalizable fields, Iterator will generate
batch fields containing lists of these custom data type instances returned
by the tokenizer.
custom_numericalize : callable
The numericalization function that will be called if the field
doesn't use a vocabulary. If using custom_numericalize and padding is
Expand Down Expand Up @@ -666,7 +667,7 @@ def numericalize(self, data):
_LOGGER.error(error_msg)
raise ValueError(error_msg)

else:
elif not self.custom_numericalize:
return None

# raw data is just a string, so we need to wrap it into an iterable
Expand Down Expand Up @@ -1005,6 +1006,50 @@ def _numericalize_tokens(self, tokens):
return numericalize_multihot(tokens, token_numericalize, self.num_of_classes)


class SentenceEmbeddingField(Field):
    """Field used for sentence-level multidimensional embeddings.

    The whole sentence is kept as a single raw string (no tokenization) and is
    numericalized by passing it to `embedding_fn`, which returns a fixed-width
    embedding vector.
    """

    def __init__(self,
                 name: str,
                 embedding_fn: Callable[[str], np.ndarray],
                 embedding_size: int,
                 vocab=None,
                 is_target=False,
                 language='en',
                 allow_missing_data=False):
        """
        Field used for sentence-level multidimensional embeddings.

        Parameters
        ----------
        name: str
            Field name, used for referencing data in the dataset.
        embedding_fn: Callable[[str], np.ndarray]
            Callable that takes a string and returns a fixed-width embedding.
            In case of missing data, this callable will be passed a None.
        embedding_size: int
            Width of the embedding. Passed to Field as `fixed_length`, so
            batch rows produced for this field are this many elements wide.
        vocab: Vocab
            Vocab that will be updated with the sentences passed to this field.
            Keep in mind that whole sentences will be passed to the vocab.
        is_target: bool
            Whether this field is a target variable. Forwarded unchanged to
            the Field superclass.
        language: str
            Language of the data. Not used in this field.
        allow_missing_data: bool
            Whether this field will allow the processing of missing data.
        """
        # tokenize=False + store_as_raw=True: the sentence is stored whole,
        # and embedding_fn acts as the custom numericalizer for it.
        super().__init__(name,
                         custom_numericalize=embedding_fn,
                         tokenizer=None,
                         language=language,
                         vocab=vocab,
                         tokenize=False,
                         store_as_raw=True,
                         store_as_tokenized=False,
                         is_target=is_target,
                         fixed_length=embedding_size,
                         allow_missing_data=allow_missing_data)


def numericalize_multihot(tokens, token_indexer, num_of_classes):
active_classes = list(map(token_indexer, tokens))
multihot_encoding = np.zeros(num_of_classes, dtype=np.bool)
Expand Down
46 changes: 23 additions & 23 deletions test/storage/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mock import patch

from podium.storage import Field, TokenizedField, MultilabelField, \
Vocab, SpecialVocabSymbols, MultioutputField, LabelField
Vocab, SpecialVocabSymbols, MultioutputField, LabelField, SentenceEmbeddingField

ONE_TO_FIVE = [1, 2, 3, 4, 5]

Expand Down Expand Up @@ -689,36 +689,14 @@ def test_missing_values_default_sequential():
custom_numericalize=lambda x: hash(x),
allow_missing_data=True)

_, data_missing = fld.preprocess(None)[0]
_, data_exists = fld.preprocess("data_string")[0]

assert data_missing == (None, None)
assert data_exists == (None, ["data_string"])
fld.finalize()

assert fld.numericalize(data_missing) is None
assert np.all(fld.numericalize(data_exists) == np.array([hash("data_string")]))


def test_missing_values_custom_numericalize():
fld = Field(name="test_field",
store_as_raw=True,
tokenize=False,
custom_numericalize=int,
allow_missing_data=True)

_, data_missing = fld.preprocess(None)[0]
_, data_exists = fld.preprocess("404")[0]

assert data_missing == (None, None)
assert data_exists == ("404", None)

fld.finalize()

assert fld.numericalize(data_missing) is None
assert np.all(fld.numericalize(data_exists) == np.array([404]))


def test_missing_symbol_index_vocab():
vocab = Vocab()
fld = Field(name="test_field",
Expand Down Expand Up @@ -874,3 +852,25 @@ def test_label_field():
_, example = x[0]
raw, _ = example
assert label_field.numericalize(example) == vocab.stoi[raw]


def test_sentence_embedding_field():
    # Deterministic stand-in for a real embedding model: the known sentence
    # maps to a fixed vector, while missing data (None) maps to a zero vector.
    def fake_embedding_fn(sentence):
        if sentence is None:
            return np.zeros(4)

        if sentence == "test_sentence":
            return np.array([1, 2, 3, 4])

    fld = SentenceEmbeddingField("test_field",
                                 embedding_fn=fake_embedding_fn,
                                 embedding_size=4,
                                 allow_missing_data=True)

    # Present data: preprocess yields a single (name, data) pair whose data
    # numericalizes to the embedding produced by the callable.
    ((_, present_data),) = fld.preprocess("test_sentence")
    present_vec = fld.numericalize(present_data)
    assert np.all(present_vec == np.array([1, 2, 3, 4]))

    # Missing data: allowed by the field, embedded as the zero vector.
    ((_, missing_data),) = fld.preprocess(None)
    missing_vec = fld.numericalize(missing_data)
    assert np.all(missing_vec == np.zeros(4))