In [1]:
import os
import sys

import json

In [2]:
import catboost
import pandas as pd
import numpy as np
import trimesh
import matplotlib.cm as cm

from tqdm import tqdm


In [3]:
from IPython.display import display, JSON

In [None]:
# Wall-clock timestamp of this run, for provenance of the outputs below.
from datetime import datetime
datetime.now().strftime('%Y-%m-%d %H:%M:%S')

'2025-11-07 13:45:28'

In [5]:
# Record the CatBoost version this notebook ran with; the models below are
# restored with catboost.CatBoostRegressor.load_model.
catboost.__version__

'1.1.1'

In [6]:
# Project root = parent of the kernel's current working directory; make it and
# its `lib/` folder importable (txparser is imported from there below).
SCRIPT_DIR = os.path.dirname(os.path.abspath(''))
sys.path.extend([SCRIPT_DIR, os.path.join(SCRIPT_DIR, 'lib')])

In [7]:
# Raise pandas display limits so wide frames and long cell values are not truncated.
pd.set_option(
    'display.max_columns', 200,
    'display.max_rows', 1000,
    'display.max_colwidth', 1000,
)

In [8]:
from txparser.consts import (
    axes_xyz as axes,
    baseline_features,
)

from txparser.utils import (
    read_json,
)

## Load validation data

In [9]:
# Only the validation split is used below; the train split path is kept
# commented out for reference.
# train_path = os.path.join(SCRIPT_DIR, 'baseline', 'data', 'train.csv')
val_path = os.path.join(SCRIPT_DIR, 'baseline', 'data', 'val.csv')

In [10]:
# Read the validation split; rows are later filtered on its `jaw` column.
val_df = pd.read_csv(val_path)
# train_df = pd.read_csv(train_path)

In [11]:
# Model inputs = baseline features without their first entry; a jaw flag
# (1 / 0) is prepended by hand when lo_features / up_features are built.
# NOTE(review): presumably baseline_features[0] is that jaw indicator — confirm.
features_names = baseline_features[1:]

## Helpers, model loading and prediction

In [None]:
def make_scene(axis=True, scale=20):
    """Create a trimesh scene lit by a single white light.

    Args:
        axis: when truthy, add a coordinate-axis marker mesh to the scene.
        scale: scale factor applied to that axis marker.

    Returns:
        The new ``trimesh.Scene``.
    """
    light = trimesh.scene.lighting.Light(
        color=np.array([255, 255, 255]),
        intensity=1.0,
    )
    scene = trimesh.Scene(lights=[light])

    if not axis:
        return scene

    # Use a separate name so the `axis` flag is not shadowed by the mesh.
    axis_mesh = trimesh.creation.axis()
    axis_mesh.apply_scale(scale)
    scene.add_geometry(axis_mesh)
    return scene

def plot_mesh(scene, m, color):
    """Paint every vertex of mesh `m` with `color`, then add `m` to `scene`."""
    visual = m.visual
    visual.vertex_colors = color
    scene.add_geometry(m)


def plot_meshes(scene, ms, color=None):
    """Add every mesh in `ms` to `scene`, skipping falsy entries (e.g. None).

    Without an explicit `color`, mesh i is painted with row i of a rainbow
    palette of ``2 * len(ms)`` colours scaled to 0-255, so only the first half
    of the palette is ever used.
    """
    palette = cm.rainbow(np.linspace(0, 1, len(ms) * 2)) * 255
    for idx, mesh in enumerate(ms):
        if mesh:
            plot_mesh(scene, mesh, palette[idx] if color is None else color)


def plot_cloud(scene, points, color, scale=1):
    """Add `points` to `scene` as one uniformly coloured point cloud.

    Does nothing when `points` is empty; otherwise the cloud is scaled by
    `scale` before it is added.
    """
    count = len(points)
    if count <= 0:
        return
    cloud = trimesh.points.PointCloud(points)
    cloud.apply_scale(scale)
    cloud.colors = [color] * count
    scene.add_geometry(cloud)


def plot_points(scene, points, radius=0.125, color=None):
    """Add a small sphere at each point of `points` to `scene`.

    Args:
        scene: target scene (anything with ``add_geometry``).
        points: sized iterable of 3D centres; ``len`` and iteration are used.
        radius: scalar radius, or a per-point sequence (absolute values are
            taken only in the sequence case). Every radius is clamped to at
            least 0.01.
        color: colour for every sphere; when None a rainbow palette of
            ``2 * len(points)`` entries is used.
    """
    n = len(points)
    palette = cm.rainbow(np.linspace(0, 1, n * 2)) * 255
    per_point = hasattr(radius, "__len__")
    last = n * 2 - 1

    for i, centre in enumerate(points):
        r = abs(radius[i]) if per_point else radius
        sphere = trimesh.primitives.Sphere(radius=max(r, 0.01), center=centre)
        if color is not None:
            sphere.visual.vertex_colors = color
        else:
            # NOTE(review): odd-indexed points get palette[i], while every
            # even-indexed point shares the last palette entry — confirm intent.
            sphere.visual.vertex_colors = palette[i if i % 2 == 1 else last]
        scene.add_geometry(sphere)

In [13]:
def make_template(prefix='tooth', dims=None):
    """Return an empty flat row: 'file_id', 'norm' and one slot per tooth and axis.

    Args:
        prefix: key prefix; slots are named ``f'{prefix}_{idx}_{d}'``.
        dims: axis suffixes ``d`` to use. Defaults to the module-level
            ``axes`` (imported from ``txparser.consts.axes_xyz``).

    Returns:
        dict with ``file_id=None``, ``norm=1`` and every slot for tooth ids
        1..32 set to None, grouped by axis (all teeth for the first suffix,
        then all teeth for the next, ...).
    """
    dims = axes if dims is None else dims
    result = {'file_id': None, 'norm': 1}

    for d in dims:
        # A single update per axis is enough; repeating it with the same keys is a no-op.
        result.update({f'{prefix}_{idx}_{d}': None for idx in range(1, 33)})

    return result


def make_row(norm, teeth, prefix='tooth'):
    """Build a flat row dict from per-tooth coordinates.

    Args:
        norm: value stored under 'norm'.
        teeth: mapping tooth id -> indexable coordinates; coordinate ``i`` is
            written to ``f'{prefix}_{idx}_{axes[i]}'``.
        prefix: key prefix, shared with the template the row starts from.

    Returns:
        ``make_template(prefix)`` with 'norm' set and the given teeth filled in.
    """
    # Pass the prefix through: with the default template, a non-default prefix
    # would yield a row mixing empty 'tooth_*' slots with '{prefix}_*' values.
    row = make_template(prefix)
    row['norm'] = norm

    for idx, center in teeth.items():
        for i, d in enumerate(axes):
            row[f'{prefix}_{idx}_{d}'] = center[i]

    return row


def make_item(row, prefix='tooth'):
    """Collect per-tooth coordinate triples from a flat row.

    Tooth ``idx`` (1..32) is included when ``f'{prefix}_{idx}_0'`` is a key of
    `row`; its value is ``[row[..._0], row[..._1], row[..._2]]``. Result keys
    are ``str(idx)``, in increasing tooth order.

    NOTE(review): the suffixes here are the literal digits 0..2, whereas
    make_template / make_row use the names in ``axes``; the two only agree if
    ``axes`` is [0, 1, 2] — confirm.
    """
    return {
        str(idx): [row[f'{prefix}_{idx}_{k}'] for k in range(3)]
        for idx in range(1, 33)
        if f'{prefix}_{idx}_0' in row
    }

In [14]:
# Draw one validation case per jaw value: `lo_*` from jaw == 1, `up_*` from
# jaw == 0. A fixed seed keeps Restart & Run All reproducible; change
# SAMPLE_SEED to inspect other cases.
SAMPLE_SEED = 42
lo_row = val_df[val_df['jaw'] == 1].sample(random_state=SAMPLE_SEED).to_dict('records')[0]
up_row = val_df[val_df['jaw'] == 0].sample(random_state=SAMPLE_SEED).to_dict('records')[0]

In [15]:
# Alternative to random sampling: pin specific cases by file_id.
# NOTE(review): `all_df` is not defined anywhere in this notebook; it must be
# built (presumably from the train/val CSVs) before these lines are uncommented.
# up_row = all_df[(all_df['jaw'] == 0) & (all_df['file_id'] == '5514e226-377e-44f1-b7db-8fd647c3cc6d')].to_dict('records')[0]
# lo_row = all_df[(all_df['jaw'] == 1) & (all_df['file_id'] == 'd4b3b195-6c11-4b0d-818f-acd343ed1b45')].to_dict('records')[0]


In [16]:
# file_ids of the two selected cases.
lo_row['file_id'], up_row['file_id']

('781cadab-8848-47d7-b92b-640e1163ee3c',
 '69975f71-f03d-44a4-883f-a954d61fd776')

In [17]:
# NOTE(review): `norm` is not referenced by any later cell shown here — drop
# it or use it; confirm intent.
norm = lo_row['norm']

In [18]:
# Model input = jaw flag followed by the row's features, in features_names
# order. The flag matches the `jaw` filter each row was drawn from.
lo_features, up_features = (
    [jaw_flag] + [row[name] for name in features_names]
    for jaw_flag, row in ((1, lo_row), (0, up_row))
)

In [19]:
# Inspect the `norm` value stored on the selected lo_* row.
lo_row['norm']

42.03857915848842

In [20]:
# Optional: uncomment to render the feature vector with IPython's JSON display.
# JSON(lo_features)

In [None]:
def teeth_jaw(teeth, jaw=0):
    """Extract one jaw's teeth and renumber them 1..16.

    `teeth` is keyed by string tooth ids. ``jaw == 0`` takes ids '1'..'16'
    unchanged; any other value takes '17'..'32' and shifts them down by 16.
    Absent ids are skipped. Result keys are strings, in increasing order.
    """
    offset = 16 if jaw != 0 else 0
    selected = {}
    for local_id in range(1, 17):
        key = str(local_id + offset)
        if key in teeth:
            selected[str(local_id)] = teeth[key]
    return selected

In [None]:
def teeth_features(features_names, jaw_teeth):
    """Return one jaw's centre-coordinate features, ordered by `features_names`.

    For every tooth id 1..16, the keys ``center_{idx}_0``, ``_1`` and ``_2``
    hold the tooth's three coordinates from `jaw_teeth` (keyed by string id),
    or None when the tooth is absent.

    Args:
        features_names: names to emit, each of the form ``center_{idx}_{i}``.
        jaw_teeth: mapping str(tooth id) -> indexable coordinate triple.

    Returns:
        list of feature values, one per name.

    Raises:
        KeyError: if a name in `features_names` is outside that key set.
    """
    missing = (None, None, None)
    features = {}

    for idx in range(1, 17):
        key = str(idx)
        coords = jaw_teeth[key] if key in jaw_teeth else missing
        for i in range(3):
            features[f'center_{idx}_{i}'] = coords[i]

    return [features[name] for name in features_names]

In [None]:
def root_coords(row):
    """Collect the ``origin_{i}_{0,1,2}`` coordinates of teeth 1..16 from `row`.

    A tooth is skipped when its ``origin_{i}_0`` value is missing according
    to ``pd.isna``; only that first coordinate is checked. Every
    ``origin_{i}_0`` key must exist in `row`.

    Returns:
        dict int tooth id -> [c0, c1, c2], in increasing tooth order.
    """
    return {
        tooth: [row[f'origin_{tooth}_{axis}'] for axis in range(3)]
        for tooth in range(1, 17)
        if not pd.isna(row[f'origin_{tooth}_0'])
    }

In [24]:
def center_coords(row):
    """Collect the ``center_{i}_{0,1,2}`` coordinates of teeth 1..16 from `row`.

    A tooth is skipped when its ``center_{i}_0`` value is missing according
    to ``pd.isna``; only that first coordinate is checked. Every
    ``center_{i}_0`` key must exist in `row`.

    Returns:
        dict int tooth id -> [c0, c1, c2], in increasing tooth order.
    """
    coords = {}
    for tooth in range(1, 17):
        keys = [f'center_{tooth}_{axis}' for axis in range(3)]
        if pd.isna(row[keys[0]]):
            continue
        coords[tooth] = [row[key] for key in keys]
    return coords

In [None]:
def load_model(idx, axis, model_file='root.model'):
    """Load one CatBoost regressor and its JSON metadata from SCRIPT_DIR/models.

    The model file is ``{idx}-{axis}-{model_file}``; its metadata lives next
    to it under the same name with a ``.json`` suffix.

    Returns:
        (model, model_info): the loaded CatBoostRegressor and whatever
        `read_json` returns for the metadata file.
    """
    model = catboost.CatBoostRegressor()
    # Build the path once and reuse it for both files.
    path = os.path.join(SCRIPT_DIR, 'models', f'{idx}-{axis}-{model_file}')

    model_info = read_json(path + '.json')
    model.load_model(path)

    return model, model_info


def load_models():
    """Load the full grid of models: ``models[tooth][axis] -> (model, info)``.

    Tooth ids run 1..16 and axis ids 0..2; each pair is passed to `load_model`
    in that order.
    """
    return {
        idx: {axis: load_model(idx, axis) for axis in range(3)}
        for idx in range(1, 17)
    }

In [None]:
def make_predict(models, features):
    """Predict three coordinates for every tooth id 1..16.

    Args:
        models: ``models[tooth][axis]`` is a tuple whose first element has a
            ``predict`` method (the shape returned by `load_models`).
        features: passed unchanged to every ``predict`` call.

    Returns:
        dict int tooth id -> [pred_axis0, pred_axis1, pred_axis2].
    """
    predictions = {}
    for tooth in range(1, 17):
        per_axis = models[tooth]
        predictions[tooth] = [per_axis[axis][0].predict(features) for axis in range(3)]
    return predictions

In [27]:
def cut(roots, jaw):
    """Keep only the entries of `roots` whose key is contained in `jaw`.

    The insertion order of `roots` is preserved.
    """
    return {key: roots[key] for key in roots if key in jaw}

In [28]:
# Load every (model, info) pair, indexed as models[tooth 1..16][axis 0..2].
models = load_models()

In [29]:
# Predict all 16 teeth for both cases; absent teeth are dropped further below.
lo_predict, up_predict = make_predict(models, lo_features), make_predict(models, up_features)

In [30]:
# Reference `origin_*` points stored on each row, plotted against the predictions.
lo_gt, up_gt = root_coords(lo_row), root_coords(up_row)

In [31]:
# `center_*` points of the teeth present on each row.
lo_centers, up_centers = center_coords(lo_row), center_coords(up_row)

In [32]:
# Keep predictions only for teeth that actually have a centre on the row.
lo_predict, up_predict = cut(lo_predict, lo_centers), cut(up_predict, up_centers)

In [33]:
sorted(map(int, lo_centers))

[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

In [34]:
sorted(map(int, up_centers))

[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

In [35]:
# lo_* case: predictions, `center_*` points and `origin_*` reference points.
s = make_scene(scale=1/2)

for _points, _radius in (
    (lo_predict, 0.025),
    (lo_centers, 0.025),
    (lo_gt, 0.02),
):
    plot_points(s, _points.values(), radius=_radius)

s.show(axes=1, height=1200, line_settings={'point_size': 1})

In [36]:
# up_* case: predictions, `center_*` points and `origin_*` reference points.
s = make_scene(scale=1/2)

for _points, _radius in (
    (up_predict, 0.025),
    (up_centers, 0.025),
    (up_gt, 0.02),
):
    plot_points(s, _points.values(), radius=_radius)

s.show(axes=1, height=1200, line_settings={'point_size': 1})