In [1]:
# Load IPython's autoreload extension so edited modules are picked up without restarting the kernel.
%load_ext autoreload

# Mode 2: reload all imported modules before each cell execution.
%autoreload 2

In [None]:
from IMLCV.configs.config_general import config

# config(
#     env="local",
#     local_ref_threads=2,
#     max_threads_local=10,
# )

# Passed to `config` as `training_cores`; the hyper-parameter cell also derives
# the macro chunk sizes from this number.
N_TRAIN = 32

# Keyword settings for IMLCV's `config`, gathered in one place so they can be
# inspected before they are applied.
cluster_settings = dict(
    account="2024_117",
    singlepoint_nodes=4,
    default_on_threads=False,
    cpu_cluster="cpu_milan_rhel_9",
    training_cores=N_TRAIN,
    walltime_training="04:00:00",
    walltime_ref="02:00:00",
    max_threads_local=10,
)

config(**cluster_settings)

# config(

#     singlepoint_nodes=4,
#     walltime="12:00:00",
#     bootstrap=False,
#     training_cores=32,

# )

In [3]:
from pathlib import Path

from IMLCV.base.rounds import Rounds
from IMLCV.examples.example_systems import CsPbI3_MACE_lattice
from IMLCV.scheme import Scheme

# All rounds of this study live below this folder.
folder = Path("perovskites") / "CsPbI3_cell_002"

if not folder.exists():
    # Nothing on disk yet: build the example system and start a new scheme from it.
    md, refs = CsPbI3_MACE_lattice()

    scheme = Scheme.from_refs(mde=md, refs=refs, folder=folder, steps=1000)
else:
    # The folder already exists: wrap the rounds stored there.
    scheme = Scheme(
        rounds=Rounds.create(folder=folder, copy=False, new_folder=False),
    )

In [4]:
# _c,_r= 3,4

# for i in scheme.rounds._i_vals(c=_c,r=_r):
#     scheme.rounds.finish_data(c=_c,r=_r,i=i)


In [5]:
# `angstrom` is not needed for the expression below; it is used later in the notebook (r_cut).
from IMLCV.base.UnitsConstants import angstrom, boltzmann, kelvin, kjmol

# Thermal energy k_B * T at 300 K, expressed in kJ/mol (shown as the cell output).
300 * kelvin * boltzmann / kjmol

2.494341595946812

In [6]:
# Unzip the stored data of CV round 1 (logs "unzipping cv=1").
# NOTE(review): presumably required before later cells can read that round from disk — confirm.
scheme.rounds.unzip_cv(cv=1)

unzipping cv=1


In [7]:
# chunk_size = 100

# n = [15, 10, 6][scheme.bias.collective_variable.n - 1]

# print(f"{n=}")

# max_bias = 100 * kjmol
# samples_per_bin = 200
# min_samples_per_bin = 5

# n_max = (2 * n) ** (scheme.bias.collective_variable.n)

# print(f"{n=} {n_max=}")

In [8]:
# `angstrom` is imported here too, so this cell does not depend on an earlier
# cell having pulled it into the namespace.
from IMLCV.base.UnitsConstants import angstrom, kjmol

# Cut-off radius used for the descriptor and for rebuilding the neighbour lists.
r_cut = 6.0 * angstrom

steps = 1000
chunk_size = None
macro_chunk = N_TRAIN * 16  # samples per worker
macro_chunk_nl = N_TRAIN * 128
samples_per_bin = 50
min_samples_per_bin = 10
T_Scale = 10
koopman = True
# Forwarded to `sample()` as the umbrella overlap parameter.
# NOTE(review): an older remark here said "10 percent overlap" although the value is 0.15 — confirm.
eps = 0.15
max_bias = 100 * kjmol
num_cv_rnds = 2
lag_n = 30
koopman_wham = True
n_max_descriptors = 1
l_max_descriptor = 1
alpha_rematch = 0.5
ncv = 2
min_cv = 2
direct_bias = False

max_cv_basis_fun = 2000

bias_num_points = 5e4
cv_num_points = 5e4
num_sample_rounds = 5


# NOTE(review): an older remark here said 30**3 (= 27000); the active value is 1e3.
n_max = 1e3

# Upper bound on the total number of umbrellas; `sample()` takes the cv_dim-th root of it.
max_blocks = 400

In [None]:
def _restrict_covariances(cov, index):
    """Restrict all four covariance blocks of ``cov`` to the entries picked by ``index`` (in place).

    ``index`` is applied to both axes of ``C00``, ``C11``, ``C01`` and ``C10``;
    it can be a boolean mask or an integer index array.
    """
    cov.C00 = cov.C00[index, :][:, index]
    cov.C11 = cov.C11[index, :][:, index]
    cov.C01 = cov.C01[index, :][:, index]
    cov.C10 = cov.C10[index, :][:, index]


def update_cv(percentile=0.1):
    """Build the descriptor pipeline from round-0 data and learn a new CV with it.

    Steps, in order:

    1. Create the ``sb_descriptor`` and apply it to the time-lagged data of CV round 0.
    2. Fit a first Sinkhorn-divergence transform (no exponential factor) and use
       its mean output to derive one ``exp_factor`` per reference.
    3. Fit the final Sinkhorn transform with those factors and apply it to the
       descriptor output.
    4. Compute the time-lagged covariances of the result and keep only features
       that have non-negligible variance and a high normalised auto-covariance.
    5. Compose ``descriptor * tr * km_tr`` and hand it to ``scheme.update_CV``
       together with a ``TransformerMAF``.

    All other settings are read from the notebook-level hyper-parameters
    (``r_cut``, ``macro_chunk``, ``lag_n`` ...) and from the global ``scheme``.

    Args:
        percentile: forwarded unchanged to ``scheme.update_CV``.
            NOTE(review): no cell in the notebook defines this value; the
            default 0.1 is assumed from the "99.9% of the points" remark at
            the call below — confirm.
    """
    import jax
    import jax.numpy as jnp

    from IMLCV.base.CV import CV, CvTrans
    from IMLCV.base.rounds import Covariances, DataLoaderOutput
    from IMLCV.implementations.CV import get_sinkhorn_divergence_2, sb_descriptor, un_atomize
    from IMLCV.implementations.CvDiscovery import TransformerMAF

    jax.clear_caches()

    print("getting descriptor")

    descriptor = sb_descriptor(
        r_cut=r_cut,
        n_max=n_max_descriptors,
        l_max=l_max_descriptor,
        reshape=True,
        reduce=True,
        mul_Z=True,
    )

    # The notebook only ever defines `scheme`, so that is the object used here.
    dlo_0 = scheme.rounds.data_loader(
        cv_round=0,
        start=0,
        weight=False,
        new_r_cut=r_cut,
        time_series=True,
        lag_n=20,
    )

    print("computing soap descriptor")

    cv_0, cv_0_t = dlo_0.apply_cv(
        descriptor,
        x=dlo_0.sp,
        x_t=dlo_0.sp_t,
        nl=dlo_0.nl,
        nl_t=dlo_0.nl,
        macro_chunk=macro_chunk,
        verbose=True,
    )

    print(f"{alpha_rematch=}")

    # Reference points of the Sinkhorn transform: the last entry of every element of cv_0.
    pi = CV.stack(*[a[-1] for a in cv_0])
    nli = dlo_0.nl

    # First pass without exponential factor; it is only used to calibrate `exp_factors`.
    print("sinkhorn div")
    tr_unscaled = (
        get_sinkhorn_divergence_2(
            nli=nli,
            pi=pi,
            alpha_rematch=alpha_rematch,
            jacobian=False,
            exp_factor=None,
        )
        * un_atomize
    )

    cv_1, _cv_1_t = dlo_0.apply_cv(
        tr_unscaled,
        x=cv_0,
        x_t=cv_0_t,
        nl=dlo_0.nl,
        nl_t=dlo_0.nl,
        macro_chunk=macro_chunk,
        verbose=True,
    )

    # Row i holds the mean first-pass output of element i of cv_1.  Adding 1000 on the
    # diagonal removes each row's own entry from the minimum, so the factor is
    # 2 / (smallest off-diagonal mean of that row).
    M = jnp.array([jnp.mean(x.cv, axis=0) for x in cv_1])
    exp_factors = 2 / jnp.min(M + jnp.eye(M.shape[0]) * 1000, axis=1)
    print(f"{exp_factors=}")

    # Second pass: the transform that actually ends up in the CV pipeline.
    print("sinkhorn div")
    tr = (
        get_sinkhorn_divergence_2(
            nli=nli,
            pi=pi,
            alpha_rematch=alpha_rematch,
            jacobian=True,
            exp_factor=exp_factors,
        )
        * un_atomize
    )

    # Applied to the descriptor output (not to the first-pass result), hence the
    # final flow below is descriptor * tr.
    cv_0, cv_0_t = dlo_0.apply_cv(
        tr,
        x=cv_0,
        x_t=cv_0_t,
        nl=dlo_0.nl,
        nl_t=dlo_0.nl,
        macro_chunk=macro_chunk,
        verbose=True,
    )

    print(f"{cv_0[0].shape=}")

    cov = Covariances.create(
        cv_0=cv_0,
        cv_1=cv_0_t,
        nl=dlo_0.nl,
        nl_t=dlo_0.nl,
        calc_pi=True,
        only_diag=False,
        symmetric=False,
        chunk_size=chunk_size,
        macro_chunk=macro_chunk,
    )

    # `argmask` tracks which original feature indices survive both selections.
    argmask = jnp.arange(cov.C00.shape[0])

    eps_pre = 1e-8
    auto_cov_threshold = 0.8
    verbose = True
    max_features_pre = 1000

    # Selection 1: drop features whose variance (relative to the largest one) is negligible
    # in either the instantaneous or the time-lagged data.
    argmask_pre = jnp.logical_and(
        jnp.diag(cov.C00) / jnp.max(jnp.diag(cov.C00)) > eps_pre**2,
        jnp.diag(cov.C11) / jnp.max(jnp.diag(cov.C11)) > eps_pre**2,
    )

    if verbose:
        print(f"{jnp.sum(argmask_pre)=} {jnp.sum(~argmask_pre)=}  {eps_pre=}")

    if jnp.sum(argmask_pre) == 0:
        print(
            f"WARNING: no modes selectected through argmask pre {jnp.diag(cov.C00)/ jnp.max(jnp.diag(cov.C00))=} {jnp.diag(cov.C11)/ jnp.max(jnp.diag(cov.C11))=}"
        )

    _restrict_covariances(cov, argmask_pre)
    argmask = argmask[argmask_pre]

    # Selection 2: rank the remaining features by their normalised auto-covariance
    # C01_ii / sqrt(C00_ii * C11_ii), highest first.
    auto_cov = jnp.einsum(
        "i,i,i->i",
        jnp.diag(cov.C00) ** (-0.5),
        jnp.diag(cov.C01),
        jnp.diag(cov.C11) ** (-0.5),
    )

    argmask_cov = jnp.argsort(auto_cov, descending=True).reshape(-1)

    print(f"{auto_cov[argmask_cov]=}")

    if auto_cov_threshold is not None:
        argmask_cov = argmask_cov[auto_cov[argmask_cov] > auto_cov_threshold]

    if max_features_pre is not None and argmask_cov.shape[0] > max_features_pre:
        argmask_cov = argmask_cov[:max_features_pre]
        print(f"reducing argmask_cov to {max_features_pre}")

    _restrict_covariances(cov, argmask_cov)
    argmask = argmask[argmask_cov]

    print(f"{auto_cov[argmask_cov]=}")

    print(f"{argmask_cov=}")

    # Feature selection as a CV transformation, so it can be chained after the Sinkhorn step.
    km_tr = CvTrans.from_cv_function(
        DataLoaderOutput._transform,
        static_argnames=["add_1", "add_1_pre"],
        add_1=False,
        add_1_pre=False,
        q=None,
        pi=None,
        argmask=argmask,
    )

    tot_flow = descriptor * tr * km_tr

    # Smoke test of the composed flow on the raw data before it is handed to the scheme.
    print("testing comp")

    cv_out, cv_out_t = dlo_0.apply_cv(
        tot_flow,
        x=dlo_0.sp,
        x_t=dlo_0.sp_t,  # time-lagged phase-space points, as in the descriptor call above
        nl=dlo_0.nl,
        nl_t=dlo_0.nl,
        macro_chunk=macro_chunk,
        verbose=True,
    )

    print(f"{cv_out[0].shape=}")

    # Start the data selection at 0 only while no CV has been learned yet.
    start = 0 if scheme.rounds.cv == 0 else 1

    scheme.update_CV(
        dlo_kwargs=dict(
            out=cv_num_points,
            num_cv_rounds=1,
            time_series=True,
            new_r_cut=r_cut,
            num=num_cv_rnds + 1,
            start=start,
            lag_n=lag_n,
            split_data=False,
            chunk_size=chunk_size,
            only_finished=True,
            T_scale=T_Scale,
            macro_chunk=macro_chunk,
            samples_per_bin=samples_per_bin,
            min_samples_per_bin=min_samples_per_bin,
            macro_chunk_nl=macro_chunk_nl,
            verbose=True,
            n_max=n_max,
            weight=True,
            only_update_nl=True,
            # reweight_to_fes=True,
            reweight_inverse_bincount=True,
            scale_times=False,
            # max_point_frac=0.6,
        ),
        transformer=TransformerMAF(
            outdim=ncv,
            descriptor=tot_flow,
            pre_scale=False,
            correct_bias=True,
            use_ground_bias=True,
            T_scale=T_Scale,
            koopman_weighting=False,
            method="tcca",
            solver="eig",
            max_features=max_cv_basis_fun,
            max_features_pre=max_cv_basis_fun,  # if scheme.rounds.cv > 0 else 100,
            trans=None,
            eps=1e-10,
            eps_pre=1e-10,
        ),
        chunk_size=chunk_size,
        plot=True,
        new_r_cut=r_cut,
        save_samples=True,
        save_multiple_cvs=False,
        test=False,
        percentile=percentile,  # get 99.9% of the points
        max_bias=max_bias,
        macro_chunk=macro_chunk,
        verbose=True,
        n_max=n_max,
    )

In [10]:
import jax.numpy as jnp


def sample(eps=0.1):
    """Run the umbrella-sampling inner loop of the global ``scheme``.

    The number of umbrellas per CV dimension is the ``cv_dim``-th root of
    ``max_blocks`` (rounded down) and never more than 12.  Every other setting
    is taken from the notebook-level hyper-parameters.

    Args:
        eps: forwarded to ``scheme.inner_loop`` as ``eps_umbrella``.
    """
    cv_dim = scheme.rounds.get_collective_variable().n

    # n_umbrella ** cv_dim stays at or below max_blocks; 12 is the hard upper limit per dimension.
    n_umbrella = min(int(jnp.floor(max_blocks ** (1 / cv_dim))), 12)

    print(f"n_umbrella: {n_umbrella}")

    scheme.inner_loop(
        n=n_umbrella,
        rnds=num_sample_rounds,  # if scheme.rounds.cv > 1 else 3,
        steps=steps,  # notebook hyper-parameter instead of a second hard-coded 1000
        init=0,
        eps_umbrella=eps,
        fes_bias_rnds=num_cv_rnds,
        chunk_size=chunk_size,
        macro_chunk=macro_chunk,
        samples_per_bin=samples_per_bin,
        min_samples_per_bin=min_samples_per_bin,
        max_bias=max_bias,
        n_max_fes=n_max,
        convergence_kl=0.05,
        thermolib=False,
        T_scale=T_Scale,
        koopman=koopman,
        koopman_wham=koopman_wham,
        out=bias_num_points,
        direct_bias=direct_bias,
        # first_round_without_bias=scheme.rounds.cv == 1,
    )

In [11]:
def get_fes_bias():
    """Call ``scheme.FESBias`` with the notebook hyper-parameters and add the result as a new round."""
    # Collected first so the full set of options is visible in one place.
    fes_settings = dict(
        plot=True,
        samples_per_bin=samples_per_bin,
        min_samples_per_bin=min_samples_per_bin,
        chunk_size=chunk_size,
        macro_chunk=macro_chunk,
        num_rnds=num_cv_rnds,
        vmax=max_bias,
        max_bias=max_bias,
        only_finished=True,
        n_max=n_max,
        thermolib=False,
        out=bias_num_points,
        T_scale=T_Scale,
        # divide_by_histogram=True,
        lag_n=lag_n,
        koopman_wham=koopman_wham,
        koopman=koopman,
        direct_bias=direct_bias,
    )

    new_bias = scheme.FESBias(**fes_settings)

    # Register the bias only after it has been computed completely.
    scheme.rounds.add_round(bias=new_bias)

In [None]:
# While the current CV index is still 0, learn a CV before the first sampling pass.
if scheme.rounds.cv == 0:
    update_cv()

# Ten outer iterations: sample with the current CV, then learn a new CV from the data.
for i in range(10):
    sample(eps=eps)
    update_cv()

n_umbrella: 12
cv_round=1
not converged kl_div=Array(2.18076735, dtype=float64)
i_0=5
done
getting descriptor
iterating low=0 high=2 num=4  start=0 stop=1
padded_vmap done
initializing neighbour list with nn=23 new_nxyz=Array([1., 1., 1.], dtype=float64)
...
setting weights to one!
len(sp) = 3
new choice
selected 2943 out of 2943 data points len(out_reweights)=3 len(out_rhos)=3
gathering data
...
len(out_sp) = 3 
tau = 40.00 fs, lag_time*timestep = 40.00 fs
computing soap descriptor
inside _apply
aplying cv func to 5 chunks of size 512 + remainder of size 383 
start time: 14:47:31.630......
finished at: 14:48:15.926
outside _apply
sinkhorn div
converting nli to info
inside _apply
aplying cv func to 5 chunks of size 512 + remainder of size 383 
start time: 14:48:35.163......
finished at: 14:48:46.894
outside _apply
aplying cv func to 5 chunks of size 512 + remainder of size 383 
start time: 14:48:47.044.compiled f
compiled chunk func
....recompiled f for last chunk
recompiled chunk func

Process WorkQueue-Submit-Process:
Process WorkQueue-Submit-Process:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/dodrio/scratch/projects/2024_026/IMLCV/micromamba/envs/py312/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/dodrio/scratch/projects/2024_026/IMLCV/micromamba/envs/py312/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/dodrio/scratch/projects/2024_026/IMLCV/micromamba/envs/py312/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/dodrio/scratch/projects/2024_026/IMLCV/micromamba/envs/py312/lib/python3.12/site-packages/parsl/process_loggers.py", line 26, in wrapped
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/dodrio/scratch/projects/2024_026/IMLCV/micromamba/envs/py312/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwar



KeyboardInterrupt: 

In [11]:
# Compute a bias with scheme.FESBias and append it as a new round (see get_fes_bias).
get_fes_bias()

estimating bias from koopman Theory!


  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.
Using Materials Project MACE for MACECalculator with /dodrio/scratch/users/vsc43693/.cache/mace/20231203mace128L1_epoch199model
Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization.


  torch.load(f=model_path, map_location=device)


ERROR:concurrent.futures:exception calling callback for <Future at 0x14a3eea51310 state=finished returned NoneType>
Traceback (most recent call last):
  File "/dodrio/scratch/projects/2024_026/IMLCV/micromamba/envs/py312/lib/python3.12/concurrent/futures/_base.py", line 340, in _invoke_callbacks
    callback(self)
  File "/dodrio/scratch/projects/2024_026/IMLCV/micromamba/envs/py312/lib/python3.12/site-packages/parsl/dataflow/dflow.py", line 407, in handle_exec_update
    self._complete_task(task_record, States.exec_done, res)
  File "/dodrio/scratch/projects/2024_026/IMLCV/micromamba/envs/py312/lib/python3.12/site-packages/parsl/dataflow/dflow.py", line 590, in _complete_task
    task_record['app_fu'].set_result(result)
  File "/dodrio/scratch/projects/2024_026/IMLCV/micromamba/envs/py312/lib/python3.12/concurrent/futures/_base.py", line 544, in set_result
    raise InvalidStateError('{}: {!r}'.format(self._state, self))
concurrent.futures._base.InvalidStateError: FINISHED: <AppFuture

In [None]:
# Manual re-run of the CV update, outside the sampling loop.
update_cv()

getting descriptor
iterating low=0 high=2 num=4  start=0 stop=1
padded_vmap done
initializing neighbour list with nn=23 new_nxyz=Array([1., 1., 1.], dtype=float64)
...
setting weights to one!
len(sp) = 3
new choice
selected 2943 out of 2943 data points len(out_reweights)=3 len(out_rhos)=3
gathering data
...
len(out_sp) = 3 
tau = 40.00 fs, lag_time*timestep = 40.00 fs
computing soap descriptor
inside _apply
aplying cv func to 5 chunks of size 512 + remainder of size 383 
start time: 10:22:15.884......
finished at: 10:23:06.404
outside _apply
sinkhorn div
converting nli to info
inside _apply
aplying cv func to 5 chunks of size 512 + remainder of size 383 
start time: 10:23:15.071......
finished at: 10:23:20.392
outside _apply
aplying cv func to 5 chunks of size 512 + remainder of size 383 
start time: 10:23:20.533.compiled f
compiled chunk func
....recompiled f for last chunk
recompiled chunk func for last chunk

finished at: 10:23:20.908
jnp.sum(argmask_pre)=Array(2280, dtype=int64) jn

DependencyError: Dependency failure for task 3. The representative cause is via DataFuture from task 2

In [18]:
import jax

from IMLCV.base.CV import CV, CvTrans
from IMLCV.implementations.CV import get_sinkhorn_divergence_2, sb_descriptor, un_atomize

# Step-by-step version of the first part of update_cv(), kept for interactive inspection.
jax.clear_caches()  # has to be called; a bare reference to the function does nothing

print("getting descriptor")

descriptor = sb_descriptor(
    r_cut=r_cut,
    n_max=n_max_descriptors,
    l_max=l_max_descriptor,
    reshape=True,
    reduce=True,
    mul_Z=True,
)

rounds = scheme.rounds

# Time-lagged, unweighted data of CV round 0 with neighbour lists rebuilt for r_cut.
dlo_0 = rounds.data_loader(
    cv_round=0,
    start=0,
    weight=False,
    new_r_cut=r_cut,
    time_series=True,
    lag_n=20,
)

# print(f"{dlo.nl=} {dlo.sp=}")

print("computing soap descriptor")

# Descriptor values for the instantaneous (cv_0) and the time-lagged (cv_0_t) points.
cv_0, cv_0_t = dlo_0.apply_cv(
    descriptor,
    x=dlo_0.sp,
    x_t=dlo_0.sp_t,
    nl=dlo_0.nl,
    nl_t=dlo_0.nl,
    macro_chunk=macro_chunk,
    verbose=True,
)

getting descriptor
iterating low=0 high=2 num=4  start=0 stop=1
padded_vmap done
initializing neighbour list with nn=23 new_nxyz=Array([1., 1., 1.], dtype=float64)
...
setting weights to one!
len(sp) = 3
new choice
selected 2943 out of 2943 data points len(out_reweights)=3 len(out_rhos)=3
gathering data
...
len(out_sp) = 3 
tau = 40.00 fs, lag_time*timestep = 40.00 fs
computing soap descriptor
inside _apply
aplying cv func to 5 chunks of size 512 + remainder of size 383 
start time: 16:32:03.971......
finished at: 16:32:45.866
outside _apply


In [19]:
jax.clear_caches()

print("sinkhorn div")

# First Sinkhorn pass without exponential factor.  The reference points are the
# last entry of every element of cv_0, stacked into a single CV.
first_pass_settings = dict(
    nli=dlo_0.nl,
    pi=CV.stack(*[a[-1] for a in cv_0]),
    alpha_rematch=alpha_rematch,
    jacobian=False,
    exp_factor=None,
)
tr = get_sinkhorn_divergence_2(**first_pass_settings) * un_atomize

# Apply the transform to the descriptor output of both the instantaneous and the lagged data.
cv_1, cv_1_t = dlo_0.apply_cv(
    tr,
    x=cv_0,
    x_t=cv_0_t,
    nl=dlo_0.nl,
    nl_t=dlo_0.nl,
    macro_chunk=macro_chunk,
    verbose=True,
)

sinkhorn div
converting nli to info
inside _apply
aplying cv func to 5 chunks of size 512 + remainder of size 383 
start time: 16:33:00.754......
finished at: 16:33:03.260
outside _apply


In [30]:
# Compute M here before displaying it: on a fresh kernel it is otherwise only
# defined by a later cell.  Row i is the mean CV value of element i of cv_1.
M = jnp.array([jnp.mean(x.cv, axis=0) for x in cv_1])
M

Array([[7.54558606e-05, 9.03382586e-03, 1.85879392e-02],
       [9.84758774e-03, 1.25990662e-04, 5.64509641e-03],
       [1.95022021e-02, 6.08752980e-03, 8.93969226e-05]], dtype=float64)

Array([553.4753575 , 885.7244648 , 821.35121558], dtype=float64)

In [None]:
# NOTE(review): presumably a sanity check for the factor 2 in `exp_factors` (exp(-2) is about 0.135) — confirm.
jnp.exp(-2)

Array(0.13533528, dtype=float64, weak_type=True)

Process WorkQueue-Submit-Process:
Process WorkQueue-Submit-Process:
Process WorkQueue-Submit-Process:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/dodrio/scratch/projects/2024_026/IMLCV/micromamba/envs/py312/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/dodrio/scratch/projects/2024_026/IMLCV/micromamba/envs/py312/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/dodrio/scratch/projects/2024_026/IMLCV/micromamba/envs/py312/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/dodrio/scratch/projects/2024_026/IMLCV/micromamba/envs/py312/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/dodrio/scratch/projects/2024_026/IMLCV/micromamba/envs/py312/lib/python3.12/site-packages/parsl/process_loggers.py", line 26, in wrapped
    r = fun

In [39]:
# Calibrate one exponential factor per reference from the first-pass output in cv_1.
# Adding 1000 on the diagonal removes each row's own entry from the minimum.
# This cell overwrites cv_1 at the end, so running it a second time calibrates
# on the second-pass output instead.
M = jnp.array([jnp.mean(x.cv, axis=0) for x in cv_1])
exp_factors = 2 / jnp.min(M + jnp.eye(M.shape[0]) * 1000, axis=1)
print(f"{exp_factors=}")  # a bare expression in the middle of a cell is not displayed

jax.clear_caches()

print("sinkhorn div")
# NOTE(review): update_cv() builds this transform with jacobian=True — confirm which one is intended.
tr = (
    get_sinkhorn_divergence_2(
        nli=dlo_0.nl,
        pi=CV.stack(*[a[-1] for a in cv_0]),
        alpha_rematch=alpha_rematch,
        jacobian=False,
        exp_factor=exp_factors,
    )
    * un_atomize
)

# Applied to the descriptor output cv_0 (not to the first-pass result).
cv_1, cv_1_t = dlo_0.apply_cv(
    tr,
    x=cv_0,
    x_t=cv_0_t,
    nl=dlo_0.nl,
    nl_t=dlo_0.nl,
    macro_chunk=macro_chunk,
    verbose=True,
)

sinkhorn div
converting nli to info
inside _apply
aplying cv func to 5 chunks of size 512 + remainder of size 383 
start time: 16:53:42.978......
finished at: 16:53:45.366
outside _apply


In [38]:
jnp.array([jnp.mean(x.cv, axis=0) for x in cv_1])

Array([[9.59319695e-01, 3.40810924e-04, 2.39086967e-07],
       [4.38505056e-03, 8.96033999e-01, 1.00793178e-02],
       [2.15833935e-05, 4.78545545e-03, 9.30003310e-01]], dtype=float64)

In [None]:
from IMLCV.base.rounds import Covariances

from IMLCV.base.rounds import DataLoaderOutput

calc_pi = True

# Time-lagged covariances of the second-pass Sinkhorn output.
cov = Covariances.create(
    cv_0=cv_1,
    cv_1=cv_1_t,
    nl=dlo_0.nl,
    nl_t=dlo_0.nl,
    # w=w_tot,
    # w_t=w_tot,
    calc_pi=calc_pi,
    only_diag=False,
    symmetric=False,
    chunk_size=chunk_size,
    macro_chunk=macro_chunk,
)

# Indices of the original features; narrowed down by both selections below.
argmask = jnp.arange(cov.C00.shape[0])

eps_pre = 1e-10
auto_cov_threshold = 0.1
verbose = True
max_features_pre = 1000

# Selection 1: keep features whose relative variance exceeds eps_pre**2 in both data sets.
argmask_pre = jnp.logical_and(
    jnp.diag(cov.C00) / jnp.max(jnp.diag(cov.C00)) > eps_pre**2,
    jnp.diag(cov.C11) / jnp.max(jnp.diag(cov.C11)) > eps_pre**2,
)

if verbose:
    print(f"{jnp.sum(argmask_pre)=} {jnp.sum(~argmask_pre)=}  {eps_pre=}")

if jnp.sum(argmask_pre) == 0:
    print(
        f"WARNING: no modes selectected through argmask pre {jnp.diag(cov.C00)/ jnp.max(jnp.diag(cov.C00))=} {jnp.diag(cov.C11)/ jnp.max(jnp.diag(cov.C11))=}"
    )

# Same row/column restriction for every covariance block.
for _block in ("C00", "C11", "C01", "C10"):
    setattr(cov, _block, getattr(cov, _block)[argmask_pre, :][:, argmask_pre])

argmask = argmask[argmask_pre]

# Selection 2: normalised auto-covariance C01_ii / sqrt(C00_ii * C11_ii), sorted from high to low.
auto_cov = jnp.einsum(
    "i,i,i->i",
    jnp.diag(cov.C00) ** (-0.5),
    jnp.diag(cov.C01),
    jnp.diag(cov.C11) ** (-0.5),
)
argmask_cov = jnp.argsort(auto_cov, descending=True).reshape(-1)

if auto_cov_threshold is not None:
    argmask_cov = argmask_cov[auto_cov[argmask_cov] > auto_cov_threshold]

if max_features_pre is not None and argmask_cov.shape[0] > max_features_pre:
    argmask_cov = argmask_cov[:max_features_pre]
    print(f"reducing argmask_cov to {max_features_pre}")

for _block in ("C00", "C11", "C01", "C10"):
    setattr(cov, _block, getattr(cov, _block)[argmask_cov, :][:, argmask_cov])

argmask = argmask[argmask_cov]

print(f"{auto_cov[argmask_cov]=}")

print(f"{argmask_cov=}")

# Feature selection wrapped as a CV transformation so it can be chained.
km_tr = CvTrans.from_cv_function(
    DataLoaderOutput._transform,
    static_argnames=["add_1", "add_1_pre"],
    add_1=False,
    add_1_pre=False,
    q=None,
    pi=None,
    argmask=argmask,
)

tot_flow = descriptor * tr * km_tr  # + LatticeInvariants2

# Check that the composed flow runs end to end on the raw data.
print("testing comp")

cv_out, cv_out_t = dlo_0.apply_cv(
    tot_flow,
    x=dlo_0.sp,
    x_t=dlo_0.sp_t,
    nl=dlo_0.nl,
    nl_t=dlo_0.nl,
    macro_chunk=macro_chunk,
    verbose=True,
)

print(f"{cv_out[0].shape=}")

aplying cv func to 5 chunks of size 512 + remainder of size 383 
start time: 14:44:30.279.compiled f
compiled chunk func
....recompiled f for last chunk
recompiled chunk func for last chunk

finished at: 14:44:31.408
jnp.sum(argmask_pre)=Array(2163, dtype=int64) jnp.sum(~argmask_pre)=Array(0, dtype=int64)  eps_pre=1e-10
reducing argmask_cov to 1000
auto_cov[argmask_cov]=Array([0.99982842, 0.99963379, 0.99927012, 0.99927004, 0.99926929,
       0.99926907, 0.99926862, 0.99926814, 0.99926791, 0.99926588,
       0.99923023, 0.99922441, 0.99921927, 0.99921103, 0.99920975,
       0.99919581, 0.99919532, 0.99919523, 0.99919511, 0.9991946 ,
       0.99919429, 0.99919234, 0.99919009, 0.9991795 , 0.99917918,
       0.9991769 , 0.99917619, 0.99917482, 0.99917381, 0.99917211,
       0.99917052, 0.99916186, 0.99916054, 0.99912876, 0.99910869,
       0.99885998, 0.99880034, 0.99878929, 0.99876088, 0.99876005,
       0.9987591 , 0.99875872, 0.99875859, 0.99875834, 0.99875832,
       0.99875788, 0.998

KeyboardInterrupt: 