In [None]:
import math
import os

import cv2
import jax.numpy as jnp
import jaxlie
import lietorch
import matplotlib.pyplot as plt
import torch

from gecco_torch.scene.gaussian_model import GaussianModel
from gecco_torch.utils.isotropic_gaussian import IsotropicGaussianSO3
from gecco_torch.utils.isotropic_plotting import visualize_so3_probabilities

In [2]:
class LogUniformSchedule:
    """
    Log-uniform noise schedule, which seems to work better in our (GECCO) context.

    Produces noise levels sigma spaced uniformly in log-space between ``min`` and
    ``max`` (both inclusive), in increasing order.

    Note: GECCO's training schedules are called as ``schedule(samples)`` with
    ``samples`` of shape (batch_size, num_points, 3) and return one sigma per batch
    element. This notebook variant has no ``__call__``; it exposes
    ``return_schedule(n)``, which returns a deterministic grid of ``n`` sigmas.

    Args:
        max: largest noise level (sigma_max); must be > 0 (``math.log`` is taken).
        min: smallest noise level (sigma_min); must be > 0. Defaults to 0.002.
        low_discrepancy: stored for interface compatibility only; it is not used by
            ``return_schedule``.
    """

    def __init__(self, max: float, min: float = 0.002, low_discrepancy: bool = True):
        # `max`/`min` shadow the builtins but are kept as parameter names so existing
        # keyword callers keep working.
        self.sigma_min = min
        self.sigma_max = max
        self.log_sigma_min = math.log(min)
        self.log_sigma_max = math.log(max)
        self.low_discrepancy = low_discrepancy

    def return_schedule(self, n, device="cuda"):
        """
        Return ``n`` noise levels spaced uniformly in log-space.

        Args:
            n: number of noise levels.
            device: torch device of the returned tensor. Defaults to "cuda" to keep
                the original behaviour; pass "cpu" to use the schedule without a GPU.

        Returns:
            1-D float tensor of shape (n,). For n >= 2, ``sigma[0] == sigma_min`` and
            ``sigma[-1] == sigma_max``. For n == 1 it holds only sigma_min, because
            ``torch.linspace(0, 1, 1)`` is ``[0.]``.
        """
        u = torch.linspace(0, 1, n, device=device)
        # Interpolate linearly in log-space, then exponentiate -> geometric spacing.
        log_sigma = u * (self.log_sigma_max - self.log_sigma_min) + self.log_sigma_min
        return log_sigma.exp()

In [3]:
# Load one trained Gaussian-splatting shape and pull out its per-Gaussian rotations.
PLY_PATH = (
    "/globalwork/giese/gaussians/02691156/1a04e3eab45ca15dd86060f189eb133"
    "/point_cloud/iteration_10000/point_cloud.ply"
)
gm = GaussianModel(3)  # NOTE(review): 3 is presumably the SH degree - TODO confirm
gm.load_ply(PLY_PATH)

# Detach so nothing downstream tracks gradients through the model parameters.
rotations = gm.get_rotation.detach()

# Reorder quaternion components (w, x, y, z) -> (x, y, z, w) for lietorch.SO3.
# Assumes get_rotation returns wxyz order, as the variable naming here implies.
rotations_xyzw = rotations[:, [1, 2, 3, 0]]

In [4]:
# Noise levels to visualise: NUM_NOISE_LEVELS sigmas, log-uniformly spaced from the
# class default sigma_min = 0.002 up to sigma_max = 165 (returned on the CUDA device).
NUM_NOISE_LEVELS = 128
schedule = LogUniformSchedule(max=165)
noise_schedule = schedule.return_schedule(NUM_NOISE_LEVELS)

In [None]:
# For every noise level sigma: perturb each Gaussian's rotation with isotropic SO(3)
# noise of scale sigma, plot the perturbed rotations on a Mollweide projection and
# save one frame per noise level to out/rot_{i}.png (later stitched into a video).
os.makedirs("out", exist_ok=True)  # savefig does not create missing directories
for i, s in enumerate(noise_schedule):
    # Same sigma for every rotation; randomness="different" gives each vmapped
    # element its own random draw.
    sigmas = s.repeat(rotations_xyzw.shape[0])
    axis_angles = torch.vmap(
        lambda r, s_single: IsotropicGaussianSO3(
            r, s_single, force_small_scale=False
        ).sample_one_vmap(),
        randomness="different",
    )(rotations_xyzw, sigmas)
    # Compose each rotation with its sampled perturbation: R * exp(axis_angle).
    samples = (lietorch.SO3(rotations_xyzw) * lietorch.SO3.exp(axis_angles)).vec()
    # jaxlie wants (w, x, y, z); reorder and copy to host once for the whole batch
    # instead of once per quaternion.
    samples_wxyz = samples[:, [3, 0, 1, 2]].cpu().numpy()
    rotation_matrices = jnp.array([jaxlie.SO3(q).as_matrix() for q in samples_wxyz])

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="mollweide")
    # NOTE(review): 0.001 is passed as the per-sample probability argument - confirm
    # its meaning against visualize_so3_probabilities.
    visualize_so3_probabilities(rotation_matrices, 0.001, ax=ax, fig=fig)
    fig.savefig(f"out/rot_{i}.png")
    # Close each frame: otherwise every figure stays open (matplotlib warns after 20,
    # memory grows, and all frames are dumped inline at the end of the cell).
    plt.close(fig)

In [6]:
# Stitch the rendered frames (one per noise level, in order) into an AVI at 10 fps.
n_frames = len(noise_schedule)  # must match the number of frames rendered above

first_frame = cv2.imread("out/rot_0.png")
if first_frame is None:  # imread returns None instead of raising on a missing file
    raise FileNotFoundError("out/rot_0.png not found - run the frame-rendering cell first")
height, width, _ = first_frame.shape

video = cv2.VideoWriter("out/video.avi", cv2.VideoWriter_fourcc(*"DIVX"), 10, (width, height))
try:
    for i in range(n_frames):
        frame_path = f"out/rot_{i}.png"
        frame = cv2.imread(frame_path)
        if frame is None:
            raise FileNotFoundError(f"{frame_path} not found")
        # VideoWriter does not reliably report frames of the wrong size, so check here.
        if frame.shape[:2] != (height, width):
            raise ValueError(f"{frame_path} has size {frame.shape[:2]}, expected {(height, width)}")
        video.write(frame)
finally:
    # Always finalize the container so a partial run still leaves a readable file.
    video.release()