A desirable property for some loss functions is **statefulness**: the loss computation relies on an internally stored, mutable variable. \
This is incompatible with the design philosophy of JAX, which requires functions to be **pure**, meaning that a function does not rely on hidden state.
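
To make the conflict concrete, here is a minimal sketch (not taken from mosaic; `call_count` and `impure_loss` are made-up names) of what happens when a jitted function touches hidden state. The Python side effect runs only while JAX traces the function, and the value seen at trace time is baked into the compiled code.

```python
import jax

call_count = {"n": 0}  # hypothetical hidden state living outside the function


@jax.jit
def impure_loss(x):
    # Python side effect: executed during tracing, not on every call.
    call_count["n"] += 1
    return x + call_count["n"]


first = float(impure_loss(1.0))   # first call traces: counter becomes 1
second = float(impure_loss(2.0))  # same input type: cached trace is reused

print(first, second, call_count["n"])
```

The counter is incremented once even though the function is called twice, which is why a stateful loss cannot simply mutate itself inside a traced computation.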

To work around this problem, mosaic adopts the method described in `common.py`:
   1. a module has a `state_index` property that is an instance of `StateIndex`
   2. part of the aux pytree is a tuple `(StateIndex, value)`, where `value` is the argument to be passed to the `update_state` method
   3. the `update_state` method is called for each module in the loss pytree that has a `state_index` property
   This last step happens in the optimization loop, after the value and gradient have been computed.
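
The three steps above can be sketched without any library. The classes below (`ToyStateIndex`, `ToyCounterLoss`) are hypothetical stand-ins written for illustration, not mosaic's `StateIndex` or `LossTerm`. The point is only that the loss call stays pure and the state change happens outside it, by building a new module.

```python
from dataclasses import dataclass, replace
from itertools import count

_ids = count()  # hypothetical id source; mosaic's StateIndex creates its own ids


@dataclass(frozen=True)
class ToyStateIndex:
    id: int


@dataclass(frozen=True)
class ToyCounterLoss:
    state_index: ToyStateIndex
    state: float = 0.0

    def __call__(self, x):
        # Pure: reads self.state but never writes it. The proposed next
        # state leaves through aux, tagged with this module's index.
        loss = x * self.state
        return loss, {"update": (self.state_index, self.state + 1.0)}

    def update_state(self, new_state):
        # Returns a new module instead of mutating the existing one.
        return replace(self, state=new_state)


term = ToyCounterLoss(state_index=ToyStateIndex(next(_ids)))
for step in range(3):
    loss, aux = term(2.0)            # steps 1-2: compute loss, collect aux
    index, value = aux["update"]
    if index == term.state_index:    # step 3: route the update to its owner
        term = term.update_state(value)

print(term.state, loss)
```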

In mosaic, a loss function is an object created using the `build_loss` method of any of the supported models. \
It holds several parameters relevant to the loss computation; in particular, it has a `loss` property containing either a `LinearCombination` or a single `LossTerm` object, the classes that actually perform the loss computation. \
It is the individual `LossTerm` that we want to allow to be stateful.  
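
As a rough mental model of how individual terms relate to their combination, the toy below composes terms with `+` and returns one aux entry per term. It is an assumption-laden sketch: `ToyTerm` and `ToyCombination` are invented here, and mosaic's `LinearCombination` may differ in weights, aux layout, and call signature.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTerm:
    name: str
    value: float

    def __call__(self, x):
        loss = self.value * x
        return loss, {self.name: loss}

    def __add__(self, other):
        return ToyCombination((self,)) + other


@dataclass(frozen=True)
class ToyCombination:
    terms: tuple

    def __add__(self, other):
        extra = other.terms if isinstance(other, ToyCombination) else (other,)
        return ToyCombination(self.terms + extra)

    def __call__(self, x):
        # Evaluate every term, sum the losses, keep each term's aux separately.
        results = [term(x) for term in self.terms]
        total = sum(loss for loss, _ in results)
        return total, [aux for _, aux in results]


combined = ToyTerm("a", 1.0) + ToyTerm("b", 2.0) + ToyTerm("c", 3.0)
total, aux = combined(2.0)
print(total, aux)
```

Because each term keeps its own aux entry, a stateful term can publish its state update without the other terms knowing about it.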

We demonstrate a dummy stateful loss and its integration in an optimization pipeline using the `ProtenixTiny` model. \
First, the model is loaded and the input features representing the target-binder complex are created.

In [1]:
import jax 
import equinox as eqx
import numpy as np
from mosaic.losses import structure_prediction as sp
from mosaic.models.protenix import ProtenixTiny
from mosaic.structure_prediction import TargetChain
import gemmi
%cd "/home/marco/DTU/ms_thesis/code/mosaic_motif_scaffolding"

# Load model
protenix = ProtenixTiny()

# Define binder length and target structure
binder_length = 120
target_structure = gemmi.read_structure("IL7RA.cif")
target_structure.remove_ligands_and_waters()
target_sequence = gemmi.one_letter_code(
    [r.name for r in target_structure[0][0]]
)

# Build complex input features
design_features, design_structure = protenix.binder_features(
    binder_length=binder_length,
    chains=[
        TargetChain(
            target_sequence,
            use_msa=True,
            template_chain=target_structure[0][0],
        )
    ],
)

INFO:2026-02-23 16:31:56,916:jax._src.xla_bridge:834: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


/home/marco/DTU/ms_thesis/code/mosaic_motif_scaffolding


2026-02-23 16:31:57,897 [/home/marco/DTU/ms_thesis/code/mosaic_motif_scaffolding/.venv/lib/python3.12/site-packages/protenix/web_service/colab_request_utils.py:195] ERROR protenix.web_service.colab_request_utils: Msa server is running.
COMPLETE: 100%|██████████| 100/100 [elapsed: 00:01 estimate remaining: 00:00]

Files downloaded and extracted successfully.



2026-02-23 16:32:03,577 [/home/marco/DTU/ms_thesis/code/mosaic_motif_scaffolding/.venv/lib/python3.12/site-packages/protenix/data/constraint_featurizer.py:392] INFO protenix.data.constraint_featurizer: Loaded constraint feature: #atom contact:0 #contact:0 #pocket:0


In addition to standard losses, we implement a dummy stateful loss. \
As described above, a stateful loss needs to have: 
- a `state_index` property acting as a unique identifier for the loss, 
- the internal `state` property which is the mutable part of the loss, 
- an `update_state` function describing how the state should be modified using the new state stored in the output `aux` PyTree.

In [2]:
from mosaic.common import StateIndex, LossTerm
from jaxtyping import Float, Array

# dummy stateful loss
class StatefulLoss(LossTerm):
    state_index: StateIndex
    state: float = 0.0
    
    def __call__(
        self,
        sequence: Float[Array, "N 20"],
        output: sp.AbstractStructureOutput,
        key):
        # Some random loss computation
        loss = 10
        next_state = self.state + 1
        return loss, {"stateful_loss": loss, "dummy": (self.state_index, next_state)}
    
    # new_state is read from aux, whose leaves are always JAX arrays, so convert it to a Python float
    def update_state(self, new_state):
        new_state = new_state.item()
        return eqx.tree_at(lambda s: s.state, self, new_state)
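
The `.item()` call in `update_state` deserves a note: values pulled out of `aux` are JAX arrays, while the `state` field is declared as a Python `float`. A small sketch of the conversion, with a made-up array standing in for a multi-sample aux leaf:

```python
import jax.numpy as jnp

# Hypothetical aux leaf from a 4-sample loss: one proposed state per sample.
per_sample_state = jnp.array([6.0, 6.0, 6.0, 6.0])

first = per_sample_state[0]  # 0-d JAX array, still not a Python number
value = first.item()         # plain Python float, safe to store on the module

print(type(first).__name__, type(value).__name__, value)
```

This is also why a single element has to be selected before calling `update_state` when the loss is evaluated over several samples.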

Now we can build and run the loss function...

In [3]:
# Build custom loss function
structure_loss = (
    sp.BinderTargetContact()
    + sp.WithinBinderContact()
    + StatefulLoss(state_index=StateIndex(), state=5.0)
    + StatefulLoss(state_index=StateIndex(), state=0.0)
)

# Wrap the combined loss in a multi-sample loss function
loss_fn = protenix.build_multisample_loss(
    loss=structure_loss,
    features=design_features,
    recycling_steps=1,
    sampling_steps=20,
    num_samples=4
)
 
pssm = jax.nn.softmax(0.5 * jax.random.gumbel(key=jax.random.key(np.random.randint(1000000)), shape=(binder_length, 20)))
loss, aux = loss_fn(sequence=pssm, key=jax.random.key(42))

n_msa 209
JIT compiling protenix trunk module...


Inspecting the outputs, we can verify the presence of the `(StateIndex, new_state)` tuple in the aux of each `StatefulLoss`:

In [4]:
print(f"loss: {loss}")
for k,v in jax.tree.leaves_with_path(aux): print(f"{k}: {v}")

loss: 30.135160446166992
(SequenceKey(idx=0), DictKey(key='target_contact')): [-5.030611 -5.030611 -5.030611 -5.030611]
(SequenceKey(idx=1), DictKey(key='intra_contact')): [-5.10455 -5.10455 -5.10455 -5.10455]
(SequenceKey(idx=2), DictKey(key='dummy'), SequenceKey(idx=0), GetAttrKey(name='id')): [50340 50340 50340 50340]
(SequenceKey(idx=2), DictKey(key='dummy'), SequenceKey(idx=1)): [6. 6. 6. 6.]
(SequenceKey(idx=2), DictKey(key='stateful_loss')): [10 10 10 10]
(SequenceKey(idx=3), DictKey(key='dummy'), SequenceKey(idx=0), GetAttrKey(name='id')): [26007 26007 26007 26007]
(SequenceKey(idx=3), DictKey(key='dummy'), SequenceKey(idx=1)): [1. 1. 1. 1.]
(SequenceKey(idx=3), DictKey(key='stateful_loss')): [10 10 10 10]
(SequenceKey(idx=4), DictKey(key='features')): []
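
Every aux leaf above carries a leading axis of length 4, matching `num_samples=4`. A plausible reading, which is an assumption about `build_multisample_loss` rather than something confirmed here, is that the per-sample loss is batched with `jax.vmap`, which stacks each aux leaf along a new first axis. The toy below (with a made-up `per_sample_loss`) reproduces that shape behaviour:

```python
import jax
import jax.numpy as jnp


def per_sample_loss(key, state):
    # Hypothetical single-sample loss: the noise stands in for structure sampling.
    noise = jax.random.normal(key)
    loss = 10.0 + 0.0 * noise
    return loss, {"next_state": state + 1.0}


keys = jax.random.split(jax.random.key(0), 4)  # one key per sample
losses, aux = jax.vmap(per_sample_loss, in_axes=(0, None))(keys, 5.0)

# Each output leaf gains a leading axis with one entry per sample.
print(losses.shape, aux["next_state"])
```

Since the state does not depend on the sample, all entries along that axis are identical, as in the output above.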


Using the `is_state_update` function from `common`, we can find these tuples and store them:

In [5]:
# extract state updates from the aux
def is_state_update(x):
    return isinstance(x, tuple) and isinstance(x[0], StateIndex)
state_index_to_update = [
            (x[0].id, x[1])
            for x in jax.tree.leaves(aux, is_leaf=is_state_update)
            if is_state_update(x)
            ]
for i,j in state_index_to_update: print(f"id: {i} - new_state: {j}")

# for multisample losses, by convention we keep only the new state from the first sample
state_index_to_update = {int(k[0].squeeze()): v[0] for k,v in state_index_to_update}
for i,j in state_index_to_update.items(): print(f"id: {i} - new_state: {j}")

id: [50340 50340 50340 50340] - new_state: [6. 6. 6. 6.]
id: [26007 26007 26007 26007] - new_state: [1. 1. 1. 1.]
id: 50340 - new_state: 6.0
id: 26007 - new_state: 1.0
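
The `is_leaf` argument is what keeps each `(index, value)` pair together: without it, `jax.tree.leaves` would descend into the tuple and the pairing would be lost. A self-contained sketch with a hypothetical `ToyIndex` class (a plain Python object, unlike mosaic's `StateIndex`, which shows up as a pytree in the output above):

```python
import jax
import numpy as np


class ToyIndex:
    # Hypothetical stand-in for StateIndex; as a plain object it is a pytree leaf.
    def __init__(self, id):
        self.id = id


def is_update(x):
    return isinstance(x, tuple) and len(x) == 2 and isinstance(x[0], ToyIndex)


aux = [
    {"contact": np.array([-5.0, -5.0])},
    {"dummy": (ToyIndex(7), np.array([6.0, 6.0])), "stateful_loss": np.array([10, 10])},
]

flat = jax.tree.leaves(aux)                     # the tuple is split into its parts
kept = jax.tree.leaves(aux, is_leaf=is_update)  # the tuple survives as one leaf
updates = {x[0].id: float(x[1][0]) for x in kept if is_update(x)}

print(len(flat), len(kept), updates)
```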


Once the new states are collected, we need to match their `id` with the respective losses:

In [6]:
def has_state_index(m):
    return (
        hasattr(m, "state_index")
        and isinstance(m.state_index, StateIndex))
def get_modules_to_update(loss):
    return tuple([
            x
            for x in jax.tree.leaves(loss, is_leaf=has_state_index)
            if has_state_index(x)
            ])
get_modules_to_update(loss_fn)

(StatefulLoss(state_index=StateIndex(id=np.int32(50340)), state=5.0),
 StatefulLoss(state_index=StateIndex(id=np.int32(26007)), state=0.0))

JAX itself offers no convenient way to replace a node inside a pytree. Fortunately, Equinox does, with the `tree_at` function, which returns an updated copy of the pytree rather than modifying it in place.

In [7]:
def replace_fn(module):
    return module.update_state(state_index_to_update[int(module.state_index.id)])
loss_fn = eqx.tree_at(where=get_modules_to_update, 
                      pytree=loss_fn, 
                      replace_fn=replace_fn)

Finally, we verify that the state update has been applied:

In [8]:
get_modules_to_update(loss_fn)

(StatefulLoss(state_index=StateIndex(id=np.int32(50340)), state=6.0),
 StatefulLoss(state_index=StateIndex(id=np.int32(26007)), state=1.0))

The implementations of `has_state_index` and `is_state_update` are in `mosaic/source/common.py`, while the `update_state` function is in `mosaic/source/optimizer.py`.
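
To show where such a state update sits relative to the gradient computation, here is a minimal sketch of an optimization loop. It is not mosaic's optimizer: `toy_loss`, the plain gradient step, and the `next_state` aux key are all invented for illustration, and the state is passed as an explicit argument instead of living inside a module.

```python
import jax
import jax.numpy as jnp


def toy_loss(x, state):
    # Pure function: the state is an explicit input, the proposal an explicit output.
    loss = jnp.sum((x - state) ** 2)
    return loss, {"next_state": state + 1.0}


x = jnp.zeros(3)
state = 0.0
grad_fn = jax.value_and_grad(toy_loss, has_aux=True)

for step in range(3):
    (loss, aux), grad = grad_fn(x, state)  # 1. value and gradient
    x = x - 0.5 * grad                     # 2. parameter update
    state = float(aux["next_state"])       # 3. state update, outside the traced code

print(state, x)
```

The state only changes between calls to `grad_fn`, so each individual evaluation remains a pure function of its inputs.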