In [None]:
from pydrake.all import (
    DiagramBuilder,
    AddMultibodyPlantSceneGraph,
    QueryObject,
    plot_system_graphviz,
    PoseBundle,
    LeafSystem,
    Value,
    MultibodyPlant,
    FramePoseVector,
)

In [None]:
class TaichiMPMSystem(LeafSystem):
    """Placeholder Drake system for a Taichi-based MPM particle simulation.

    Ports:
      - input ``query_object`` (abstract ``QueryObject``): intended to be wired
        to SceneGraph's query output port.
      - output ``particles_pose`` (abstract ``FramePoseVector``): intended to be
        wired to the SceneGraph source-pose port of the "particles" source.

    NOTE(review): no simulation is implemented yet. The output is left at the
    value produced by ``alloc`` (a default-constructed FramePoseVector), which
    presumably is only valid while the "particles" source registers no frames.
    Confirm this once particle frames are registered.
    """

    def __init__(self):
        super().__init__()

        self.set_name("taichi_mpm")
        self.DeclareAbstractInputPort("query_object", Value[QueryObject]())
        self.DeclareAbstractOutputPort(
            "particles_pose",
            alloc=lambda: Value[FramePoseVector](),
            calc=self._calc_particles_pose,
            # Depend only on state rather than the default (all sources), so
            # this output does not feed through from the query_object input.
            # Presumably chosen to keep the SceneGraph -> mpm -> SceneGraph
            # wiring free of an algebraic loop.
            prerequisites_of_calc={self.all_state_ticket()},
        )

    def _calc_particles_pose(self, context, output):
        """Calc callback for the ``particles_pose`` output port.

        Drake invokes output-port calc callbacks as ``calc(context, output)``.
        The previous single-argument lambda raised TypeError on ``Eval``.
        This is currently a no-op that leaves ``output`` at its allocated value.
        """
        # Both arguments are unused until the MPM simulation is implemented.
        del context, output

In [None]:
# Sanity check: build a diagram holding only a finalized (empty) plant and its
# SceneGraph, then evaluate the plant's geometry-pose output on a plant context.
builder = DiagramBuilder()
plant, scene_graph = AddMultibodyPlantSceneGraph(builder, 0.0)
plant.Finalize()

diagram = builder.Build()
context = plant.CreateDefaultContext()

# Typed accessor for the same port the string name "geometry_pose" selects.
# The last expression is left bare so the notebook displays the evaluated poses.
plant.get_geometry_poses_output_port().Eval(context)

In [None]:
# Wire the placeholder MPM system into SceneGraph. It reads SceneGraph's query
# object and publishes poses for a separately registered "particles" source.
builder = DiagramBuilder()
plant, scene_graph = AddMultibodyPlantSceneGraph(builder, 0.0)
# Finalize the plant before building the diagram. Drake expects a finalized
# MultibodyPlant once it is part of a built diagram.
plant.Finalize()

# Named source id (the old name `id` shadowed the Python builtin).
particles_source_id = scene_graph.RegisterSource("particles")
mpm = builder.AddSystem(TaichiMPMSystem())
builder.Connect(
    mpm.GetOutputPort("particles_pose"),
    scene_graph.get_source_pose_port(particles_source_id),
)
builder.Connect(scene_graph.get_query_output_port(), mpm.GetInputPort("query_object"))

diagram = builder.Build()
diagram.set_name("diagram")

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Render the wired system diagram. The large canvas keeps the graphviz
# rendering legible.
plt.figure(figsize=(20, 20))
plot_system_graphviz(diagram)
plt.title(f"System diagram: {diagram.get_name()}")
# Ending with show() renders the figure without echoing the AxesImage repr.
plt.show()