In [None]:
import sys
sys.path.append('/mnt/famli_netapp_shared/C1_ML_Analysis/src/ShapeAXI/')
import torch
from torch import nn
import shapeaxi
from shapeaxi import utils
import plotly.graph_objects as go
import plotly.express as px
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures import (Meshes, Pointclouds)
import json
import random

from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene

from pytorch3d.renderer import (
    TexturesVertex
)

In [None]:
def plot_pointcloud(mesh, title="", n_points=5000):
    """Show an interactive 3D scatter plot of points sampled on a mesh surface.

    Points are drawn with ``sample_points_from_meshes`` and coloured by their
    z coordinate.

    Args:
        mesh: A pytorch3d ``Meshes`` batch; only the first mesh is plotted.
        title: Title displayed above the figure.
        n_points: Number of surface points to sample.
    """
    points = sample_points_from_meshes(mesh, n_points)
    # Index the first mesh instead of ``squeeze()``: squeeze would also drop the
    # point axis when n_points == 1 and break the x/y/z split. Converting to a
    # host numpy array keeps plotting independent of the tensor's device.
    x, y, z = points[0].detach().cpu().numpy().T
    fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=dict(
        size=2,
        color=z,                # set color to an array/list of desired values
        colorscale='Viridis',   # choose a colorscale
        opacity=0.8
    ))])
    fig.update_layout(title_text=title)
    fig.show()

In [None]:
mount_point = '/mnt/raid/home/jprieto/3DTeethSeg'
scan_id = 'GSHA8E4C'  # 3DTeethSeg lower-jaw scan to inspect

surf = utils.ReadSurf(f'{mount_point}/lower/{scan_id}/{scan_id}_lower.obj')
# Context managers close the JSON files deterministically instead of leaking handles.
with open(f'{mount_point}/lower/{scan_id}/{scan_id}_lower.json') as f:
    labels = json.load(f)
with open(f'{mount_point}/Batch_2_4_23_24/{scan_id}_lower__kpt.json') as f:
    landmarks = json.load(f)

In [None]:
def generate_random_color(rng=random):
    """Return a random RGB colour as a tuple of three ints in [0, 255].

    Args:
        rng: Source of randomness exposing ``randint``; the ``random`` module by
            default, or e.g. a seeded ``random.Random`` for reproducible colours.
    """
    return (rng.randint(0, 255), rng.randint(0, 255), rng.randint(0, 255))

def assign_random_colors(labels, rng=random):
    """Map every distinct label in ``labels`` to a random RGB colour.

    Labels are visited in first-occurrence order rather than ``set`` order
    (which, for strings, changes between interpreter runs due to hash
    randomisation), so a seeded ``rng`` gives the same mapping on every run.

    Args:
        labels: Iterable of hashable labels; duplicates are allowed.
        rng: Passed through to ``generate_random_color``.

    Returns:
        dict mapping each distinct label to an ``(r, g, b)`` tuple.
    """
    return {label: generate_random_color(rng) for label in dict.fromkeys(labels)}

# Seed the RNG that the colour helpers draw from so that re-running the notebook
# (for a given label iteration order) reproduces the same colours.
COLOR_SEED = 0
random.seed(COLOR_SEED)

color_mapping_mesh = assign_random_colors(labels["labels"])
colors_mesh = [color_mapping_mesh[label] for label in labels["labels"]]

# One colour per landmark, shared by landmarks with the same "key".
landmark_keys = [obj["key"] for obj in landmarks["objects"]]
color_mapping_landmarks = assign_random_colors(landmark_keys)
colors_landmarks = [color_mapping_landmarks[key] for key in landmark_keys]

In [None]:
V, F = utils.PolyDataToTensors_v_f(surf)

# Label colours scaled to [0, 1] with a leading batch axis for TexturesVertex
# (assumes labels["labels"] holds one label per mesh vertex — TODO confirm).
vertex_rgb = torch.tensor(colors_mesh, dtype=torch.float).unsqueeze(0) / 255.0
textures = TexturesVertex(vertex_rgb)
mesh = Meshes(verts=V.unsqueeze(0), faces=F.unsqueeze(0), textures=textures)

# Landmarks as a single-cloud Pointclouds batch, coloured per landmark key.
landmarks_coords = torch.tensor([obj["coord"] for obj in landmarks["objects"]])
landmark_rgb = torch.tensor(colors_landmarks, dtype=torch.float).unsqueeze(0) / 255.0
landmarks_pc = Pointclouds(points=landmarks_coords.unsqueeze(0), features=landmark_rgb)

In [None]:
# One subplot showing the coloured jaw mesh together with its landmark points.
scene = {
    "Dental Challenge": {
        "dental": mesh,
        "landmarks": landmarks_pc,
    }
}
fig = plot_scene(scene, pointcloud_marker_size=10)
fig.show()

In [None]:
# Quick look at a sparse 128-point surface sample of the dental mesh.
plot_pointcloud(mesh=mesh, n_points=128)

In [None]:
import ocnn
from ocnn.octree import Octree, Points

class Transform:
  r''' Converts a sample dict into an ocnn octree, with optional augmentation.

  :meth:`__call__` runs :meth:`preprocess` (numpy arrays -> :class:`Points`),
  then :meth:`transform` (random flip/rotate/translate/scale when ``distort``
  is set, optional normal re-orientation, clipping to [-1, 1]), and finally
  :meth:`points2octree`.

  Args:
    depth (int): Depth of the octree to build.
    full_depth (int): Forwarded to :class:`Octree` as ``full_depth``.
    distort (bool): Whether to apply random augmentation.
    angle (list): Per-axis (x, y, z) maximum rotation, in degrees.
    interval (list): Per-axis rotation step, in degrees. Rotations are sampled
        as integer multiples of the step within ``[-angle, angle]``; a zero
        step disables rotation about that axis.
    scale (float): Per-axis scale factors are sampled from ``[1 - scale, 1 + scale)``.
    uniform (bool): If True, the x scale factor is reused for y and z.
    jitter (float): Per-axis translations are sampled from ``[-jitter, jitter)``.
    flip (list): Per-axis probability of flipping along x, y, z.
    orient_normal (str): If non-empty, forwarded to ``Points.orient_normal``.
    **kwargs: Ignored (presumably so a larger config dict can be unpacked).
  '''

  def __init__(self, depth: int, full_depth: int, distort: bool, angle: list,
               interval: list, scale: float, uniform: bool, jitter: float,
               flip: list, orient_normal: str = '', **kwargs):
    super().__init__()

    # for octree building
    self.depth = depth
    self.full_depth = full_depth

    # for data augmentation
    self.distort = distort
    self.angle = angle
    self.interval = interval
    self.scale = scale
    self.uniform = uniform
    self.jitter = jitter
    self.flip = flip

    # for other transformations
    self.orient_normal = orient_normal

  def __call__(self, sample: dict, idx: int):
    r''' Runs preprocess -> transform -> octree construction on one sample.

    Args:
      sample (dict): Must hold numpy arrays under ``'points'`` and
          ``'normals'``; both keys are popped, i.e. the dict is modified in place.
      idx (int): Sample index, forwarded to the stages (unused by them here).

    Returns:
      dict: ``sample`` with ``'points'`` (a clipped :class:`Points`),
      ``'inbox_mask'`` and ``'octree'`` set.
    '''

    output = self.preprocess(sample, idx)
    output = self.transform(output, idx)
    output['octree'] = self.points2octree(output['points'])
    return output

  def preprocess(self, sample: dict, idx: int):
    r''' Transforms :attr:`sample` to :class:`Points` and performs some specific
    transformations, like normalization.

    Pops the numpy arrays ``'points'`` and ``'normals'`` from ``sample`` and
    stores a :class:`Points` built from them under ``'points'``.
    '''

    xyz = torch.from_numpy(sample.pop('points'))
    normals = torch.from_numpy(sample.pop('normals'))
    sample['points'] = Points(xyz, normals)
    return sample

  def transform(self, sample: dict, idx: int):
    r''' Applies the general transformations provided by :obj:`ocnn`.

    When ``distort`` is set, the points are flipped, rotated, translated and
    scaled (in that order) with parameters from :meth:`rnd_parameters`. Normals
    are then re-oriented if ``orient_normal`` is set, and finally the cloud is
    clipped to [-1, 1]; the mask returned by the clip is stored as
    ``'inbox_mask'``.
    '''

    # The augmentations including rotation, scaling, and jittering.
    points = sample['points']
    if self.distort:
      rng_angle, rng_scale, rng_jitter, rnd_flip = self.rnd_parameters()
      points.flip(rnd_flip)
      points.rotate(rng_angle)
      points.translate(rng_jitter)
      points.scale(rng_scale)

    if self.orient_normal:
      points.orient_normal(self.orient_normal)

    # !!! NOTE: Clip the point cloud to [-1, 1] before building the octree
    inbox_mask = points.clip(min=-1, max=1)
    sample.update({'points': points, 'inbox_mask': inbox_mask})
    return sample

  def points2octree(self, points: Points):
    r''' Converts the input :attr:`points` to an octree of depth
    ``self.depth`` (with ``self.full_depth``).
    '''

    octree = Octree(self.depth, self.full_depth)
    octree.build_octree(points)
    return octree

  def rnd_parameters(self):
    r''' Generates random parameters for data augmentation.

    Returns:
      tuple: ``(angle, scale, jitter, flip)``. ``angle`` is a float tensor with
      the 3 rotation angles in radians. ``scale`` and ``jitter`` are float
      tensors of shape (3,). ``flip`` is a string listing the axes
      (``'x'``, ``'y'``, ``'z'``) to flip.
    '''

    rnd_angle = [None] * 3
    for i in range(3):
      # torch.randint needs integer bounds: coerce in case the config provides
      # float angles/intervals. Treat a zero step as "no rotation" instead of
      # raising ZeroDivisionError.
      if self.interval[i]:
        rot_num = int(self.angle[i] // self.interval[i])
      else:
        rot_num = 0
      rnd = torch.randint(low=-rot_num, high=rot_num+1, size=(1,))
      # degrees -> radians
      rnd_angle[i] = rnd * self.interval[i] * (3.14159265 / 180.0)
    rnd_angle = torch.cat(rnd_angle)

    rnd_scale = torch.rand(3) * (2 * self.scale) - self.scale + 1.0
    if self.uniform:
      rnd_scale[1] = rnd_scale[0]
      rnd_scale[2] = rnd_scale[0]

    rnd_flip = ''
    for i, c in enumerate('xyz'):
      if torch.rand([1]) < self.flip[i]:
        rnd_flip = rnd_flip + c

    rnd_jitter = torch.rand(3) * (2 * self.jitter) - self.jitter
    return rnd_angle, rnd_scale, rnd_jitter, rnd_flip


class CollateBatch:
  r''' Merge a list of octrees and points into a batch.

  Collates a list of per-sample dicts into a single dict keyed like ``batch[0]``:

  * keys containing ``'octree'`` are merged into one octree whose neighbor
    indices are then constructed;
  * keys containing ``'points'`` are merged with ``ocnn.octree.merge_points``
    when ``merge_points`` is True, otherwise kept as lists;
  * for keys containing ``'label'`` the collected values are converted to a
    tensor stored under the fixed key ``'label'``. NOTE(review): for a key such
    as ``'labels'`` this adds a separate ``'label'`` entry and leaves
    ``'labels'`` as a list — confirm this is intended;
  * any other key maps to the list of per-sample values.
  '''

  def __init__(self, merge_points: bool = False):
    self.merge_points = merge_points

  def __call__(self, batch: list):
    assert type(batch) == list

    collated = {}
    for name in batch[0]:
      values = [sample[name] for sample in batch]
      collated[name] = values

      if 'octree' in name:
        super_octree = ocnn.octree.merge_octrees(values)
        # NOTE: the neighbor indices must be constructed on the merged octree
        super_octree.construct_all_neigh()
        collated[name] = super_octree

      if self.merge_points and 'points' in name:
        collated[name] = ocnn.octree.merge_points(collated[name])

      if 'label' in name:
        collated['label'] = torch.tensor(collated[name])

    return collated

In [None]:
# Root of the shared data mount (re-points mount_point away from the 3DTeethSeg folder).
mount_point = '/mnt/raid/home/jprieto'

In [None]:
surf = utils.ReadSurf(f'{mount_point}/ModelNet40/airplane/train/airplane_0129.off')
surf = utils.GetUnitSurf(surf)
V, F = utils.PolyDataToTensors_v_f(surf)
N = utils.GetNormalsTensor(surf)

# Clip to [-1, 1] before building the octree, as Transform.transform does. The
# bounds produced by GetUnitSurf are not guaranteed here, so out-of-range
# vertices are dropped rather than fed to the octree builder.
points_0 = Points(V, normals=N)
points_0.clip(min=-1, max=1)

octree_0 = Octree(16)
octree_0.build_octree(points_0)

print(V.shape, N.shape)


In [None]:
surf = utils.ReadSurf(f'{mount_point}/ModelNet40/airplane/test/airplane_0656.off')
surf = utils.GetUnitSurf(surf)
V, F = utils.PolyDataToTensors_v_f(surf)
N = utils.GetNormalsTensor(surf)

# Clip to [-1, 1] before building the octree, as Transform.transform does. The
# bounds produced by GetUnitSurf are not guaranteed here, so out-of-range
# vertices are dropped rather than fed to the octree builder.
points_1 = Points(V, normals=N)
points_1.clip(min=-1, max=1)

octree_1 = Octree(16)
octree_1.build_octree(points_1)
print(V.shape, N.shape)

In [None]:
# Batch the two airplane octrees into one merged octree.
shape_octrees = [octree_0, octree_1]
octree = ocnn.octree.merge_octrees(shape_octrees)
octree.construct_all_neigh()  # NOTE: the merged octree needs its neighbor indices built explicitly


In [None]:
OCTREE_PLOT_DEPTH = 12  # octree level to visualise (the octrees above were built with depth 16)

x, y, z, b = octree.xyzb(OCTREE_PLOT_DEPTH)
# Hand plotly host numpy arrays so the cell also works if the octree lives on a GPU.
x_np, y_np, z_np = (t.cpu().numpy() for t in (x, y, z))
fig = go.Figure(data=[go.Scatter3d(x=x_np, y=y_np, z=z_np, mode='markers', marker=dict(
        size=2,
        color=z_np,             # colour nodes by their z coordinate
        colorscale='Viridis',   # choose a colorscale
        opacity=0.8
    ))])
fig.update_layout(title_text=f'Merged octree nodes at depth {OCTREE_PLOT_DEPTH}')
fig.show()

print(x.shape, y.shape, z.shape, b.shape)

In [None]:
# in_channels=6 presumably matches the 'NP' input feature used below (normals +
# positions, 3 channels each) — confirm against the ocnn feature definitions.
resnet = ocnn.models.ResNet(
    in_channels=6,
    out_channels=1280,
    resblock_num=1,
    stages=3,
    nempty=False,
)

In [None]:
# Shape smoke test: forward the merged octree once. No gradients are needed, so
# skip building the autograd graph. Use the depth the octree was actually built
# with (16 above) instead of repeating the literal.
with torch.no_grad():
    resnet_out = resnet(octree.get_input_feature('NP').to(torch.float), octree=octree, depth=octree.depth)
resnet_out.shape