Skip to content

Commit

Permalink
add npygenericscatter tests
Browse files Browse the repository at this point in the history
  • Loading branch information
calgray committed Jan 14, 2022
1 parent 0c1170f commit edcbbca
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 7 deletions.
2 changes: 1 addition & 1 deletion daliuge-engine/dlg/apps/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def run(self):

##
# @brief GenericNpyScatterApp
# @details An APP that splits about any object that can be converted to a numpy array
# @details An APP that splits about any axis on any npy format data drop
# into as many parts as the app has outputs, provided that the initially converted numpy
# array has enough elements. The return will be a numpy array of arrays, where the first
# axis is of length len(outputs). The modulo remainder of the length of the original array and
Expand Down
61 changes: 55 additions & 6 deletions daliuge-engine/test/apps/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,23 @@
import time
import unittest
from multiprocessing.pool import ThreadPool
from numpy import random, mean, array, concatenate

from numpy import random, mean, array, concatenate, random, testing
from psutil import cpu_count

from dlg import droputils
from dlg.apps.simple import GenericScatterApp, SleepApp, CopyApp, SleepAndCopyApp, \
from dlg.apps.simple import (
GenericScatterApp,
GenericNpyScatterApp,
SleepApp,
CopyApp,
SleepAndCopyApp,
ListAppendThrashingApp
)
from dlg.apps.simple import RandomArrayApp, AverageArraysApp, HelloWorldApp
from dlg.ddap_protocol import DROPStates
from dlg.drop import NullDROP, InMemoryDROP, FileDROP, NgasDROP

if sys.version_info >= (3, 8):
from dlg.manager.shared_memory_manager import DlgSharedMemoryManager
from numpy import random, mean, array, concatenate
from psutil import cpu_count

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -207,6 +210,52 @@ def test_genericScatter(self):
data_out = concatenate([data1, data2])
self.assertEqual(data_in.all(), data_out.all())

def test_genericNpyScatter(self):
data_in = random.rand(100, 100)
b = InMemoryDROP("b", "b")
droputils.save_numpy(b, data_in)
s = GenericNpyScatterApp("s", "s", num_of_copies=2)
s.addInput(b)
o1 = InMemoryDROP("o1", "o1")
o2 = InMemoryDROP("o2", "o2")
for x in o1, o2:
s.addOutput(x)
self._test_graph_runs((b, s, o1, o2), b, (o1, o2), timeout=4)

data1 = droputils.load_numpy(o1)
data2 = droputils.load_numpy(o2)
data_out = concatenate([data1, data2])
self.assertEqual(data_in.all(), data_out.all())

def test_genericNpyScatter_multi(self):
data1_in = random.rand(100, 100)
data2_in = random.rand(100, 100)
b = InMemoryDROP("b", "b")
c = InMemoryDROP("c", "c")
droputils.save_numpy(b, data1_in)
droputils.save_numpy(c, data2_in)
s = GenericNpyScatterApp("s", "s", num_of_copies=2, scatter_axes="[0,0]")
s.addInput(b)
s.addInput(c)
o1 = InMemoryDROP("o1", "o1")
o2 = InMemoryDROP("o2", "o2")
o3 = InMemoryDROP("o3", "o3")
o4 = InMemoryDROP("o4", "o4")
for x in o1, o2, o3, o4:
s.addOutput(x)
self._test_graph_runs((b, s, o1, o2, o3, o4), (b, c), (o1, o2, o3, o4), timeout=4)

data11 = droputils.load_numpy(o1)
data12 = droputils.load_numpy(o2)
data1_out = concatenate([data11, data12])
self.assertEqual(data1_out.shape, data1_in.shape)
testing.assert_array_equal(data1_out, data1_in)

data21 = droputils.load_numpy(o3)
data22 = droputils.load_numpy(o4)
data2_out = concatenate([data21, data22])
testing.assert_array_equal(data2_out, data2_in)

def test_listappendthrashing(self, size=1000):
a = InMemoryDROP('a', 'a')
b = ListAppendThrashingApp('b', 'b', size=size)
Expand Down

0 comments on commit edcbbca

Please sign in to comment.