Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable pickling for vocabulary #2391

Merged
merged 2 commits into from Jan 18, 2019
Merged
Changes from all commits
Commits
File filter...
Filter file types
Jump to…
Jump to file or symbol
Failed to load files and symbols.
+47 −0
Diff settings

Always

Just for now

@@ -4,6 +4,7 @@
"""

import codecs
import copy
import logging
import os
from collections import defaultdict
@@ -232,6 +233,38 @@ def __init__(self,
tokens_to_add,
min_pretrained_embeddings)


def __getstate__(self):
"""
Need to sanitize defaultdict and defaultdict-like objects
by converting them to vanilla dicts when we pickle the vocabulary.
"""
state = copy.copy(self.__dict__)
state["_token_to_index"] = dict(state["_token_to_index"])
state["_index_to_token"] = dict(state["_index_to_token"])

if "_retained_counter" in state:
state["_retained_counter"] = {key: dict(value)
for key, value in state["_retained_counter"].items()}

return state

def __setstate__(self, state):
    """
    Restore a pickled vocabulary.

    The pickled state carries plain dicts (see ``__getstate__``); here we
    rebuild the namespace-aware DefaultDict subclasses and copy the saved
    mappings back into them.
    """
    # pylint: disable=attribute-defined-outside-init
    self.__dict__ = copy.copy(state)

    token_to_index = _TokenToIndexDefaultDict(self._non_padded_namespaces,
                                              self._padding_token,
                                              self._oov_token)
    token_to_index.update(state["_token_to_index"])

    index_to_token = _IndexToTokenDefaultDict(self._non_padded_namespaces,
                                              self._padding_token,
                                              self._oov_token)
    index_to_token.update(state["_index_to_token"])

    self._token_to_index = token_to_index
    self._index_to_token = index_to_token

def save_to_files(self, directory: str) -> None:
"""
Persist this Vocabulary to files so it can be reloaded later.
@@ -1,4 +1,5 @@
import codecs
import pickle
import gzip
import zipfile
from copy import deepcopy
@@ -29,6 +30,19 @@ def setUp(self):
self.dataset = Batch([self.instance])
super(TestVocabulary, self).setUp()

def test_pickling(self):
    # A vocabulary should survive a pickle round trip with all of its
    # mappings and configuration intact.
    vocab = Vocabulary.from_instances(self.dataset)

    round_tripped = pickle.loads(pickle.dumps(vocab))

    assert dict(round_tripped._index_to_token) == dict(vocab._index_to_token)
    assert dict(round_tripped._token_to_index) == dict(vocab._token_to_index)
    assert round_tripped._non_padded_namespaces == vocab._non_padded_namespaces
    assert round_tripped._oov_token == vocab._oov_token
    assert round_tripped._padding_token == vocab._padding_token
    assert round_tripped._retained_counter == vocab._retained_counter

def test_from_dataset_respects_max_vocab_size_single_int(self):
max_vocab_size = 1
vocab = Vocabulary.from_instances(self.dataset, max_vocab_size=max_vocab_size)
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.