Skip to content

Commit

Permalink
add transpose test
Browse files Browse the repository at this point in the history
  • Loading branch information
silencrown committed Apr 28, 2023
1 parent f79281a commit bb43db6
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tenseal/tensors/plaintensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,21 +128,22 @@ def broadcast_(self, shape: List[int]):

def transpose(self, axes: List[int] = None):
"Copies the transpose to a new tensor"
new_tensor = None
new_tensor = PlainTensor(tensor=self.data.data(), shape=self.shape, dtype=self._dtype)
if axes is None:
new_tensor = PlainTensor(tensor=self.data.data(), shape=self.shape, dtype=self._dtype)
return new_tensor.transpose_()
elif isinstance(axes, list) and all(isinstance(x, int) for x in axes):
new_tensor = PlainTensor(tensor=self.data.data(), shape=self.shape, dtype=self._dtype)
return new_tensor.transpose_(axes)
else:
raise TypeError("axes must be a list of integers")
return new_tensor.transpose_()
raise TypeError("transpose axes must be a list of integers")

def transpose_(self, axes: List[int] = None):
"Tries to transpose the tensor"
if axes is None:
self.data.transpose_()
elif isinstance(axes, list) and all(isinstance(x, int) for x in axes):
self.data.transpose_(axes)
else:
raise TypeError("transpose axes must be a list of integers")
return self

@classmethod
Expand Down
28 changes: 28 additions & 0 deletions tests/cpp/tensors/ckkstensor_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,34 @@ TEST_P(CKKSTensorTest, TestTranspose) {
ASSERT_TRUE(are_close(decr.data(), {1, 3, 5, 2, 4, 6}));
}

TEST_P(CKKSTensorTest, TestTransposeWithAxes) {
auto enc_type = get<1>(GetParam());

auto ctx = TenSEALContext::Create(scheme_type::ckks, 8192, -1,
{60, 40, 40, 60}, enc_type);
ASSERT_TRUE(ctx != nullptr);
ctx->generate_galois_keys();

auto ldata =
PlainTensor(vector<double>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),
vector<size_t>({2, 3, 2}));

auto l = CKKSTensor::Create(ctx, ldata, std::pow(2, 40));

// Transpose with specified axes
auto res = l->transpose({0, 2, 1});
ASSERT_THAT(res->shape(), ElementsAreArray({2, 2, 3}));
ASSERT_THAT(l->shape(), ElementsAreArray({2, 3, 2}));
auto decr = res->decrypt();
ASSERT_TRUE(are_close(decr.data(), {1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12}));

// Transpose inplace with specified axes
l->transpose_inplace({0, 2, 1});
ASSERT_THAT(l->shape(), ElementsAreArray({2, 2, 3}));
decr = l->decrypt();
ASSERT_TRUE(are_close(decr.data(), {1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12}));
}

TEST_P(CKKSTensorTest, TestSubscript) {
auto enc_type = get<1>(GetParam());

Expand Down
24 changes: 24 additions & 0 deletions tests/python/tenseal/tensors/test_ckks_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,3 +826,27 @@ def test_transpose(context, data, shape):
assert tensor.shape == list(expected.shape)
result = np.array(tensor.decrypt().tolist())
assert np.allclose(result, expected, rtol=0, atol=0.01)

@pytest.mark.parametrize(
"data, shape, axes",
[
([i for i in range(6)], [1, 2, 3], [0, 2, 1]),
([i for i in range(12)], [2, 2, 3], [0, 2, 1]),
([i for i in range(2 * 3 * 4 * 5)], [2, 3, 4, 5], [0, 3, 2, 1]),
],
)
def test_transpose_with_axes(context, data, shape, axes):
tensor = ts.ckks_tensor(context, ts.plain_tensor(data, shape))

expected = np.transpose(np.array(data).reshape(shape), axes)

newt = tensor.transpose(axes)
assert tensor.shape == shape
assert newt.shape == list(expected.shape)
result = np.array(newt.decrypt().tolist())
assert np.allclose(result, expected, rtol=0, atol=0.01)

tensor.transpose_(axes)
assert tensor.shape == list(expected.shape)
result = np.array(tensor.decrypt().tolist())
assert np.allclose(result, expected, rtol=0, atol=0.01)
22 changes: 22 additions & 0 deletions tests/python/tenseal/tensors/test_plain_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,25 @@ def test_transpose(data, shape):
tensor.transpose_()
assert tensor.shape == list(expected.shape)
assert np.array(tensor.tolist()).any() == expected.any()

@pytest.mark.parametrize(
"data, shape, axes",
[
([i for i in range(6)], [1, 2, 3], [0, 2, 1]),
([i for i in range(12)], [2, 2, 3], [0, 2, 1]),
([i for i in range(2 * 3 * 4 * 5)], [2, 3, 4, 5], [0, 3, 2, 1]),
],
)
def test_transpose(data, shape, axes):
tensor = ts.plain_tensor(data, shape)

expected = np.transpose(np.array(data).reshape(shape), axes)

newt = tensor.transpose(axes)
assert tensor.shape == shape
assert newt.shape == list(expected.shape)
assert np.array(newt.tolist()).any() == expected.any()

tensor.transpose_(axes)
assert tensor.shape == list(expected.shape)
assert np.array(tensor.tolist()).any() == expected.any()

0 comments on commit bb43db6

Please sign in to comment.