In [38]:
import jax.numpy as jnp
import jax
import jax.lax as lax

In [10]:

# Toy 0/1 schedule. Each literal row below is ONE channel's time series;
# transposing puts time on axis 0, giving shape (T, K) = (6, 6).
schedule = jnp.transpose(
    jnp.array(
        [
            [0, 1, 0, 0, 0, 1],
            [0, 0, 1, 0, 1, 0],
            [0, 1, 0, 0, 1, 0],
            [1, 0, 0, 0, 0, 1],
            [0, 1, 0, 0, 1, 0],
            [1, 0, 1, 0, 0, 0],
        ]
    )
)


@jax.jit
def all_disconnect_durations(schedule: jnp.ndarray,
                             time_step: float,
                             T: "int | None" = None):
    """
    Per-transition disconnect durations for every channel of a 0/1 schedule.

    A disconnect is a 1 -> 0 step between rows i and i+1 (diff[i] == -1).
    It lasts until the first 0 -> 1 step j > i (diff[j] == +1), i.e.
    (j - i) * time_step.  A disconnect that is never followed by a reconnect
    gets the fallback duration T * time_step.  All other entries are 0.
    A schedule that *starts* at 0 has no 1 -> 0 step for that leading run,
    so that run produces no entry.

    schedule:   (T, K) array of 0/1 or booleans; axis 0 is time
    time_step:  scalar float, duration of one step
    T:          number of timesteps in `schedule`; only used for the
                fallback duration.  Defaults to schedule.shape[0].
    returns:
      durations:        (T-1, K) float array; row i describes the transition
                        between rows i and i+1 of `schedule`
      disconnect_mask:  (T-1, K) bool array, True where diff == -1
    """
    if T is None:
        # Shapes are concrete at trace time, so this stays a Python int.
        T = schedule.shape[0]

    # diff along time -> (T1, K) with T1 = T - 1
    diff = jnp.diff(schedule.astype(jnp.int32), axis=0)
    T1, K = diff.shape

    disconnect_mask = diff == -1

    # Row index of every reconnect (diff == +1); T1 is the "no reconnect" sentinel.
    idx = jnp.arange(T1)[:, None]                                   # (T1, 1)
    reconnect_idx = jnp.where(diff == 1, idx, T1)                   # (T1, K)

    # Suffix minimum: first reconnect at or after row i.
    first_reconnect_from = jax.lax.cummin(reconnect_idx, axis=0, reverse=True)

    # Shift up one row -> first reconnect strictly AFTER row i (sentinel at the end).
    sentinel_row = jnp.full((1, K), T1, dtype=first_reconnect_from.dtype)
    next_reconnect = jnp.concatenate([first_reconnect_from[1:], sentinel_row], axis=0)
    reconnected = next_reconnect < T1

    raw_dur = (next_reconnect - idx) * time_step
    fallback = T * time_step

    # disconnected & reconnected -> raw_dur; disconnected & never reconnected -> fallback; else 0
    durations = jnp.where(disconnect_mask,
                          jnp.where(reconnected, raw_dur, fallback),
                          0.0)

    return durations, disconnect_mask

all_disconnect_durations(schedule, time_step=0.1, T=len(schedule))  # T = number of timesteps (rows)

(Array([[0. , 0. , 0. , 0.4, 0. , 0.1],
        [0.3, 0. , 0.2, 0. , 0.2, 0. ],
        [0. , 0.1, 0. , 0. , 0. , 0.6],
        [0. , 0. , 0. , 0. , 0. , 0. ],
        [0. , 0.6, 0.6, 0. , 0.6, 0. ]], dtype=float32, weak_type=True),
 Array([[False, False, False,  True, False,  True],
        [ True, False,  True, False,  True, False],
        [False,  True, False, False, False,  True],
        [False, False, False, False, False, False],
        [False,  True,  True, False,  True, False]], dtype=bool))

In [63]:
# Synthetic 0/1 data: shape (1000, 420) with P(0) = 0.9 and P(1) = 0.1,
# drawn from a fixed PRNG seed so re-running the cell reproduces the same array.
key = jax.random.PRNGKey(0)
v = jax.random.choice(
    key,
    jnp.arange(2),
    shape=(1000, 420),
    p=jnp.array([0.9, 0.1]),
)

def disconnect_time(v: jnp.ndarray, time_step) -> jnp.ndarray:
    """Length (in time units) of every run of zeros in `v`, reported at the run's end.

    v:         array whose axis 0 is time.  1-D in the vmapped call of this
               notebook; any trailing axes are treated as independent channels.
    time_step: scalar duration of one step.
    returns:   array with v's shape.  At the last index of each run of zeros
               it holds run_length * time_step; every other entry is 0.  A run
               that reaches the end of `v` is reported at the final index.
    """
    is_zero = (v == 0)
    n = is_zero.shape[0]

    # Time index, broadcastable against any trailing channel axes.
    t = jnp.arange(n).reshape((n,) + (1,) * (is_zero.ndim - 1))

    # Index of the most recent non-zero at or before t (-1 if none yet).
    # A running max is an associative scan, so this is log-depth instead of
    # a sequential step-by-step lax.scan.
    last_nonzero = lax.cummax(jnp.where(is_zero, -1, t), axis=0)

    # Length of the current run of zeros ending at t (0 where v is non-zero).
    runs = jnp.where(is_zero, t - last_nonzero, 0)

    # A run ends where a zero is followed by a non-zero, or by the end of the array.
    next_is_zero = jnp.concatenate([is_zero[1:], jnp.zeros_like(is_zero[:1])], axis=0)
    run_ends = is_zero & ~next_is_zero

    # Zero out everything except at run-ends.
    return runs * run_ends * time_step


# Apply disconnect_time independently to every row of v.T (i.e. every column of v),
# sharing the scalar time_step across the batch, and compile the batched call.
disconnect_times = jax.jit(
    jax.vmap(disconnect_time, in_axes=(0, None))
)(jnp.transpose(v), 0.1)


Array([[0.1       , 1.        , 1.        , ..., 1.        , 1.        ,
        0.3       ],
       [1.        , 0.2       , 1.        , ..., 1.        , 1.        ,
        1.2       ],
       [0.1       , 1.        , 1.        , ..., 1.        , 1.        ,
        1.2       ],
       ...,
       [1.        , 0.2       , 1.        , ..., 1.        , 1.        ,
        0.8       ],
       [0.1       , 1.        , 1.        , ..., 1.        , 1.        ,
        1.1       ],
       [1.        , 1.        , 1.        , ..., 1.        , 1.        ,
        0.90000004]], dtype=float32, weak_type=True)