# Transformation examples

## Introduction
This notebook illustrates some of the mesh transformation capabilities provided by plaid-ops.

In [None]:
import os

import numpy as np
from datasets import load_dataset
from IPython.display import Image as IPyImage
from IPython.display import display
from PIL import Image as PILImage
from plaid.bridges.huggingface_bridge import (
    huggingface_dataset_to_plaid,
    huggingface_description_to_problem_definition,
)

from plaid_ops.common.visualization import plot_field, plot_sample_field
from plaid_ops.mesh.transformations import (
    compute_bounding_box,
    project_on_other_datset,
    project_on_regular_grid,
)

# Configuration of the example: source dataset on the HuggingFace Hub, number
# of training samples converted to PLAID, and number of worker processes used
# for the conversion. Only a couple of samples are used to keep the demo fast.
HF_DATASET_NAME = "PLAID-datasets/2D_Multiscale_Hyperelasticity"
N_SAMPLES = 2
N_PROCESSES = 2

hf_dataset = load_dataset(HF_DATASET_NAME, split="all_samples")
pb_def = huggingface_description_to_problem_definition(hf_dataset.info.description)
# Keep only the first N_SAMPLES ids of the "DOE_train" split.
ids = pb_def.get_split("DOE_train")[:N_SAMPLES]
dataset, _ = huggingface_dataset_to_plaid(
    hf_dataset, ids=ids, processes_number=N_PROCESSES
)

## Dataset-wide projection for a constant rectilinear mesh

We start by illustrating the `u1` field from the first sample:

In [None]:
img_name = "transformation_1.png"
# Documentation and CI builds show the pre-rendered image; local runs
# re-plot the field and refresh that image file.
use_prerendered = os.environ.get("READTHEDOCS") == "True" or os.environ.get(
    "GITHUB_ACTIONS"
)
if not use_prerendered:
    img_array = plot_sample_field(
        dataset[ids[0]],
        "u1",
        title="Unstructured mesh",
        show_edges=True,
        scalar_bar_args={"title": "u1"},
    )
    img = PILImage.fromarray(img_array)
    img.save(img_name)
    display(img)
else:
    display(IPyImage(filename=img_name))

Then, we project all of the dataset's meshes onto a constant rectilinear mesh. This process works seamlessly for 2D and 3D meshes. It relies on finite element interpolation, which exploits the order of the underlying finite element representation of the solution:

In [None]:
# Bounding box computed over the dataset, and number of nodes along each axis
# of the target constant rectilinear grid.
bbox = compute_bounding_box(dataset)
dims = [101, 101]
projected_dataset = project_on_regular_grid(
    dataset,
    dimensions=dims,
    bbox=bbox,
    verbose=True,
)

img_name = "transformation_2.png"
on_docs_build = os.environ.get("READTHEDOCS") == "True"
if on_docs_build or os.environ.get("GITHUB_ACTIONS"):
    # Pre-rendered image shipped with the docs.
    display(IPyImage(filename=img_name))
else:
    plot_options = {
        "title": "Projection on regular grid mesh",
        "show_edges": True,
        "scalar_bar_args": {"title": "u1"},
    }
    img_array = plot_sample_field(projected_dataset[ids[0]], "u1", **plot_options)
    img = PILImage.fromarray(img_array)
    img.save(img_name)
    display(img)

We can easily project back onto the initial meshes of the dataset, again using finite element interpolation:

In [None]:
# Interpolate the fields of the grid-projected dataset back onto the meshes of
# the original dataset (note: `project_on_other_datset` is the library's
# spelling of the function name).
inv_projected_dataset = project_on_other_datset(
    projected_dataset, dataset, verbose=True
)

img_name = "transformation_3.png"
# Docs/CI builds display the pre-rendered image instead of re-plotting.
if os.environ.get("READTHEDOCS") == "True" or os.environ.get("GITHUB_ACTIONS"):
    display(IPyImage(filename=img_name))
else:
    img_array = plot_sample_field(
        inv_projected_dataset[ids[0]],
        "u1",
        title="Projection back to initial mesh",
        show_edges=True,
        scalar_bar_args={"title": "u1"},
    )
    img = PILImage.fromarray(img_array)
    img.save(img_name)
    display(img)

We compute the error made by the combined direct and inverse projections, and illustrate it:

In [None]:
# Pointwise u1 discrepancy on the first sample's original mesh after a
# grid projection followed by the inverse projection.
round_trip_u1 = inv_projected_dataset[ids[0]].get_field("u1")
reference_u1 = dataset[ids[0]].get_field("u1")
error = round_trip_u1 - reference_u1

img_name = "transformation_4.png"
prerendered = os.environ.get("READTHEDOCS") == "True" or os.environ.get(
    "GITHUB_ACTIONS"
)
if prerendered:
    display(IPyImage(filename=img_name))
else:
    img_array = plot_field(
        dataset[ids[0]],
        error,
        title="u1 error from projection and inverse projection",
        show_edges=True,
        scalar_bar_args={"title": "u1 error"},
    )
    img = PILImage.fromarray(img_array)
    img.save(img_name)
    display(img)

We compute the norm of the error made by direct and inverse projections:

In [None]:
# Euclidean norm of the round-trip error vector.
u1_error_norm = np.linalg.norm(error)
print(f"u1 error norm from projection and inverse projection = {u1_error_norm}")

Now, we compare our approach with a more naive one, which assigns to each target node the field value of the nearest node in the source mesh:

In [None]:
from Muscat.Bridges.CGNSBridge import CGNSToMesh, MeshToCGNS
from Muscat.MeshTools.ConstantRectilinearMeshTools import CreateConstantRectilinearMesh
from plaid.containers.sample import Sample
from scipy.spatial import KDTree

Computation of the direct projection:

In [None]:
# Source mesh (converted to a Muscat mesh) and u1 field of the first sample.
mesh = CGNSToMesh(dataset[ids[0]].get_mesh())
u1 = dataset[ids[0]].get_field("u1")

# Background grid built from the same `bbox` and `dims` as the regular-grid
# projection above: `dims` nodes per axis, hence `dims - 1` cells per axis.
grid_shape = np.array(dims)
spacing = np.divide(bbox[1] - bbox[0], grid_shape - 1)
background_mesh = CreateConstantRectilinearMesh(
    dimensions=dims, origin=bbox[0], spacing=spacing
)
naive_proj_sample = Sample()
naive_proj_sample.add_tree(MeshToCGNS(background_mesh))

# Nearest-neighbour transfer: every background node takes the u1 value of the
# closest source-mesh node (KDTree.query returns (distances, indices)).
kdtree = KDTree(mesh.nodes)
_, id_bg_nodes = kdtree.query(background_mesh.nodes)
naive_proj_u1 = u1[id_bg_nodes]

Computation of the inverse projection and illustration of the error:

In [None]:
# Inverse nearest-neighbour transfer: every node of the source mesh takes the
# value of the closest background-grid node. A single query suffices; the
# second, identical query previously stored in `id_origin_nodes` was unused.
kdtree = KDTree(background_mesh.nodes)
_, id_bg_nodes = kdtree.query(mesh.nodes)

naive_inv_proj_u1 = naive_proj_u1[id_bg_nodes]

# Pointwise u1 discrepancy of the naive round trip on the original mesh.
error = naive_inv_proj_u1 - dataset[ids[0]].get_field("u1")

img_name = "transformation_5.png"
# Docs/CI builds display the pre-rendered image instead of re-plotting.
if os.environ.get("READTHEDOCS") == "True" or os.environ.get("GITHUB_ACTIONS"):
    display(IPyImage(filename=img_name))
else:
    img_array = plot_field(
        dataset[ids[0]],
        error,
        title="u1 error from naive projection and inverse projection",
        show_edges=True,
        scalar_bar_args={"title": "u1 error"},
    )
    img = PILImage.fromarray(img_array)
    img.save(img_name)
    display(img)

We compute the norm of the error made by naive direct and inverse projections:

In [None]:
# Label the naive result distinctly so it is not confused with the
# finite-element projection error norm, which uses the same wording otherwise.
print(
    f"u1 error norm from naive projection and inverse projection = {np.linalg.norm(error)}"
)