In [1]:
import os

In [2]:
# Enter the project directory so the relative "inference-data/..." paths below
# resolve. Guarded so that re-running this cell (already inside the directory)
# does not fail with FileNotFoundError.
PROJECT_DIR = "made-ml-demo-app-model"
if os.path.basename(os.path.normpath(os.getcwd())) != PROJECT_DIR:
    os.chdir(PROJECT_DIR)

In [3]:
import json
import pathlib
import pickle
import random

import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual
import torch
import trimesh
import numpy as np
from trimesh import creation
from torch_geometric.data  import Batch
from torch_geometric import utils
from pytransform3d.transformations import random_transform

from point_cloud_cls import SimpleClsLDGCN, BaseTransform
from train_param import DataParams

In [4]:
# Paths are relative; they rely on the working directory set by the os.chdir cell.
label_encoder_dump_filepath = pathlib.Path("inference-data", "label_encoder", "label_encoder.pickle")

In [5]:
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load this artifact from a trusted source. The object is used below for
# `classes_` and `inverse_transform` (presumably a scikit-learn LabelEncoder — confirm).
with open(label_encoder_dump_filepath, "rb") as dump_file:
    label_encoder = pickle.load(dump_file)

In [6]:
path_to_model_config = pathlib.Path("inference-data", "model", "model_params.json")

In [7]:
# JSON object whose keys are passed as keyword arguments to SimpleClsLDGCN.
with open(path_to_model_config, "r") as _model_params_file:
    model_params = json.load(_model_params_file)

In [8]:
# Build the architecture only; trained weights are loaded from the checkpoint later.
cls_model = SimpleClsLDGCN(**model_params)

In [9]:
# The config is only needed for construction; drop it from the namespace.
del model_params

In [10]:
# Saved state dict for cls_model, consumed by load_state_dict below.
path_to_checkpoint = pathlib.Path("inference-data", "model", "model_state.pth")

In [11]:
# Device the checkpoint tensors are mapped onto when loading.
device = "cpu"

In [12]:
# Preprocessing applied to each mesh before inference; presumably resamples it to
# DataParams().num_points points — confirm in point_cloud_cls.BaseTransform.
inference_transform = BaseTransform(DataParams().num_points)

In [13]:
# map_location remaps stored tensors onto `device` (e.g. a GPU-saved checkpoint on CPU).
# NOTE(review): torch.load unpickles the file; consider weights_only=True if the
# installed torch version supports it.
# eval() switches layers such as dropout/batch-norm (if present) to inference mode;
# the trailing ';' suppresses the module repr in the cell output.
cls_model.load_state_dict(torch.load(path_to_checkpoint, map_location=device))
cls_model.eval();

In [14]:
# Class names known to the encoder; model output indices are decoded against these.
label_encoder.classes_

array(['cone', 'cube', 'cylinder', 'plane', 'torus', 'uv_sphere'],
      dtype='<U9')

In [15]:
def cone_generator(transform):
    """Build a cone with random radius and height, posed by ``transform``.

    The radius is drawn from U(0.2, 2) first, then the height from U(1, 3),
    using the global ``random`` module; ``transform`` is forwarded to trimesh.
    """
    cone_radius = random.uniform(0.2, 2)
    cone_height = random.uniform(1, 3)
    return creation.cone(radius=cone_radius, height=cone_height, transform=transform)

def cube_generator(transform):
    """Build a cube with a random edge length, posed by ``transform``.

    The edge length is drawn from U(0.1, 2) with the global ``random`` module
    and used for all three extents.
    """
    edge = random.uniform(0.1, 2)
    cube_extents = (edge,) * 3
    return creation.box(cube_extents, transform=transform)

def cylinder_generator(transform):
    """Build a cylinder with random radius and height, posed by ``transform``.

    The radius is drawn from U(0.2, 2) first, then the height from U(1, 3),
    using the global ``random`` module; ``transform`` is forwarded to trimesh.
    """
    cyl_radius = random.uniform(0.2, 2)
    cyl_height = random.uniform(1, 3)
    return creation.cylinder(radius=cyl_radius, height=cyl_height, transform=transform)

def plane_generator(transform):
    """Build a fixed 2x2 square in the z=0 plane and apply ``transform`` to it.

    The square spans [-1, 1] in x and y and is made of two triangles. Unlike
    the other generators, its size is not randomised.
    """
    corners = np.array([
        [1, 1, 0],
        [-1, 1, 0],
        [-1, -1, 0],
        [1, -1, 0],
    ])
    triangles = np.array([
        [0, 1, 2],
        [2, 3, 0],
    ])
    square = trimesh.Trimesh(vertices=corners, faces=triangles)
    square.apply_transform(transform)
    return square

def sphere_generator(transform):
    """Build an icosphere with a random radius and apply ``transform`` to it.

    Parameters
    ----------
    transform : array-like
        Homogeneous transform (presumably 4x4) applied to the mesh, matching
        how the other ``*_generator`` functions use their argument.

    Returns
    -------
    trimesh.Trimesh
        Icosphere with 2 subdivisions and radius drawn from U(1, 10).
    """
    # NOTE(review): the radius range (1, 10) is much larger than the other
    # generators' sizes; presumably fine if BaseTransform normalises scale — confirm.
    radius = random.uniform(1, 10)
    sphere = creation.icosphere(radius=radius, subdivisions=2)
    # Apply the pose explicitly so spheres are transformed like every other primitive.
    sphere.apply_transform(transform)
    return sphere

In [23]:
# English class names (as stored in label_encoder.classes_) -> Russian labels
# printed for the prediction. Every class the model can predict needs an entry,
# including "torus", which has no generator below but can still be predicted.
translate_mapping = {
    "cone": "конус",
    "cube": "куб",
    "cylinder": "цилиндр",
    "plane": "плоскость",
    "torus": "тор",
    "uv_sphere": "сфера"
}

# Russian display names -> mesh factories. Each factory takes a homogeneous
# transform (presumably 4x4) and returns a trimesh mesh. The key order is the
# order shown in the dropdown.
generators = {
    "конус": cone_generator,
    "куб": cube_generator,
    "цилиндр": cylinder_generator,
    "плоскость": plane_generator,
    "сфера": sphere_generator
}

In [24]:
# Dropdown offering every primitive that has a generator; preselects the first one.
model_list = widgets.Dropdown(
    options=tuple(generators),
    value=next(iter(generators)),
    description='Выбор 3D примитива:',
    disabled=False,
    layout={'width': 'max-content'},
)

In [25]:
def classify_object(mesh, model, transform):
    """Run the point-cloud classifier on a single trimesh mesh.

    Parameters
    ----------
    mesh : trimesh.Trimesh
        Mesh to classify; converted with ``torch_geometric.utils.from_trimesh``.
    model : torch.nn.Module
        Classifier that accepts a ``torch_geometric`` ``Batch``.
    transform : callable
        Preprocessing applied to the converted ``Data`` object before batching.

    Returns
    -------
    torch.Tensor
        Raw model output for the one-element batch (presumably shape
        ``(1, num_classes)`` scores — confirm against SimpleClsLDGCN).
    """
    data = utils.from_trimesh(mesh)
    batch = Batch.from_data_list([transform(data)])
    # Call the module rather than ``model.forward`` so registered forward hooks
    # and the rest of ``nn.Module.__call__`` still run.
    return model(batch)

In [33]:
@torch.no_grad()
def generate_primitive(label, 
                       generator,
                       model,
                       inf_transform,
                       label_encoder,
                       translate_mapping: dict):
    """Generate a randomly rotated primitive, classify it, print and render it.

    Parameters
    ----------
    label : str
        Key into ``generator`` selecting the primitive (the value chosen in the dropdown).
    generator : dict
        Maps primitive names to callables taking a transform matrix and
        returning a trimesh mesh.
    model : torch.nn.Module
        Classifier passed on to ``classify_object``.
    inf_transform : callable
        Preprocessing passed on to ``classify_object``.
    label_encoder :
        Fitted encoder whose ``inverse_transform`` maps class indices to names.
    translate_mapping : dict
        Class name -> display name. Classes missing from the mapping are
        printed under their raw class name instead of raising ``KeyError``.

    Returns
    -------
    Whatever ``mesh.show(viewer="notebook")`` returns (the notebook viewer).
    """
    pose = random_transform()
    # Zero the translation column (assumes a 4x4 homogeneous matrix) so only the
    # random rotation is applied and the primitive stays at the origin.
    pose[:-1, -1] = 0
    mesh = generator[label](pose)
    score = classify_object(mesh, model, inf_transform)
    # Convert explicitly so the encoder receives a plain numpy index array.
    predicted_index = score.argmax(dim=1).cpu().numpy()
    predicted_class = label_encoder.inverse_transform(predicted_index)[0]
    print(translate_mapping.get(predicted_class, predicted_class))
    return mesh.show(viewer="notebook")

In [34]:
# Build the demo UI. Only ``label`` becomes a user control; ``fixed(...)`` passes
# the remaining arguments through unchanged instead of turning them into widgets.
# The trailing ';' suppresses the repr of the returned function in the cell output.
interact_manual(
    generate_primitive,
    label=model_list,
    generator=fixed(generators),
    model=fixed(cls_model),
    inf_transform=fixed(inference_transform),
    label_encoder=fixed(label_encoder),
    translate_mapping=fixed(translate_mapping),
);

interactive(children=(Dropdown(description='Выбор 3D примитива:', index=2, layout=Layout(width='max-content'),…

<function __main__.generate_primitive(label, generator, model, inf_transform, label_encoder, translate_mapping: dict)>