In [1]:
import os

In [2]:
# The inference artifacts loaded below use paths relative to the model repository,
# so work from inside it. Guarded so that re-running this cell (when the kernel is
# already inside the directory) does not raise FileNotFoundError.
if os.path.basename(os.getcwd()) != "made-ml-demo-app-model":
    os.chdir("made-ml-demo-app-model")

In [3]:
import json
import pathlib
import pickle
import random

from IPython.display import display
import ipywidgets as widgets 
import torch
import trimesh
import umap
import numpy as np
from trimesh import creation
from torch_geometric.data  import Batch
from torch_geometric import utils
from pytransform3d.transformations import random_transform
from sklearn.preprocessing import LabelEncoder, StandardScaler
import plotly.graph_objects as go

from nn_model import SimpleClsLDGCN, BaseTransform
from train_param import DataParams

In [4]:
# Pickled, fitted LabelEncoder used to map model output indices back to class names.
label_encoder_dump_filepath = pathlib.Path("inference-data") / "label_encoder" / "label_encoder.pickle"

In [5]:
# NOTE: unpickling can execute arbitrary code. This file is expected to ship with the
# model repository and is trusted; never point this at a file from an untrusted source.
label_encoder = pickle.loads(label_encoder_dump_filepath.read_bytes())

In [6]:
# JSON file with the keyword arguments used to construct SimpleClsLDGCN.
path_to_model_config = pathlib.Path("inference-data") / "model" / "model_params.json"

In [7]:
# Keyword arguments for SimpleClsLDGCN. JSON text is UTF-8 by specification, so decode
# explicitly instead of relying on the platform's locale encoding (e.g. cp1252 on Windows).
with open(path_to_model_config, "r", encoding="utf-8") as _model_params_file:
    model_params = json.load(_model_params_file)

In [8]:
# Build the classifier from the hyper-parameters stored next to the checkpoint.
cls_model = SimpleClsLDGCN(**dict(model_params))

In [9]:
# The hyper-parameters are no longer needed once the model is built; drop the global
# to keep the namespace small. Note: re-running this cell on its own raises NameError.
del model_params

In [10]:
# Saved state_dict of the trained classifier.
path_to_checkpoint = pathlib.Path("inference-data") / "model" / "model_state.pth"

In [11]:
# All inference in this notebook runs on the CPU; also used as `map_location` when
# loading the checkpoint so weights saved on a GPU can be restored here.
device = "cpu"

In [12]:
# Preprocessing applied to every mesh before it is fed to the model. BaseTransform lives
# in nn_model; it presumably samples DataParams().num_points surface points — TODO confirm.
inference_transform = BaseTransform(DataParams().num_points)

In [13]:
# Restore the trained weights onto `device` and switch the model to inference mode
# (`train(False)` is exactly what `eval()` calls).
# NOTE: torch.load unpickles the checkpoint; only load trusted files.
cls_model.load_state_dict(
    torch.load(path_to_checkpoint, map_location=device)
)
cls_model.train(False);

# 3D primitive classification

The model was trained for 3D primitive classification using [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric).

The following classes were used for classification:

In [14]:
# One class name per line, in the encoder's (i.e. model output) order.
print("\n".join(map(str, label_encoder.classes_)))

cone
cube
cylinder
plane
torus
uv_sphere


All primitives are 3D models. Each primitive is converted to a point cloud before classification: points are randomly sampled on its surface, and their count is:

In [15]:
# Number of points sampled on each primitive's surface before classification.
DataParams().num_points

512

[The repository with the classification model](https://github.com/KernelA/made-ml-demo-app-model)

In [16]:
def cone_generator(transform):
    """Return a cone with random radius in [0.2, 2] and height in [0.2, 4].

    `transform` is passed straight to trimesh's cone constructor.
    """
    # Call arguments are evaluated left to right, so the radius is drawn before the height.
    return creation.cone(random.uniform(0.2, 2), random.uniform(0.2, 4), transform=transform)

def cube_generator(transform):
    """Return a cube with a random edge length in [0.1, 2].

    `transform` is passed straight to trimesh's box constructor.
    """
    edge = random.uniform(0.1, 2)
    return creation.box((edge,) * 3, transform=transform)

def cylinder_generator(transform):
    """Return a cylinder with random radius in [0.2, 2] and height in [1, 3].

    `transform` is passed straight to trimesh's cylinder constructor.
    """
    # Tuple items are evaluated left to right: radius first, then height.
    cyl_radius, cyl_height = random.uniform(0.2, 2), random.uniform(1, 3)
    return creation.cylinder(cyl_radius, cyl_height, transform=transform)

def plane_generator(transform):
    """Return a double-sided 2x2 square lying in the z = 0 plane, transformed by `transform`.

    The square is split into two triangles, and each triangle is listed with both
    windings so that the quad has faces oriented towards both sides.
    """
    corners = np.array([(1, 1, 0), (-1, 1, 0), (-1, -1, 0), (1, -1, 0)])
    front_faces = [(0, 1, 2), (2, 3, 0)]
    back_faces = [(2, 1, 0), (0, 3, 2)]
    square = trimesh.Trimesh(vertices=corners, faces=np.array(front_faces + back_faces))
    square.apply_transform(transform)
    return square

def sphere_generator(transform):
    """Return an icosphere with a random radius in [1, 10], transformed by `transform`.

    The transform used to be ignored here, unlike in every other generator. It is now
    applied consistently. The callers in this notebook zero the translation part, so
    this only rotates the mesh: the sphere's shape is unchanged, but its vertex layout
    is oriented the same way as the other primitives.
    """
    # NOTE(review): the label encoder calls this class "uv_sphere", but an icosphere is
    # generated here — confirm this matches the meshes the model was trained on.
    radius = random.uniform(1, 10)
    sphere = creation.icosphere(radius=radius, subdivisions=2)
    sphere.apply_transform(transform)
    return sphere

def torus_generator(transform):
    """Return a random torus built by revolving a circle, transformed by `transform`.

    The tube (circle) radius is drawn from [0.1, 5]. The circle's center is placed on
    the x axis at a distance from the origin in [radius + 0.1, radius + 10], so the
    profile never reaches the revolution axis (trimesh revolves about the 2D Y axis).
    """
    tube_radius = random.uniform(0.1, 5)
    center_offset = random.uniform(tube_radius + 0.1, tube_radius + 10)
    profile = trimesh.path.creation.circle(tube_radius, (center_offset, 0))
    # `discrete[0]` is the polyline approximation of the single circle entity.
    return creation.revolve(profile.discrete[0], transform=transform)

In [17]:
# Class name stored in the label encoder -> caption shown to the user.
translate_mapping = {
    "cone": "Cone",
    "cube": "Cube",
    "cylinder": "Cylinder",
    "plane": "Plane",
    "uv_sphere": "Sphere",
    "torus": "Torus"
}

# Caption -> function(transform) that builds a random mesh of that primitive.
generators = {
    "Cone": cone_generator,
    "Cube": cube_generator,
    "Cylinder": cylinder_generator,
    "Plane": plane_generator,
    "Sphere": sphere_generator,
    "Torus": torus_generator
}

# These three must stay in sync. A predicted class missing from translate_mapping
# raises KeyError during classification, and a caption without a matching generator
# (or vice versa) means a prediction can never be marked as correct.
assert set(translate_mapping) == set(label_encoder.classes_), \
    "translate_mapping keys differ from the label encoder classes"
assert set(translate_mapping.values()) == set(generators), \
    "captions in translate_mapping differ from the generators keys"

In [18]:
label_widget = widgets.Label("Choose a 3D primitive for classification:")

# One toggle button per primitive caption, in the insertion order of `generators`.
model_list_widget = widgets.ToggleButtons(
    options=tuple(generators),
    description='',
    disabled=False,
    button_style='',  # one of 'success', 'info', 'warning', 'danger' or ''
)

In [19]:
# The "material" template in Voila does not render dropdown lists,
# so ToggleButtons were chosen instead.
primitive_choosing = widgets.GridBox(children=[model_list_widget])

In [20]:
# Prompt text stacked above the primitive picker.
header_widget = widgets.VBox(children=[label_widget, primitive_choosing])

In [21]:
# Latest class probabilities; set by generate_and_calassify_primitive, read by plot_score.
score = None

# Left: bar chart of model confidence. Right: the generated mesh.
output_plot_widget = go.FigureWidget(layout=go.Layout(title="Model confidence"))
output_mesh_widget = widgets.Output(layout=widgets.Layout(border="solid", width="60%"))
output_widget = widgets.HBox(
    [output_plot_widget, output_mesh_widget],
    layout=widgets.Layout(height="600px"),
)

In [22]:
# RGB vertex colors: red marks a misclassified mesh, green a correctly classified one.
RED_COLOR, GREEN_COLOR = (255, 0, 0), (0, 255, 0)

In [23]:
@torch.no_grad()
def classify_object(mesh, model, transform):
    """Run `model` on a single mesh and return its output for a batch of one.

    The mesh is converted to a PyTorch Geometric ``Data`` object, preprocessed with
    `transform`, and collated into a one-element ``Batch``. Gradients are disabled.

    Returns:
        The model output for the batch; callers index ``[0]`` to get this mesh's scores.
    """
    sample = transform(utils.from_trimesh(mesh))
    batch = Batch.from_data_list([sample])
    # Call the module itself rather than `.forward()` so that registered hooks also run.
    return model(batch)

In [24]:
def generate_and_calassify_primitive(label,
                       generator,
                       model,
                       inference_transform,
                       label_encoder,
                       translate_mapping: dict):
    """Generate a randomly rotated primitive, classify it and render it.

    Args:
        label: caption of the chosen primitive (a key of `generator`).
        generator: mapping caption -> function(transform) that returns a trimesh mesh.
        model: classifier, run through `classify_object`.
        inference_transform: preprocessing applied to the mesh before inference.
        label_encoder: fitted encoder mapping output indices to class names.
        translate_mapping: class name -> caption, used to compare the prediction with `label`.

    Side effects:
        Sets the global `score` to ``exp`` of the model output (presumably the model
        emits log-probabilities — TODO confirm). Displays the mesh in `output_mesh_widget`,
        colored green if the prediction matches `label` and red otherwise.

    Returns:
        The tensor stored in `score` (earlier versions returned None).
    """
    global score

    random_transformation = random_transform()
    # Keep only the rotation: zero the translation column of the homogeneous matrix.
    random_transformation[:-1, -1] = 0
    mesh = generator[label](random_transformation)

    raw_score = classify_object(mesh, model, inference_transform)[0]
    # Pass a plain int rather than a 0-d tensor to sklearn.
    predicted_index = int(raw_score.argmax())
    pred_label = label_encoder.inverse_transform([predicted_index])[0]

    color = GREEN_COLOR if label == translate_mapping[pred_label] else RED_COLOR
    mesh.visual = trimesh.visual.ColorVisuals(mesh=mesh, vertex_colors=color)

    score = torch.exp(raw_score)
    with output_mesh_widget:
        display(mesh.show(viewer="notebook", smooth=True))
    return score

In [25]:
def plot_score(label_encoder, translate_mapping):
    """Redraw the confidence bar chart in `output_plot_widget` from the global `score`.

    Bars are labelled with the captions of the encoder's classes, in encoder order,
    so they line up with the entries of `score`.
    """
    captions = tuple(translate_mapping[class_name] for class_name in label_encoder.classes_)
    output_plot_widget.data = []
    output_plot_widget.add_trace(go.Bar(x=captions, y=score))

In [26]:
# Idle caption of the button; the click handler restores it after each run.
btn_title = "Classify"
classify_button = widgets.Button(
    description=btn_title,
    disabled=False,
    tooltip="Classify chosen primitive",
    button_style="",
    layout=widgets.Layout(width="auto"),
)

In [27]:
def calassify_callback(button):
    """Click handler for the classify button.

    Disables the button and shows a progress caption while a random primitive of the
    selected kind is generated, classified and plotted. The caption and enabled state
    are restored in ``finally``, so an exception during classification cannot leave
    the UI stuck in the disabled state. The exception still propagates, so ipywidgets
    reports it.
    """
    button.disabled = True
    output_mesh_widget.clear_output()
    button.description = "Classification..."
    try:
        generate_and_calassify_primitive(model_list_widget.value,
                                         generator=generators,
                                         model=cls_model,
                                         inference_transform=inference_transform,
                                         label_encoder=label_encoder,
                                         translate_mapping=translate_mapping)
        plot_score(label_encoder, translate_mapping)
    finally:
        button.description = btn_title
        button.disabled = False

In [28]:
# Wire the classify button to its handler.
classify_button.on_click(callback=calassify_callback)

# Classification of randomly generated 3D primitives

Choose a primitive from the list below and press the "Classify" button. A histogram of the model's confidence is shown on the left, and the generated 3D model on the right (green if it was classified correctly, red otherwise). Widgets may occasionally fail to load; if that happens, reload the page or rerun the example.

In [29]:
# Classification demo: primitive picker, button, then the confidence plot and mesh view.
widgets.VBox(children=[header_widget, classify_button, output_widget])

VBox(children=(VBox(children=(Label(value='Choose a 3D primitive for classification:'), GridBox(children=(Togg…

# Global feature visualization with UMAP 

A chosen number of examples is generated for each class. For each example, the model computes a global feature vector. These vectors are standardized, reduced to two components with UMAP, and plotted on a plane.

With a large number of examples the computation may take a long time, because everything runs on the CPU.

In [30]:
@torch.no_grad()
def get_features(mesh, model, batch):
    """Return the model's global feature vectors for a list of preprocessed samples.

    Args:
        mesh: unused; kept for compatibility with existing call sites.
        model: model exposing ``global_feature(batch)``.
        batch: list of PyTorch Geometric ``Data`` objects, collated here into one ``Batch``.
    """
    return model.global_feature(Batch.from_data_list(batch))

In [31]:
# Number of meshes generated per primitive class for the UMAP plot. The value is only
# reported when the slider is released (continuous_update=False).
sample_umap_widget = widgets.IntSlider(
    value=20,
    min=5,
    max=100,
    step=1,
    description="",
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)
label_sample_umap_widget = widgets.Label("Number of examples for each class:")
header_umap_widget = widgets.HBox([label_sample_umap_widget, sample_umap_widget])

In [32]:
# Progress of the UMAP pipeline as a fraction in [0, 1].
progress_widget = widgets.FloatProgress(min=0.0, max=1.0, value=0.0)

In [33]:
umap_vis_button = widgets.Button(
    description="Plot features",
    disabled=False,
    tooltip="Plot features",
    button_style="",
    layout=widgets.Layout(width="auto"),
)
# Link the `disabled` state of both buttons in the front end, so that disabling one
# (while a computation runs) also disables the other.
l = widgets.jslink((umap_vis_button, "disabled"), (classify_button, "disabled"))

In [34]:
# Scatter plot target for the 2-D UMAP embedding of the global features.
fig = go.FigureWidget(layout=go.Layout(title="Feature plot with UMAP"));

In [35]:
def generate_umap(num_samples: int, primitive_generator, model, inference_transform):
    """Generate random primitives, embed their global features and plot a 2-D UMAP.

    For every caption in `primitive_generator`, `num_samples` randomly rotated meshes
    are generated, preprocessed with `inference_transform`, and passed through
    `model.global_feature` as one batch per class. The features are standardized,
    projected to two dimensions with UMAP, and drawn in the global `fig`, colored by
    class. Progress is reported through `progress_widget`: 0-0.5 while generating,
    0.75 after UMAP, 1.0 once plotted.

    Args:
        num_samples: number of meshes per class; must be positive.
        primitive_generator: mapping caption -> function(transform) returning a mesh.
        model: classifier exposing ``global_feature`` (previously ignored in favor of
            the global `cls_model`).
        inference_transform: preprocessing applied to each mesh.

    Raises:
        ValueError: if `num_samples` is not positive or there are no generators.
    """
    if num_samples <= 0 or not primitive_generator:
        raise ValueError("num_samples must be positive and primitive_generator non-empty")

    fig.data = []
    # Start from 0 so the bar only moves forward.
    progress_widget.value = 0.0

    features = []
    labels = []
    total = len(primitive_generator) * num_samples
    generated = 0

    for label, make_mesh in primitive_generator.items():
        batch = []
        for _ in range(num_samples):
            random_transformation = random_transform()
            # Keep only the rotation: zero the translation column of the homogeneous matrix.
            random_transformation[:-1, -1] = 0
            mesh = make_mesh(random_transformation)
            batch.append(inference_transform(utils.from_trimesh(mesh)))
            generated += 1
            progress_widget.value = generated / total * 0.5

        # get_features ignores its first (mesh) argument.
        features.extend(get_features(None, model, batch).cpu().numpy())
        labels.extend([label] * num_samples)

    progress_widget.value = 0.5

    # Integer class codes are used as marker colors.
    colors = LabelEncoder().fit_transform(labels)
    reducer = umap.UMAP(learning_rate=0.8, verbose=False, n_epochs=200)
    embedding = reducer.fit_transform(StandardScaler().fit_transform(features))
    progress_widget.value = 0.75

    fig.add_trace(
        go.Scatter(
            x=embedding[:, 0],
            y=embedding[:, 1],
            text=labels,
            mode="markers",
            marker=dict(color=colors),
        )
    )
    progress_widget.value = 1
    
def umap_gen_callback(button):
    """Click handler for the "Plot features" button.

    Disables the button and shows a progress caption on both buttons while the UMAP
    projection is computed. The classify button is disabled too, through the jslink
    on `disabled`. Captions and state are restored in ``finally``, so an error cannot
    leave both buttons stuck; the exception still propagates.
    """
    button.disabled = True
    new_title = "UMAP projection..."
    button.description = new_title
    classify_button.description = new_title
    try:
        generate_umap(sample_umap_widget.value, generators, cls_model, inference_transform)
    finally:
        classify_button.description = btn_title
        button.description = "Plot features"
        button.disabled = False

In [36]:
# Wire the "Plot features" button to its handler.
umap_vis_button.on_click(callback=umap_gen_callback)

In [37]:
# UMAP demo: sample-count slider, button, progress bar, then the scatter plot.
umap_layout = widgets.VBox(children=[header_umap_widget, umap_vis_button, progress_widget, fig])
umap_layout

VBox(children=(HBox(children=(Label(value='Numer of examples for each class:'), IntSlider(value=20, continuous…

# Applications

This example demonstrates the application of deep learning methods to 3D data processing, also known as 3D ML.

It can be used for the following applications:

* Searching among 3D models
* 3D object classification from data captured by devices such as RGB-D cameras or LiDAR
* Clustering 3D models
* and others

[Some info about it (Habr)](https://habr.com/ru/company/itmai/blog/503358/)