Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions python/singa/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,3 +1583,14 @@ def copy_from_numpy(data, np_array):
data.CopyIntDataFromHostPtr(np_array)
else:
print('Not implemented yet for ', dt)


def concatenate(tensors, axis):
'''
concatenate tensors on given axis, all the dim should be the same
except the axis to be concatenated.
'''
ctensors = singa.VecTensor()
for t in tensors:
ctensors.append(t.data)
return _call_singa_func(singa.ConcatOn, ctensors, axis)
18 changes: 18 additions & 0 deletions test/python/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,24 @@ def test_dnnl_pooling_avg(self):
np.testing.assert_array_almost_equal(
tensor.to_numpy(_cTensor_to_pyTensor(dx0_ct)), dx1)

def test_concat(self):
np1 = np.random.random([5, 6, 7, 8]).astype(np.float32)
np2 = np.random.random([5, 6, 7, 1]).astype(np.float32)
np3 = np.concatenate((np1, np2), axis=3)

for dev in [cpu_dev, gpu_dev]:
t1 = tensor.Tensor(device=dev, data=np1)
t2 = tensor.Tensor(device=dev, data=np2)

ctensors = singa_api.VecTensor()
ctensors.append(t1.data)
ctensors.append(t2.data)

t3_ct = singa_api.ConcatOn(ctensors, 3)

np.testing.assert_array_almost_equal(
tensor.to_numpy(_cTensor_to_pyTensor(t3_ct)), np3)


if __name__ == '__main__':
unittest.main()
13 changes: 13 additions & 0 deletions test/python/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,19 @@ def test_transpose_then_reshape(self):
tensor.to_numpy(ta),
np.reshape(a.transpose(TRANSPOSE_AXES), RESHAPE_DIMS))

def test_concatenate(self):
np1 = np.random.random([5, 6, 7, 8]).astype(np.float32)
np2 = np.random.random([5, 6, 7, 1]).astype(np.float32)
np3 = np.concatenate((np1, np2), axis=3)

for dev in [cpu_dev, gpu_dev]:
t1 = tensor.Tensor(device=dev, data=np1)
t2 = tensor.Tensor(device=dev, data=np2)

t3 = tensor.concatenate((t1, t2), 3)

np.testing.assert_array_almost_equal(tensor.to_numpy(t3), np3)

def test_subscription_cpu(self):
np1 = np.random.random((5, 5, 5, 5)).astype(np.float32)
sg_tensor = tensor.Tensor(device=cpu_dev, data=np1)
Expand Down
2 changes: 2 additions & 0 deletions tool/travis/depends.sh
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,5 @@ conda config --add channels conda-forge
# linting
conda install -c conda-forge pylint
conda install -c conda-forge cpplint
conda install -c conda-forge deprecated
python -c "from deprecated import deprecated"