In [1]:
# Root of the SHREC 2020 protein training set (one sub-folder per sample).
# Set the SHREC_PROTEIN_TRAIN_DIR environment variable to point elsewhere
# instead of editing this hard-coded local default.
import os

path = os.environ.get(
    "SHREC_PROTEIN_TRAIN_DIR",
    "/home/ioannis/Desktop/programming/data/SHREC/SHREC2020_protein/training",
)

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import open3d as o3d
import numpy as np
import os
import MinkowskiEngine as ME
import torch_geometric.io as tio

In [3]:
# Visual sanity check: show one training mesh with its label-1 vertices in red
# and every other vertex grey. Paths are built from ``path`` so the data root
# is defined in a single place.
idx = 1
prot1_path = os.path.join(path, str(idx), "triangulatedSurf.off")
prot1_gt_path = os.path.join(path, str(idx), "vertexMap.txt")
mesh = o3d.io.read_triangle_mesh(prot1_path)
mesh.compute_vertex_normals()
vertices = np.asarray(mesh.vertices)
with open(prot1_gt_path) as f:
    # One label per line; skip blank lines (e.g. a trailing empty line),
    # which would otherwise make float() raise.
    labels = np.array([float(line) for line in f if line.strip()])
# Colours are assigned per vertex, so the label file must cover every vertex.
assert len(labels) == len(vertices), (
    f"{prot1_gt_path}: {len(labels)} labels for {len(vertices)} vertices"
)
colors = 0.5 * np.ones((len(labels), 3))
colors[labels == 1.0, 0] = 1.0
mesh.vertex_colors = o3d.utility.Vector3dVector(colors)
o3d.visualization.draw_geometries([mesh])

In [10]:
def read_labels(path):
    """Read one numeric label per line from ``path`` into a float64 array.

    Blank lines (for example a trailing empty line at the end of the file)
    are skipped instead of making ``float`` raise ``ValueError``.

    Args:
        path: path to a text file with one number per line.

    Returns:
        1-D ``np.ndarray`` of dtype float64 (empty if the file has no values).
    """
    with open(path) as f:
        return np.array(
            [float(line) for line in f if line.strip()], dtype=np.float64
        )

class ProteinDataset(Dataset):
    """SHREC 2020 protein surfaces with per-vertex labels.

    ``path`` must contain one sub-folder per sample, each holding
    ``triangulatedSurf.off`` (the triangle mesh) and ``vertexMap.txt``
    (one label per line, one line per mesh vertex).

    Each item is a dict ``{'x': vertices, 'y': labels}`` that is passed
    through every callable in ``transform``, in order.
    """

    def __init__(self, path, transform=None):
        """
        Args:
            path: dataset root directory.
            transform: optional list of callables ``sample -> sample``.
                Defaults to no transforms.
        """
        self._path = path
        self._mesh_name = "triangulatedSurf.off"
        self._label_name = "vertexMap.txt"

        # os.listdir order is filesystem-dependent and may include stray
        # files. Keep only sub-folders and sort them so index i always maps
        # to the same sample. Numeric names sort numerically ("2" before
        # "10"); any non-numeric names go after them.
        sample_dirs = [
            name for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name))
        ]
        self._idxs = sorted(
            sample_dirs,
            key=lambda name: (0, int(name), name) if name.isdecimal() else (1, 0, name),
        )

        # ``None`` instead of a mutable ``[]`` default, which would be shared
        # by every instance created without an explicit transform list.
        self.transform = [] if transform is None else transform

    def __len__(self):
        return len(self._idxs)

    def __getitem__(self, idx):
        """Load, label and transform the sample at 0-based position ``idx``.

        ``idx`` indexes the sorted folder list. On-disk folder names are
        used as-is, and those names start at 1, not 0.

        Raises:
            ValueError: if the label file does not have exactly one label
                per mesh vertex.
        """
        sample_id = self._idxs[idx]

        parent_folder = os.path.join(self._path, sample_id)
        mesh_path = os.path.join(parent_folder, self._mesh_name)
        label_path = os.path.join(parent_folder, self._label_name)

        mesh = o3d.io.read_triangle_mesh(mesh_path)
        vertices = np.asarray(mesh.vertices)
        labels = read_labels(label_path)

        # Downstream transforms index 'x' and 'y' with the same indices, so a
        # length mismatch would silently misalign points and labels.
        if len(labels) != len(vertices):
            raise ValueError(
                f"sample {sample_id}: {len(vertices)} vertices but {len(labels)} labels"
            )

        sample = {
            'x' : vertices,
            'y' : labels
        }

        for t in self.transform:
            sample = t(sample)

        return sample

In [21]:
import numpy as np
class Normalize():
    """Scale a sample so its farthest point from the origin has unit norm.

    Only ``d['x']`` (rows of point coordinates) is rescaled; no centering is
    applied. Other keys are passed through untouched. Empty input, or input
    whose points are all at the origin, is returned unscaled rather than
    raising or producing NaN/inf.
    """

    def __call__(self, d):
        verts = np.asarray(d['x'])
        if verts.size == 0:
            # Nothing to scale; ``.max()`` of an empty array would raise.
            return d

        scale = np.linalg.norm(verts, axis=-1).max()
        if scale > 0:
            # Out-of-place division: works for integer input too and does not
            # mutate an array the caller may still hold.
            d['x'] = verts / scale

        return d
        
from torchsparse.utils.quantize import sparse_quantize

class Quantize():
    """Voxel-downsample a sample with torchsparse's ``sparse_quantize``.

    Keeps only the points (and matching labels) whose indices
    ``sparse_quantize(..., return_index=True)`` returns. This is presumably
    one representative per occupied voxel; confirm against the torchsparse
    version in use. The returned dict has exactly three keys: 'coords'
    (quantized coordinates), 'x' and 'y'. Any other input keys are dropped.
    """

    def __init__(self, voxel_size=0.05):
        # Voxel edge length, in the same units as the incoming points.
        self.voxel_size = voxel_size

    def __call__(self, d):
        xyz, lbl = d['x'], d['y']

        voxel_coords, keep = sparse_quantize(
            xyz, voxel_size=self.voxel_size, return_index=True
        )

        return {'coords': voxel_coords, 'x': xyz[keep], 'y': lbl[keep]}
        

In [22]:
def minkowski_collate(list_data):
    """Collate quantized samples into MinkowskiEngine batch inputs.

    Each element of ``list_data`` is a dict with 'coords' (integer voxel
    coordinates, as produced by ``Quantize``), 'x' (per-point features) and
    'y' (per-point labels).

    Returns:
        dict with "coordinates" (batched by ``ME.utils.sparse_collate``,
        which adds a batch-index column), "features" and "labels".
    """
    coordinates, features, labels = ME.utils.sparse_collate(
        [d['coords'] for d in list_data],
        # sparse_collate keeps the feature dtype it is given (float64 coming
        # from open3d vertices), so cast explicitly to float32 for the network.
        [np.asarray(d['x'], dtype=np.float32) for d in list_data],
        [d['y'] for d in list_data],
        # ``dtype`` selects the dtype of the *coordinates*, which must stay
        # integer voxel indices. It does not apply to the features.
        dtype=torch.int32,
    )

    return {
        "coordinates": coordinates,
        "features"   : features,
        "labels"     : labels
    }

In [48]:
# Scale every protein into the unit ball, then voxel-downsample with
# 0.001-sized voxels (in those normalised units).
dataset = ProteinDataset(
    path,
    transform=[Normalize(), Quantize(voxel_size=0.001)],
)

In [50]:
# Preview up to the first 10 processed samples as point clouds (label-1 points
# in red, the rest grey). The loop is bounded by len(dataset) so a smaller
# dataset does not raise IndexError.
for i in range(min(10, len(dataset))):
    s = dataset[i]
    points = s["x"]
    labels = s["y"]
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    color = 0.5 * np.ones((len(points), 3))
    color[labels == 1, 0] = 1
    pcd.colors = o3d.utility.Vector3dVector(color)
    o3d.visualization.draw_geometries([pcd])