# Feature engineering examples

## Introduction
This notebook illustrates some of the feature engineering capabilities provided by plaid-ops, focusing on the computation of signed-distance function (SDF) fields.

In [None]:
import os

from datasets import load_dataset
from IPython.display import Image as IPyImage
from IPython.display import display
from PIL import Image as PILImage
from plaid.bridges.huggingface_bridge import (
    huggingface_dataset_to_plaid,
    huggingface_description_to_problem_definition,
)

from plaid_ops.common.visualization import plot_field
from plaid_ops.mesh.feature_engineering import (
    compute_sdf,
    update_dataset_with_sdf,
    update_sample_with_sdf,
)

# Load the "all_samples" split of the 2D_Multiscale_Hyperelasticity dataset.
hf_dataset = load_dataset(
    "PLAID-datasets/2D_Multiscale_Hyperelasticity",
    split="all_samples",
)

# The problem definition is rebuilt from the description stored with the dataset.
hf_description = hf_dataset.info.description
pb_def = huggingface_description_to_problem_definition(hf_description)

# Keep only the first two ids of the "DOE_train" split to keep the example small.
train_ids = pb_def.get_split("DOE_train")
ids = train_ids[:2]

# Convert the selected samples to a plaid dataset; the second returned value
# is not needed in this notebook.
dataset, _ = huggingface_dataset_to_plaid(
    hf_dataset,
    ids=ids,
    processes_number=2,
)

## Signed-distance function (SDF) computation: per sample and dataset-wide

In [None]:
# Single-sample variant: report whether an "sdf" field is listed among the
# sample's field names before and after calling `update_sample_with_sdf`.
sample = dataset[ids[0]]

sdf_in_sample_before = "sdf" in sample.get_field_names()
print("[before update] 'sdf' in sample fields ?", sdf_in_sample_before)

updated_sample = update_sample_with_sdf(sample)

sdf_in_sample_after = "sdf" in updated_sample.get_field_names()
print("[after update] 'sdf' in sample fields ?", sdf_in_sample_after)

In [None]:
# Dataset variant: same check as for a single sample, performed on the first
# selected sample of the dataset before and after `update_dataset_with_sdf`.
sdf_in_dataset_before = "sdf" in dataset[ids[0]].get_field_names()
print("[before update] 'sdf' in dataset fields ?", sdf_in_dataset_before)

updated_dataset = update_dataset_with_sdf(dataset)

sdf_in_dataset_after = "sdf" in updated_dataset[ids[0]].get_field_names()
print("[after update] 'sdf' in dataset fields ?", sdf_in_dataset_after)

In [None]:
# Compute the SDF field of the first selected sample. `computed_sdf` is reused
# later in the notebook for the comparison with the naive estimate.
sample = dataset[ids[0]]
computed_sdf = compute_sdf(sample)

img_name = "feature_engineering_1.png"

# When READTHEDOCS equals "True" or GITHUB_ACTIONS is set to a non-empty value,
# the figure is not rendered: the image file already stored under `img_name`
# is displayed instead.
use_saved_image = (
    os.environ.get("READTHEDOCS") == "True" or os.environ.get("GITHUB_ACTIONS")
)
if not use_saved_image:
    # Render the field, save the picture to `img_name`, then show it.
    img_array = plot_field(
        sample,
        computed_sdf,
        title="SDF illustration",
        scalar_bar_args={"title": "sdf"},
    )
    img = PILImage.fromarray(img_array)
    img.save(img_name)
    display(img)
else:
    display(IPyImage(filename=img_name))

This computation relies on the finite element engine provided by Muscat. It computes the exact distance from each node to the boundary, i.e., to the surface elements that define the mesh boundary (for both 2D and 3D meshes seamlessly).

We now illustrate the error introduced by a naive computation of the SDF, which measures the distance from each node to the nearest boundary node instead of the distance to the nearest point on the boundary elements.

In [None]:
import numpy as np
from Muscat.Bridges.CGNSBridge import CGNSToMesh
from scipy.spatial import KDTree

# Convert the CGNS mesh of the first selected sample to a Muscat mesh.
mesh = CGNSToMesh(dataset[ids[0]].get_mesh())

# Node ids carried by the "Holes" and "Ext_bound" nodal tags; together they are
# used here as the set of boundary nodes.
ids_holes = mesh.GetNodalTag("Holes").GetIds()
ids_ext_boundary = mesh.GetNodalTag("Ext_bound").GetIds()
ids_boundary = np.hstack((ids_holes, ids_ext_boundary))

# Naive estimate: distance from every mesh node to its nearest boundary node.
# `KDTree.query` returns (distances, indices); only the distances are kept.
kdtree = KDTree(mesh.nodes[ids_boundary])
naive_sdf, _ = kdtree.query(mesh.nodes)

# Gap between the SDF computed earlier in the notebook and the naive estimate.
difference_sdf = computed_sdf - naive_sdf

img_name = "feature_engineering_2.png"

# When READTHEDOCS equals "True" or GITHUB_ACTIONS is set to a non-empty value,
# the figure is not rendered: the image file already stored under `img_name`
# is displayed instead.
use_saved_image = (
    os.environ.get("READTHEDOCS") == "True" or os.environ.get("GITHUB_ACTIONS")
)
if not use_saved_image:
    # Render the difference field, save the picture to `img_name`, then show it.
    img_array = plot_field(
        sample,
        difference_sdf,
        title="SDF error computation",
        scalar_bar_args={"title": "sdf error"},
    )
    img = PILImage.fromarray(img_array)
    img.save(img_name)
    display(img)
else:
    display(IPyImage(filename=img_name))