Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add spherical sampling, conversion between spherical and cartesian #661

Merged
merged 2 commits into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/modules/kaolin.ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Tensor batching operators are in :ref:`kaolin.ops.batch`, conversions of 3D mode
:titlesonly:

kaolin.ops.batch
kaolin.ops.coords
kaolin.ops.conversions
kaolin.ops.gcn
kaolin.ops.mesh
Expand Down
1 change: 1 addition & 0 deletions kaolin/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import batch
from . import conversions
from . import coords
from . import gcn
from . import mesh
from . import random
Expand Down
61 changes: 61 additions & 0 deletions kaolin/ops/coords.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division

import torch

def spherical2cartesian(azimuth, elevation, distance=None):
"""Convert spherical coordinates to cartesian.

Assuming X toward camera, Z-up and Y-right.

Args:
azimuth (torch.Tensor): azimuth in radianss.
elevation (torch.Tensor): elevation in radians.
distance (torch.Tensor or float, optional): distance. Default: 1.

Returns:
(torch.Tensor, torch.Tensor, torch.Tensor):
x, y, z, of same shape and dtype than inputs.
"""
if distance is None:
z = torch.sin(elevation)
temp = torch.cos(elevation)
else:
z = torch.sin(elevation) * distance
temp = torch.cos(elevation) * distance
x = torch.cos(azimuth) * temp
y = torch.sin(azimuth) * temp
return x, y, z

def cartesian2spherical(x, y, z):
"""Convert cartersian coordinates to spherical in radians.

Assuming X toward camera, Z-up and Y-right.

Args:
x (torch.Tensor): X components of the coordinates.
y (torch.Tensor): Y components of the coordinates.
z (torch.Tensor): Z components of the coordinates.

Returns:
(torch.Tensor, torch.Tensor, torch.Tensor):
azimuth, elevation, distance, of same shape and dtype than inputs.
"""
distance = torch.sqrt(x ** 2 + y ** 2 + z ** 2)
elevation = torch.asin(z / distance)
azimuth = torch.atan2(y, x)
return azimuth, elevation, distance
35 changes: 33 additions & 2 deletions kaolin/ops/random.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019,20-21 NVIDIA CORPORATION & AFFILIATES.
# Copyright (c) 2019,20-21-22 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -14,12 +14,12 @@
# limitations under the License.

import random
import math

import numpy as np
import torch
from .spc.uint8 import uint8_to_bits


def manual_seed(torch_seed, random_seed=None, numpy_seed=None):
"""Set the seed for random and torch modules.

Expand Down Expand Up @@ -171,3 +171,34 @@ def random_spc_octrees(batch_size, max_level, device='cpu'):
octree_length += cur_nodes.shape[0]
lengths.append(octree_length)
return torch.cat(octrees, dim=0), torch.tensor(lengths, dtype=torch.torch.int32)

def sample_spherical_coords(shape,
azimuth_low=0., azimuth_high=math.pi * 2.,
elevation_low=0., elevation_high=math.pi * 0.5,
device='cpu', dtype=torch.float):
"""Sample spherical coordinates with a uniform distribution.

Args:
shape (Sequence): shape of outputs.
azimuth_low (float, optional): lower bound for azimuth, in radian. Default: 0.
azimuth_high (float, optional): higher bound for azimuth, in radian. Default: 2 * pi.
elevation_low (float, optional): lower bound for elevation, in radian. Default: 0.
elevation_high (float, optional): higher bound for elevation, in radian. Default: pi / 2.
device (torch.device, optional): device of the output tensor. Default: 'cpu'.
dtype (torch.dtype, optional): dtype of the output tensor. Default: torch.float.

Returns:
(torch.Tensor, torch.Tensor): the azimuth and elevation, both of desired ``shape``.
"""
low = torch.tensor([
[azimuth_low], [math.sin(elevation_low)]
], device=device, dtype=dtype).reshape(2, *[1 for _ in shape])
high = torch.tensor([
[azimuth_high], [math.sin(elevation_high)]
], device=device, dtype=dtype).reshape(2, *[1 for _ in shape])

rand = torch.rand([2, *shape], dtype=dtype, device=device)
inter_samples = low + rand * (high - low)
azimuth = inter_samples[0]
elevation = torch.asin(inter_samples[1])
return azimuth, elevation
15 changes: 14 additions & 1 deletion kaolin/render/camera/legacy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2019,20-21 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -62,6 +63,12 @@ def generate_rotate_translate_matrices(camera_position, look_at, camera_up_direc
camz_length_bx1 = camz_bx3.norm(dim=1, keepdim=True)
camz_bx3 = camz_bx3 / (camz_length_bx1 + 1e-10)

# torch.cross don't support broadcast
# (https://github.com/pytorch/pytorch/issues/39656)
if camera_up_direction.shape[0] < camz_bx3.shape[0]:
camera_up_direction = camera_up_direction.repeat(camz_bx3.shape[0], 1)
elif camera_up_direction.shape[0] > camz_bx3.shape[0]:
camz_bx3 = camz_bx3.repeat(camera_up_direction.shape[0], 1)
camx_bx3 = torch.cross(camz_bx3, camera_up_direction, dim=1)
camx_len_bx1 = camx_bx3.norm(dim=1, keepdim=True)
camx_bx3 = camx_bx3 / (camx_len_bx1 + 1e-10)
Expand Down Expand Up @@ -97,6 +104,12 @@ def generate_transformation_matrix(camera_position, look_at, camera_up_direction
"""
z_axis = (camera_position - look_at)
z_axis /= z_axis.norm(dim=1, keepdim=True)
# torch.cross don't support broadcast
# (https://github.com/pytorch/pytorch/issues/39656)
if camera_up_direction.shape[0] < z_axis.shape[0]:
camera_up_direction = camera_up_direction.repeat(z_axis.shape[0], 1)
elif z_axis.shape[0] < camera_up_direction.shape[0]:
z_axis = z_axis.repeat(camera_up_direction.shape[0], 1)
x_axis = torch.cross(camera_up_direction, z_axis, dim=1)
x_axis /= x_axis.norm(dim=1, keepdim=True)
y_axis = torch.cross(z_axis, x_axis, dim=1)
Expand Down
3 changes: 1 addition & 2 deletions kaolin/render/mesh/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019,20-21 NVIDIA CORPORATION & AFFILIATES.
# Copyright (c) 2019,20-21-22 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -16,7 +16,6 @@
from __future__ import division

import torch
import torch.nn

from .. import camera
from ... import ops
Expand Down
98 changes: 98 additions & 0 deletions tests/python/kaolin/ops/test_coords.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import math
import torch

from kaolin.utils.testing import FLOAT_TYPES, check_tensor
from kaolin.ops.coords import cartesian2spherical, spherical2cartesian

@pytest.mark.parametrize('device, dtype', FLOAT_TYPES)
class TestCartesian2Spherical:
@pytest.fixture(autouse=True)
def coords(self, device, dtype):
coords = torch.rand((11, 7, 3), device=device, dtype=dtype) * 10. - 5.
return {
'x': coords[..., 0],
'y': coords[..., 1],
'z': coords[..., 2]
}

def test_cartesian2spherical(self, coords, dtype):
x = coords['x']
y = coords['y']
z = coords['z']

azimuth, elevation, distance = cartesian2spherical(x, y, z)
# This is pretty much how it is currently implemented in the function
# but this is very simple
expected_distance = torch.sqrt(
x ** 2 + y ** 2 + z ** 2)
expected_elevation = torch.asin(z / distance)
expected_azimuth = torch.atan2(y, z)
assert torch.allclose(azimuth, expected_azimuth)
assert torch.allclose(elevation, expected_elevation)
assert torch.allclose(distance, expected_distance)

def test_cartesian2spherical2cartesian(self, coords):
x = coords['x']
y = coords['y']
z = coords['z']

azimuth, elevation, distance = cartesian2spherical(x, y, z)
out_x, out_y, out_z = spherical2cartesian(azimuth, elevation, distance)
assert torch.allclose(x, out_x)
assert torch.allclose(y, out_y)
assert torch.allclose(z, out_z)

@pytest.mark.parametrize('device, dtype', FLOAT_TYPES)
class TestCartesian2Spherical:
@pytest.fixture(autouse=True)
def coords(self, device, dtype):
# Not uniform but good enough
return {
'azimuth': (torch.rand((11, 7), device=device, dtype=dtype) * 2. - 1.) * math.pi,
'elevation': (torch.rand((11, 7), device=device, dtype=dtype) - 0.5) * math.pi,
'distance': torch.rand((11, 7), device=device, dtype=dtype) * 10. + 0.1
}

def test_spherical2cartesian(self, coords, dtype):
azimuth = coords['azimuth']
elevation = coords['elevation']
distance = coords['distance']

x, y, z = spherical2cartesian(azimuth, elevation, distance)
# This is pretty much how it is currently implemented in the function
# but this is very simple
expected_z = torch.sin(elevation) * distance
temp = torch.cos(elevation) * distance
expected_x = torch.cos(azimuth) * temp
expected_y = torch.sin(azimuth) * temp
assert torch.equal(x, expected_x)
assert torch.equal(y, expected_y)
assert torch.equal(z, expected_z)

def test_spherical2cartesian2spherical(self, coords):
azimuth = coords['azimuth']
elevation = coords['elevation']
distance = coords['distance']

x, y, z = spherical2cartesian(azimuth, elevation, distance)
out_azimuth, out_elevation, out_distance = cartesian2spherical(
x, y, z)
assert torch.allclose(azimuth, out_azimuth, rtol=1e-3, atol=1e-3)
assert torch.allclose(elevation, out_elevation, rtol=1e-1, atol=1e-1)
assert torch.allclose(distance, out_distance, rtol=1e-3, atol=1e-3)
55 changes: 32 additions & 23 deletions tests/python/kaolin/ops/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,17 @@
import pytest
import torch

from kaolin.ops.random import random_shape_per_tensor, random_tensor, \
random_spc_octrees, manual_seed
from kaolin.utils.testing import BOOL_TYPES, NUM_TYPES, check_tensor, \
check_spc_octrees

import kaolin as kal
from kaolin.utils.testing import BOOL_TYPES, NUM_TYPES, FLOAT_TYPES, \
check_tensor, check_spc_octrees

@pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("min_shape,max_shape",
[(None, (3, 3)), ((5, 5), (5, 5))])
def test_random_shape_per_tensor(batch_size, min_shape, max_shape):
old_seed = torch.initial_seed()
torch.manual_seed(0)
shape_per_tensor = random_shape_per_tensor(batch_size, min_shape, max_shape)
shape_per_tensor = kal.ops.random.random_shape_per_tensor(batch_size, min_shape, max_shape)
if min_shape is None:
min_shape = tuple([1] * len(max_shape))
min_shape = torch.tensor(min_shape).unsqueeze(0)
Expand All @@ -42,18 +40,18 @@ def test_random_shape_per_tensor(batch_size, min_shape, max_shape):
@pytest.mark.parametrize("min_shape,max_shape", [((5, 5, 5), (30, 30, 30))])
def test_random_shape_per_tensor_seed(batch_size, min_shape, max_shape):
threshold = batch_size * len(max_shape) * 0.9
manual_seed(0)
shape_per_tensor1 = random_shape_per_tensor(batch_size, min_shape,
kal.ops.random.manual_seed(0)
shape_per_tensor1 = kal.ops.random.random_shape_per_tensor(batch_size, min_shape,
max_shape)
shape_per_tensor2 = random_shape_per_tensor(batch_size, min_shape,
shape_per_tensor2 = kal.ops.random.random_shape_per_tensor(batch_size, min_shape,
max_shape)
assert torch.sum(shape_per_tensor1 != shape_per_tensor2) > threshold
manual_seed(0)
shape_per_tensor3 = random_shape_per_tensor(batch_size, min_shape,
kal.ops.random.manual_seed(0)
shape_per_tensor3 = kal.ops.random.random_shape_per_tensor(batch_size, min_shape,
max_shape)
assert torch.equal(shape_per_tensor1, shape_per_tensor3)
manual_seed(1)
shape_per_tensor4 = random_shape_per_tensor(batch_size, min_shape,
kal.ops.random.manual_seed(1)
shape_per_tensor4 = kal.ops.random.random_shape_per_tensor(batch_size, min_shape,
max_shape)
assert torch.sum(shape_per_tensor1 != shape_per_tensor4) > threshold

Expand All @@ -62,7 +60,7 @@ def test_random_shape_per_tensor_seed(batch_size, min_shape, max_shape):
@pytest.mark.parametrize("low,high", [(0, 1), (3, 5), (10, 10)])
@pytest.mark.parametrize("shape", [(1,), (3, 3)])
def test_random_tensor(low, high, shape, dtype, device):
tensor = random_tensor(low, high, shape, dtype, device)
tensor = kal.ops.random.random_tensor(low, high, shape, dtype, device)
check_tensor(tensor, shape, dtype, device)
assert (low <= tensor).all()
assert (tensor <= high).all()
Expand All @@ -72,7 +70,7 @@ def test_random_tensor(low, high, shape, dtype, device):
@pytest.mark.parametrize("low,high", [(0, 1)])
@pytest.mark.parametrize("shape", [(1,), (3, 3)])
def test_random_tensor(low, high, shape, dtype, device):
tensor = random_tensor(low, high, shape, dtype, device)
tensor = kal.ops.random.random_tensor(low, high, shape, dtype, device)
check_tensor(tensor, shape, dtype, device)
assert (low <= tensor).all()
assert (tensor <= high).all()
Expand All @@ -82,20 +80,31 @@ def test_random_tensor(low, high, shape, dtype, device):
@pytest.mark.parametrize("shape", [(10, 10)])
def test_random_tensor_seed(low, high, shape):
threshold = shape[0] * shape[1] * 0.9
manual_seed(0)
tensor1 = random_tensor(low, high, shape)
tensor2 = random_tensor(low, high, shape)
kal.ops.random.manual_seed(0)
tensor1 = kal.ops.random.random_tensor(low, high, shape)
tensor2 = kal.ops.random.random_tensor(low, high, shape)
assert torch.sum(tensor1 != tensor2) > threshold
manual_seed(0)
tensor3 = random_tensor(low, high, shape)
kal.ops.random.manual_seed(0)
tensor3 = kal.ops.random.random_tensor(low, high, shape)
assert torch.equal(tensor1, tensor3)
manual_seed(1)
tensor4 = random_tensor(low, high, shape)
kal.ops.random.manual_seed(1)
tensor4 = kal.ops.random.random_tensor(low, high, shape)
assert torch.sum(tensor1 != tensor4) > threshold

@pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("level", [1, 3])
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_random_spc_octree(batch_size, level, device):
octrees, lengths = random_spc_octrees(batch_size, level, device)
octrees, lengths = kal.ops.random.random_spc_octrees(batch_size, level, device)
check_spc_octrees(octrees, lengths, batch_size, level, device)

@pytest.mark.parametrize("device,dtype", FLOAT_TYPES)
def test_sample_spherical_coords(device, dtype):
azimuth, elevation = kal.ops.random.sample_spherical_coords(
(11, 7), azimuth_low=0.1, azimuth_high=0.3,
elevation_low=0.3, elevation_high=0.6, device=device, dtype=dtype
)
check_tensor(azimuth, (11, 7), dtype, device)
check_tensor(elevation, (11, 7), dtype, device)
assert torch.all(azimuth >= 0.1) and torch.all(azimuth <= 0.3)
assert torch.all(elevation >= 0.3) and torch.all(elevation <= 0.6)
1 change: 1 addition & 0 deletions tests/python/kaolin/render/mesh/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import pytest
import math
import torch
from kaolin.utils.testing import FLOAT_TYPES, check_tensor
from kaolin.render.mesh.utils import texture_mapping
Expand Down