In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial.transform import Rotation
%matplotlib qt

def _disk_outline(normal, point, radius, n_points=50):
    """Return a (3, n_points) array of points on the rim of a disk.

    normal:   disk normal; any non-zero length (it is normalised here).
    point:    disk centre, three coordinates.
    radius:   scalar radius; size-1 numpy arrays / torch tensors are accepted.
    n_points: number of samples along the rim (first and last coincide).
    """
    theta = np.linspace(0, 2 * np.pi, n_points)
    # Coerce to a plain float so a torch tensor radius never mixes with numpy arrays.
    radius = np.asarray(radius, dtype=float).item()
    rim = np.vstack([radius * np.cos(theta), radius * np.sin(theta), np.zeros_like(theta)])

    # arccos below needs a true cosine, so the normal must be unit length.
    normal = np.asarray(normal, dtype=float)
    normal = normal / np.linalg.norm(normal)
    z_axis = np.array([0.0, 0.0, 1.0])
    rot_axis = np.cross(z_axis, normal)
    axis_len = np.linalg.norm(rot_axis)
    if axis_len > 1e-12:
        # clip guards arccos against dot products that round just outside [-1, 1]
        rot_angle = np.arccos(np.clip(np.dot(z_axis, normal), -1.0, 1.0))
        rot_matrix = Rotation.from_rotvec(rot_angle * rot_axis / axis_len).as_matrix()
        rim = rot_matrix @ rim
    # else: normal is +z or -z, the rim already lies in the correct plane
    # (dividing by the zero-length axis here used to produce NaNs for -z).

    return rim + np.asarray(point, dtype=float).reshape(3, 1)


def add_disk(normal, point, radius, color, axes=None):
    """Draw the rim of a disk in 3-D with ``plot3D``.

    normal: disk normal (any non-zero length).
    point:  disk centre, three coordinates.
    radius: scalar radius (plain number, or size-1 array/tensor).
    color:  matplotlib colour for the outline.
    axes:   axes to draw on; defaults to the module-level ``ax``.
    """
    target = ax if axes is None else axes
    x_disk, y_disk, z_disk = _disk_outline(normal, point, radius)
    target.plot3D(x_disk, y_disk, z_disk, linestyle='-', color=color, label='Circle')

def xnorm(a):
    """Scale the first three components of ``a`` to unit Euclidean length.

    Returns a new 3-element list; components past index 2 are ignored.
    For plain Python numbers a zero vector raises ZeroDivisionError.
    """
    comps = [a[idx] for idx in range(3)]
    length = math.sqrt(sum(c * c for c in comps))
    return [c / length for c in comps]

# Shared disk radius (integer tensor; loss() clamps radii against r.item()).
r = torch.tensor([5])
# Disk centres.
C1 = torch.tensor([0, 0, 0])
C2 = torch.tensor([4, 4, 4])

# Disk normals. e1 is unit length, e2 is not; both are normalised where used.
_SQRT3 = 3 ** 0.5
e1 = [1 / _SQRT3, 1 / _SQRT3, 1 / _SQRT3]
e2 = [-2 / _SQRT3, 1 / _SQRT3, 1 / _SQRT3]

# Trainable polar coordinates of one point per disk: radii start at 5,
# angles are drawn from torch.rand(1) * 2*pi. theta1 is sampled before theta2.
# (An alternative start also randomises the radii via torch.rand(1) * r.)
radius1 = torch.tensor([5.]).requires_grad_()
theta1 = (torch.rand(1) * math.pi * 2).requires_grad_()
radius2 = torch.tensor([5.]).requires_grad_()
theta2 = (torch.rand(1) * math.pi * 2).requires_grad_()
params = (radius1, theta1, radius2, theta2)

def _disk_point(center, normal, radius, theta):
    """Return the point at polar coordinates (radius, theta) on a 3-D disk.

    The in-plane frame is v = n x z, u = n x v (both normalised), where n is the
    normalised ``normal``. When n is (anti)parallel to z that cross product is the
    zero vector, which would collapse every point onto the centre, so the x axis is
    used as the helper direction instead.
    """
    n = F.normalize(torch.as_tensor(normal, dtype=torch.float32), dim=0)
    helper = torch.tensor([0.0, 0.0, 1.0])
    if torch.linalg.vector_norm(torch.cross(n, helper, dim=0)) < 1e-6:
        helper = torch.tensor([1.0, 0.0, 0.0])
    # Explicit dim: torch.cross without it is deprecated.
    v = F.normalize(torch.cross(n, helper, dim=0), dim=0)
    u = F.normalize(torch.cross(n, v, dim=0), dim=0)
    return center + (radius * torch.sin(theta) * v + radius * torch.cos(theta) * u)


def loss(radius1, theta1, radius2, theta2):
    """Distance between one point on each disk, back-propagated immediately.

    Each point is given in polar coordinates on its disk: disk 1 is centred at
    module-level C1 with normal e1, disk 2 at C2 with normal e2. Radii are clamped
    to [0, r.item()], so out-of-range values are projected onto the boundary.

    Side effect: calls ``backward()`` on the distance, accumulating into the
    ``.grad`` of every input that requires grad; the caller must zero them.

    Returns:
        (distance, point1, point2) — a 0-dim tensor and two 3-element tensors.
    """
    radius1_clamped = torch.clamp(radius1, 0, r.item())
    radius2_clamped = torch.clamp(radius2, 0, r.item())

    point1 = _disk_point(C1, e1, radius1_clamped, theta1)
    point2 = _disk_point(C2, e2, radius2_clamped, theta2)

    # vector_norm's backward is 0 at the zero vector; (sum)**0.5 has an infinite
    # derivative there, which turned the gradient into NaN when the points coincide.
    distance = torch.linalg.vector_norm(point2 - point1)
    distance.backward()
    return distance, point1, point2

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

lr = 0.1  # plain gradient-descent step size (loop-invariant)
for i in range(200):
    # loss() calls backward itself, so every parameter's .grad is filled here.
    l, point1, point2 = loss(radius1, theta1, radius2, theta2)
    # Update the leaf tensors in place without recording the step in the graph
    # (replaces the discouraged direct .data mutation).
    with torch.no_grad():
        for p in params:
            p -= lr * p.grad
            p.grad.zero_()
    if i % 10 == 0:
        print('Loss Train Set:', l.item())
        ax.scatter(*point1.detach().numpy(), color='r', alpha=0.5)
        ax.scatter(*point2.detach().numpy(), color='g', alpha=0.5)

# Final points in black. They come from the last forward pass, i.e. before the
# final parameter update.
ax.scatter(*point1.detach().numpy(), color='k', alpha=1)
ax.scatter(*point2.detach().numpy(), color='k', alpha=1)
# r is an integer tensor; pass it as a plain number so add_disk never multiplies
# a torch tensor by numpy arrays (the captured cell warning points at those lines).
add_disk(xnorm(e1), C1.numpy(), r.item(), 'r')
add_disk(xnorm(e2), C2.numpy(), r.item(), 'g')
plt.show()

Loss Train Set: tensor(13.4121, grad_fn=<PowBackward0>)
Loss Train Set: tensor(4.0397, grad_fn=<PowBackward0>)
Loss Train Set: tensor(2.8479, grad_fn=<PowBackward0>)
Loss Train Set: tensor(2.6493, grad_fn=<PowBackward0>)
Loss Train Set: tensor(2.4893, grad_fn=<PowBackward0>)
Loss Train Set: tensor(2.3618, grad_fn=<PowBackward0>)
Loss Train Set: tensor(2.2613, grad_fn=<PowBackward0>)
Loss Train Set: tensor(2.1829, grad_fn=<PowBackward0>)
Loss Train Set: tensor(2.1222, grad_fn=<PowBackward0>)
Loss Train Set: tensor(2.0756, grad_fn=<PowBackward0>)
Loss Train Set: tensor(2.0399, grad_fn=<PowBackward0>)
Loss Train Set: tensor(2.0127, grad_fn=<PowBackward0>)
Loss Train Set: tensor(1.9920, grad_fn=<PowBackward0>)
Loss Train Set: tensor(1.9764, grad_fn=<PowBackward0>)
Loss Train Set: tensor(1.9645, grad_fn=<PowBackward0>)
Loss Train Set: tensor(1.9556, grad_fn=<PowBackward0>)
Loss Train Set: tensor(1.9488, grad_fn=<PowBackward0>)
Loss Train Set: tensor(1.9437, grad_fn=<PowBackward0>)
Loss Trai

  x = radius * np.cos(theta)
  y = radius * np.sin(theta)
