Skip to content
This repository has been archived by the owner on May 2, 2022. It is now read-only.

Commit

Permalink
add ResumableDatasetLoop (moved from fwtwirl)
Browse files Browse the repository at this point in the history
  • Loading branch information
TaiSakuma committed Feb 16, 2018
1 parent c78537a commit a1ed350
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 8 deletions.
3 changes: 1 addition & 2 deletions alphatwirl/datasetloop/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from .loop import DatasetLoop

from .loop import DatasetLoop, ResumableDatasetLoop
37 changes: 37 additions & 0 deletions alphatwirl/datasetloop/loop.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
# Tai Sakuma <tai.sakuma@gmail.com>
import os
import gzip

try:
import cPickle as pickle
except:
import pickle

##__________________________________________________________________||
class DatasetLoop(object):
Expand All @@ -24,3 +31,33 @@ def __call__(self):
return self.reader.end()

##__________________________________________________________________||
class ResumableDatasetLoop(object):

def __init__(self, datasets, reader, workingarea):
self.datasets = datasets
self.reader = reader
self.workingarea = workingarea

def __repr__(self):
name_value_pairs = (
('datasets', self.datasets),
('reader', self.reader),
('workingarea', self.workingarea),
)
return '{}({})'.format(
self.__class__.__name__,
', '.join(['{}={!r}'.format(n, v) for n, v in name_value_pairs]),
)

def __call__(self):
self.reader.begin()
for dataset in self.datasets:
self.reader.read(dataset)

path = os.path.join(self.workingarea.path, 'reader.p.gz')
with gzip.open(path, 'wb') as f:
pickle.dump(self.reader, f, protocol=pickle.HIGHEST_PROTOCOL)

return self.reader.end()

##__________________________________________________________________||
28 changes: 22 additions & 6 deletions tests/unit/datasetloop/test_loop_DatasetLoop.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Tai Sakuma <tai.sakuma@gmail.com>
import sys
import pytest

try:
Expand All @@ -7,21 +8,36 @@
import mock

from alphatwirl.datasetloop import DatasetLoop
from alphatwirl.datasetloop import ResumableDatasetLoop

##__________________________________________________________________||
@pytest.fixture()
def reader():
return mock.Mock()

@pytest.fixture()
def datasets():
dataset1 = mock.Mock()
dataset2 = mock.Mock()
return [dataset1, dataset2]

@pytest.fixture()
def obj(datasets, reader):
return DatasetLoop(datasets, reader)
def reader():
return mock.Mock()

@pytest.fixture()
def workingarea(monkeypatch):
module = sys.modules['alphatwirl.datasetloop.loop']
monkeypatch.setattr(module, 'os', mock.Mock())
monkeypatch.setattr(module, 'gzip', mock.MagicMock())
monkeypatch.setattr(module, 'pickle', mock.Mock())
return mock.Mock()

@pytest.fixture(
params=[0, 1],
ids=('DatasetLoop', 'ResumableDatasetLoop')
)
def obj(request, datasets, reader, workingarea):
if request.param == 0:
return DatasetLoop(datasets, reader)
else:
return ResumableDatasetLoop(datasets, reader, workingarea)

def test_repr(obj):
repr(obj)
Expand Down
63 changes: 63 additions & 0 deletions tests/unit/datasetloop/test_loop_ResumableDatasetLoop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Tai Sakuma <tai.sakuma@gmail.com>
import os
import gzip
import pytest

try:
import cPickle as pickle
except:
import pickle

try:
import unittest.mock as mock
except ImportError:
import mock

from alphatwirl.datasetloop import ResumableDatasetLoop

##__________________________________________________________________||
class MockReader(object):
def begin(self):
pass

def read(self, dataset):
pass

def end(self):
pass

##__________________________________________________________________||

@pytest.fixture()
def reader():
ret = MockReader()
ret.original_id = id(ret)
return ret

@pytest.fixture()
def datasets():
dataset1 = mock.Mock()
dataset2 = mock.Mock()
return [dataset1, dataset2]

@pytest.fixture()
def workingarea(tmpdir_factory):
ret = mock.Mock()
ret.path = str(tmpdir_factory.mktemp(''))
return ret

@pytest.fixture()
def obj(datasets, reader, workingarea):
return ResumableDatasetLoop(datasets, reader, workingarea)

def test_repr(obj):
repr(obj)

def test_call(obj, reader, workingarea):
result = obj()
path = os.path.join(workingarea.path, 'reader.p.gz')
with gzip.open(path, 'rb') as f:
reader_unpickled = pickle.load(f)
assert reader.original_id == reader_unpickled.original_id

##__________________________________________________________________||

0 comments on commit a1ed350

Please sign in to comment.