In [None]:
# Good resources
# https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/polygen/training.ipynb
# https://towardsdatascience.com/generating-3d-models-with-polygen-and-pytorch-4895f3f61a2e

# https://pytorch3d.org/tutorials/render_textured_meshes

In [19]:
import numpy as np
import pandas as pd
import torch

In [2]:
def load_obj(filename):
  """Load vertices from .obj wavefront format file."""
  vertices = []
  with open(filename, 'r') as mesh:
    for line in mesh:
      data = line.split()
      if len(data) > 0 and data[0] == 'v':
        vertices.append(data[1:])
  return np.array(vertices, dtype=np.float32)

In [5]:
root = 'data'
data_path = f'{root}/3D_scans'
label_path = f'{root}/labels'

In [11]:
fname = '0EJBIPTC'
verts_upper = load_obj(f'{data_path}/{fname}/{fname}_lower.obj')
verts_lower = load_obj(f'{data_path}/{fname}/{fname}_upper.obj')
print(verts_upper)

[[   1.1006801    17.167776   -102.11712       0.50196075    0.50196075
     0.50196075]
 [   5.3306537   -20.593128    -91.18822       0.50196075    0.50196075
     0.50196075]
 [ -16.466866     -8.0680485   -90.530266      0.50196075    0.50196075
     0.50196075]
 ...
 [  -2.9833782   -18.776342    -82.41675       0.50196075    0.50196075
     0.50196075]
 [  19.891361      5.172693    -90.22645       0.50196075    0.50196075
     0.50196075]
 [ -22.399826     13.0559225   -92.015         0.50196075    0.50196075
     0.50196075]]


In [15]:
# Lexicoghraphic sorting (order vertices for size reduction later)
def lexi_sort(verts):
    keys = [verts[..., i] for i in range(verts.shape[-1])]
    idxs = np.lexsort(keys)
    verts = verts[idxs]
    return verts
    
verts_upper = lexi_sort(verts_upper)
verts_lower = lexi_sort(verts_lower)

In [17]:
# Normalize vertex coordinates and quantize. This converts to 8-bit values
# This approach used in Pixel RNNs and WaveNet
def quantize(verts):
    # normalize vertices to range [0.0, 1.0]
    lims = [-1.0, 1.0]
    norm_verts = (verts - lims[0]) / (lims[1] - lims[0])
    
    # quantize vertices to integers in range [0, 255]
    n_vals = 2 ** 8
    delta = 1. / n_vals
    return np.maximum(np.minimum((norm_verts // delta), n_vals - 1), 0).astype(np.int32)

upper_quant = quantize(verts_upper)
lower_quant = quantize(verts_lower)

## Attempt 2 (Using PyTorch3D)

In [22]:
%pip install "git+https://github.com/facebookresearch/pytorch3d.git"

^C
Note: you may need to restart the kernel to use updated packages.


ERROR: For req: pytorch3d==0.6.1. Invalid script entry point: <ExportEntry pytorch3d_implicitron_runner = projects.implicitron_trainer.experiment:None []> - A callable suffix is required. Cf https://packaging.python.org/specifications/entry-points/#use-for-scripts for more information.
You should consider upgrading via the 'c:\users\rjsmi\appdata\local\programs\python\python38\python.exe -m pip install --upgrade pip' command.


Collecting git+https://github.com/facebookresearch/pytorch3d.git
  Cloning https://github.com/facebookresearch/pytorch3d.git to c:\users\rjsmi\appdata\local\temp\pip-req-build-bjflp1n0
Collecting fvcore
  Downloading fvcore-0.1.5.post20220305.tar.gz (50 kB)
Collecting iopath
  Downloading iopath-0.1.9-py3-none-any.whl (27 kB)
Collecting yacs>=0.1.6
  Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Collecting tabulate
  Downloading tabulate-0.8.9-py3-none-any.whl (25 kB)
Collecting portalocker
  Downloading portalocker-2.4.0-py2.py3-none-any.whl (16 kB)
Building wheels for collected packages: pytorch3d, fvcore
  Building wheel for pytorch3d (setup.py): started
  Building wheel for pytorch3d (setup.py): still running...
  Building wheel for pytorch3d (setup.py): still running...
  Building wheel for pytorch3d (setup.py): still running...
  Building wheel for pytorch3d (setup.py): finished with status 'done'
  Created wheel for pytorch3d: filename=pytorch3d-0.6.1-cp38-cp38-win_amd64.whl s

ERROR: For req: pytorch3d==0.6.1. Invalid script entry point: <ExportEntry pytorch3d_implicitron_runner = projects.implicitron_trainer.experiment:None []> - A callable suffix is required. Cf https://packaging.python.org/specifications/entry-points/#use-for-scripts for more information.
You should consider upgrading via the 'C:\Users\rjsmi\AppData\Local\Programs\Python\Python38\python.exe -m pip install --upgrade pip' command.


In [23]:
from pytorch3d.io import load_obj

In [25]:
verts_l, faces_l, aux_l = load_obj(f'{data_path}/{fname}/{fname}_lower.obj')
verts_u, faces_u, aux_u = load_obj(f'{data_path}/{fname}/{fname}_upper.obj')



# ------ Dataset ------

Docs on PyTorch3D dataloaders can be found [here](https://pytorch3d.org/docs/datasets)

In [None]:
# This is a work in progress. Currently, this is the dataset from MeshDataset on that GitHub page we were looking at.

from torch.utils.data import Dataset

class Mesh_Dataset(Dataset):
    def __init__(self, data_list_path, num_classes=15, patch_size=7000):
        """
        Args:
            h5_path (string): Path to the txt file with h5 files.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.data_list = pd.read_csv(data_list_path, header=None)
        self.num_classes = num_classes
        self.patch_size = patch_size

    def __len__(self):
        return self.data_list.shape[0]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        i_mesh = self.data_list.iloc[idx][0] #vtk file name

        # read vtk
        mesh = load(i_mesh)
        labels = mesh.getCellArray('Label').astype('int32').reshape(-1, 1)

        #create one-hot map
#        label_map = np.zeros([mesh.cells.shape[0], self.num_classes], dtype='int32')
#        label_map = np.eye(self.num_classes)[labels]
#        label_map = label_map.reshape([len(labels), self.num_classes])

        # move mesh to origin
        N = mesh.NCells()
        points = vtk2numpy(mesh.polydata().GetPoints().GetData())
        ids = vtk2numpy(mesh.polydata().GetPolys().GetData()).reshape((N, -1))[:,1:]
        cells = points[ids].reshape(N, 9).astype(dtype='float32')

        mean_cell_centers = mesh.centerOfMass()
        cells[:, 0:3] -= mean_cell_centers[0:3]
        cells[:, 3:6] -= mean_cell_centers[0:3]
        cells[:, 6:9] -= mean_cell_centers[0:3]

        # customized normal calculation; the vtk/vedo build-in function will change number of points
        v1 = np.zeros([mesh.NCells(), 3], dtype='float32')
        v2 = np.zeros([mesh.NCells(), 3], dtype='float32')
        v1[:, 0] = cells[:, 0] - cells[:, 3]
        v1[:, 1] = cells[:, 1] - cells[:, 4]
        v1[:, 2] = cells[:, 2] - cells[:, 5]
        v2[:, 0] = cells[:, 3] - cells[:, 6]
        v2[:, 1] = cells[:, 4] - cells[:, 7]
        v2[:, 2] = cells[:, 5] - cells[:, 8]
        mesh_normals = np.cross(v1, v2)
        mesh_normal_length = np.linalg.norm(mesh_normals, axis=1)
        mesh_normals[:, 0] /= mesh_normal_length[:]
        mesh_normals[:, 1] /= mesh_normal_length[:]
        mesh_normals[:, 2] /= mesh_normal_length[:]
        mesh.addCellArray(mesh_normals, 'Normal')

        # preprae input and make copies of original data
        points = mesh.points().copy()
        points[:, 0:3] -= mean_cell_centers[0:3]
        normals = mesh.getCellArray('Normal').copy() # need to copy, they use the same memory address
        barycenters = mesh.cellCenters() # don't need to copy
        barycenters -= mean_cell_centers[0:3]

        #normalized data
        maxs = points.max(axis=0)
        mins = points.min(axis=0)
        means = points.mean(axis=0)
        stds = points.std(axis=0)
        nmeans = normals.mean(axis=0)
        nstds = normals.std(axis=0)

        for i in range(3):
            cells[:, i] = (cells[:, i] - means[i]) / stds[i] #point 1
            cells[:, i+3] = (cells[:, i+3] - means[i]) / stds[i] #point 2
            cells[:, i+6] = (cells[:, i+6] - means[i]) / stds[i] #point 3
            barycenters[:,i] = (barycenters[:,i] - mins[i]) / (maxs[i]-mins[i])
            normals[:,i] = (normals[:,i] - nmeans[i]) / nstds[i]

        X = np.column_stack((cells, barycenters, normals))
        Y = labels

        # initialize batch of input and label
        X_train = np.zeros([self.patch_size, X.shape[1]], dtype='float32')
        Y_train = np.zeros([self.patch_size, Y.shape[1]], dtype='int32')
        S1 = np.zeros([self.patch_size, self.patch_size], dtype='float32')
        S2 = np.zeros([self.patch_size, self.patch_size], dtype='float32')

        # calculate number of valid cells (tooth instead of gingiva)
        positive_idx = np.argwhere(labels>0)[:, 0] #tooth idx
        negative_idx = np.argwhere(labels==0)[:, 0] # gingiva idx

        num_positive = len(positive_idx) # number of selected tooth cells

        if num_positive > self.patch_size: # all positive_idx in this patch
            positive_selected_idx = np.random.choice(positive_idx, size=self.patch_size, replace=False)
            selected_idx = positive_selected_idx
        else:   # patch contains all positive_idx and some negative_idx
            num_negative = self.patch_size - num_positive # number of selected gingiva cells
            positive_selected_idx = np.random.choice(positive_idx, size=num_positive, replace=False)
            negative_selected_idx = np.random.choice(negative_idx, size=num_negative, replace=False)
            selected_idx = np.concatenate((positive_selected_idx, negative_selected_idx))

        selected_idx = np.sort(selected_idx, axis=None)

        X_train[:] = X[selected_idx, :]
        Y_train[:] = Y[selected_idx, :]

        # output to visualize
#        mesh2 = Easy_Mesh()
#        mesh2.cells = X_train[:, 0:9]
#        mesh2.update_cell_ids_and_points()
#        mesh2.cell_attributes['Normal'] = X_train[:, 12:15]
#        mesh2.cell_attributes['Label'] = Y_train
#        mesh2.to_vtp('tmp.vtp')
        if  torch.cuda.is_available():
            TX = torch.as_tensor(X_train[:, 9:12], device='cuda')
            TD = torch.cdist(TX, TX)
            D = TD.cpu().numpy()
        else:
            D = distance_matrix(X_train[:, 9:12], X_train[:, 9:12])

        S1[D<0.1] = 1.0
        S1 = S1 / np.dot(np.sum(S1, axis=1, keepdims=True), np.ones((1, self.patch_size)))

        S2[D<0.2] = 1.0
        S2 = S2 / np.dot(np.sum(S2, axis=1, keepdims=True), np.ones((1, self.patch_size)))

        X_train = X_train.transpose(1, 0)
        Y_train = Y_train.transpose(1, 0)

        sample = {'cells': torch.from_numpy(X_train), 'labels': torch.from_numpy(Y_train),
                  'A_S': torch.from_numpy(S1), 'A_L': torch.from_numpy(S2)}

        return sample