# Setup

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import open3d as o3d
import numpy as np

In [3]:
# Device-agnostic setup: prefer the GPU when PyTorch reports one, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device  # echo the selected device as the notebook cell output

'cuda'

## Utils: From script/common.py

In [4]:
from torch.autograd import Variable

class switch(object):
    """
    Emulate a C-style ``switch`` statement with a ``for`` loop.

    Usage::

        for case in switch(value):
            if case('a', 'b'):
                ...
                break
            if case():  # default
                ...

    Once a case matches, every later ``case(...)`` call also returns True
    (C-style fall-through) until the loop body executes ``break``.
    """

    def __init__(self, value):
        self.value = value
        self.fall = False  # becomes True once a case has matched

    def __iter__(self):
        """Yield the match method once, then stop."""
        yield self.match
        # Simply returning ends the iteration. An explicit `raise StopIteration`
        # inside a generator is turned into RuntimeError by PEP 479 (Python 3.7+),
        # which broke every switch whose body did not `break`.

    def match(self, *args):
        """
        Return True if the current case body should run.

        That is the case when fall-through is already active, when called with
        no arguments (the default case), or when ``self.value`` is one of ``args``.
        """
        if self.fall or not args:
            return True
        if self.value in args:
            self.fall = True
            return True
        return False
def s2_grid(n_alpha, n_beta):
    """
    Sample a regular (beta, alpha) grid of directions on the sphere S^2.

    beta (polar angle) takes n_beta values spread over [0, pi) and alpha
    (azimuth) takes n_alpha values spread over [0, 2*pi). Each is shifted by
    half a step, so samples sit at cell centres.

    :param n_alpha: number of azimuth samples
    :param n_beta: number of polar-angle samples
    :return: ndarray of shape (n_beta * n_alpha, 2). Column 0 is beta and
             column 1 is alpha; beta varies slowest.
    """
    betas = np.linspace(0, np.pi, n_beta, endpoint=False) + np.pi / n_beta / 2
    alphas = np.linspace(0, 2 * np.pi, n_alpha, endpoint=False) + np.pi / n_alpha
    # Row-major over (beta, alpha): each beta repeated n_alpha times, alphas cycled.
    beta_col = np.repeat(betas, n_alpha)
    alpha_col = np.tile(alphas, n_beta)
    return np.column_stack((beta_col, alpha_col))

def change_coordinates(coords, radius, p_from='C', p_to='S'):
    """
    Convert points between spherical and Cartesian coordinates.

    In the spherical system a point is (beta, alpha). beta is the polar angle
    from the +z axis, in [0, pi]. alpha is the azimuth in the x-y plane,
    measured from +x. The names beta/alpha match the SO(3) code (S^2 being
    the quotient SO(3)/SO(2)); many sources, such as Wikipedia, call them
    theta and phi.

    :param coords: array whose last axis holds (beta, alpha) for spherical
        input or (x, y, z) for Cartesian input
    :param radius: sphere radius used when producing Cartesian points
        ('S' -> 'C'). For 'C' -> 'S' beta is computed from each point's own
        norm, so input points need not lie on the unit sphere.
    :param p_from: 'C' for Cartesian or 'S' for spherical coordinates
    :param p_to: 'C' for Cartesian or 'S' for spherical coordinates
    :return: converted array. Its last axis has 3 entries for 'C' and 2 for
        'S'. If p_from == p_to, ``coords`` itself is returned. Azimuths
        produced by 'C' -> 'S' come from arctan2 and lie in [-pi, pi].
    :raises ValueError: for any other combination of systems
    """
    if p_from == p_to:
        return coords
    elif p_from == 'S' and p_to == 'C':
        beta = coords[..., 0]
        alpha = coords[..., 1]
        r = radius

        out = np.empty(beta.shape + (3,))
        sin_beta = np.sin(beta)
        out[..., 0] = r * sin_beta * np.cos(alpha)  # x
        out[..., 1] = r * sin_beta * np.sin(alpha)  # y
        out[..., 2] = r * np.cos(beta)  # z
        return out

    elif p_from == 'C' and p_to == 'S':
        x = coords[..., 0]
        y = coords[..., 1]
        z = coords[..., 2]

        # Normalise z by each point's length so non-unit points get a valid
        # polar angle (plain arccos(z) is NaN for |z| > 1). The origin keeps
        # cos(beta) = 0, which is what the unit-sphere formula gave it.
        norm = np.sqrt(x * x + y * y + z * z)
        cos_beta = np.divide(z, norm, out=np.zeros(np.shape(z)), where=norm > 0)

        out = np.empty(x.shape + (2,))
        # Clip guards against |cos_beta| exceeding 1 by a rounding ulp.
        out[..., 0] = np.arccos(np.clip(cos_beta, -1.0, 1.0))  # beta
        out[..., 1] = np.arctan2(y, x)  # alpha
        return out

    else:
        raise ValueError('Unknown conversion:' + str(p_from) + ' to ' + str(p_to))

def get_voxel_coordinate(radius, rad_n, azi_n, ele_n):
    """
    Cartesian centres of a spherical voxel grid.

    Directions come from ``s2_grid`` (ele_n polar samples x azi_n azimuth
    samples), placed on a sphere of ``radius``. They are then copied into
    rad_n radial shells, each scaled by its bin centre (i + 0.5) / rad_n.

    :param radius: outer sphere radius
    :param rad_n: number of radial shells
    :param azi_n: number of azimuth samples
    :param ele_n: number of elevation (polar) samples
    :return: ndarray of shape (rad_n, ele_n * azi_n, 3)
    """
    directions = change_coordinates(s2_grid(n_alpha=azi_n, n_beta=ele_n), radius, 'S', 'C')
    shell_scale = np.arange(rad_n) / rad_n + 1 / (2 * rad_n)
    # Broadcast (rad_n, 1, 1) against (1, ele_n * azi_n, 3).
    return shell_scale.reshape(rad_n, 1, 1) * directions[np.newaxis]

def angles2rotation_matrix(angles):
    """
    Build a 3x3 rotation matrix from three Euler angles in radians.

    Rotations about the fixed x, y and z axes are composed as ``Rz @ Ry @ Rx``.

    :param angles: indexable holding the angles about x, y and z
    :return: (3, 3) float ndarray
    """
    cos_x, sin_x = np.cos(angles[0]), np.sin(angles[0])
    cos_y, sin_y = np.cos(angles[1]), np.sin(angles[1])
    cos_z, sin_z = np.cos(angles[2]), np.sin(angles[2])

    rot_x = np.array([
        [1, 0, 0],
        [0, cos_x, -sin_x],
        [0, sin_x, cos_x],
    ])
    rot_y = np.array([
        [cos_y, 0, sin_y],
        [0, 1, 0],
        [-sin_y, 0, cos_y],
    ])
    rot_z = np.array([
        [cos_z, -sin_z, 0],
        [sin_z, cos_z, 0],
        [0, 0, 1],
    ])
    return rot_z.dot(rot_y.dot(rot_x))

def pad_image(input, kernel_size):
    """
    Pad an image so a 'valid' convolution with ``kernel_size`` preserves its size.

    The width axis (dim 3) is padded circularly (wrap-around) and the height
    axis (dim 2) with zeros. Odd kernels are padded by (kernel_size - 1) // 2
    on both sides. Even kernels are padded by kernel_size // 2 on the trailing
    side only (right / bottom).

    :param input: tensor of shape [B, C, H, W]
    :param kernel_size: convolution kernel size
    :return: tensor [B, C, H + p, W + p] with p = kernel_size - 1 for odd
             kernels and kernel_size // 2 for even ones; same dtype and device
             as ``input``
    """
    if kernel_size % 2 == 0:
        before, after = 0, kernel_size // 2
    else:
        before = after = (kernel_size - 1) // 2

    # Circular width padding: [last `before` cols | input | first `after` cols].
    # The `before > 0` guard matters: input[..., -0:] selects the WHOLE width,
    # which used to double the width for kernel_size == 1.
    pieces = [input[:, :, :, -before:]] if before > 0 else []
    pieces += [input, input[:, :, :, :after]]
    output = torch.cat(pieces, dim=3)

    # Zero height padding; new_zeros keeps the input's dtype and device.
    B, C, _, W = output.shape
    rows = [output]
    if before > 0:
        rows.insert(0, output.new_zeros(B, C, before, W))
    if after > 0:
        rows.append(output.new_zeros(B, C, after, W))
    return torch.cat(rows, dim=2)

def pad_image_3d(input, kernel_size):
    """
    Pad a volume so a 'valid' convolution with ``kernel_size`` preserves H and W.

    The last axis (dim 4) is padded circularly (wrap-around) and dim 3 with
    zeros; dim 2 is untouched. Odd kernels are padded by (kernel_size - 1) // 2
    on both sides. Even kernels are padded by kernel_size // 2 on the trailing
    side only.

    :param input: tensor of shape [B, C, D, H, W]
    :param kernel_size: convolution kernel size
    :return: tensor [B, C, D, H + p, W + p] with p = kernel_size - 1 for odd
             kernels and kernel_size // 2 for even ones; same dtype and device
             as ``input``
    """
    if kernel_size % 2 == 0:
        before, after = 0, kernel_size // 2
    else:
        before = after = (kernel_size - 1) // 2

    # Circular padding along dim 4. Guard `before > 0`: input[..., -0:] would
    # select the whole axis (the old kernel_size == 1 bug).
    pieces = [input[:, :, :, :, -before:]] if before > 0 else []
    pieces += [input, input[:, :, :, :, :after]]
    output = torch.cat(pieces, dim=4)

    # Zero padding along dim 3; new_zeros keeps the input's dtype and device.
    B, C, D, _, W = output.shape
    rows = [output]
    if before > 0:
        rows.insert(0, output.new_zeros(B, C, D, before, W))
    if after > 0:
        rows.append(output.new_zeros(B, C, D, after, W))
    return torch.cat(rows, dim=3)

def var_to_invar(pts, rad_n, azi_n, ele_n):
    """
    Undo the azimuthal rotation of each voxel's grouped points.

    Voxel ``n`` is laid out row-major as (rad, ele, azi), i.e.
    ``n = (r * ele_n + e) * azi_n + a``. Every point grouped at a voxel in
    azimuth bin ``a`` is rotated about the z axis by ``-a * 2*pi / azi_n``.

    :param pts: grouped points, [B, N, nsample, 3] with N == rad_n * azi_n * ele_n
    :param rad_n: number of radial bins
    :param azi_n: number of azimuth bins
    :param ele_n: number of elevation bins
    :return: rotated points, [B, N, nsample, 3], same dtype and device as ``pts``
    """
    B, N, nsample, C = pts.shape
    assert N == rad_n * azi_n * ele_n, (
        f"N={N} does not match rad_n*azi_n*ele_n={rad_n * azi_n * ele_n}")

    # One z-axis rotation per azimuth bin, equal to
    # angles2rotation_matrix([0, 0, -a * 2*pi / azi_n]).
    angles = -np.arange(azi_n) * (2 * np.pi / azi_n)
    cos_a, sin_a = np.cos(angles), np.sin(angles)
    R = np.zeros((azi_n, 3, 3))
    R[:, 0, 0] = cos_a
    R[:, 0, 1] = -sin_a
    R[:, 1, 0] = sin_a
    R[:, 1, 1] = cos_a
    R[:, 2, 2] = 1.0
    # Match the input dtype: a hard-coded float32 R breaks float64 inputs.
    R = torch.as_tensor(R, dtype=pts.dtype, device=pts.device)

    # reshape (not view) also accepts non-contiguous input. R broadcasts over
    # B / rad / ele, so there is no need to materialise repeated copies of it.
    grouped = pts.reshape(B, rad_n, ele_n, azi_n, nsample, C)
    rotated = torch.matmul(grouped, R.view(1, 1, 1, azi_n, 3, 3).transpose(-1, -2))
    return rotated.reshape(B, N, nsample, C)

import numpy as np
from scipy.spatial import cKDTree

def ball_query(pts, new_pts, radius, nsample):
    """
    :param pts: all points, [B, N, 3]
    :param new_pts: query points, [B, S, 3]
    :param radius: local spherical radius
    :param nsample: max sample number in local sphere
    :return: indices of sampled points around new_pts [B, S, nsample]
    """
    device = pts.device
    B, N, C = pts.shape
    _, S, _ = new_pts.shape

    # Create an empty tensor to hold the indices of sampled points
    sampled_indices = torch.zeros(B, S, nsample, dtype=torch.long, device=device)

    for b in range(B):
        # Calculate pairwise distances between all points and query points
        pts_b = pts[b]  # [N, 3]
        new_pts_b = new_pts[b]  # [S, 3]

        # Expand dimensions for broadcasting
        pts_exp = pts_b.unsqueeze(0)  # [1, N, 3]
        new_pts_exp = new_pts_b.unsqueeze(1)  # [S, 1, 3]

        # Compute squared distances
        dists_sq = torch.sum((pts_exp - new_pts_exp) ** 2, dim=-1)  # [S, N]

        # Find points within the radius
        mask = dists_sq <= radius ** 2
        for s in range(S):
            indices = torch.nonzero(mask[s]).squeeze(1)  # [num_points_in_sphere]

            if indices.numel() > nsample:
                # If there are more points than nsample, randomly sample
                indices = indices[torch.randperm(indices.size(0))[:nsample]]
            elif indices.numel() < nsample:
                # If there are fewer points than nsample, pad with zeros if indices is empty
                pad_size = nsample - indices.numel()
                if pad_size > 0:
                    if indices.numel() == 0:
                        # No points found, pad with zeros (or any other placeholder index like -1)
                        indices = torch.zeros(nsample, dtype=torch.long, device=device)
                    else:
                        # Repeat the last index to fill remaining spots
                        indices = torch.cat([indices, indices[-1].repeat(pad_size)])

            # Store the indices of the sampled points
            sampled_indices[b, s, :indices.numel()] = indices

    return sampled_indices

def grouping_operation(features, idx):
    """
    Group features by index.

    :param features: feature tensor of shape (B, C, N)
    :param idx: integer index tensor of shape (B, npoint, nsample), values in [0, N)
    :return: grouped features of shape (B, C, npoint, nsample), where
             out[b, c, p, k] = features[b, c, idx[b, p, k]]
    """
    B, C, N = features.shape
    _, npoint, nsample = idx.shape

    # torch.gather returns a tensor shaped like the index, so the index must
    # cover all C channels. A channel size of 1 silently kept only channel 0.
    # expand is a view, so nothing is copied here.
    flat_idx = idx.reshape(B, 1, npoint * nsample).expand(B, C, npoint * nsample)
    grouped = torch.gather(features, 2, flat_idx.long())
    return grouped.reshape(B, C, npoint, nsample)

def sphere_query(pts, new_pts, radius, nsample):
    """
    Group up to ``nsample`` neighbour coordinates within ``radius`` of each query.

    Indices come from ``ball_query``. Any slot after slot 0 whose index equals
    slot 0's index is treated as padding, and its coordinates are replaced by
    the query point itself.

    :param pts: all points, [B, N, 3]
    :param new_pts: query points, [B, S, 3]
    :param radius: local spherical radius
    :param nsample: max sample number in local sphere
    :return: grouped neighbour coordinates, [B, S, nsample, 3]
    """
    pts = pts.contiguous()
    new_pts = new_pts.contiguous()
    group_idx = ball_query(pts, new_pts, radius, nsample)  # [B, S, nsample]

    # Duplicates of the first index mark padded slots; slot 0 itself is kept.
    is_pad = group_idx == group_idx[:, :, :1]
    is_pad[:, :, 0] = False

    # (B, 3, S, nsample) -> (B, S, nsample, 3)
    grouped = grouping_operation(pts.transpose(1, 2).contiguous(), group_idx)
    grouped = grouped.permute(0, 2, 3, 1)

    # Replace padded slots with the query point itself (broadcast over nsample).
    return torch.where(is_pad.unsqueeze(-1), new_pts.unsqueeze(2), grouped)

def l2_norm(input, axis=1):
    """
    Scale ``input`` to unit L2 norm along ``axis``.

    Norms below 1e-12 are clamped to 1e-12 (``F.normalize``'s default eps), so
    an all-zero vector stays zero instead of becoming NaN via 0 / 0.

    :param input: tensor to normalise
    :param axis: dimension along which the norm is taken
    :return: tensor of the same shape as ``input``
    """
    return F.normalize(input, p=2, dim=axis)

def cal_Z_axis(local_cor, local_weight=None, ref_point=None):
    """
    Estimate a local axis as the least-variance direction of a point patch.

    The axis is the eigenvector of local_cor^T @ (local_cor [* local_weight])
    with the smallest eigenvalue.

    :param local_cor: patch coordinates, [B, N, 3] (presumably already
        relative to the patch centre — TODO confirm with callers)
    :param local_weight: optional weights multiplied into local_cor before the
        product; presumably [B, N, 1] — TODO confirm
    :param ref_point: optional [B, 3]. If given, each axis is flipped so that
        its dot product with ref_point is <= 0. If None, the sign returned by
        eigh is kept (previously a None here crashed with a TypeError).
    :return: [B, 3] axes (unit eigenvectors)
    """
    if local_weight is None:
        cov_matrix = torch.matmul(local_cor.transpose(-1, -2), local_cor)
    else:
        # Equivalent to the deprecated `Variable(cov, requires_grad=True)`:
        # the weighted matrix is DETACHED into a new leaf, so no gradient reaches
        # local_cor / local_weight through it. NOTE(review): confirm intended.
        cov_matrix = torch.matmul(
            local_cor.transpose(-1, -2), local_cor * local_weight
        ).detach().requires_grad_(True)

    # eigh returns eigenvalues in ascending order; column 0 is the eigenvector
    # of the smallest one.
    Z_axis = torch.linalg.eigh(cov_matrix, UPLO="U")[1][:, :, 0]

    if ref_point is not None:
        flip = torch.sum(Z_axis * ref_point, dim=1, keepdim=True) > 0
        Z_axis = torch.where(flip, -Z_axis, Z_axis)
    return Z_axis

def RodsRotatFormula(a, b):
    """
    Batched rotation between two sets of 3-D vectors via Rodrigues' formula.

    Builds R = I + sin(theta) K + (1 - cos(theta)) K^2. Here K is the
    cross-product matrix of the unit axis a x b and theta is the angle between
    a and b, so R rotates a onto b's direction. The function returns R^T;
    for row vectors, ``a @ R^T`` (i.e. ``a @ returned``) points along b.

    :param a: [B, 3]
    :param b: [B, 3] (moved to a's device)
    :return: [B, 3, 3]

    NOTE: if a and b are anti-parallel, a x b is zero and the result
    degenerates to the identity.
    """
    B = a.shape[0]
    b = b.to(a.device)

    # Explicit dim=1: without it torch.cross uses the first size-3 dimension,
    # which is the batch dimension when B == 3.
    axis = F.normalize(torch.cross(a, b, dim=1), p=2, dim=1)
    # Clamp so a cosine rounded slightly past +/-1 cannot make acos return NaN.
    cos_sim = F.cosine_similarity(a, b).clamp(-1.0, 1.0)
    theta = torch.acos(cos_sim).view(B, 1, 1)

    kx, ky, kz = axis[:, 0], axis[:, 1], axis[:, 2]
    zero = torch.zeros_like(kx)
    # Skew-symmetric K with K @ v == axis x v; shape [B, 3, 3].
    K = torch.stack([
        torch.stack([zero, -kz, ky], dim=1),
        torch.stack([kz, zero, -kx], dim=1),
        torch.stack([-ky, kx, zero], dim=1),
    ], dim=1)

    eye = torch.eye(3, dtype=K.dtype, device=K.device).expand(B, 3, 3)
    R = eye + torch.sin(theta) * K + (1 - torch.cos(theta)) * torch.matmul(K, K)
    return R.transpose(-1, -2)