In [1]:
import nibabel as nib
import torch
import os
import vtk
from vtk.util.numpy_support import vtk_to_numpy
import numpy as np
from vtk.util.numpy_support import numpy_to_vtk

class OFFReader():
	def __init__(self):
		FileName = None
		Output = None

	def SetFileName(self, fileName):
		self.FileName = fileName

	def GetOutput(self):
		return self.Output

	def Update(self):
		with open(self.FileName) as file:

			first_string = file.readline() # either 'OFF' or 'OFFxxxx xxxx x'

			if 'OFF' != first_string[0:3]:
				raise('Not a valid OFF header!')

			elif first_string[3:4] != '\n':
				new_first = 'OFF'
				new_second = first_string[3:]
				n_verts, n_faces, n_dontknow = tuple([int(s) for s in new_second.strip().split(' ')])		

			else:
				n_verts, n_faces, n_dontknow = tuple([int(s) for s in file.readline().strip().split(' ')])

			surf = vtk.vtkPolyData()
			points = vtk.vtkPoints()
			cells = vtk.vtkCellArray()

			for i_vert in range(n_verts):
				p = [float(s) for s in file.readline().strip().split(' ')]
				points.InsertNextPoint(p[0], p[1], p[2])

			for i_face in range(n_faces):
				
				t = [int(s) for s in file.readline().strip().split(' ')]

				if(t[0] == 1):
					vertex = vtk.vtkVertex()
					vertex.GetPointIds().SetId(0, t[1])
					cells.InsertNextCell(line)
				elif(t[0] == 2):
					line = vtk.vtkLine()
					line.GetPointIds().SetId(0, t[1])
					line.GetPointIds().SetId(1, t[2])
					cells.InsertNextCell(line)
				elif(t[0] == 3):
					triangle = vtk.vtkTriangle()
					triangle.GetPointIds().SetId(0, t[1])
					triangle.GetPointIds().SetId(1, t[2])
					triangle.GetPointIds().SetId(2, t[3])
					cells.InsertNextCell(triangle)

			surf.SetPoints(points)
			surf.SetPolys(cells)

			self.Output = surf

def ScaleSurf(surf, mean_arr = None, scale_factor = None, copy=True):
    if(copy):
        surf_copy = vtk.vtkPolyData()
        surf_copy.DeepCopy(surf)
        surf = surf_copy

    shapedatapoints = surf.GetPoints()

    #calculate bounding box
    mean_v = [0.0] * 3
    bounds_max_v = [0.0] * 3

    bounds = shapedatapoints.GetBounds()

    mean_v[0] = (bounds[0] + bounds[1])/2.0
    mean_v[1] = (bounds[2] + bounds[3])/2.0
    mean_v[2] = (bounds[4] + bounds[5])/2.0
    bounds_max_v[0] = max(bounds[0], bounds[1])
    bounds_max_v[1] = max(bounds[2], bounds[3])
    bounds_max_v[2] = max(bounds[4], bounds[5])

    shape_points = vtk_to_numpy(shapedatapoints.GetData())
    
    #centering points of the shape
    if mean_arr is None:
        mean_arr = np.array(mean_v)
    # print("Mean:", mean_arr)
    shape_points = shape_points - mean_arr

    #Computing scale factor if it is not provided
    if(scale_factor is None):
        bounds_max_arr = np.array(bounds_max_v)
        scale_factor = 1.0/np.linalg.norm(bounds_max_arr - mean_arr)

    #scale points of the shape by scale factor
    # print("Scale:", scale_factor)
    shape_points = np.multiply(shape_points, scale_factor)

    #assigning scaled points back to shape
    shapedatapoints.SetData(numpy_to_vtk(shape_points))

    return surf, mean_arr, scale_factor

def ScaleSurfT(surf, mean_arr=None, scale_factor=None, copy=True):
    if copy:
        # Perform a deep copy if needed (create a new tensor with the same data)
        surf = surf.clone()

    if mean_arr is None:
        mean_arr = surf.mean(dim=0)
    
    bounds_max_arr = surf.max(dim=0)[0]

    # Centering points of the shape
    surf = surf - mean_arr

    # Computing scale factor if it is not provided
    if scale_factor is None:
        scale_factor = 1.0 / (bounds_max_arr - mean_arr).norm()

    # Scale points of the shape by scale factor
    surf = surf * scale_factor

    return surf, mean_arr, scale_factor

class UnitSurfTransform:
    # This transform is used to make sure that the surface is in the unit cube
    def __init__(self, scale_factor=None):
        self.scale_factor = scale_factor

    def __call__(self, surf):
        if isinstance(surf, torch.Tensor):
            return GetUnitSurfT(surf)
        else:
            return GetUnitSurf(surf)

def GetUnitSurf(surf, mean_arr = None, scale_factor = None, copy=True):
    unit_surf, surf_mean, surf_scale = ScaleSurf(surf, mean_arr, scale_factor, copy)
    return unit_surf

def GetUnitSurfT(surf, mean_arr=None, scale_factor=None, copy=True):
    unit_surf, surf_mean, surf_scale = ScaleSurfT(surf, mean_arr, scale_factor, copy)
    return unit_surf

def data_to_tensor(path):
    data = nib.freesurfer.read_morph_data(path)
    data = data.byteswap().newbyteorder()
    data = torch.from_numpy(data).float()
    return data

def ReadSurf(fileName):

    fname, extension = os.path.splitext(fileName)    
    extension = extension.lower()    
    if extension == ".vtk":
        reader = vtk.vtkPolyDataReader()
        reader.SetFileName(fileName)
        reader.Update()
        surf = reader.GetOutput()
    elif extension == ".vtp":
        reader = vtk.vtkXMLPolyDataReader()
        reader.SetFileName(fileName)
        reader.Update()
        surf = reader.GetOutput()    
    elif extension == ".stl":
        reader = vtk.vtkSTLReader()
        reader.SetFileName(fileName)
        reader.Update()
        surf = reader.GetOutput()
    elif extension == ".off":
        reader = OFFReader()
        reader.SetFileName(fileName)
        reader.Update()
        surf = reader.GetOutput()
    elif extension == ".obj":
        if os.path.exists(fname + ".mtl"):
            obj_import = vtk.vtkOBJImporter()
            obj_import.SetFileName(fileName)
            obj_import.SetFileNameMTL(fname + ".mtl")
            textures_path = os.path.normpath(os.path.dirname(fname) + "/../images")
            if os.path.exists(textures_path):
                textures_path = os.path.normpath(fname.replace(os.path.basename(fname), ''))
                obj_import.SetTexturePath(textures_path)
            else:
                textures_path = os.path.normpath(fname.replace(os.path.basename(fname), ''))                
                obj_import.SetTexturePath(textures_path)
                    

            obj_import.Read()

            actors = obj_import.GetRenderer().GetActors()
            actors.InitTraversal()
            append = vtk.vtkAppendPolyData()

            for i in range(actors.GetNumberOfItems()):
                surfActor = actors.GetNextActor()
                append.AddInputData(surfActor.GetMapper().GetInputAsDataSet())
            
            append.Update()
            surf = append.GetOutput()
            
        else:
            reader = vtk.vtkOBJReader()
            reader.SetFileName(fileName)
            reader.Update()
            surf = reader.GetOutput()
    elif extension == '.gii':

        import nibabel as nib
        from fsl.data import gifti

        surf = nib.load(fileName)
        coords = surf.agg_data('pointset')
        triangles = surf.agg_data('triangle')

        points = vtk.vtkPoints()

        for c in coords:
            points.InsertNextPoint(c[0], c[1], c[2])

        cells = vtk.vtkCellArray()

        for t in triangles:
            t_vtk = vtk.vtkTriangle()
            t_vtk.GetPointIds().SetId(0, t[0])
            t_vtk.GetPointIds().SetId(1, t[1])
            t_vtk.GetPointIds().SetId(2, t[2])
            cells.InsertNextCell(t_vtk)

        surf = vtk.vtkPolyData()
        surf.SetPoints(points)
        surf.SetPolys(cells)
    else:
        raise Exception("File format not supported")
    
    return surf

def PolyDataToTensors_v_f(surf, device='cpu'):

    verts, faces, = PolyDataToNumpy_v_f(surf)
    
    verts = torch.tensor(verts).to(torch.float32).to(device)
    faces = torch.tensor(faces).to(torch.int64).to(device)
    
    return verts, faces

def PolyDataToNumpy_v_f(surf):

    vtk.vtkObject.GlobalWarningDisplayOff()
    verts = vtk_to_numpy(surf.GetPoints().GetData())
    faces = vtk_to_numpy(surf.GetPolys().GetData()).reshape(-1, 4)[:,1:]
    
    return verts, faces

def compute_verts(path):
    wm_vtk_path = os.path.join(path, f'lh.white.vtk')
    surf = ReadSurf(wm_vtk_path)

    transform = UnitSurfTransform()
    surf_norm = transform(surf)

    verts, faces = PolyDataToTensors_v_f(surf_norm)
    
    return verts

path1 = '/CMF/data/floda/abcd-data-release-5.1/data/sub-NDARINV021N0FLH/sub-NDARINV021N0FLH_ses-baselineYear1Arm1/surf/'
path2 = '/CMF/data/floda/abcd-data-release-5.1/data/sub-NDARINV028D3ELL/sub-NDARINV028D3ELL_ses-4YearFollowUpYArm1/surf/'

In [2]:
import ocnn
from ocnn.octree import Octree, Points
from typing import List, Optional


class PatchEmbed(torch.nn.Module):

  def __init__(self, in_channels: int = 3, dim: int = 96, num_down: int = 2,
               nempty: bool = True, **kwargs):
    super().__init__()
    self.num_stages = num_down
    self.delta_depth = -num_down
    channels = [int(dim * 2**i) for i in range(-self.num_stages, 1)]

    self.convs = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
        in_channels if i == 0 else channels[i], channels[i], kernel_size=[3],
        stride=1, nempty=nempty) for i in range(self.num_stages)])
    self.downsamples = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
        channels[i], channels[i+1], kernel_size=[2], stride=2, nempty=nempty)
        for i in range(self.num_stages)])
    self.proj = ocnn.modules.OctreeConvBnRelu(
        channels[-1], dim, kernel_size=[3], stride=1, nempty=nempty)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
      print("Data shape:", data.shape)
      for i in range(self.num_stages):
        depth_i = depth - i
        print("Depth:", depth_i)
        data = self.convs[i](data, octree, depth_i)
        print("coucou")
        data = self.downsamples[i](data, octree, depth_i)
      data = self.proj(data, octree, depth_i - 1)
      return data

class OctFormer(torch.nn.Module):

  def __init__(self, in_channels: int,
               channels: List[int] = [96, 192, 384, 384],
               num_blocks: List[int] = [2, 2, 18, 2],
               num_heads: List[int] = [6, 12, 24, 24],
               patch_size: int = 26, dilation: int = 4, drop_path: float = 0.5,
               nempty: bool = True, stem_down: int = 2, **kwargs):
    super().__init__()
    self.patch_size = patch_size
    self.dilation = dilation
    self.nempty = nempty
    self.num_stages = len(num_blocks)
    self.stem_down = stem_down
    drop_ratio = torch.linspace(0, drop_path, sum(num_blocks)).tolist()

    self.patch_embed = PatchEmbed(in_channels, channels[0], stem_down, nempty)
    # self.layers = torch.nn.ModuleList([OctFormerStage(
    #     dim=channels[i], num_heads=num_heads[i], patch_size=patch_size,
    #     drop_path=drop_ratio[sum(num_blocks[:i]):sum(num_blocks[:i+1])],
    #     dilation=dilation, nempty=nempty, num_blocks=num_blocks[i],)
    #     for i in range(self.num_stages)])
    # self.downsamples = torch.nn.ModuleList([Downsample(
    #     channels[i], channels[i + 1], kernel_size=[2],
    #     nempty=nempty) for i in range(self.num_stages - 1)])

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    data = self.patch_embed(data, octree, depth)
    depth = depth - self.stem_down   # current octree depth
    octree = OctreeT(octree, self.patch_size, self.dilation, self.nempty,
                     max_depth=depth, start_depth=depth-self.num_stages+1)
    features = {}
    for i in range(self.num_stages):
      depth_i = depth - i
      data = self.layers[i](data, octree, depth_i)
      features[depth_i] = data
      if i < self.num_stages - 1:
        data = self.downsamples[i](data, octree, depth_i)
    return features




class ClsHeader(torch.nn.Module):
  def __init__(self, out_channels: int, in_channels: int,
               nempty: bool = False, dropout: float = 0.5):
    super().__init__()
    self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
    self.cls_header = torch.nn.Sequential(
        ocnn.modules.FcBnRelu(in_channels, 256),
        torch.nn.Dropout(p=dropout),
        torch.nn.Linear(256, out_channels))

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    data = self.global_pool(data, octree, depth)
    logit = self.cls_header(data)
    return logit


class OctFormerCls(torch.nn.Module):

  def __init__(self, in_channels: int, out_channels: int,
               channels: List[int] = [96, 192, 384, 384],
               num_blocks: List[int] = [2, 2, 18, 2],
               num_heads: List[int] = [6, 12, 24, 24],
               patch_size: int = 32, dilation: int = 4,
               drop_path: float = 0.5, nempty: bool = True,
               stem_down: int = 2, head_drop: float = 0.5, **kwargs):
    super().__init__()
    self.backbone = OctFormer(
        in_channels, channels, num_blocks, num_heads, patch_size, dilation,
        drop_path, nempty, stem_down)
    # self.head = ClsHeader(
    #     out_channels, channels[-1], nempty, head_drop)
    # self.apply(self.init_weights)

  def init_weights(self, m):
    if isinstance(m, torch.nn.Linear):
      torch.nn.init.trunc_normal_(m.weight, std=0.02)
      if isinstance(m, torch.nn.Linear) and m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    features = self.backbone(data, octree, depth)
    curr_depth = min(features.keys())
    output = self.head(features[curr_depth], octree, curr_depth)
    return output

In [3]:
def GetUnitSurf(surf, mean_arr = None, scale_factor = None, copy=True):
    unit_surf, surf_mean, surf_scale = ScaleSurf(surf, mean_arr, scale_factor, copy)
    return unit_surf

def GetNormalsTensor(surf):
    normals = ComputeNormals(surf)
    normals = vtk_to_numpy(normals.GetPointData().GetNormals())
    return torch.tensor(normals, dtype=torch.double)

def ComputeNormals(surf):
    normals = vtk.vtkPolyDataNormals()
    normals.SetInputData(surf);
    normals.ComputeCellNormalsOn();
    normals.ComputePointNormalsOn();
    normals.SplittingOff();
    normals.Update()
    
    return normals.GetOutput()

In [4]:
from ocnn.dataset import CollateBatch
from thsolver import Dataset

# channel: 4
# feature: ND
# find_unused_parameters: False
# name: octformercls
# nempty: False
# nout: 40
# sync_bn: False
# use_checkpoint: False

# data = octree.get_input_feature('ND').to(torch.float)
# resnet = OctFormerCls(in_channels=4, out_channels=40) 
# resnet(data, octree=octree, depth=6).shape

path = '/work/floda/source/tools/octformer/data/ModelNet40/ModelNet40/airplane/train/airplane_0001.off'

surf1 = ReadSurf(path)
surf1 = GetUnitSurf(surf1)
V1, F1 = PolyDataToTensors_v_f(surf1)
N1 = GetNormalsTensor(surf1)

octree = Octree(6)
octree.build_octree(Points(V1, normals=N1))
octree.construct_all_neigh()
octree_feature = ocnn.modules.InputFeature('ND')
data = octree_feature(octree)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

in_channels = 4 
out_channels = 1

model = OctFormerCls(in_channels=in_channels, out_channels=out_channels)
model = model.to(device)

label = torch.tensor([0]).to(device) 
fake_batch = {
    'octree': octree.to(device),
    'data': data.to(device),
    'label': label
}
model.eval() 

with torch.no_grad():
    logits = model(fake_batch['data'], fake_batch['octree'], fake_batch['octree'].depth)



Data shape: torch.Size([4464, 4])
Depth: 6


AssertionError: The shape of input data is wrong.