-
Notifications
You must be signed in to change notification settings - Fork 212
/
scene_object.py
424 lines (357 loc) · 15 KB
/
scene_object.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""
Class representing objects in the scene
"""
import tensorflow as tf
from .object import Object
from .radio_material import RadioMaterial
import drjit as dr
import mitsuba as mi
from .utils import mi_to_tf_tensor, angles_to_mitsuba_rotation, normalize,\
theta_phi_from_unit_vec
from sionna.constants import PI
class SceneObject(Object):
    # pylint: disable=line-too-long
    r"""
    SceneObject()

    Every object in the scene is implemented by an instance of this class.

    A scene object wraps a Mitsuba shape. Moving it (``position``) or
    rotating it (``orientation``) keeps three copies of its geometry in
    sync: the Mitsuba scene parameter ``mesh-<name>.vertex_positions``,
    the per-primitive vertices/normals stored in the scene's
    ``solver_paths``, and the wedge data (origins, edge vectors, normals)
    stored in ``solver_paths``.
    """

    def __init__(self,
                 name,
                 object_id,
                 scene,
                 mi_shape,
                 radio_material=None):
        # Parameters, as used by this class:
        #   name           : object name, forwarded to ``Object``; also used to
        #                    address ``mesh-<name>.vertex_positions`` in the
        #                    Mitsuba scene parameters
        #   object_id      : id used to index ``solver_paths.shape_indices``
        #                    and compared against ``wedges_objects``
        #   scene          : owning scene; provides ``dtype``,
        #                    ``mi_scene_params``, ``solver_paths``,
        #                    ``get``/``add`` and ``scene_geometry_updated``
        #   mi_shape       : Mitsuba shape backing this object
        #   radio_material : ``None``, a material name (``str``) or a
        #                    ``RadioMaterial`` instance

        # Initialize the base class Object
        super().__init__(name)

        # The object id, scene and Mitsuba shape are set *before* the radio
        # material, because the ``radio_material`` setter reads
        # ``self.object_id`` and ``self.scene``.
        # NOTE(review): assumes ``Object.scene`` is backed by ``self._scene``
        # -- confirm in object.py.
        self._object_id = object_id
        self._scene = scene
        self._mi_shape = mi_shape

        # Set the radio material
        self.radio_material = radio_material

        rdtype = scene.dtype.real_dtype
        # Velocity is initialized to zero
        self.velocity = tf.cast([0, 0, 0], dtype=rdtype)
        # Orientation of the object is initialized to (0,0,0)
        self._orientation = tf.cast([0., 0., 0.], dtype=rdtype)

        # Mitsuba types matching the scene precision: single precision for
        # tf.complex64 scenes, double precision otherwise.
        if scene.dtype == tf.complex64:
            self._mi_point_t = mi.Point3f
            self._mi_vec_t = mi.Vector3f
            self._mi_scalar_t = mi.Float
            self._mi_transform_t = mi.Transform4f
        else:
            self._mi_point_t = mi.Point3d
            self._mi_vec_t = mi.Vector3d
            self._mi_scalar_t = mi.Float64
            self._mi_transform_t = mi.Transform4d

    @property
    def object_id(self):
        r"""
        int : Return the identifier of this object
        """
        return self._object_id

    @property
    def radio_material(self):
        r"""
        :class:`~sionna.rt.RadioMaterial` : Get/set the radio material of the
        object. Setting can be done by using either an instance of
        :class:`~sionna.rt.RadioMaterial` or the material name (`str`).
        If the radio material is not part of the scene, it will be added. This
        can raise an error if a different radio material with the same name was
        already added to the scene.
        """
        return self._radio_material

    @radio_material.setter
    def radio_material(self, mat):
        # Note: _radio_material is set at __init__, but pylint doesn't see it.
        # pylint: disable=access-member-before-definition

        # Resolve ``mat`` to a RadioMaterial instance, or None
        if mat is None:
            mat_obj = None
        elif isinstance(mat, str):
            mat_obj = self.scene.get(mat)
            # Also rejects names that resolve to something other than a
            # RadioMaterial (e.g. another scene object)
            if not isinstance(mat_obj, RadioMaterial):
                err_msg = f"Unknown radio material '{mat}'"
                raise TypeError(err_msg)
        elif isinstance(mat, RadioMaterial):
            mat_obj = mat
        else:
            err_msg = ("The material must be a material name (str) or an "
                       "instance of RadioMaterial")
            raise TypeError(err_msg)

        # Remove the object from the set of the currently used material, if
        # any. ``_radio_material`` does not exist yet on the first call made
        # from __init__.
        current = getattr(self, '_radio_material', None)
        if current:
            current.discard_object_using(self.object_id)

        # Assign the new material
        self._radio_material = mat_obj

        # If the radio material is set to None, we can stop here
        if not mat_obj:
            return

        # Register this object with the new material and add the material
        # to the scene if not already done
        mat_obj.add_object_using(self.object_id)
        self.scene.add(mat_obj)

    @property
    def velocity(self):
        """
        [3], tf.float : Get/set the velocity vector [m/s]
        """
        return self._velocity

    @velocity.setter
    def velocity(self, v):
        v = tf.cast(v, self._scene.dtype.real_dtype)
        # Explicit rank/size check: comparing ``tf.shape(v)`` to 3 and taking
        # its truth value is ambiguous for inputs of rank > 1.
        if v.shape.rank != 1 or v.shape[0] != 3:
            raise ValueError("`velocity` must have shape [3]")
        self._velocity = v

    @property
    def position(self):
        """
        [3], tf.float : Get/set the position vector [m] of the center
        of the object. The center is defined as the center of the object's
        axis-aligned bounding box (AABB). Setting the position translates
        all vertices of the object so that this center moves to the new
        position.
        """
        rdtype = self._scene.dtype.real_dtype
        # Bounding box of the Mitsuba shape
        bbox = self._mi_shape.bbox()
        # [3]
        bbox_min = mi_to_tf_tensor(bbox.min, rdtype)
        # [3]
        bbox_max = mi_to_tf_tensor(bbox.max, rdtype)
        half = tf.cast(0.5, rdtype)
        # [3]
        return half*(bbox_min + bbox_max)

    @position.setter
    def position(self, new_position):
        # Real dtype
        rdtype = self._scene.dtype.real_dtype
        new_position = tf.cast(new_position, rdtype)

        ## Update Mitsuba vertices
        scene_params = self._scene.mi_scene_params
        vertices_key = f'mesh-{self.name}.vertex_positions'
        # [num_vertices*3] -> [num_vertices, 3]
        vertices = mi_to_tf_tensor(scene_params[vertices_key], rdtype)
        vertices = tf.reshape(vertices, [-1, 3])
        # Translation moving the current AABB center onto ``new_position``.
        # It must be computed before the Mitsuba scene is updated.
        # [1, 3]
        translation_vector = tf.expand_dims(new_position - self.position,
                                            axis=0)
        # [num_vertices, 3]
        translated_vertices = vertices + translation_vector
        # Flatten and cast to the Mitsuba type to update the Mitsuba scene
        # [num_vertices*3]
        fltn_translated_vertices = tf.reshape(translated_vertices, [-1])
        scene_params[vertices_key] =\
            self._mi_scalar_t(fltn_translated_vertices)
        scene_params.update()

        ## Update Sionna vertices (read back from the updated Mitsuba shape)
        self._update_solver_primitives(rdtype)

        ## Update Sionna wedges: translate their origins
        solver_paths = self._scene.solver_paths
        # [num_wedges]
        wedges_ind = self._wedges_indices()
        # [num_wedges, 3]
        wedges_origin = tf.gather(solver_paths.wedges_origin, wedges_ind,
                                  axis=0)
        wedges_origin += translation_vector
        # [num_wedges, 1]
        wedges_ind = tf.expand_dims(wedges_ind, axis=1)
        solver_paths.wedges_origin.scatter_nd_update(wedges_ind, wedges_origin)

        # Trigger scene callback
        self._scene.scene_geometry_updated()

    @property
    def orientation(self):
        r"""
        [3], tf.float : Get/set the orientation :math:`(\alpha, \beta, \gamma)`
        [rad] specified through three angles corresponding to a
        3D rotation as defined in :eq:`rotation`. Setting the orientation
        rotates the object about the center of its axis-aligned bounding
        box, relative to its current orientation.
        """
        return self._orientation

    @orientation.setter
    def orientation(self, new_orient):
        # Real dtype
        rdtype = self._scene.dtype.real_dtype
        new_orient = tf.cast(new_orient, rdtype)

        # Rotation taking the current orientation to the new one: undo the
        # current rotation, then apply the new one.
        new_rotation = angles_to_mitsuba_rotation(new_orient)
        cur_rotation = angles_to_mitsuba_rotation(self._orientation.numpy())
        inv_cur_rotation = cur_rotation.inverse()

        # The object is first translated so that its current AABB center is
        # at the origin, then rotated, then translated back. The center is
        # evaluated once, before any vertex is modified.
        center = self.position.numpy()
        transform = ( self._mi_transform_t.translate(center)
                      @ new_rotation
                      @ inv_cur_rotation
                      @ self._mi_transform_t.translate(-center) )

        ## Update Mitsuba vertices
        scene_params = self._scene.mi_scene_params
        vertices_key = f'mesh-{self.name}.vertex_positions'
        # [num_vertices*3] -> [num_vertices, 3] as Mitsuba points
        vertices = dr.unravel(self._mi_point_t, scene_params[vertices_key])
        # Apply the transform
        vertices = transform.transform_affine(vertices)
        # Flatten to update the Mitsuba scene
        # NOTE(review): unlike the position setter, which casts with
        # ``self._mi_scalar_t``, this always writes float32 values -- confirm
        # this is intended for double-precision scenes.
        fltn_vertices = tf.reshape(vertices, [-1])
        fltn_vertices = tf.cast(fltn_vertices, tf.float32)
        scene_params[vertices_key] = fltn_vertices
        scene_params.update()

        ## Update Sionna vertices (read back from the updated Mitsuba shape)
        # [n_prims, 1]
        sl = self._update_solver_primitives(rdtype)
        solver_paths = self._scene.solver_paths

        ## Update Sionna normals
        # [n_prims, 3]
        normals = solver_paths.normals.gather_nd(sl)
        # Cast to Mitsuba vectors and rotate them
        normals = self._mi_vec_t(normals)
        normals = transform.transform_affine(normals)
        # [n_prims, 3]
        normals = mi_to_tf_tensor(normals, rdtype)
        solver_paths.normals.scatter_nd_update(sl, normals)

        ## Update Sionna wedges
        # [num_wedges]
        wedges_ind = self._wedges_indices()
        # Corresponding origins, e_hat, and normals
        # [num_wedges, 3]
        wedges_origin = tf.gather(solver_paths.wedges_origin, wedges_ind,
                                  axis=0)
        # [num_wedges, 3]
        wedges_e_hat = tf.gather(solver_paths.wedges_e_hat, wedges_ind,
                                 axis=0)
        # [num_wedges, 2, 3] -> [num_wedges*2, 3]
        wedges_normals = tf.gather(solver_paths.wedges_normals, wedges_ind,
                                   axis=0)
        wedges_normals = tf.reshape(wedges_normals, [-1, 3])
        # Cast to Mitsuba types: origins are points, the others are vectors
        wedges_origin = self._mi_point_t(wedges_origin)
        wedges_e_hat = self._mi_vec_t(wedges_e_hat)
        wedges_normals = self._mi_vec_t(wedges_normals)
        # Transform all quantities
        wedges_origin = transform.transform_affine(wedges_origin)
        wedges_e_hat = transform.transform_affine(wedges_e_hat)
        wedges_normals = transform.transform_affine(wedges_normals)
        # Cast back to TensorFlow
        # [num_wedges, 3]
        wedges_origin = mi_to_tf_tensor(wedges_origin, rdtype)
        # [num_wedges, 3]
        wedges_e_hat = mi_to_tf_tensor(wedges_e_hat, rdtype)
        # [num_wedges*2, 3] -> [num_wedges, 2, 3]
        wedges_normals = mi_to_tf_tensor(wedges_normals, rdtype)
        wedges_normals = tf.reshape(wedges_normals, [-1, 2, 3])
        # Update the wedges
        # [num_wedges, 1]
        wedges_ind = tf.expand_dims(wedges_ind, axis=1)
        solver_paths.wedges_origin.scatter_nd_update(wedges_ind, wedges_origin)
        solver_paths.wedges_e_hat.scatter_nd_update(wedges_ind, wedges_e_hat)
        solver_paths.wedges_normals.scatter_nd_update(wedges_ind,
                                                      wedges_normals)

        self._orientation = new_orient

        # Trigger scene callback
        self._scene.scene_geometry_updated()

    def _update_solver_primitives(self, rdtype):
        r"""
        Copies this object's triangle vertices from the Mitsuba shape into
        ``solver_paths.primitives``.

        Input
        -----
        rdtype : tf.DType
            Real dtype of the TensorFlow copy of the vertices

        Output
        ------
        sl : [n_prims, 1], tf.int32
            Rows of the solver's per-primitive tensors that belong to this
            object, usable with ``gather_nd``/``scatter_nd_update``
        """
        mi_shape = self._mi_shape
        solver_paths = self._scene.solver_paths
        shape_ind = solver_paths.shape_indices[self.object_id]
        prim_offset = solver_paths.prim_offsets[shape_ind]
        n_prims = mi_shape.face_count()

        face_indices3 = mi_shape.face_indices(dr.arange(mi.UInt32, n_prims))
        # Flatten. This is required for calling vertex_position
        # [n_prims*3]
        face_indices = dr.ravel(face_indices3)
        # Vertex coordinates, cast to TensorFlow
        # [n_prims*3, 3]
        vertex_coords = mi_shape.vertex_position(face_indices)
        vertex_coords = mi_to_tf_tensor(vertex_coords, rdtype)
        # [n_prims, vertices per triangle : 3, 3]
        vertex_coords = tf.reshape(vertex_coords, [n_prims, 3, 3])

        # This object's primitives are stored contiguously from prim_offset
        # [n_prims, 1]
        sl = tf.range(prim_offset, prim_offset + n_prims, dtype=tf.int32)
        sl = tf.expand_dims(sl, axis=1)
        solver_paths.primitives.scatter_nd_update(sl, vertex_coords)
        return sl

    def _wedges_indices(self):
        r"""
        Returns the unique row indices ``[num_wedges]`` of the solver wedges
        for which ``wedges_objects`` matches this object's id.
        ``tf.unique`` removes duplicates when a row matches in several
        columns.
        """
        wedges_objects = self._scene.solver_paths.wedges_objects
        wedges_ind, _ = tf.unique(
            tf.where(wedges_objects == self.object_id)[:, 0])
        return wedges_ind

    def look_at(self, target):
        # pylint: disable=line-too-long
        r"""
        Sets the orientation so that the x-axis points toward a position or
        an ``Object``.

        Input
        -----
        target : [3], float | :class:`sionna.rt.Object` | str
            A position or the name or instance of an
            :class:`sionna.rt.Object` in the scene to point toward

        Raises
        ------
        ValueError
            If ``target`` is a name that does not resolve to an ``Object``,
            or a position that is not a three-element vector
        """
        rdtype = self._scene.dtype.real_dtype

        # Get position to look at
        if isinstance(target, str):
            obj = self.scene.get(target)
            if not isinstance(obj, Object):
                raise ValueError(f"No camera, device, or object named '{target}' found.")
            target = obj.position
        elif isinstance(target, Object):
            target = target.position
        else:
            target = tf.cast(target, dtype=rdtype)
            # Rank check first so that scalars get a clear error instead of
            # an IndexError from ``target.shape[0]``
            if target.shape.rank != 1 or target.shape[0] != 3:
                raise ValueError("`target` must be a three-element vector")

        # Compute angles relative to LCS
        x = target - self.position
        x, _ = normalize(x)
        theta, phi = theta_phi_from_unit_vec(x)
        alpha = phi # Rotation around z-axis
        beta = theta-PI/2 # Rotation around y-axis
        gamma = 0.0 # Rotation around x-axis
        self.orientation = (alpha, beta, gamma)