# FIGConvNet inference and visualization notebook

This notebook demonstrates how to use a pre-trained FIGConvNet model to perform inference
on the DrivAerNet dataset.

The following items are required and need to be downloaded separately:

* Pre-trained [FIGConvNet checkpoint](to_be_provided).
* [DrivAerNet](https://github.com/Mohamedelrefaie/DrivAerNet/tree/main/DrivAerNet_v1) dataset.
    The dataset must be converted to the WebDataset format. For simplicity, a small subset
    of the dataset has already been converted to WebDataset format and can be used in this
    example as-is.
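
WebDataset shards are ordinary tar archives. All files whose names share the same prefix before the first dot (the sample key) make up one sample, and the remaining suffix identifies each component. The minimal sketch below writes such a shard using only the Python standard library. The key `0001`, the component names `points` and `pressure`, the JSON encoding and the helper `write_shard` are all hypothetical. The conversion actually used for this example may store different fields in a different encoding.

```python
import io
import json
import os
import tarfile
import tempfile


def write_shard(path, samples):
    """Write samples to a WebDataset-style tar shard (hypothetical helper).

    `samples` maps a sample key (for example, a design id) to a dict of
    component name -> JSON-serializable value. Each component is stored as
    a tar member named "<key>.<component>.json", so all files of one sample
    share the same key prefix and are written next to each other.
    """
    with tarfile.open(path, "w") as tar:
        for key, components in samples.items():
            for name, value in components.items():
                payload = json.dumps(value).encode("utf-8")
                info = tarfile.TarInfo(name=f"{key}.{name}.json")
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))


# Hypothetical sample: three surface points and their pressure values.
SHARD = os.path.join(tempfile.mkdtemp(), "shard-000000.tar")
write_shard(
    SHARD,
    {
        "0001": {
            "points": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
            "pressure": [-120.5, 35.0, 210.25],
        }
    },
)
```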

The inputs to the model are:
* Point cloud representing the surface of the car.

The outputs of the model are:
* Pressure at each surface point.
* Wall shear stresses at each surface point.

Before we begin, let's import some common packages.

In [1]:
from pathlib import Path
import sys

import numpy as np
import pyvista as pv
import torch
import warp as wp

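# Make the parent directory (which contains the `src` package) importable.
if sys.path[0] != "..":
    sys.path.insert(0, "..")

device = torch.device("cuda:0")
torch.cuda.set_device(device)  # Make cuda:0 the default CUDA device for PyTorch.
wp.init()
wp.set_device(str(device))  # Run Warp kernels on the same GPU.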


Warp 1.0.2 initialized:
   CUDA Toolkit 11.5, Driver 12.2
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "NVIDIA GeForce RTX 3090" (24 GiB, sm_86, mempool enabled)
     "cuda:1"   : "NVIDIA TITAN RTX" (24 GiB, sm_75, mempool enabled)
   CUDA peer access:
     Not supported
   Kernel cache:
     /home/du/.cache/warp/1.0.2


## Dataset visualization

This section provides visualizations of the original, mesh-based dataset.

In [2]:
# Path to the dataset and pressure VTK files.
# Note: update `drivaer_orig_path` as needed.
drivaer_orig_path = Path("/data/src/modulus/data/DrivAerNet/mini")
vtk_path = drivaer_orig_path / "SurfacePressureVTK"

output_dir = drivaer_orig_path.parent / "vis"

In [3]:
# Design id 0001 corresponds to the DrivAer_F_D_WM_WW_0001.vtk file, which belongs to the test set.
design_ids = ["0001"]


# PyVista camera position: (camera location, focal point, view-up vector).
camera_position = [
    (-3., -4.5, 5),
    (1, 0, 0.6),
    (0.1, 0.1, 0.9),
]


def create_mesh_vis(
    mesh: pv.PolyData,
    plotter: pv.Plotter,
    camera_position,
):
    plotter.subplot(0, 0)
    # Solid mesh visualization
    plotter.add_mesh(mesh, color="lightgrey")
    plotter.camera_position = camera_position
    plotter.add_text("Mesh", position="upper_left")


def create_mesh_vis_gt(
    mesh: pv.PolyData,
    scalar_name: str,
    plotter: pv.Plotter,
    camera_position,
):
    plotter.subplot(0, 1)
    # Solid mesh visualization with scalar.
    plotter.add_mesh(
        mesh,
        scalars=scalar_name,
        cmap="jet",
        clim=(-600, 400),
        show_scalar_bar=False,
    )
    plotter.camera_position = camera_position
    plotter.add_text("GT Pressure", position="upper_right")


def visualize_meshes(
    output_dir: Path = None,
):
    for design in design_ids:
        mesh_file = vtk_path / f"DrivAer_F_D_WM_WW_{design}.vtk"
        mesh = pv.read(mesh_file)

        plotter = pv.Plotter(shape=(1, 2))
        # Create input mesh vis.
        create_mesh_vis(mesh, plotter, camera_position)
        # Create GT pressure vis.
        create_mesh_vis_gt(mesh, "p", plotter, camera_position)

        if output_dir:
            output_dir.mkdir(parents=True, exist_ok=True)
            plotter.save_graphic(output_dir / f"{design}_gt_p_mesh.pdf")
        else:
            plotter.show()


output_dir = None
# Uncomment and update the line below to save the figures as PDF files instead.
# output_dir = Path("/data/src/modulus/data/DrivAerNet/vis")
visualize_meshes(output_dir)



## Running model inference

This section shows how to run model inference on samples from the DrivAerNet dataset.

For simplicity, inference uses the same WebDataset-based dataloader as training.
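
Because a WebDataset shard is a tar file, reading it back amounts to walking the members in order and grouping consecutive files by key (the name up to the first dot of the basename). The sketch below does this with the standard library on a small in-memory shard. `iter_samples` and `make_demo_shard` are hypothetical helpers that only illustrate the grouping convention. They are not a replacement for the `webdataset` library or for `src.data`.

```python
import io
import json
import os
import tarfile


def iter_samples(fileobj):
    """Yield (key, {extension: bytes}) pairs from a WebDataset-style tar.

    A member's key is its path up to the first "." in the basename, and the
    rest of the name is its extension. Consecutive members with the same key
    are grouped into one sample, which assumes each sample's files were
    written contiguously.
    """
    current_key, current = None, {}
    with tarfile.open(fileobj=fileobj, mode="r") as tar:
        for member in tar:
            if not member.isfile():
                continue
            dirname, basename = os.path.split(member.name)
            stem, _, ext = basename.partition(".")
            key = os.path.join(dirname, stem)
            if current and key != current_key:
                yield current_key, current
                current = {}
            current_key = key
            current[ext] = tar.extractfile(member).read()
    if current:
        yield current_key, current


def make_demo_shard():
    """Build a tiny in-memory shard with two hypothetical samples."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for key, pressure in [("0001", [1.0, 2.0]), ("0002", [3.0])]:
            for ext, value in [("pressure.json", pressure), ("meta.json", {"id": key})]:
                payload = json.dumps(value).encode("utf-8")
                info = tarfile.TarInfo(name=f"{key}.{ext}")
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))
    buf.seek(0)
    return buf


samples = {
    key: {ext: json.loads(raw) for ext, raw in parts.items()}
    for key, parts in iter_samples(make_demo_shard())
}
print(samples["0001"]["pressure.json"])  # → [1.0, 2.0]
```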

### Creating the dataloader

Instantiate the dataloader first. Update the path to the WebDataset directory as needed.

In [4]:
import src.data


num_points = 65536  # Number of surface points per sample.
dataset = src.data.DrivAerNetDataModule(
    drivaer_orig_path.parent / "drivaernet_webdataset",
    num_points=num_points,
    preprocessors=[
        src.data.drivaernet_datamodule.DrivAerNetPreprocessor(num_points)
    ]
)
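
`num_points` fixes how many surface points each sample contributes, so every batch has the same shape. This notebook does not show how `DrivAerNetPreprocessor` picks those points. The sketch below assumes a common approach: uniform random sampling without replacement when the surface has enough points and with replacement otherwise, with the same indices applied to coordinates and pressures so they stay aligned. `sample_fixed_size` is a hypothetical name.

```python
import random


def sample_fixed_size(points, values, num_points, seed=None):
    """Return exactly `num_points` aligned (point, value) samples.

    Hypothetical helper: samples without replacement when the cloud is large
    enough, otherwise with replacement. The same indices select points and
    values, so per-point labels stay attached to their coordinates.
    """
    if len(points) != len(values):
        raise ValueError("points and values must have the same length")
    if not points:
        raise ValueError("cannot sample from an empty point cloud")
    rng = random.Random(seed)
    n = len(points)
    if n >= num_points:
        idx = rng.sample(range(n), num_points)
    else:
        idx = [rng.randrange(n) for _ in range(num_points)]
    return [points[i] for i in idx], [values[i] for i in idx]


# Toy cloud: point i lies at x = i and carries the value 10 * i.
toy_points = [(float(i), 0.0, 0.0) for i in range(10)]
toy_values = [10.0 * i for i in range(10)]
sub_points, sub_values = sample_fixed_size(toy_points, toy_values, num_points=4, seed=0)
```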

### Creating the model

The following code creates a model and loads pre-trained weights from a checkpoint file.
The model must be instantiated with exactly the same arguments that were used
when the model was trained.

Note: update the path to the checkpoint as needed.
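
The constructor arguments must match because PyTorch's `Module.load_state_dict` raises a `RuntimeError` when a stored tensor's shape differs from the model's. With the default `strict=True`, it also raises when keys are missing or unexpected. Changing an argument such as `kernel_size` or `hidden_channels` changes parameter shapes. The sketch below mimics that check on plain `{name: shape}` dicts. The helper and the parameter names are illustrative only, not FIGConvNet's actual state dict.

```python
def compare_state_dicts(model_shapes, ckpt_shapes):
    """Compare parameter name -> shape maps the way a strict load would.

    Returns (missing, unexpected, mismatched): keys the checkpoint lacks,
    keys the model lacks, and (name, model_shape, checkpoint_shape) for
    keys present in both whose shapes differ.
    """
    missing = sorted(set(model_shapes) - set(ckpt_shapes))
    unexpected = sorted(set(ckpt_shapes) - set(model_shapes))
    mismatched = [
        (name, tuple(model_shapes[name]), tuple(ckpt_shapes[name]))
        for name in sorted(set(model_shapes) & set(ckpt_shapes))
        if tuple(model_shapes[name]) != tuple(ckpt_shapes[name])
    ]
    return missing, unexpected, mismatched


# A 2D conv weight has shape (out_channels, in_channels, k, k), so a model
# trained with kernel_size=5 but rebuilt with kernel_size=3 cannot load it.
trained = {"conv.weight": (16, 16, 5, 5), "conv.bias": (16,)}
rebuilt = {"conv.weight": (16, 16, 3, 3), "conv.bias": (16,)}
missing, unexpected, mismatched = compare_state_dicts(rebuilt, trained)
print(mismatched)  # → [('conv.weight', (16, 16, 3, 3), (16, 16, 5, 5))]
```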

In [5]:
import src.networks
from modulus.models.figconvnet.geometries import GridFeaturesMemoryFormat


model = src.networks.FIGConvUNetDrivAerNet(
    aabb_max=[2.75, 1.5, 1.0],
    aabb_min=[-2.75, -1.5, -1.0],
    hidden_channels=[16, 16, 16],
    in_channels=1,
    kernel_size=5,
    mlp_channels=[2048, 2048],
    neighbor_search_type="radius",
    num_down_blocks=1,
    num_levels=2,
    out_channels=1,
    pooling_layers=[2],
    pooling_type="max",
    reductions=["mean"],
    resolution_memory_format_pairs=[
        (GridFeaturesMemoryFormat.b_xc_y_z, [  5, 150, 100]),
        (GridFeaturesMemoryFormat.b_yc_x_z, [250,   3, 100]),
        (GridFeaturesMemoryFormat.b_zc_x_y, [250, 150,   2]),
    ],
    use_rel_pos_encode=True,
)
# Load the checkpoint, mapping its tensors onto the target device.
chk = torch.load("/data/src/modulus/models/fignet/2060187/model_00103.pth", map_location=device)
model.load_state_dict(chk["model"])
model = model.to(device)
model.eval()



FIGConvUNetDrivAerNet(
  (point_feature_to_grids): ModuleList(
    (0): Sequential(
      (0): PointFeatureToGrid(
        (conv): PointFeatureConv(in_channels=16 out_channels=16 search_type=radius reductions=['mean'] rel_pos_encode=True)
      )
      (1): GridFeatureMemoryFormatConverter(memory_format=GridFeaturesMemoryFormat.b_xc_y_z)
    )
    (1): Sequential(
      (0): PointFeatureToGrid(
        (conv): PointFeatureConv(in_channels=16 out_channels=16 search_type=radius reductions=['mean'] rel_pos_encode=True)
      )
      (1): GridFeatureMemoryFormatConverter(memory_format=GridFeaturesMemoryFormat.b_yc_x_z)
    )
    (2): Sequential(
      (0): PointFeatureToGrid(
        (conv): PointFeatureConv(in_channels=16 out_channels=16 search_type=radius reductions=['mean'] rel_pos_encode=True)
      )
      (1): GridFeatureMemoryFormatConverter(memory_format=GridFeaturesMemoryFormat.b_zc_x_y)
    )
  )
  (down_blocks): ModuleList(
    (0-1): 2 x Sequential(
      (0): GridFeatureConv2D
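
To get a feel for the grid settings above, note that the three entries of `resolution_memory_format_pairs` describe grids over the same bounding box. Each is coarse along one axis (5, 3, or 2 cells) and fine along the other two. The sketch below computes per-axis cell sizes and cell counts. It assumes each grid splits the box uniformly into the given number of cells per axis, which is a simplification; the implementation may place grid points differently.

```python
AABB_MIN = (-2.75, -1.5, -1.0)
AABB_MAX = (2.75, 1.5, 1.0)
# Resolutions copied from `resolution_memory_format_pairs` above.
RESOLUTIONS = {
    "b_xc_y_z": (5, 150, 100),
    "b_yc_x_z": (250, 3, 100),
    "b_zc_x_y": (250, 150, 2),
}


def cell_size(aabb_min, aabb_max, resolution):
    """Per-axis cell extent of a uniform grid spanning the bounding box."""
    return tuple((hi - lo) / n for lo, hi, n in zip(aabb_min, aabb_max, resolution))


def num_cells(resolution):
    """Total number of cells in a 3D grid."""
    nx, ny, nz = resolution
    return nx * ny * nz


for name, res in RESOLUTIONS.items():
    dx, dy, dz = cell_size(AABB_MIN, AABB_MAX, res)
    print(f"{name}: {res} cells={num_cells(res)} size=({dx:.3f}, {dy:.3f}, {dz:.3f})")

factorized = sum(num_cells(r) for r in RESOLUTIONS.values())
dense = num_cells((250, 150, 100))
print(f"factorized: {factorized} cells, dense: {dense} cells")
```

Each grid holds 75,000 cells, 225,000 in total, compared with 3,750,000 for a single dense 250 × 150 × 100 grid using the finest resolution on every axis.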

### Run inference

The following code runs inference on the first batch returned by the dataloader and decodes the normalized predictions back to pressure values.

In [6]:
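# Inference only: disable gradient tracking globally.
torch.set_grad_enabled(False)

# Take the first batch from the dataloader.
sample = next(iter(dataset.train_dataloader()))
# Extract the model input (surface point coordinates) from the batch.
vertices = model.data_dict_to_input(sample)
# Ground-truth time-averaged surface pressure, used for comparison below.
pressure = sample["time_avg_pressure"]
# The model returns normalized per-point pressure and a drag prediction.
normalized_pred, drag_pred = model(vertices)
# Map the normalized predictions back to pressure values.
pred = dataset.normalizer.decode(normalized_pred)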


Module modulus.models.figconvnet.warp_neighbor_search load on device 'cuda:0' took 153.56 ms
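
`dataset.normalizer.decode` maps the network's normalized output back to pressure values. Its exact form is not shown in this notebook. The sketch below assumes standard score normalization, where training targets are encoded as `(p - mean) / std` and decoding computes `x * std + mean`. `StandardNormalizer` and its statistics are hypothetical placeholders, not values from DrivAerNet.

```python
from dataclasses import dataclass


@dataclass
class StandardNormalizer:
    """Hypothetical mean/std normalizer, not the `src.data` implementation."""

    mean: float
    std: float

    def encode(self, values):
        """Physical values -> normalized values used as training targets."""
        return [(v - self.mean) / self.std for v in values]

    def decode(self, values):
        """Normalized model outputs -> physical values (inverse of encode)."""
        return [v * self.std + self.mean for v in values]


# Placeholder statistics chosen only for illustration.
demo_normalizer = StandardNormalizer(mean=-150.0, std=120.0)
example_pressure = [-600.0, -150.0, 400.0]
encoded = demo_normalizer.encode(example_pressure)
decoded = demo_normalizer.decode(encoded)
```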


### Visualize inputs and predictions

In [10]:
camera_position = [
    (-3., -4.5, 5),
    (0.5, 0, 0.6),
    (0.1, 0.1, 0.9),
]

def plot_results(
    vertices: np.ndarray,
    pred: np.ndarray,
    gt: np.ndarray,
    camera_position,
    scalar_name: str = "p",
):
    plotter = pv.Plotter(shape=(1, 3))

    pc = pv.PolyData(vertices)
    gt_name = scalar_name + "_gt"
    pc[gt_name] = gt
    pred_name = scalar_name + "_pred"
    pc[pred_name] = pred
    # Use the ground-truth range for both panels so that colors are comparable.
    clim = (float(np.min(gt)), float(np.max(gt)))

    plotter.subplot(0, 0)
    plotter.add_points(pv.PolyData(vertices))
    plotter.camera_position = camera_position
    plotter.camera.zoom(0.5)
    plotter.add_text("Input point cloud", position="upper_left")

    plotter.subplot(0, 1)
    plotter.add_mesh(pc, scalars=pred_name, cmap="jet", clim=clim, show_scalar_bar=False)
    plotter.camera_position = camera_position
    plotter.camera.zoom(0.5)
    plotter.add_scalar_bar(title=scalar_name, vertical=True)
    plotter.add_text("Predicted point cloud", position="upper_left")

    plotter.subplot(0, 2)
    plotter.add_mesh(pc, scalars=gt_name, cmap="jet", clim=clim, show_scalar_bar=False)
    plotter.camera_position = camera_position
    plotter.camera.zoom(0.5)
    # PyVista reuses an existing scalar bar with the same title; the trailing
    # space gives this subplot its own bar.
    plotter.add_scalar_bar(title=scalar_name + " ", vertical=True)
    plotter.add_text("GT point cloud", position="upper_left")

    plotter.show()

plot_results(
    vertices[0].cpu().numpy(),
    pred[0].cpu().numpy(),
    pressure[0].cpu().numpy(),
    camera_position,
)


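
With min/max color limits, a handful of extreme pressure values can stretch the colormap so that most of the surface shows up in only a few colors. One option, sketched below, is to take the limits from percentiles of the ground truth and pass the same `clim` to both the prediction and ground-truth `add_mesh` calls. The 1st/99th percentile choice and the helper names `percentile` and `shared_clim` are assumptions, not part of this notebook.

```python
import math


def percentile(values, q):
    """Percentile with linear interpolation between sorted values.

    Follows the same rule as NumPy's default ("linear") method, q in [0, 100].
    """
    if not values:
        raise ValueError("values must be non-empty")
    if not 0.0 <= q <= 100.0:
        raise ValueError("q must be in [0, 100]")
    data = sorted(values)
    pos = (len(data) - 1) * q / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    return data[lo] + (data[hi] - data[lo]) * (pos - lo)


def shared_clim(gt, low_q=1.0, high_q=99.0):
    """Color limits taken from the ground truth, to reuse for the prediction."""
    return percentile(gt, low_q), percentile(gt, high_q)


# Hypothetical pressures: 99 moderate values plus two extreme outliers.
gt_values = [float(v) for v in range(-49, 50)] + [-2000.0, 1500.0]
print(shared_clim(gt_values))  # → (-49.0, 49.0)
```

Inside the notebook, `np.percentile(gt, [1, 99])` gives the same limits for a NumPy array without the pure-Python helper.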