-
Notifications
You must be signed in to change notification settings - Fork 10
/
model.py
1695 lines (1319 loc) · 56 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import dataclasses
import functools
import pathlib
from typing import Any, Dict, List, Optional, Tuple, Union
import jax
import jax.numpy as jnp
import jax_dataclasses
import numpy as np
import rod
from jax_dataclasses import Static
import jaxsim.physics.algos.aba
import jaxsim.physics.algos.crba
import jaxsim.physics.algos.forward_kinematics
import jaxsim.physics.algos.rnea
import jaxsim.physics.model.physics_model
import jaxsim.physics.model.physics_model_state
import jaxsim.typing as jtp
from jaxsim import high_level, logging, physics, sixd
from jaxsim.physics.algos import soft_contacts
from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
from jaxsim.utils import JaxsimDataclass, Mutability, Vmappable, oop
from .common import VelRepr
@jax_dataclasses.pytree_dataclass
class ModelData(JaxsimDataclass):
    """
    Snapshot of a model at a given time: its state, its input and the state of
    its soft contacts.
    """

    # State of the model (joint and base positions / velocities).
    model_state: jaxsim.physics.model.physics_model_state.PhysicsModelState
    # Input of the model (joint generalized forces `tau`, link external forces `f_ext`).
    model_input: jaxsim.physics.model.physics_model_state.PhysicsModelInput
    # State of the soft-contacts model.
    contact_state: jaxsim.physics.algos.soft_contacts.SoftContactsState

    @staticmethod
    def zero(physics_model: physics.model.physics_model.PhysicsModel) -> "ModelData":
        """
        Build a ModelData whose fields are all zero-initialized.

        Args:
            physics_model: The physics model used to size every field.

        Returns:
            A ModelData with zero state, zero input and zero contact state, shaped
            after the given physics model.
        """
        state_module = jaxsim.physics.model.physics_model_state
        contacts_module = jaxsim.physics.algos.soft_contacts

        zero_state = state_module.PhysicsModelState.zero(physics_model=physics_model)
        zero_input = state_module.PhysicsModelInput.zero(physics_model=physics_model)
        zero_contacts = contacts_module.SoftContactsState.zero(
            physics_model=physics_model
        )

        return ModelData(
            model_state=zero_state,
            model_input=zero_input,
            contact_state=zero_contacts,
        )
@jax_dataclasses.pytree_dataclass
class StepData(JaxsimDataclass):
    """
    Class used to store the data computed at each step of the simulation.

    It groups the model data and the real input at the beginning of the step
    (t0), the accelerations computed by ABA at t0, and the model and contact
    states at the end of the step (tf).
    """

    # Initial time of the step.
    t0: float
    # Final time of the step.
    tf: float
    # Step size (presumably tf - t0; not enforced here).
    dt: float

    # Starting model data and real input (tau, f_ext) computed at t0
    t0_model_data: ModelData = dataclasses.field(repr=False)
    # NOTE(review): "real" presumably denotes the input actually applied at t0,
    # as opposed to the user-provided one — confirm with the integrator code.
    t0_model_input_real: jaxsim.physics.model.physics_model_state.PhysicsModelInput = (
        dataclasses.field(repr=False)
    )

    # ABA output
    # Base acceleration computed at t0.
    t0_base_acceleration: jtp.Vector = dataclasses.field(repr=False)
    # Joint accelerations computed at t0.
    t0_joint_acceleration: jtp.Vector = dataclasses.field(repr=False)

    # (new ODEState)
    # Starting from t0_model_data, can be obtained by integrating the ABA output
    # and tangential_deformation_dot (which is fn of ode_state at t0)
    # Model state reached at tf.
    tf_model_state: jaxsim.physics.model.physics_model_state.PhysicsModelState = (
        dataclasses.field(repr=False)
    )
    # Soft-contacts state reached at tf.
    tf_contact_state: jaxsim.physics.algos.soft_contacts.SoftContactsState = (
        dataclasses.field(repr=False)
    )

    # Free-form auxiliary data; an empty dict when not provided.
    aux: Dict[str, Any] = dataclasses.field(default_factory=dict)
@jax_dataclasses.pytree_dataclass
class Model(Vmappable):
"""
High-level class to operate on a simulated model.
"""
model_name: Static[str]
physics_model: physics.model.physics_model.PhysicsModel = dataclasses.field(
repr=False
)
velocity_representation: Static[VelRepr] = dataclasses.field(default=VelRepr.Mixed)
data: ModelData = dataclasses.field(default=None, repr=False)
# ========================
# Initialization and state
# ========================
@staticmethod
def build_from_model_description(
    model_description: Union[str, pathlib.Path, rod.Model],
    model_name: str | None = None,
    vel_repr: VelRepr = VelRepr.Mixed,
    gravity: jtp.Array = jaxsim.physics.default_gravity(),
    is_urdf: bool | None = None,
    considered_joints: List[str] | None = None,
) -> "Model":
    """
    Build a Model object from a model description.

    Args:
        model_description: A path to an SDF/URDF file, a string with its content,
            or an already parsed/built rod model.
        model_name: Optional name overriding the one found in the description.
        vel_repr: The velocity representation of the built model.
        gravity: The 3D gravity vector.
        is_urdf: Whether the description is a URDF (True) or an SDF (False).
            It is inferred automatically when a file path is given.
        considered_joints: The joints to keep. If None, all joints are kept.

    Returns:
        The built Model object.
    """
    import jaxsim.parsers.rod

    # Turn the input resource into the intermediate model description.
    intermediate_description = jaxsim.parsers.rod.build_model_description(
        model_description=model_description, is_urdf=is_urdf
    )

    # Keeping only a subset of joints lumps together the links connected by
    # the removed joints; the removed joints are assigned a zero position.
    if considered_joints is not None:
        intermediate_description = intermediate_description.reduce(
            considered_joints=considered_joints
        )

    return Model.build(
        physics_model=jaxsim.physics.model.physics_model.PhysicsModel.build_from(
            model_description=intermediate_description, gravity=gravity
        ),
        model_name=model_name,
        vel_repr=vel_repr,
    )
@staticmethod
def build_from_sdf(
    sdf: Union[str, pathlib.Path],
    model_name: str | None = None,
    vel_repr: VelRepr = VelRepr.Mixed,
    gravity: jtp.Array = jaxsim.physics.default_gravity(),
    is_urdf: bool | None = None,
    considered_joints: List[str] | None = None,
) -> "Model":
    """
    Build a Model object from an SDF description.

    Deprecated: logs a deprecation warning and forwards all the arguments to
    ``Model.build_from_model_description``, passing ``sdf`` as model description.
    """
    template = "Model.{} is deprecated, use Model.{} instead."
    warning_text = template.format("build_from_sdf", "build_from_model_description")
    logging.warning(msg=warning_text)

    return Model.build_from_model_description(
        model_description=sdf,
        model_name=model_name,
        vel_repr=vel_repr,
        gravity=gravity,
        is_urdf=is_urdf,
        considered_joints=considered_joints,
    )
@staticmethod
def build(
    physics_model: jaxsim.physics.model.physics_model.PhysicsModel,
    model_name: str | None = None,
    vel_repr: VelRepr = VelRepr.Mixed,
) -> "Model":
    """
    Build a Model object from a physics model.

    Args:
        physics_model: The physics model.
        model_name: Optional name overriding the one stored in the physics
            model description.
        vel_repr: The velocity representation of the built model.

    Returns:
        The built Model object, with zero-initialized data.

    Raises:
        RuntimeError: If the built model is not valid.
    """
    # Fall back to the name stored in the model description.
    if model_name is None:
        model_name = physics_model.description.name

    model = Model(
        physics_model=physics_model,
        model_name=model_name,
        velocity_representation=vel_repr,
    )

    # The data field starts as None; filling it changes the PyTree structure,
    # presumably the reason why validation is disabled in this context.
    with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
        model.zero()

    if not model.valid():
        raise RuntimeError("The model is not valid.")

    return model
@functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False)
def reduce(
    self, considered_joints: tuple[str, ...], keep_base_pose: bool = False
) -> None:
    """
    Reduce the model in place, lumping together the links connected by the
    joints that are not listed in ``considered_joints``.

    Args:
        considered_joints: The joints to keep.
        keep_base_pose: If True, restore the base position and orientation the
            model had before the reduction. Otherwise the model keeps the data
            freshly built (zeroed) by ``Model.build``.

    Raises:
        RuntimeError: If the model is vectorized.
    """
    if self.vectorized:
        raise RuntimeError("Cannot reduce a vectorized model.")

    # The description raises if some of the joints do not exist in the model.
    description = self.physics_model.description.reduce(
        considered_joints=list(considered_joints)
    )

    reduced = Model.build(
        physics_model=jaxsim.physics.model.physics_model.PhysicsModel.build_from(
            model_description=description,
            gravity=self.physics_model.gravity[0:3],
        ),
        model_name=self.name(),
        vel_repr=self.velocity_representation,
    )

    # Snapshot the current base pose before the data gets replaced.
    base_position = self.base_position()
    base_quaternion = self.base_orientation(dcm=False)

    # The PyTree structure changes here, hence validation is disabled above.
    self.physics_model = reduced.physics_model
    self.data = reduced.data

    if not keep_base_pose:
        return

    self.reset_base_position(position=base_position)
    self.reset_base_orientation(orientation=base_quaternion, dcm=False)
@functools.partial(oop.jax_tf.method_rw, jit=False)
def zero(self) -> None:
    """Reset all the model data (state, input and contact state) to zero."""
    zeroed = ModelData.zero(physics_model=self.physics_model)
    self.data = zeroed
@functools.partial(oop.jax_tf.method_rw, jit=False)
def zero_input(self) -> None:
    """Reset only the model input to zero, leaving state and contact state untouched."""
    zeroed = ModelData.zero(physics_model=self.physics_model)
    self.data.model_input = zeroed.model_input
@functools.partial(oop.jax_tf.method_rw, jit=False)
def zero_state(self) -> None:
    """Reset the model state and the contact state to zero, keeping the input."""
    zeroed = ModelData.zero(physics_model=self.physics_model)
    self.data.model_state, self.data.contact_state = (
        zeroed.model_state,
        zeroed.contact_state,
    )
@functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False)
def set_velocity_representation(self, vel_repr: VelRepr) -> None:
    """
    Change the velocity representation used by the model.

    Args:
        vel_repr: The new velocity representation. Nothing happens if it is
            already the active one.
    """
    if self.velocity_representation is not vel_repr:
        self.velocity_representation = vel_repr
# ==========
# Properties
# ==========
@functools.partial(oop.jax_tf.method_ro, jit=False)
def valid(self) -> jtp.Bool:
    """
    Return whether all the links and all the joints of the model are valid.

    The joints are checked only if all the links are valid.
    """
    links_ok = all([link.valid() for link in self.links()])
    everything_ok = links_ok and all([joint.valid() for joint in self.joints()])
    return jnp.array(everything_ok, dtype=bool)
@functools.partial(oop.jax_tf.method_ro, jit=False)
def floating_base(self) -> jtp.Bool:
    """Return whether the model has a floating base, as a boolean JAX array."""
    is_floating = self.physics_model.is_floating_base
    return jnp.array(is_floating, dtype=bool)
@functools.partial(oop.jax_tf.method_ro, jit=False)
def dofs(self) -> jtp.Int:
    """Return the number of degrees of freedom, i.e. the number of joint positions."""
    all_joint_positions = self.joint_positions()
    return all_joint_positions.size
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def name(self) -> str:
    """Return the name of the model."""
    model_name = self.model_name
    return model_name
@functools.partial(oop.jax_tf.method_ro, jit=False)
def nr_of_links(self) -> jtp.Int:
    """Return the number of links of the model, as an integer JAX array."""
    link_count = len(self.links())
    return jnp.array(link_count, dtype=int)
@functools.partial(oop.jax_tf.method_ro, jit=False)
def nr_of_joints(self) -> jtp.Int:
    """Return the number of joints of the model, as an integer JAX array."""
    joint_count = len(self.joints())
    return jnp.array(joint_count, dtype=int)
@functools.partial(oop.jax_tf.method_ro)
def total_mass(self) -> jtp.Float:
    """Return the total mass of the model, i.e. the sum of the masses of its links."""
    link_masses = jnp.array([link.mass() for link in self.links()])
    return jnp.sum(link_masses, dtype=float)
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def get_link(self, link_name: str) -> high_level.link.Link:
    """
    Return the link with the given name.

    Args:
        link_name: The name of the link.

    Returns:
        The high-level Link object.

    Raises:
        ValueError: If the model has no link with the given name.
    """
    if link_name in self.link_names():
        (link,) = self.links(link_names=(link_name,))
        return link

    msg = f"Link '{link_name}' is not part of model '{self.name()}'"
    raise ValueError(msg)
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def get_joint(self, joint_name: str) -> high_level.joint.Joint:
    """
    Return the joint with the given name.

    Args:
        joint_name: The name of the joint.

    Returns:
        The high-level Joint object.

    Raises:
        ValueError: If the model has no joint with the given name.
    """
    if joint_name in self.joint_names():
        (joint,) = self.joints(joint_names=(joint_name,))
        return joint

    msg = f"Joint '{joint_name}' is not part of model '{self.name()}'"
    raise ValueError(msg)
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def link_names(self) -> tuple[str, ...]:
    """Return the names of the links, in the insertion order of the description."""
    links_dict = self.physics_model.description.links_dict
    return tuple(links_dict)
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def joint_names(self) -> tuple[str, ...]:
    """Return the names of the joints, in the insertion order of the description."""
    joints_dict = self.physics_model.description.joints_dict
    return tuple(joints_dict)
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def links(
    self, link_names: tuple[str, ...] | None = None
) -> tuple[high_level.link.Link, ...]:
    """
    Return the high-level links of the model.

    Args:
        link_names: Optional names of the links to return, in the desired order.
            If None, all the links are returned sorted by their index.

    Returns:
        A tuple of Link objects that share the mutability of this model.

    Raises:
        KeyError: If one of the requested names is not a link of the model.
    """
    descriptions = sorted(
        self.physics_model.description.links_dict.values(),
        key=lambda description: description.index,
    )

    links_by_name = {
        description.name: high_level.link.Link(
            link_description=description,
            _parent_model=self,
            batch_size=self.batch_size,
        )
        for description in descriptions
    }

    # Propagate the mutability of the model to all its links.
    for link in links_by_name.values():
        link._set_mutability(self._mutability())

    # Iterating the dict yields the names sorted by link index.
    selected_names = links_by_name if link_names is None else link_names
    return tuple(links_by_name[name] for name in selected_names)
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def joints(
    self, joint_names: tuple[str, ...] | None = None
) -> tuple[high_level.joint.Joint, ...]:
    """
    Return the high-level joints of the model.

    Args:
        joint_names: Optional names of the joints to return, in the desired order.
            If None, all the joints are returned sorted by their index.

    Returns:
        A tuple of Joint objects that share the mutability of this model.

    Raises:
        KeyError: If one of the requested names is not a joint of the model.
    """
    descriptions = sorted(
        self.physics_model.description.joints_dict.values(),
        key=lambda description: description.index,
    )

    joints_by_name = {
        description.name: high_level.joint.Joint(
            joint_description=description,
            _parent_model=self,
            batch_size=self.batch_size,
        )
        for description in descriptions
    }

    # Propagate the mutability of the model to all its joints.
    for joint in joints_by_name.values():
        joint._set_mutability(self._mutability())

    # Iterating the dict yields the names sorted by joint index.
    selected_names = joints_by_name if joint_names is None else joint_names
    return tuple(joints_by_name[name] for name in selected_names)
@functools.partial(oop.jax_tf.method_ro, static_argnames=["link_names", "terrain"])
def in_contact(
    self,
    link_names: tuple[str, ...] | None = None,
    terrain: Terrain = FlatTerrain(),
) -> jtp.Vector:
    """
    Return which links have at least one collidable point at or below the terrain.

    Args:
        link_names: The links to check. If None, all the links are checked.
        terrain: The terrain whose height is compared with the collidable points.

    Returns:
        A boolean vector with one entry per checked link, in the given order.

    Raises:
        ValueError: If some of the link names are not part of the model.
    """
    if link_names is None:
        link_names = self.link_names()

    if not set(link_names).issubset(self.link_names()):
        raise ValueError("One or more link names are not part of the model")

    from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel

    # World positions of all collidable points (their velocities are not needed).
    W_p_Ci, _ = collidable_points_pos_vel(
        model=self.physics_model,
        q=self.data.model_state.joint_positions,
        qd=self.data.model_state.joint_velocities,
        xfb=self.data.model_state.xfb(),
    )

    x, y, z = W_p_Ci[0, :], W_p_Ci[1, :], W_p_Ci[2, :]
    point_below_terrain = z <= jax.vmap(terrain.height)(x, y)

    def link_has_contact(link_index):
        # gc.body presumably stores the parent link index of each collidable point.
        return jnp.where(
            self.physics_model.gc.body == link_index,
            point_below_terrain,
            jnp.zeros_like(point_below_terrain, dtype=bool),
        ).any()

    link_indices = jnp.array(
        [link.index() for link in self.links(link_names=link_names)]
    )
    return jax.vmap(link_has_contact)(link_indices)
# =================
# Multi-DoF methods
# =================
@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_positions(self, joint_names: tuple[str, ...] | None = None) -> jtp.Vector:
    """
    Return the positions of the selected joints.

    Args:
        joint_names: The joints to consider; the selection is resolved by
            ``self._joint_indices`` (None presumably selects all joints).
    """
    indices = self._joint_indices(joint_names=joint_names)
    return self.data.model_state.joint_positions[indices]
@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_random_positions(
    self,
    joint_names: tuple[str, ...] | None = None,
    key: jax.Array | None = None,
) -> jtp.Vector:
    """
    Sample joint positions uniformly between the joint limits.

    Args:
        joint_names: The joints to sample. If None, all joints are considered.
        key: The PRNG key. If None, a fixed key with seed 0 is used, so calls
            without an explicit key always return the same sample.

    Returns:
        The sampled joint positions.
    """
    rng_key = jax.random.PRNGKey(seed=0) if key is None else key
    lower, upper = self.joint_limits(joint_names=joint_names)

    return jax.random.uniform(
        key=rng_key,
        shape=lower.shape,
        minval=lower,
        maxval=upper,
    )
@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_velocities(
    self, joint_names: tuple[str, ...] | None = None
) -> jtp.Vector:
    """
    Return the velocities of the selected joints.

    Args:
        joint_names: The joints to consider; the selection is resolved by
            ``self._joint_indices`` (None presumably selects all joints).
    """
    indices = self._joint_indices(joint_names=joint_names)
    return self.data.model_state.joint_velocities[indices]
@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_generalized_forces_targets(
    self, joint_names: tuple[str, ...] | None = None
) -> jtp.Vector:
    """
    Return the joint generalized forces (``tau``) stored in the model input.

    Args:
        joint_names: The joints to consider; the selection is resolved by
            ``self._joint_indices`` (None presumably selects all joints).
    """
    indices = self._joint_indices(joint_names=joint_names)
    return self.data.model_input.tau[indices]
@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_limits(
    self, joint_names: tuple[str, ...] | None = None
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Return the position limits of the selected joints.

    Args:
        joint_names: The joints to consider. If None, all joints are considered.

    Returns:
        A tuple ``(s_low, s_high)``. For every joint the two limits are sorted, so
        that ``s_low <= s_high`` also when they are declared in the wrong order.
    """
    selected_joints = self.joints(
        joint_names if joint_names is not None else self.joint_names()
    )

    # One row per joint, holding the two values returned by position_limit().
    limits = jnp.vstack(
        jnp.array([joint.position_limit() for joint in selected_joints])
    )

    return jnp.min(limits, axis=1), jnp.max(limits, axis=1)
# =========
# Base link
# =========
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def base_frame(self) -> str:
    """Return the name of the root of the model description, used as base frame."""
    root = self.physics_model.description.root
    return root.name
@functools.partial(oop.jax_tf.method_ro)
def base_position(self) -> jtp.Vector:
    """Return the base position stored in the model state, with singleton axes squeezed."""
    stored_position = self.data.model_state.base_position
    return stored_position.squeeze()
@functools.partial(oop.jax_tf.method_ro, static_argnames=["dcm"])
def base_orientation(self, dcm: bool = False) -> jtp.Vector:
    """
    Return the base orientation.

    The stored quaternion is normalized first: the Baumgarte stabilization of the
    integrator drives its norm towards 1 but does not enforce it at every instant.

    Args:
        dcm: If True, return the rotation matrix; otherwise the unit quaternion
            in wxyz order.

    Returns:
        The base orientation as a wxyz unit quaternion or as a rotation matrix.
    """
    quaternion = self.data.model_state.base_quaternion
    quaternion_wxyz = quaternion.squeeze() / jnp.linalg.norm(quaternion)

    if not dcm:
        return quaternion_wxyz

    # The SO3 helper expects the xyzw convention: move the scalar part last.
    quaternion_xyzw = quaternion_wxyz[np.array([1, 2, 3, 0])]
    return sixd.so3.SO3.from_quaternion_xyzw(quaternion_xyzw).as_matrix()
@functools.partial(oop.jax_tf.method_ro)
def base_transform(self) -> jtp.MatrixJax:
    """Return the 4x4 homogeneous transform W_H_B of the base w.r.t. the world."""
    rotation = self.base_orientation(dcm=True)
    translation = jnp.vstack(self.base_position())

    upper_rows = jnp.block([rotation, translation])
    bottom_row = jnp.array([0, 0, 0, 1])
    return jnp.vstack([upper_rows, bottom_row])
@functools.partial(oop.jax_tf.method_ro)
def base_velocity(self) -> jtp.Vector:
    """
    Return the 6D base velocity (linear part first, angular part second),
    converted from the stored inertial representation to the active one.
    """
    state = self.data.model_state
    W_v_WB = jnp.hstack([state.base_linear_velocity, state.base_angular_velocity])
    return self.inertial_to_active_representation(array=W_v_WB)
@functools.partial(oop.jax_tf.method_ro)
def external_forces(self) -> jtp.Matrix:
    """
    Return the active external forces acting on the robot.

    These forces are a user input and are not computed by the physics engine.
    During the simulation they are summed to other terms, such as the forces
    due to the contact with the environment.

    Returns:
        A matrix of shape (n_links, 6) with the external forces acting on the
        links, expressed in the active representation.
    """
    # Internally the forces are always stored in inertial representation.
    W_f_ext = self.data.model_input.f_ext

    def to_active(W_f):
        return self.inertial_to_active_representation(W_f, is_force=True)

    return jax.vmap(to_active, in_axes=0)(W_f_ext)
# =======================
# Single link r/w methods
# =======================
@functools.partial(
    oop.jax_tf.method_rw, jit=True, static_argnames=["link_name", "additive"]
)
def apply_external_force_to_link(
    self,
    link_name: str,
    force: jtp.Array | None = None,
    torque: jtp.Array | None = None,
    additive: bool = True,
) -> None:
    """
    Apply an external 6D force to a link.

    The force and torque are interpreted in the active velocity representation:
    the world frame (Inertial), the link frame L (Body), or the frame with the
    origin of L and the orientation of the world (Mixed). The result is stored
    in inertial representation.

    Args:
        link_name: The name of the target link.
        force: The linear force. Defaults to zero.
        torque: The torque. Defaults to zero.
        additive: If True, sum to the external force already stored for the
            link; otherwise overwrite it.

    Raises:
        ValueError: If the active velocity representation is not supported.
    """
    link = self.get_link(link_name=link_name)
    link._set_mutability(mutability=self._mutability())

    # 6D force in the active representation, missing components set to zero.
    f_ext = jnp.hstack(
        [
            force if force is not None else jnp.zeros(3),
            torque if torque is not None else jnp.zeros(3),
        ]
    )

    representation = self.velocity_representation
    if representation is VelRepr.Inertial:
        W_f_ext = f_ext
    else:
        # W_H_C: pose of the frame C in which the force is expressed.
        if representation is VelRepr.Body:
            W_H_C = link.transform()
        elif representation is VelRepr.Mixed:
            W_H_C = jnp.eye(4).at[0:3, 3].set(link.transform()[0:3, 3])
        else:
            raise ValueError(self.velocity_representation)

        # Forces transform with the transpose of the inverse velocity adjoint.
        C_X_W = sixd.se3.SE3.from_matrix(W_H_C).inverse().adjoint()
        W_f_ext = C_X_W.transpose() @ f_ext

    row = link.index()
    current = self.data.model_input.f_ext[row, :]
    updated = current + W_f_ext if additive else W_f_ext
    self.data.model_input.f_ext = self.data.model_input.f_ext.at[row, :].set(updated)
@functools.partial(
    oop.jax_tf.method_rw, jit=True, static_argnames=["link_name", "additive"]
)
def apply_external_force_to_link_com(
    self,
    link_name: str,
    force: jtp.Array | None = None,
    torque: jtp.Array | None = None,
    additive: bool = True,
) -> None:
    """
    Apply an external 6D force to the center of mass of a link.

    The force and torque are interpreted in the active velocity representation:
    the world frame (Inertial), the frame with the origin at the link CoM and the
    orientation of the link (Body), or the frame with the origin at the link CoM
    and the orientation of the world (Mixed). The result is stored in inertial
    representation.

    Args:
        link_name: The name of the target link.
        force: The linear force. Defaults to zero.
        torque: The torque. Defaults to zero.
        additive: If True, sum to the external force already stored for the
            link; otherwise overwrite it.

    Raises:
        ValueError: If the active velocity representation is not supported.
    """
    link = self.get_link(link_name=link_name)
    link._set_mutability(mutability=self._mutability())

    # 6D force in the active representation, missing components set to zero.
    f_ext = jnp.hstack(
        [
            force if force is not None else jnp.zeros(3),
            torque if torque is not None else jnp.zeros(3),
        ]
    )

    representation = self.velocity_representation
    if representation is VelRepr.Inertial:
        W_f_ext = f_ext
    else:
        # W_H_C: pose of the CoM-centered frame C in which the force is expressed.
        if representation is VelRepr.Body:
            L_p_CoM = link.com_position(in_link_frame=True)
            W_H_C = link.transform() @ jnp.eye(4).at[0:3, 3].set(L_p_CoM)
        elif representation is VelRepr.Mixed:
            W_p_CoM = link.com_position(in_link_frame=False)
            W_H_C = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
        else:
            raise ValueError(self.velocity_representation)

        # Forces transform with the transpose of the inverse velocity adjoint.
        C_X_W = sixd.se3.SE3.from_matrix(W_H_C).inverse().adjoint()
        W_f_ext = C_X_W.transpose() @ f_ext

    row = link.index()
    current = self.data.model_input.f_ext[row, :]
    updated = current + W_f_ext if additive else W_f_ext
    self.data.model_input.f_ext = self.data.model_input.f_ext.at[row, :].set(updated)
# ================================================
# Generalized methods and free-floating quantities
# ================================================
@functools.partial(oop.jax_tf.method_ro)
def generalized_position(self) -> Tuple[jtp.Matrix, jtp.Vector]:
    """Return the generalized position as the pair (base transform W_H_B, joint positions)."""
    W_H_B = self.base_transform()
    s = self.joint_positions()
    return W_H_B, s
@functools.partial(oop.jax_tf.method_ro)
def generalized_velocity(self) -> jtp.Vector:
    """Return the generalized velocity: base velocity followed by joint velocities."""
    v_base = self.base_velocity()
    s_dot = self.joint_velocities()
    return jnp.hstack([v_base, s_dot])
@functools.partial(oop.jax_tf.method_ro, static_argnames=["output_vel_repr"])
def generalized_free_floating_jacobian(
    self, output_vel_repr: VelRepr | None = None
) -> jtp.Matrix:
    """
    Return the vertically stacked free-floating Jacobians of all the links.

    The Jacobian of each link is first computed with Inertial output
    representation. For Body and Mixed outputs it is then converted so that the
    base link B, rather than each link frame, acts as body frame.

    Args:
        output_vel_repr: The output velocity representation. If None, the
            active velocity representation of the model is used.

    Returns:
        The stacked per-link Jacobians, ordered as ``self.link_names()``.

    Raises:
        ValueError: If the output velocity representation is not supported.
    """
    if output_vel_repr is None:
        output_vel_repr = self.velocity_representation

    # The Inertial -> output transform does not depend on the link: compute it
    # once instead of recomputing the base transform for every link.
    if output_vel_repr is VelRepr.Inertial:
        O_X_W = None
    elif output_vel_repr is VelRepr.Body:
        W_H_B = self.base_transform()
        O_X_W = sixd.se3.SE3.from_matrix(W_H_B).inverse().adjoint()
    elif output_vel_repr is VelRepr.Mixed:
        W_H_B = self.base_transform()
        W_H_BW = jnp.array(W_H_B).at[0:3, 0:3].set(jnp.eye(3))
        O_X_W = sixd.se3.SE3.from_matrix(W_H_BW).inverse().adjoint()
    else:
        raise ValueError(output_vel_repr)

    # Build all the link objects with a single call, in link_names() order.
    # Calling get_link() per name would rebuild every Link object each time.
    links = self.links(link_names=self.link_names())

    W_J_links = [link.jacobian(output_vel_repr=VelRepr.Inertial) for link in links]

    if O_X_W is not None:
        W_J_links = [O_X_W @ W_J_Wi for W_J_Wi in W_J_links]

    return jnp.vstack(W_J_links)
@functools.partial(oop.jax_tf.method_ro)
def free_floating_mass_matrix(self) -> jtp.Matrix:
    """
    Return the free-floating mass matrix in the active velocity representation.

    CRBA returns the mass matrix in body-fixed representation. For the other
    representations it is transformed as ``invT.T @ M_body @ invT``, where
    ``invT`` maps the active generalized velocity to the body-fixed one.

    Raises:
        ValueError: If the active velocity representation is not supported.
    """
    M_body = jaxsim.physics.algos.crba.crba(
        model=self.physics_model,
        q=self.data.model_state.joint_positions,
    )

    representation = self.velocity_representation
    if representation is VelRepr.Body:
        return M_body

    if representation is VelRepr.Inertial:
        # B_X_W maps inertial-fixed base velocities to body-fixed ones.
        B_X_C = sixd.se3.SE3.from_matrix(self.base_transform()).inverse().adjoint()
    elif representation is VelRepr.Mixed:
        # Zeroing the translation of W_H_B keeps only the base orientation: the
        # resulting matrix is the pose BW_H_B of B w.r.t. the mixed frame B[W],
        # so its inverse adjoint is B_X_BW (mixed to body-fixed velocities).
        BW_H_B = self.base_transform().at[0:3, 3].set(jnp.zeros(3))
        B_X_C = sixd.se3.SE3.from_matrix(BW_H_B).inverse().adjoint()
    else:
        raise ValueError(self.velocity_representation)

    n = self.dofs()
    zero_6n = jnp.zeros(shape=(6, n))
    invT = jnp.vstack(
        [
            jnp.block([B_X_C, zero_6n]),
            jnp.block([zero_6n.T, jnp.eye(n)]),
        ]
    )

    return invT.T @ M_body @ invT
@functools.partial(oop.jax_tf.method_ro)
def free_floating_bias_forces(self) -> jtp.Vector:
    """
    Return the free-floating bias forces.

    They are the generalized forces that RNEA returns with zero input, zero base
    acceleration and zero joint accelerations. They are computed on the editable
    model yielded by ``self.editable``.
    """
    with self.editable(validate=True) as model:
        model.zero_input()
        base_force, joint_forces = model.inverse_dynamics(
            base_acceleration=jnp.zeros(6), joint_accelerations=None
        )
        return jnp.hstack([base_force, joint_forces])
@functools.partial(oop.jax_tf.method_ro)
def free_floating_gravity_forces(self) -> jtp.Vector:
    """
    Return the free-floating gravity forces.

    They are the generalized forces that RNEA returns with zero input, zero base
    and joint velocities, and zero accelerations. They are computed on the
    editable model yielded by ``self.editable``.
    """
    with self.editable(validate=True) as model:
        model.zero_input()

        # Zero every velocity of the state so that only gravity contributes.
        for velocity_field in (
            "joint_velocities",
            "base_linear_velocity",
            "base_angular_velocity",
        ):
            setattr(
                model.data.model_state,
                velocity_field,
                jnp.zeros_like(getattr(model.data.model_state, velocity_field)),
            )

        base_force, joint_forces = model.inverse_dynamics(
            base_acceleration=jnp.zeros(6), joint_accelerations=None
        )
        return jnp.hstack([base_force, joint_forces])
@functools.partial(oop.jax_tf.method_ro)
def momentum(self) -> jtp.Vector:
    """
    Return the 6D momentum of the model in the active representation.

    The momentum is first computed in body-fixed representation, where the first
    6 rows of the mass matrix are the Jacobian of the floating-base momentum.
    """
    with self.editable(validate=True) as m:
        m.set_velocity_representation(vel_repr=VelRepr.Body)
        M = m.free_floating_mass_matrix()
        nu = m.generalized_velocity()
        B_h = M[0:6, :] @ nu

    # Forces (and momenta) map from B to W through the transpose of B_X_W.
    B_X_W: jtp.Array = (
        sixd.se3.SE3.from_matrix(self.base_transform()).inverse().adjoint()
    )
    W_h = B_X_W.T @ B_h

    return self.inertial_to_active_representation(array=W_h, is_force=True)
# ===========
# CoM methods
# ===========
@functools.partial(oop.jax_tf.method_ro)
def com_position(self) -> jtp.Vector:
    """
    Return the position of the center of mass of the whole model.

    It is the mass-weighted average of the CoM positions of the links.

    Returns:
        The CoM position expressed in the world frame.
    """
    m = self.total_mass()
    W_H_L = self.forward_kinematics()

    # Mass-weighted homogeneous CoM positions of the links, directly in the world
    # frame. Passing through the base frame (W -> B -> W) is not needed since the
    # result is expressed in W anyway.
    W_ph_CoM_links = [
        link.mass()
        * W_H_L[link.index()]
        @ jnp.hstack([link.com_position(in_link_frame=True), 1])
        for link in self.links()
    ]

    # The homogeneous component becomes sum(masses) / m = 1 and is dropped.
    W_ph_CoM = (1 / m) * jnp.sum(jnp.array(W_ph_CoM_links), axis=0)
    return W_ph_CoM[0:3]
# ==========
# Algorithms
# ==========
@functools.partial(oop.jax_tf.method_ro)
def forward_kinematics(self) -> jtp.Array:
    """
    Return the world transforms W_H_i of all the links.

    The result is indexed by link index, and each entry is a homogeneous
    transform of the link frame w.r.t. the world frame.
    """
    state = self.data.model_state
    return jaxsim.physics.algos.forward_kinematics.forward_kinematics_model(
        model=self.physics_model,
        q=state.joint_positions,
        xfb=state.xfb(),
    )
@functools.partial(oop.jax_tf.method_ro)
def inverse_dynamics(
self,
joint_accelerations: jtp.Vector | None = None,
base_acceleration: jtp.Vector | None = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
"""
Compute inverse dynamics with the RNEA algorithm.
Args:
joint_accelerations: the joint accelerations to consider.
base_acceleration: the base acceleration in the active representation to consider.
Returns:
A tuple containing the 6D force in active representation applied to the base
to obtain the considered base acceleration, and the joint torques to apply
to obtain the considered joint accelerations.
"""
# Build joint accelerations if not provided
joint_accelerations = (
joint_accelerations
if joint_accelerations is not None
else jnp.zeros_like(self.joint_positions())
)
# Build base acceleration if not provided
base_acceleration = (
base_acceleration if base_acceleration is not None else jnp.zeros(6)
)
if base_acceleration.size != 6:
raise ValueError(base_acceleration.size)
def to_inertial(C_vd_WB, W_H_C, C_v_WB, W_vl_WC):
W_X_C = sixd.se3.SE3.from_matrix(W_H_C).adjoint()
C_X_W = sixd.se3.SE3.from_matrix(W_H_C).inverse().adjoint()