In [1]:
import torch
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from matplotlib.path import Path
import torch.nn.functional as F

In [None]:
# Regular 3-D sampling grid over the cube [min_xyz, max_xyz]^3.
min_xyz = -4
max_xyz = 4
axis_len = max_xyz - min_xyz
axis_num_pts = 80  # samples per axis -> axis_num_pts**3 query points in total
x = torch.linspace(min_xyz, max_xyz, axis_num_pts)
y = torch.linspace(min_xyz, max_xyz, axis_num_pts)
z = torch.linspace(min_xyz, max_xyz, axis_num_pts)
# indexing='ij' is the historical default; passing it explicitly silences the
# torch.meshgrid deprecation warning and pins the layout: after flattening, z
# varies fastest, so each run of axis_num_pts consecutive rows shares one (x, y).
X, Y, Z = torch.meshgrid(x, y, z, indexing='ij')
stacked_tensor = torch.stack([X, Y, Z], dim=3)
# (axis_num_pts**3, 3): one (x, y, z) query point per row.
reshaped_tensor = torch.reshape(stacked_tensor, (-1, 3))

In [3]:
def rotate_xyz(p, angles):
    '''
    Rotate row-vector points by Euler angles (θx, θy, θz) given in degrees.

    The rotation matrix is R = Rz(θz) @ Ry(θy) @ Rx(θx); each row of ``p`` is
    mapped to ``R @ row``, i.e. the result is ``p @ R.T``.

    p      -- (n, 3) tensor of points
    angles -- three angles in degrees (tensor, list or tuple); may require grad

    The matrix is assembled with torch.stack (not torch.tensor) so gradients
    flow back to ``angles``, e.g. when they are rows of an optimised tensor.

    Reference: http://www.songho.ca/opengl/gl_anglestoaxes.html
    '''
    angles = torch.as_tensor(angles, dtype=p.dtype, device=p.device)
    # Exact degree -> radian conversion.
    θx, θy, θz = torch.deg2rad(angles).unbind()
    cx, sx = torch.cos(θx), torch.sin(θx)
    cy, sy = torch.cos(θy), torch.sin(θy)
    cz, sz = torch.cos(θz), torch.sin(θz)
    R_zyx = torch.stack([
        torch.stack([cz * cy, -sz * cx + cz * sy * sx, sz * sx + cz * sy * cx]),
        torch.stack([sz * cy, cz * cx + sz * sy * sx, -cz * sx + sz * sy * cx]),
        torch.stack([-sy, cy * sx, cy * cx]),
    ])
    return torch.matmul(p, R_zyx.T)

## Operations

In [21]:
# Reference notes for combining SDFs (this cell computes nothing):
#   intersection -> element-wise maximum of the shapes' signed distances
#   union        -> element-wise minimum of the shapes' signed distances
# NOTE: the triple-quoted string below is a bare expression, so Jupyter echoes
# it as this cell's output; the variable names in it are illustrative only.
'''
compute the element-wise maximum
intersection, _ = torch.max(torch.stack([signed_distances1, signed_distances2]), dim=0)
union = torch.min(link1_signed_distances, link2_signed_distances)
'''

# Sphere

In [4]:
def sdf_sphere(p, r, center, angles):
    """Signed distance from each row of ``p`` to a sphere of radius ``r``.

    The points are rotated with rotate_xyz(p, angles) first and the offset from
    ``center`` is taken afterwards, so ``center`` is expressed in the rotated
    frame. Negative values are inside; returns one value per row of ``p``.
    """
    offset = rotate_xyz(p, angles) - center
    dist_to_center = offset.pow(2).sum(dim=1).sqrt()
    return dist_to_center - r

In [5]:
# Sphere of radius 1 centred at (1, 1, 1), no rotation.
r = 1
center = torch.tensor([1, 1, 1])
angles = torch.tensor([0, 0, 0])
sphere1_signed_distances_vec = sdf_sphere(reshaped_tensor, r, center, angles)

In [None]:
# Show the grid points on or inside the sphere (signed distance <= 0).
sphere_pts = reshaped_tensor[sphere1_signed_distances_vec <= 0]
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
ax.scatter(sphere_pts[:, 0], sphere_pts[:, 1], sphere_pts[:, 2], c='blue', marker='o')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
plt.show()

# Box

In [34]:
def sdBox(p, b, center, angles):
    '''
    Exact signed distance from each row of p to a box.

    p      -- (n, 3) tensor of query points
    b      -- half-extents of the box along x, y, z (corner offset from its centre)
    center -- box centre, subtracted after the rotation below
    angles -- rotation angles in degrees, applied to p via rotate_xyz

    Because p is rotated before ``center`` is subtracted, the rotation is about
    the world origin and ``center`` is expressed in the rotated frame. With
    non-zero angles the box's world-space centre is therefore in general not
    ``center``. Returns a (n,) tensor, negative inside.
    '''
    local = rotate_xyz(p, angles) - center
    q = local.abs() - b
    outside_part = torch.maximum(q, torch.zeros(3)).norm(dim=1)
    interior_part = torch.minimum(q.max(dim=1).values, torch.zeros(1))
    return outside_part + interior_part

In [26]:
# Box with half-extent 1, centre (1, 1, 1), rotated 30° about x.
b = torch.ones(3, dtype=torch.int64)
center = torch.ones(3, dtype=torch.int64)
angles = torch.tensor((30, 0, 0))
box_signed_distances = sdBox(reshaped_tensor, b, center, angles)

box_signed_distances shape: torch.Size([1000000])


In [None]:
# Show the grid points on or inside the box (signed distance <= 0).
box_pts = reshaped_tensor[box_signed_distances <= 0]
fig = plt.figure()
ax2 = fig.add_subplot(111, projection='3d')
ax2.scatter(box_pts[:, 0], box_pts[:, 1], box_pts[:, 2], c='blue', marker='o')
# Label this cell's own axes (ax2); `ax` would be a stale axes object from an
# earlier figure.
ax2.set_xlabel('X')
ax2.set_ylabel('Y')
ax2.set_zlabel('Z')
plt.show()

# Torus

In [72]:
def sdTorus(p, t, center, angles):
    '''
    Exact signed distance from each row of p to a torus lying in the local
    xz-plane (its symmetry axis is y).

    p      -- (n, 3) tensor of query points
    t      -- two values (ring_radius, tube_radius): t[0] is the distance from
              the torus centre to the middle of the tube, t[1] is the tube
              radius; a tensor, list or tuple is accepted
    center -- torus centre, subtracted after the rotation
    angles -- rotation angles in degrees, applied to p via rotate_xyz (so the
              rotation is about the world origin)

    Returns a (n,) tensor (also when n == 1), negative inside the tube.
    '''
    p = rotate_xyz(p, angles) - center
    t = torch.as_tensor(t, dtype=p.dtype, device=p.device).reshape(2)
    # Signed radial offset from the tube's centre circle, in the xz-plane.
    radial = torch.norm(p[:, [0, 2]], dim=1) - t[0]
    q = torch.stack((radial, p[:, 1]), dim=1)
    return torch.norm(q, dim=1) - t[1]


In [73]:
# Torus with ring radius 3 and tube radius 2, at the origin, unrotated.
radius = torch.tensor((3, 2))
center = torch.zeros((1, 3), dtype=torch.int64)
angles = torch.zeros(3, dtype=torch.int64)
torus_signed_distances = sdTorus(reshaped_tensor, radius, center, angles)

In [None]:
torus_pts = reshaped_tensor[torus_signed_distances <= 0]
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
ax.scatter(torus_pts[:, 0], torus_pts[:, 1], torus_pts[:, 2], c='blue', marker='o')
ax.view_init(elev=10, azim=90)
plt.show()

# Triangular Prism

In [84]:

def sdTriPrism(p, h, center, angles):
    '''
    Signed distance bound for a triangular prism (exact inside; it may
    underestimate the distance outside, near edges).

    The cross-section is an equilateral triangle in the local xy-plane with one
    vertex pointing along +y. The prism extends along z.

    p      -- (n, 3) tensor of query points
    h      -- 2-element tensor: every triangle face lies at distance h[0]/2
              from the prism axis, so the base edge is at y = -h[0]/2 and the
              apex at y = h[0]; h[1] is the half-length along z
    center -- prism centre, subtracted after the rotation
    angles -- rotation angles in degrees, applied to p via rotate_xyz

    0.866025 ≈ sqrt(3)/2 = cos(30°) and 0.5 = sin(30°): (±0.866025, 0.5) are
    the unit normals of the two slanted faces, (0, -1) is that of the base.
    Returns a (n,) tensor, negative inside.
    '''
    # rotate, then translate to the prism's frame
    p_local = rotate_xyz(p, angles) - center
    q = torch.abs(p_local)

    # All three triangle faces sit at distance h[0]/2 from the axis, so the
    # same offset is subtracted from the slanted and the base terms.
    half_size = h[0] * 0.5
    slanted = q[:, 0] * 0.866025 + p_local[:, 1] * 0.5 - half_size
    base = -p_local[:, 1] - half_size
    caps = q[:, 2] - h[1]
    return torch.max(torch.stack((caps, slanted, base), dim=1), dim=1)[0]

In [105]:
h = torch.tensor((4, 4))
center = torch.zeros(3, dtype=torch.int64)
angles = torch.zeros(3, dtype=torch.int64)
triPrism_signed_distances = sdTriPrism(reshaped_tensor, h, center, angles)

In [None]:
prism_pts = reshaped_tensor[triPrism_signed_distances <= 0]
fig, ax2 = plt.subplots(subplot_kw={'projection': '3d'})
ax2.scatter(prism_pts[:, 0], prism_pts[:, 1], prism_pts[:, 2], c='blue', marker='o')
ax2.view_init(elev=30, azim=30)
ax2.set_xlabel('X')
ax2.set_ylabel('Y')
ax2.set_zlabel('Z')
plt.show()

# Cone

In [6]:
def sdCone(p, c, h, center=torch.tensor([0, 0, 0]), angles=torch.tensor([0, 0, 0])):
    '''
    Exact signed distance from each row of p to a solid cone.

    In the local frame the apex is at the origin and the axis is y. The cone
    opens downward to a circular base at y = -h, whose radius is h*c[0]/c[1].

    p      -- (n, 3) tensor of query points
    c      -- 2-element tensor (sin, cos) of the cone's half-angle
    h      -- height (one-element tensor or a number)
    center -- apex position, subtracted after the rotation
    angles -- rotation angles in degrees, applied to p via rotate_xyz

    Returns a (n,) tensor, negative inside.
    '''
    p_local = rotate_xyz(p, angles) - center
    h = torch.as_tensor(h, dtype=p_local.dtype).reshape(())
    # Base-rim point in the 2-D (radial, y) half-plane; torch.stack keeps it
    # differentiable with respect to c and h.
    q = torch.stack((h * c[0] / c[1], -h))
    # Each point reduced to (distance from the y-axis, height).
    w = torch.stack((torch.norm(p_local[:, [0, 2]], dim=1), p_local[:, 1]), dim=1)

    # a: offset to the closest point on the slanted segment apex -> rim.
    t_side = torch.clamp((w @ q) / torch.dot(q, q), 0.0, 1.0)
    a = w - t_side.unsqueeze(1) * q
    # b: offset to the closest point on the base segment (axis -> rim at y = -h).
    t_base = torch.clamp(w[:, 0] / q[0], 0.0, 1.0)
    b = w - torch.stack((q[0] * t_base, q[1].expand_as(t_base)), dim=1)

    k = torch.sign(q[1])
    d = torch.minimum((a * a).sum(dim=1), (b * b).sum(dim=1))
    # s < 0 exactly for points inside the cone; it supplies the sign.
    s = torch.maximum(k * (w[:, 0] * q[1] - w[:, 1] * q[0]), k * (w[:, 1] - q[1]))
    return torch.sqrt(d) * torch.sign(s)

In [7]:
# α is 45° in radians; sdCone receives c = (sin, cos) of α/2.
α = torch.tensor([45 * 3.141592653589793 / 180], dtype=torch.float32)
c = torch.cat((torch.sin(α / 2), torch.cos(α / 2)))
h = torch.tensor([3.0], dtype=torch.float32)
center = torch.zeros(3, dtype=torch.float32)
angles = torch.tensor([0.0, 0.0, 30.0], dtype=torch.float32)
cone_signed_distances = sdCone(reshaped_tensor, c, h, center, angles)

In [None]:
cone_pts = reshaped_tensor[cone_signed_distances <= 0]
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
ax.scatter(cone_pts[:, 0], cone_pts[:, 1], cone_pts[:, 2], c='blue', marker='o')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.view_init(elev=120, azim=-90)
plt.show()

# Hexagonal Prism

In [19]:
def sdHexPrism(p, h, center, angles):
    '''
    Exact signed distance from each row of p to a hexagonal prism.

    The hexagon lies in the local xy-plane with flat edges at y = ±h[0]. The
    prism extends along z.

    p      -- (n, 3) tensor of query points
    h      -- 2-element tensor: h[0] is the hexagon's apothem (centre-to-edge
              distance), h[1] is the half-length along z
    center -- prism centre, subtracted after the rotation
    angles -- rotation angles in degrees, applied to p via rotate_xyz

    k = (-cos 30°, sin 30°, tan 30°): (k0, k1) is the normal of the mirror line
    used to fold the plane into one sector, and k2 * h[0] is half an edge.
    Returns a (n,) tensor, negative inside.
    '''
    p = torch.abs(rotate_xyz(p, angles) - center)
    kx, ky, kz = -0.8660254, 0.5, 0.57735

    # Fold (|x|, |y|) across the 30° mirror line so that every point lands in
    # the sector whose nearest edge is the top one (y = h[0]). Built
    # out-of-place to keep autograd's saved tensors intact.
    fold = torch.clamp(kx * p[:, 0] + ky * p[:, 1], max=0.0)
    px = p[:, 0] - 2.0 * fold * kx
    py = p[:, 1] - 2.0 * fold * ky

    # Distance to the top edge segment: clamp only x to the edge's extent; the
    # edge itself lies on y = h[0].
    half_edge = kz * h[0]
    edge_x = torch.minimum(torch.maximum(px, -half_edge), half_edge)
    d1 = torch.norm(torch.stack((px - edge_x, py - h[0]), dim=1), dim=1) * torch.sign(py - h[0])
    d2 = p[:, 2] - h[1]
    d = torch.stack((d1, d2), dim=1)
    return torch.clamp(torch.max(d1, d2), max=0.0) + torch.norm(torch.clamp(d, min=0.0), dim=1)

In [20]:
h = torch.tensor([1.0, 4.0], dtype=torch.float32)
center = torch.zeros(3, dtype=torch.float32)
angles = torch.tensor([90.0, 0.0, 90.0], dtype=torch.float32)
hexPrism_signed_distances = sdHexPrism(reshaped_tensor, h, center, angles)

In [None]:
hex_pts = reshaped_tensor[hexPrism_signed_distances <= 0]
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
ax.scatter(hex_pts[:, 0], hex_pts[:, 1], hex_pts[:, 2], c='blue', marker='o')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
plt.show()

# Capsule

In [23]:
def sdCapsule(p, a, b, r, center, angles):
    """
    Calculates the signed distance of each point to a capsule: all points
    within distance r of the segment between end points a and b.

    Arguments:
    p      -- a tensor of size (n, 3) with the query point(s)
    a      -- a tensor of size (3,), first end point of the segment
    b      -- a tensor of size (3,), second end point of the segment
    r      -- capsule radius: a one-element tensor or a Python number
    center -- offset subtracted from the points after the rotation
    angles -- rotation angles in degrees, applied to p via rotate_xyz

    Returns:
    A 1D tensor of size (n,) with the signed distance(s), negative inside
    (also for n == 1). If a == b, the capsule degenerates to a sphere of
    radius r around a.
    """
    # rotate, then translate the points
    p = rotate_xyz(p, angles) - center
    r = torch.as_tensor(r, dtype=p.dtype, device=p.device).reshape(-1)

    pa = p - a                      # a broadcasts over the n rows
    ba = (b - a).reshape(1, 3)
    # Clamp the squared length so a zero-length segment gives h = 0, not NaN.
    ba_len2 = torch.clamp(torch.sum(ba * ba, dim=1), min=torch.finfo(p.dtype).tiny)
    # Parameter of the closest point on the segment, restricted to [0, 1].
    h = torch.clamp(torch.sum(pa * ba, dim=1) / ba_len2, 0.0, 1.0)
    return torch.norm(pa - ba * h.unsqueeze(1), dim=1) - r

In [28]:
# Capsule along local z from (0, 0, 3) to (0, 0, -3), radius 2.
a = torch.tensor([0.0, 0.0, 3.0], dtype=torch.float32)
b = torch.tensor([0.0, 0.0, -3.0], dtype=torch.float32)
r = torch.tensor([2.0], dtype=torch.float32)
center = torch.tensor([2.0, 0.0, 0.0], dtype=torch.float32)
angles = torch.tensor([30.0, 0.0, 0.0], dtype=torch.float32)
capsule_signed_distances = sdCapsule(reshaped_tensor, a, b, r, center, angles)

In [None]:
capsule_pts = reshaped_tensor[capsule_signed_distances <= 0]
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
ax.scatter(capsule_pts[:, 0], capsule_pts[:, 1], capsule_pts[:, 2], c='blue', marker='o')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.view_init(elev=30, azim=10)
plt.show()

# Vertical Capsule

In [43]:
def sdCapsule2(p, h, r, center, angles):
    '''
    Exact signed distance to a vertical capsule: all points within distance r
    of the segment from (0, 0, 0) to (0, h, 0) in the local frame.

    p      -- (n, 3) tensor of query points
    h      -- segment length along +y (one-element tensor or number)
    r      -- capsule radius (one-element tensor or number)
    center -- offset subtracted from the points after the rotation
    angles -- rotation angles in degrees, applied to p via rotate_xyz

    Returns a (n,) tensor, negative inside.
    '''
    # rotate, then translate the points
    p = rotate_xyz(p, angles) - center
    h = torch.as_tensor(h, dtype=p.dtype, device=p.device).reshape(-1)
    r = torch.as_tensor(r, dtype=p.dtype, device=p.device).reshape(-1)

    # Subtract the nearest point of the segment from y. This is done
    # out-of-place: writing back into p after clamp has read p[:, 1] can trip
    # autograd's "modified by an inplace operation" check when p depends on
    # tensors that require grad.
    y_nearest = torch.minimum(torch.maximum(p[:, 1], torch.zeros_like(h)), h)
    p_seg = torch.stack((p[:, 0], p[:, 1] - y_nearest, p[:, 2]), dim=1)
    return torch.norm(p_seg, dim=1) - r

In [44]:
h = torch.tensor([5.0], dtype=torch.float32)
r = torch.tensor([3.0], dtype=torch.float32)
center = torch.zeros(3, dtype=torch.float32)
angles = torch.zeros(3, dtype=torch.float32)
capsule_signed_distances2 = sdCapsule2(reshaped_tensor, h, r, center, angles)

In [None]:
capsule2_pts = reshaped_tensor[capsule_signed_distances2 <= 0]
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
ax.scatter(capsule2_pts[:, 0], capsule2_pts[:, 1], capsule2_pts[:, 2], c='blue', marker='o')
plt.show()


# Octahedron - bound (not exact)

In [30]:
def sdOctahedron(p, s, center, angles):
    '''
    Signed distance bound for the regular octahedron |x| + |y| + |z| = s.

    p      -- (n, 3) tensor of query points
    s      -- size: the vertices sit at distance s from the centre on each axis
    center -- offset subtracted from the points after the rotation
    angles -- rotation angles in degrees, applied to p via rotate_xyz

    (|x| + |y| + |z| - s) / sqrt(3) is the signed distance to the face plane
    facing the point. The constant 0.57735027 = 1/sqrt(3) normalises that
    plane's normal (1, 1, 1). The zero level set is the octahedron surface.
    The value is exact inside and over the faces, but it underestimates the
    true distance near edges and vertices, hence "bound".
    Returns a (n,) tensor, negative inside.
    '''
    # rotate, then translate the points
    q = torch.abs(rotate_xyz(p, angles) - center)
    return (q[:, 0] + q[:, 1] + q[:, 2] - s) * 0.57735027

In [40]:
s = torch.tensor([3.0], dtype=torch.float32)
center = torch.zeros(3, dtype=torch.float32)
angles = torch.zeros(3, dtype=torch.float32)
octahedron_signed_distances = sdOctahedron(reshaped_tensor, s, center, angles)

In [None]:
octa_pts = reshaped_tensor[octahedron_signed_distances <= 0]
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
ax.scatter(octa_pts[:, 0], octa_pts[:, 1], octa_pts[:, 2], c='blue', marker='o')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.view_init(elev=30, azim=30)
plt.show()

# Ellipsoid - bound (not exact!)

In [49]:
def sdEllipsoid(p, r, center, angles):
    '''
    Approximate (bound, not exact) signed distance to an ellipsoid.

    p      -- (n, 3) tensor of query points
    r      -- 3 radii of the ellipsoid along the local x, y, z axes
    center -- offset subtracted from the points after the rotation
    angles -- rotation angles in degrees, applied to p via rotate_xyz

    The formula is k0 * (k0 - 1) / k1 with k0 = |p / r| and k1 = |p / r^2|;
    its zero level set (k0 == 1) is the ellipsoid surface. At the centre
    itself k0 = k1 = 0 and the formula is 0/0, so that point returns -min(r),
    the exact distance from the centre to the surface.
    Returns a (n,) tensor (also for n == 1), negative inside.
    '''
    # rotate, then translate the points
    p = rotate_xyz(p, angles) - center

    r = torch.reshape(r, (1, 3))
    k0 = torch.norm(p / r, dim=1)
    k1 = torch.norm(p / (r * r), dim=1)
    at_center = k1 == 0
    # Use a safe denominator before dividing so that neither the value nor the
    # gradient of the discarded branch becomes NaN.
    safe_k1 = torch.where(at_center, torch.ones_like(k1), k1)
    approx = k0 * (k0 - 1) / safe_k1
    return torch.where(at_center, -torch.min(r), approx)

In [50]:
r = torch.tensor([3.0, 1.5, 0.5], dtype=torch.float32)
center = torch.zeros(3, dtype=torch.float32)
angles = torch.zeros(3, dtype=torch.float32)
ellipsoid_signed_distances = sdEllipsoid(reshaped_tensor, r, center, angles)

In [None]:
ellipsoid_pts = reshaped_tensor[ellipsoid_signed_distances <= 0]
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
ax.scatter(ellipsoid_pts[:, 0], ellipsoid_pts[:, 1], ellipsoid_pts[:, 2], c='blue', marker='o')
plt.show()

# Link - exact 

In [5]:
def sdLink(p, le, r1, r2, center, angles):
    '''
    Exact signed distance to a chain link: a torus in the local xy-plane
    stretched along y by a straight middle section.

    p      -- (n, 3) tensor of query points
    le     -- half the length of the straight section along y
    r1     -- radius of the link's centre line (from the axis of each rounded
              end to the middle of the tube)
    r2     -- radius of the tube
    center -- offset subtracted from the points after the rotation
    angles -- rotation angles in degrees, applied to p via rotate_xyz

    Returns a (n,) tensor (also for n == 1), negative inside.
    '''
    # rotate, then translate the points
    p = rotate_xyz(p, angles) - center
    # Collapse the straight section (|y| - le, floored at 0) so both halves of
    # the link reduce to the same torus cross-section.
    y_collapsed = torch.clamp(torch.abs(p[:, 1]) - le, min=0.0)
    ring = torch.norm(torch.stack((p[:, 0], y_collapsed), dim=1), dim=1) - r1
    return torch.norm(torch.stack((ring, p[:, 2]), dim=1), dim=1) - r2

In [54]:
# Link: straight-section half-length 2, centre-line radius 1.5, tube radius 1.
le = torch.FloatTensor([2])
r1 = torch.FloatTensor([1.5])
r2 = torch.FloatTensor([1])
center = torch.FloatTensor([0, 0, 0])
angles = torch.FloatTensor([0, 0, 0])
link_signed_distances = sdLink(reshaped_tensor, le, r1, r2, center, angles)

In [None]:
link_pts = reshaped_tensor[link_signed_distances <= 0]
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
ax.scatter(link_pts[:, 0], link_pts[:, 1], link_pts[:, 2], c='blue', marker='o')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.view_init(elev=80, azim=10)
plt.show()

# SDF Scene

In [23]:
def sdScene(parameters, points=None):
    '''
    Evaluate every primitive of the demo scene on a set of query points.

    parameters -- (num_rows, 6) tensor; row i is [cx, cy, cz, θx, θy, θz]
                  (centre, then rotation angles in degrees). Rows 0-4 drive the
                  five spheres and the link reuses row 1.
    points     -- (n, 3) query points; defaults to the global grid
                  ``reshaped_tensor``

    Returns (num_shapes, final_signed_distances): a list with one (n,)
    signed-distance tensor per shape, and its length.
    '''
    if points is None:
        points = reshaped_tensor

    # (parameter row, radius) per sphere; the first sphere is slightly smaller.
    sphere_specs = [(0, 0.8), (1, 1), (2, 1), (3, 1), (4, 1)]
    final_signed_distances = [
        sdf_sphere(points, radius, parameters[row, 0:3], parameters[row, 3:6])
        for row, radius in sphere_specs
    ]

    # link 1
    # NOTE(review): the link shares parameter row 1 with the second sphere and
    # parameter row 5 is never read -- confirm this is intended.
    le = torch.FloatTensor([0.7])
    r1 = torch.FloatTensor([0.375])
    r2 = torch.FloatTensor([0.20])
    final_signed_distances.append(
        sdLink(points, le, r1, r2, parameters[1, 0:3], parameters[1, 3:6]))

    # Derived from the list so pairwise losses/plots cover every shape
    # (a hard-coded count silently drops shapes).
    num_shapes = len(final_signed_distances)
    return num_shapes, final_signed_distances

In [7]:

def define_shape(points_2d, vertices=None):
    '''
    Rasterise a 2-D target polygon onto a set of xy points.

    points_2d -- (m, 2) array-like of xy points (a CPU torch tensor works)
    vertices  -- polygon vertices as a sequence of (x, y) pairs; defaults to
                 the triangle (-6, -6), (6, 0), (0, 6)

    Returns a (m,) float64 torch tensor: 1.0 where the point is inside the
    polygon, 0.0 otherwise.
    '''
    if vertices is None:
        vertices = [(-6, -6), (6, 0), (0, 6)]
    # matplotlib treats the path as closed for containment tests.
    inside = Path(vertices).contains_points(points_2d)
    return torch.tensor(inside.astype(float))

In [8]:
# One (x, y) row per z-column of the grid; torch.unique returns them sorted.
point_xy = torch.unique(reshaped_tensor[:, 0:2], dim=0)
point_xy.shape

torch.Size([6400, 2])

In [9]:
inside = define_shape(point_xy)
# Append the 0/1 target mask as a third column -> (num_xy_points, 3).
pts_inside = torch.cat((point_xy, inside.unsqueeze(1)), dim=1)
pts_inside.shape

torch.Size([6400, 3])

In [10]:
def get_union(final_signed_distances):
    '''
    Union of several SDFs, reduced to 0/1 occupancy.

    final_signed_distances -- list of equally shaped signed-distance tensors,
        one per shape; they are stacked into (num_shapes, num_points)

    Returns an int64 tensor holding 1 where the union SDF (element-wise
    minimum over shapes) is <= 0, and 0 elsewhere. Also prints it.
    '''
    per_shape = torch.stack(final_signed_distances, dim=0)
    union_sdf_01 = (per_shape.min(dim=0).values <= 0).long()
    print('union_sdf.shape: ', union_sdf_01.shape)
    print('union_sdf: ', union_sdf_01)
    return union_sdf_01

In [11]:
# for final_signed_distances is a tensor of shape (num_points, num_shapes)
def project_union_sdf2(final_signed_distances, xy_points=None, pts_per_column=None):
    '''
    Project the union of all shapes onto the xy-plane as a 0/1 mask.

    final_signed_distances -- (num_points, num_shapes) tensor, or a list of
        (num_points,) tensors (one per shape, stacked along dim 1)
    xy_points      -- (num_columns, 2) xy coordinates, one per z-column;
                      defaults to the global ``point_xy``
    pts_per_column -- number of consecutive points sharing one (x, y);
                      defaults to the global ``axis_num_pts``

    Assumes each run of ``pts_per_column`` consecutive rows is one z-column, in
    the same order as ``xy_points``. This holds for the meshgrid('ij') +
    torch.unique construction used in this notebook.

    Returns a (num_columns, 3) tensor [x, y, occupied], where occupied is 1.0
    if the union SDF is strictly negative anywhere in that column and 0.0
    otherwise. The mask comes from a comparison, so it carries no gradient.
    Raises ValueError if the point count does not match the column layout.
    '''
    if isinstance(final_signed_distances, (list, tuple)):
        final_signed_distances = torch.stack(final_signed_distances, dim=1)
    if xy_points is None:
        xy_points = point_xy
    if pts_per_column is None:
        pts_per_column = axis_num_pts

    union_sdf = torch.min(final_signed_distances, dim=1)[0]
    expected = xy_points.shape[0] * pts_per_column
    if union_sdf.shape[0] != expected:
        raise ValueError(
            f'expected {expected} points ({xy_points.shape[0]} columns x '
            f'{pts_per_column}), got {union_sdf.shape[0]}')
    # One min per z-column, vectorised over all columns.
    column_min = union_sdf.reshape(-1, pts_per_column).min(dim=1)[0]
    occupied = (column_min < 0).to(xy_points.dtype)
    return torch.cat([xy_points, occupied.unsqueeze(1)], dim=1)

In [13]:
def plot_union(parameters):
    '''Scatter-plot the grid points inside the union of all scene shapes.'''
    _, shape_sdfs = sdScene(parameters)
    union_sdf = torch.stack(shape_sdfs, dim=0).min(dim=0).values
    occupied = reshaped_tensor[union_sdf <= 0]

    fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
    ax.scatter(occupied[:, 0], occupied[:, 1], occupied[:, 2], c='blue', marker='o')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    plt.show()

In [14]:
def plot_intersections(parameters):
    '''
    Plot the pairwise intersections of the scene's shapes.

    For every pair (i, j) with i < j, the intersection SDF is the element-wise
    max of the two shapes' signed distances. Grid points where it is <= 0 are
    scattered in cell (i, j) of a num_shapes x num_shapes subplot grid.
    '''
    num_shapes, shape_sdfs = sdScene(parameters)
    fig = plt.figure(figsize=(10, 10))
    pairs = ((i, j) for i in range(num_shapes) for j in range(i + 1, num_shapes))
    for i, j in pairs:
        pair_sdf = torch.maximum(shape_sdfs[i], shape_sdfs[j])
        overlap = reshaped_tensor[pair_sdf <= 0]
        ax = fig.add_subplot(num_shapes, num_shapes, i * num_shapes + j + 1, projection='3d')
        ax.scatter(overlap[:, 0], overlap[:, 1], overlap[:, 2], c='blue', marker='o')

    # Leave room between the small 3-D panels.
    plt.subplots_adjust(hspace=0.4, wspace=0.4)
    plt.show()

In [15]:
def plot_shape(parameters):
    '''
    Scatter-plot the 2-D target shape: the xy grid points flagged by the
    global ``inside`` mask. ``parameters`` is accepted but not used.
    '''
    target_xy = point_xy[inside > 0]
    fig, ax = plt.subplots()
    ax.scatter(target_xy[:, 0], target_xy[:, 1], c='blue', marker='o')
    plt.show()
    

In [16]:
def calulate_loss(parameters, w1=1, w2=1):
    '''
    Scene loss = w1 * intersection_loss + w2 * projection_loss.

    intersection_loss -- for each shape pair (i < j), the sum of the
        non-positive values of the pairwise intersection SDF (max of the two
        SDFs), negated and divided by num_shapes; penalises overlapping shapes.
    projection_loss -- squared error between the xy projection of the union
        (project_union_sdf2: 1 where a z-column contains an interior point)
        and the global target mask ``inside``.

    NOTE: project_union_sdf2 produces a hard 0/1 mask, so the projection term
    is a constant with respect to ``parameters`` and contributes no gradient.
    Only the intersection term drives the optimisation. The detach below makes
    that explicit instead of re-wrapping the value as a fresh leaf tensor.

    Returns the loss as a float32 scalar tensor.
    '''
    intersection_loss = 0
    num_shapes, final_signed_distances = sdScene(parameters)
    # (num_points, num_shapes): one column per shape.
    final_signed_distances = torch.stack(final_signed_distances, dim = 1)
    print('final_signed_distances.shape: ', final_signed_distances.shape)

    for i in range(num_shapes):
        for j in range(i + 1, num_shapes):
            pair_sdf = torch.max(final_signed_distances[:, i], final_signed_distances[:, j])
            intersection_loss += torch.sum(pair_sdf[pair_sdf <= 0])

    intersection_loss = - intersection_loss/num_shapes
    print('Intersection Loss: ', intersection_loss)

    projected = project_union_sdf2(final_signed_distances.detach())[:, 2]
    projection_loss = torch.square(projected - inside).sum()
    print('Projection Loss: ', projection_loss, '\n')
    loss = w1*intersection_loss + w2*projection_loss
    return loss.float()

In [None]:
# Initial scene: one row per shape, [cx, cy, cz, θx, θy, θz] (angles in degrees).
parameters = [[0.5,0.5,0.5,0,0,0], [2.5,0.5,0.5,0,0,0], [-2.5, -1, 2,0,0,0], [1,3,3,0,0,0], [-1, -3, 0.5,0,0,0], [3, 3, 3,0,0,0]]
parameters = torch.tensor(parameters, dtype=torch.float32, requires_grad=True)
plot_intersections(parameters)
plot_union(parameters)
plot_shape(parameters)
optimizer = torch.optim.Adam([parameters], lr=0.05)

# Number of optimisation steps; 0 skips the loop and only re-plots the scene.
num_steps = 0
for i in range(num_steps):
    optimizer.zero_grad()
    loss = calulate_loss(parameters)
    loss.backward()

    # `parameters` is a plain tensor (it has no named_parameters()), so report
    # its gradient norm directly.
    print('Gradients:')
    if parameters.grad is not None:
        print('parameters', parameters.grad.norm(2).item())

    print("Gradient before update:", parameters.grad)
    optimizer.step()
    if i % 2 == 0:
        print(loss)
        print(parameters)

plot_intersections(parameters)
plot_union(parameters)