In [2]:
import sys
sys.path.append('/work/floda/source/ShapeAXI/')
import torch
from torch import nn
import shapeaxi
from shapeaxi import utils
import plotly.graph_objects as go
import plotly.express as px
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures import (Meshes, Pointclouds)
import json
import random
import vtk

from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene

from pytorch3d.renderer import (
    TexturesVertex
)

AttributeError: partially initialized module 'lightning.fabric' has no attribute 'accelerators' (most likely due to a circular import)

In [3]:
import ocnn
from ocnn.octree import Octree, Points

class Transform:
  r''' Turns a raw sample dict into ocnn :class:`Points` plus an
  :class:`Octree`, optionally applying random data augmentation on the way.

  Args:
    depth (int): depth passed to :class:`Octree` in :meth:`points2octree`.
    full_depth (int): full depth passed to :class:`Octree`.
    distort (bool): if True, apply random flip / rotation / translation /
        scaling in :meth:`transform`.
    angle (list): per-axis maximum rotation in degrees (3 entries).
    interval (list): per-axis rotation step in degrees (3 entries); sampled
        angles are integer multiples of it.
    scale (float): scale factors are drawn uniformly from
        ``[1 - scale, 1 + scale)``.
    uniform (bool): if True, the x scale factor is reused for y and z.
    jitter (float): per-axis translations are drawn uniformly from
        ``[-jitter, jitter)``.
    flip (list): per-axis flip probability (3 entries, order x, y, z).
    orient_normal (str): if non-empty, forwarded to ``Points.orient_normal``.
    **kwargs: accepted and ignored.
  '''

  def __init__(self, depth: int, full_depth: int, distort: bool, angle: list,
               interval: list, scale: float, uniform: bool, jitter: float,
               flip: list, orient_normal: str = '', **kwargs):
    super().__init__()

    # for octree building
    self.depth = depth
    self.full_depth = full_depth

    # for data augmentation
    self.distort = distort
    self.angle = angle
    self.interval = interval
    self.scale = scale
    self.uniform = uniform
    self.jitter = jitter
    self.flip = flip

    # for other transformations
    self.orient_normal = orient_normal

  def __call__(self, sample: dict, idx: int):
    r''' Runs :meth:`preprocess`, :meth:`transform` and :meth:`points2octree`.

    Args:
      sample (dict): must contain ``'points'`` (a numpy array) and may
          contain ``'normals'`` (a numpy array); it is modified in place.
      idx (int): sample index, forwarded to the hooks (unused by this class).

    Returns:
      dict: :attr:`sample` with ``'points'`` (:class:`Points`),
      ``'inbox_mask'`` and ``'octree'`` entries.
    '''
    output = self.preprocess(sample, idx)
    output = self.transform(output, idx)
    output['octree'] = self.points2octree(output['points'])
    return output

  def preprocess(self, sample: dict, idx: int):
    r''' Wraps ``sample['points']`` and the optional ``sample['normals']``
    (numpy arrays) into a :class:`Points` stored back in ``sample['points']``.

    Both keys are popped from :attr:`sample`. Subclasses can override this
    hook to add dataset-specific steps (e.g. normalization); this base
    implementation does none.
    '''
    xyz = torch.from_numpy(sample.pop('points'))
    # Normals are optional: Points can be built from positions alone.
    normals = sample.pop('normals', None)
    if normals is not None:
      normals = torch.from_numpy(normals)
    sample['points'] = Points(xyz, normals)
    return sample

  def transform(self, sample: dict, idx: int):
    r''' Applies the general transformations provided by :obj:`ocnn`:
    optional random augmentation, optional normal orientation, and clipping
    to ``[-1, 1]`` (the mask of in-box points is stored as
    ``sample['inbox_mask']``).
    '''
    points = sample['points']
    if self.distort:
      rnd_angle, rnd_scale, rnd_jitter, rnd_flip = self.rnd_parameters()
      points.flip(rnd_flip)
      points.rotate(rnd_angle)
      points.translate(rnd_jitter)
      points.scale(rnd_scale)

    if self.orient_normal:
      points.orient_normal(self.orient_normal)

    # !!! NOTE: Clip the point cloud to [-1, 1] before building the octree
    inbox_mask = points.clip(min=-1, max=1)
    sample.update({'points': points, 'inbox_mask': inbox_mask})
    return sample

  def points2octree(self, points: 'Points'):
    r''' Builds an :class:`Octree` of depth :attr:`depth` (full depth
    :attr:`full_depth`) from :attr:`points`.
    '''
    octree = Octree(self.depth, self.full_depth)
    octree.build_octree(points)
    return octree

  def rnd_parameters(self):
    r''' Draws random augmentation parameters.

    Returns:
      tuple: ``(angle, scale, jitter, flip)``. ``angle`` is a 3-element
      tensor of rotations in radians, each an integer multiple of
      ``interval[i]`` bounded by ``angle[i]``. ``scale`` and ``jitter`` are
      3-element tensors. ``flip`` is a string of the axes (from ``'xyz'``)
      to flip.
    '''
    import math

    rnd_angle = [None] * 3
    for i in range(3):
      # int(): torch.randint only accepts integer bounds, so float configs
      # such as angle=180.0, interval=1.0 would otherwise raise.
      rot_num = int(self.angle[i] // self.interval[i])
      rnd = torch.randint(low=-rot_num, high=rot_num + 1, size=(1,))
      rnd_angle[i] = rnd * self.interval[i] * (math.pi / 180.0)
    rnd_angle = torch.cat(rnd_angle)

    rnd_scale = torch.rand(3) * (2 * self.scale) - self.scale + 1.0
    if self.uniform:
      rnd_scale[1] = rnd_scale[0]
      rnd_scale[2] = rnd_scale[0]

    rnd_flip = ''
    for i, c in enumerate('xyz'):
      if torch.rand([1]) < self.flip[i]:
        rnd_flip = rnd_flip + c

    rnd_jitter = torch.rand(3) * (2 * self.jitter) - self.jitter
    return rnd_angle, rnd_scale, rnd_jitter, rnd_flip


class CollateBatch:
  r''' Merge a list of per-sample dicts (octrees, points, labels, ...) into a
  single batch dict.

  Args:
    merge_points (bool): if True, entries whose key contains ``'points'`` are
        merged with ``ocnn.octree.merge_points``; otherwise they stay a list.
  '''

  def __init__(self, merge_points: bool = False):
    self.merge_points = merge_points

  def __call__(self, batch: list):
    r''' Collates :attr:`batch` (a non-empty list of dicts sharing the keys
    of ``batch[0]``).

    Each key maps to the list of its values across the batch, except:

    * keys containing ``'octree'`` -> one merged octree with neighbor indices;
    * keys containing ``'points'`` (when :attr:`merge_points`) -> merged points;
    * keys containing ``'label'`` -> ``torch.tensor`` of the values.
    '''
    assert isinstance(batch, list)

    outputs = {}
    for key in batch[0].keys():
      values = [b[key] for b in batch]

      # Merge a batch of octrees into one super octree
      if 'octree' in key:
        octree = ocnn.octree.merge_octrees(values)
        # NOTE: remember to construct the neighbor indices
        octree.construct_all_neigh()
        values = octree

      # Merge a batch of points
      if 'points' in key and self.merge_points:
        values = ocnn.octree.merge_points(values)

      # Convert the labels to a Tensor, stored under the same key so that
      # keys like 'seg_label' are converted too (and 'label' is not clobbered).
      if 'label' in key:
        values = torch.tensor(values)

      outputs[key] = values

    return outputs

In [16]:
from vtk.util.numpy_support import vtk_to_numpy

def GetNormalsTensor(surf):
    """Return the per-point normals of `surf` as a float64 torch tensor.

    Normals are (re)computed with `utils.ComputeNormals`, read from the
    result's point data and copied into a new tensor (`torch.tensor` copies,
    so the result does not share memory with the VTK array).

    NOTE(review): `surf` is presumably a vtkPolyData — confirm; it is not a
    file path.
    """
    surf_with_normals = utils.ComputeNormals(surf)
    normals_vtk = surf_with_normals.GetPointData().GetNormals()
    normals_np = vtk_to_numpy(normals_vtk)
    return torch.tensor(normals_np, dtype=torch.double)

In [1]:
import torch
import shapeaxi.saxi_transforms as saxi_transforms
from ocnn.octree import Octree, Points

surf1 = "/CMF/data/floda/abcd-data-release-5.1/data/sub-NDARINVHTNVLRBR/sub-NDARINVHTNVLRBR_ses-2YearFollowUpYArm1/surf/lh.white.vtk"
surf2 = "/CMF/data/floda/abcd-data-release-5.1/data/sub-34/sub-NDARINVHTNVLRBR_ses-2YearFollowUpYArm1/surf/lh.white.vtk"

# NOTE(review): this transform is created but never applied in this cell; the
# vertices are rescaled explicitly by `normalize_to_unit_cube` below instead.
# TODO confirm whether UnitSurfTransform should be applied to the surfaces.
transform = saxi_transforms.UnitSurfTransform()


def normalize_to_unit_cube(verts, margin=0.99):
    """Center `verts` on its bounding box and rescale into [-margin, margin].

    ocnn's `Octree.build_octree` expects coordinates in [-1, 1]; raw surface
    vertices are not guaranteed to lie in that range. The scaling is
    isotropic (one factor for all axes), so the shape is not distorted.

    Args:
        verts: vertex coordinates, assumed to be an (N, 3) torch tensor —
            TODO confirm the return type of utils.PolyDataToTensors_v_f.
        margin: largest absolute coordinate after rescaling (< 1 keeps the
            points strictly inside the octree's bounding box).

    Returns:
        A new tensor with the same shape as `verts`.
    """
    bb_min = verts.min(dim=0).values
    bb_max = verts.max(dim=0).values
    center = (bb_min + bb_max) / 2
    # Half of the longest bounding-box side; clamped so a degenerate
    # (single-point) input does not divide by zero.
    half_extent = ((bb_max - bb_min).max() / 2).clamp_min(1e-12)
    return (verts - center) * (margin / half_extent)


def surf_to_octree(verts, depth=16):
    """Build an ocnn Octree of `depth` from vertex positions (no normals)."""
    octree = Octree(depth)
    octree.build_octree(Points(normalize_to_unit_cube(verts)))
    return octree


# NOTE(review): PolyDataToTensors_v_f is given a file path here — confirm it
# accepts paths rather than a vtkPolyData.
V1, F1 = utils.PolyDataToTensors_v_f(surf1)
V2, F2 = utils.PolyDataToTensors_v_f(surf2)
# TODO: normals (see GetNormalsTensor) are not used yet; the octrees are
# built from vertex positions only.

octree_1 = surf_to_octree(V1)
octree_2 = surf_to_octree(V2)

KeyboardInterrupt: 

In [None]:
# Combine the two per-surface octrees into one batched octree.
octrees_to_merge = [octree_1, octree_2]
octree = ocnn.octree.merge_octrees(octrees_to_merge)
octree.construct_all_neigh()  # NOTE: neighbor indices must be constructed after merging

In [None]:
# Scatter plot of the merged octree's node coordinates at one depth.
# `b` is presumably the batch index of each node (which surface it came
# from) — TODO confirm against ocnn's Octree.xyzb docs.
PLOT_DEPTH = 12
x, y, z, b = octree.xyzb(PLOT_DEPTH)
# Explicit CPU numpy arrays for plotly (also works if the octree is on a GPU).
x_np, y_np, z_np = (t.cpu().numpy() for t in (x, y, z))
fig = go.Figure(data=[go.Scatter3d(
    x=x_np, y=y_np, z=z_np, mode='markers',
    marker=dict(
        size=2,
        color=z_np,             # color nodes by their z coordinate
        colorscale='Viridis',
        opacity=0.8,
    ),
)])
fig.update_layout(
    title=f'Merged octree nodes at depth {PLOT_DEPTH}',
    scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),
)
fig.show()

print(x.shape, y.shape, z.shape, b.shape)

In [None]:
# The forward pass in this notebook feeds `octree.get_input_feature('P')`,
# which per the ocnn docs is the 3-channel global xyz position. in_channels
# must match it; 6 channels would need e.g. normals + position ('NP'), but the
# octrees here are built from `Points(V)` without normals.
resnet = ocnn.models.ResNet(in_channels=3, out_channels=1280, resblock_num=1, stages=3, nempty=False)

In [None]:
# Forward pass as an output-shape sanity check. The depth comes from the octree
# itself rather than a duplicated literal. no_grad: only the shape is inspected,
# so there is no need to build the autograd graph.
with torch.no_grad():
    input_feature = octree.get_input_feature('P').to(torch.float)
    resnet_out = resnet(input_feature, octree=octree, depth=octree.depth)
resnet_out.shape