In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
import math
from importlib import reload

import torch

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..', '..')))

import mars
from mars import spin_model, spectra_manager, constants, population, concat

In [2]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
import mars.population.transform as transform
from mars.population.contexts import Context, SummedContext
from mars import spin_model

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
dtype = torch.float64

# 1. Create Samples for Checking

In [4]:
def create_samples():
    g_tensor_1 = spin_model.Interaction((2.02, 2.04, 2.06), dtype=dtype)
    zfs_1 = spin_model.DEInteraction((200 * 1e6, 50 * 1e6), dtype=dtype)  # D=200 MHz, E=50 MHz
    g_tensor_2 = spin_model.Interaction((2.02, 2.04, 2.06), dtype=dtype)
    zfs_2 = spin_model.DEInteraction((200 * 1e6, 50 * 1e6), dtype=dtype)
    
    base_spin_system = spin_model.SpinSystem(
        electrons=[1.0],
        g_tensors=[g_tensor_1],
        electron_electron=[(0, 0, zfs_1)],
        dtype=dtype,
    )
    sample_1 = spin_model.MultiOrientedSample(
        base_spin_system=base_spin_system,
        gauss=0.001,
        lorentz=0.001,
        mesh=(4, 4),
        dtype=dtype,
    )


    g_tensor = spin_model.Interaction((2.02, 2.14, 2.16), dtype=dtype)
    zfs = spin_model.DEInteraction((200 * 1e6, 70 * 1e6), dtype=dtype)
    base_spin_system = spin_model.SpinSystem(
        electrons=[1.0],
        g_tensors=[g_tensor_2],
        electron_electron=[(0, 0, zfs_2)],
        dtype=dtype
    )
    sample_2 = spin_model.MultiOrientedSample(
        base_spin_system=base_spin_system,
        gauss=0.001,
        lorentz=0.001,
        mesh=(4, 4),
        dtype=dtype,
)
    return sample_1, sample_2

def get_eigen_basis(sample, filed: float):
    magnetic_field = torch.tensor(filed)
    F, _, _, Gz = sample.get_hamiltonian_terms()
    H = F + Gz * magnetic_field
    H = H.unsqueeze(-3)
    values, vectors = torch.linalg.eigh(H)
    return vectors

# 2. Create contexts functions

In [5]:
def create_contexts_type_1(sample_1, sample_2, basis_1, basis_2):
    context_1 = Context(
        basis=basis_1,
        sample=sample_1,
        init_populations=[0.0, 0.0, 1.0],
        out_probs=torch.tensor([123.3, 3123, 111.0], dtype=dtype),
        free_probs=torch.tensor([[-0.0, 100.0, 0.0], 
                                 [100.0, -0.0, 667.0], 
                                 [0.0, 667.0, -0.0]], dtype=dtype) * 0.0,
        dtype=dtype
    )
    context_2 = Context(
        basis=basis_2,
        sample=sample_2,
        init_populations=[0.0, 1.0, 0.0],
        out_probs=torch.tensor([434.0, 1233.0, 4343.0], dtype=dtype),
        free_probs=torch.tensor([[-0.0, 321.0, 123.0], 
                                 [321.0, -0.0, 454.0], 
                                 [123.0, 454.0, -0.0]], dtype=dtype) * 0.0,
        dtype=dtype
    )
    return context_1, context_2
    
def create_contexts_type_2(sample_1, sample_2, basis_1, basis_2):
    context_1 = Context(
        basis=basis_1,
        sample=sample_1,
        init_populations=[0.5, 0.35, 0.15],
        out_probs=[0.0, 200.3, 100.2],
        free_probs=torch.tensor([[-0.0, 100.0, 0.0], 
                                 [100.0, -0.0, 100.0], 
                                 [0.0, 100.0, -0.0]], dtype=dtype),
        dtype=dtype
    )
    context_2 = Context(
        basis=basis_2,
        sample=sample_2,
        init_populations=[0.3, 0.4, 0.3], 
        out_probs=[111.0, 1234.0, 1232.0],
        free_probs=torch.tensor([[0.0, 100.0, 0.0], 
                                 [100.0, 0.0, 200.0], 
                                 [0.0, 200.0, -2.0]], dtype=dtype),
        dtype=dtype
    )
    return context_1, context_2
    
def create_contexts_type_3(sample_1, sample_2, basis_1, basis_2):
    context_1 = Context(
        basis=basis_1,
        sample=sample_1,
        init_populations=[0.5, 0.25, 0.25],
        out_probs=[1000.0, 200.3, 100.2],
        dephasing = [1e4, 1e5, 1e6],
        dtype=dtype
    )
    context_2 = Context(
        basis=basis_2,
        sample=sample_2,
        init_populations=[0.3, 0.4, 0.3], 
        out_probs=[111.0, 1234.0, 1232.0],
        dephasing = [2e5, 1e4, 1e3],
        dtype=dtype
    )
    return context_1, context_2
    
def get_context_pairs(
    creator_func, 
    sample_a, 
    sample_b, 
    use_same_sample: bool = False
) -> list[tuple]:
    """Generate context pairs for different basis combinations."""
    s1 = sample_a if use_same_sample else sample_a
    s2 = sample_a if use_same_sample else sample_b

    return [
        creator_func(s1, s2, "xyz", "zeeman"),
        creator_func(s1, s2, "zfs", "multiplet")
    ]


context_creators = [
        create_contexts_type_1,
        create_contexts_type_2,
        create_contexts_type_3,
    ]

# 3. Create check Context functions

In [6]:
import torch
import numpy as np
import mars.population.transform as transform
from mars.population.contexts import Context, SummedContext
from mars import spin_model

def check_multiplication_correctness(
    context_1: Context,
    context_2: Context,
    full_system_vectors_1: torch.Tensor,
    full_system_vectors_2: torch.Tensor,
    atol: float = 1e-6,
    rtol: float = 1e-5
) -> None:
    """
    Validate Kronecker product (@) of contexts representing independent subsystems.
    
    Physical rules tested:
    - Hilbert space: H_total = H₁ ⊗ H₂  →  dim_total = dim₁ × dim₂
    - Initial state: ρ_total = ρ₁ ⊗ ρ₂  (Kronecker product)
    - Dynamics: K_total = K₁ ⊗ I₂ + I₁ ⊗ K₂  (sum of local operators)
    - Basis transformation preserves unitarity
    """
    full_system_vectors_tot = transform.batched_kron(
        full_system_vectors_1, full_system_vectors_2
    )
    mul_context = context_1 @ context_2

    # 1. Dimensionality check
    dim1 = context_1.spin_system_dim
    dim2 = context_2.spin_system_dim
    dim_tot = mul_context.spin_system_dim
    assert dim_tot == dim1 * dim2, \
        f"Dimension mismatch: {dim1} ⊗ {dim2} = {dim_tot} (expected {dim1 * dim2})"
    """
    # 2. Initial populations: Kronecker product
    pop1 = context_1.get_transformed_init_populations(full_system_vectors_1, normalize=False)
    pop2 = context_2.get_transformed_init_populations(full_system_vectors_2, normalize=False)
    pop_tot = mul_context.get_transformed_init_populations(full_system_vectors_tot, normalize=False)
    
    if pop1 is not None and pop2 is not None:
        expected_pop = transform.batched_kron(pop1.unsqueeze(-1), pop2.unsqueeze(-1)).squeeze(-1)
        assert torch.allclose(pop_tot, expected_pop, atol=atol, rtol=rtol), \
            "Initial populations don't combine via Kronecker product"
        assert torch.allclose(pop_tot.sum(-1), torch.ones_like(pop_tot.sum(-1)), atol=1e-5, rtol=rtol), \
            "Total probability not conserved after Kronecker product"
    
    dens1 = context_1.get_transformed_init_density(full_system_vectors_1)
    dens2 = context_2.get_transformed_init_density(full_system_vectors_2)
    dens_tot = mul_context.get_transformed_init_density(full_system_vectors_tot)
    
    if dens1 is not None and dens2 is not None:
        # Expected: ρ_tot = ρ₁ ⊗ ρ₂
        expected_dens = transform.batched_kron(dens1, dens2)
        assert torch.allclose(dens_tot, expected_dens, atol=atol, rtol=rtol), \
            "Density matrices don't combine via Kronecker product"
        assert torch.allclose(dens_tot, dens_tot.conj().transpose(-2, -1), atol=1e-5, rtol=rtol), \
            "Composite density matrix not Hermitian"
    

    # 3. Out probs
    probs1 = context_1.get_transformed_out_probs(full_system_vectors_1)
    probs2 = context_2.get_transformed_out_probs(full_system_vectors_2)
    probs_tot = mul_context.get_transformed_out_probs(full_system_vectors_tot)
    
    if probs1 is not None and probs2 is not None:
        I1 = torch.ones(dim1, device=probs1.device, dtype=probs1.dtype)
        I2 = torch.ones(dim2, device=probs2.device, dtype=probs2.dtype)
        
        # Broadcast to common batch shape
        batch_shape = torch.broadcast_shapes(probs1.shape[:-1], probs2.shape[:-1])
        probs1_exp = probs1.expand(batch_shape + (dim1, ))
        probs2_exp = probs2.expand(batch_shape + (dim2,))
        I1_exp = I1.expand(batch_shape + (dim1,))
        I2_exp = I2.expand(batch_shape + (dim2,))
        
        expected_probs = transform.batched_kron(probs1_exp, I2_exp) + \
                         transform.batched_kron(I1_exp, probs2_exp)
        assert torch.allclose(probs_tot, expected_probs, atol=atol, rtol=rtol), \
            "Out matrices don't combine as K₁⊗I + I⊗K₂"
        
    # 4. Free probs

    probs1 = context_1.get_transformed_free_probs(full_system_vectors_1)
    probs2 = context_2.get_transformed_free_probs(full_system_vectors_2)
    probs_tot = mul_context.get_transformed_free_probs(full_system_vectors_tot)

    if probs1 is not None and probs2 is not None:
        I1 = torch.eye(dim1, device=probs1.device, dtype=probs1.dtype)
        I2 = torch.eye(dim2, device=probs2.device, dtype=probs2.dtype)
        
        # Broadcast to common batch shape
        batch_shape = torch.broadcast_shapes(probs1.shape[:-2], probs2.shape[:-2])
        probs1_exp = probs1.expand(batch_shape + (dim1, dim1))
        probs2_exp = probs2.expand(batch_shape + (dim2, dim2))
        I1_exp = I1.expand(batch_shape + (dim1, dim1))
        I2_exp = I2.expand(batch_shape + (dim2, dim2))
        
        expected_probs = transform.batched_kron(probs1_exp, I2_exp) + \
                         transform.batched_kron(I1_exp, probs2_exp)
        print((probs_tot - expected_probs)[-1, 0])

        assert torch.allclose(probs_tot, expected_probs, atol=atol, rtol=rtol), \
            "Rate matrices don't combine as K₁⊗I + I⊗K₂"
     """
    
    # 5. Superoperators: Liouville-space Kronecker sum
    superop1 = context_1.get_transformed_free_superop(full_system_vectors_1)
    superop2 = context_2.get_transformed_free_superop(full_system_vectors_2)
    superop_tot = mul_context.get_transformed_free_superop(full_system_vectors_tot)
    
    if superop1 is not None and superop2 is not None:
        # Expected: R_tot = R₁ ⊗ I_Liouv + I_Liouv ⊗ R₂
        # where I_Liouv has dimension (dim², dim²)
        dim1_sq = dim1 * dim1
        dim2_sq = dim2 * dim2
        I1_liouv = torch.eye(dim1_sq, device=superop1.device, dtype=superop1.dtype)
        I2_liouv = torch.eye(dim2_sq, device=superop2.device, dtype=superop2.dtype)
        
        batch_shape = torch.broadcast_shapes(superop1.shape[:-2], superop2.shape[:-2])
        superop1_exp = superop1.expand(batch_shape + (dim1_sq, dim1_sq))
        superop2_exp = superop2.expand(batch_shape + (dim2_sq, dim2_sq))
        I1_exp = I1_liouv.expand(batch_shape + (dim1_sq, dim1_sq))
        I2_exp = I2_liouv.expand(batch_shape + (dim2_sq, dim2_sq))
        
        expected_superop = transform.batched_kron(superop1_exp, I2_exp) + \
                           transform.batched_kron(I1_exp, superop2_exp)
        
        expected_superop = transform.reshape_superoperator_tensor_to_kronecker_basis(
            expected_superop, subsystem_dims=[dim1, dim2]
        )
        assert torch.allclose(superop_tot, expected_superop, atol=atol, rtol=rtol), \
            "Superoperators don't combine via Liouville Kronecker sum"

    if mul_context.basis is not None:
        U = mul_context.basis
        U_dag = U.conj().transpose(-2, -1)
        identity = torch.eye(U.shape[-1], device=U.device, dtype=U.dtype)
        prod = U_dag @ U
        assert torch.allclose(prod, identity.expand_as(prod), atol=1e-5), \
            "Composite basis transformation not unitary"
    
    print("Multiplication test passed: dimensions, populations, densities, rates, superoperators")


def check_sum_correctness(
    context_1: Context,
    context_2: Context,
    full_system_vectors: torch.Tensor,
    atol: float = 1e-6,
    rtol: float = 1e-5
) -> None:
    """
    Validate addition (+) of contexts representing parallel relaxation mechanisms.
    
    Physical rules tested:
    - Dimensions remain unchanged (same Hilbert space)
    - Populations sum (if both defined)
    - Rate matrices sum element-wise
    - Superoperators sum element-wise
    - Detailed balance preserved in summed thermal rates
    """
    sum_context = context_1 + context_2
    
    # 1. Dimensionality preserved
    dim1 = context_1.spin_system_dim
    dim2 = context_2.spin_system_dim
    dim_sum = sum_context.spin_system_dim
    assert dim1 == dim2 == dim_sum, \
        f"Dimension mismatch in sum: {dim1} + {dim2} = {dim_sum}"
    
    pop1 = context_1.get_transformed_init_populations(full_system_vectors, normalize=False)
    pop2 = context_2.get_transformed_init_populations(full_system_vectors, normalize=False)
    pop_sum = sum_context.get_transformed_init_populations(full_system_vectors, normalize=False)
    
    if pop1 is not None and pop2 is not None:
        expected_pop = pop1 + pop2
        assert torch.allclose(pop_sum, expected_pop, atol=atol, rtol=rtol), \
            "Populations don't sum correctly"
    
    dens1 = context_1.get_transformed_init_density(full_system_vectors)
    dens2 = context_2.get_transformed_init_density(full_system_vectors)
    dens_sum = sum_context.get_transformed_init_density(full_system_vectors)
    
    if dens1 is not None and dens2 is not None:
        expected_dens = dens1 + dens2
        assert torch.allclose(dens_sum, expected_dens, atol=atol, rtol=rtol), \
            "Density matrices don't sum correctly"
        
        # Hermiticity preserved after summation
        assert torch.allclose(dens_sum, dens_sum.conj().transpose(-2, -1), atol=1e-5), \
            "Summed density matrix not Hermitian"
        
        # Trace should be sum of individual traces
        trace1 = torch.diagonal(dens1, dim1=-2, dim2=-1).sum(-1)
        trace2 = torch.diagonal(dens2, dim1=-2, dim2=-1).sum(-1)
        trace_sum = torch.diagonal(dens_sum, dim1=-2, dim2=-1).sum(-1)
        assert torch.allclose(trace_sum, trace1 + trace2, atol=1e-5), \
            "Trace not additive under summation"
    
    # 4. Rate matrices sum
    probs1 = context_1.get_transformed_free_probs(full_system_vectors)
    probs2 = context_2.get_transformed_free_probs(full_system_vectors)
    probs_sum = sum_context.get_transformed_free_probs(full_system_vectors)
    
    if probs1 is not None and probs2 is not None:
        expected_probs = probs1 + probs2
        assert torch.allclose(probs_sum, expected_probs, atol=atol, rtol=rtol), \
            "Rate matrices don't sum correctly"
    
    # 5. Superoperators sum
    superop1 = context_1.get_transformed_free_superop(full_system_vectors)
    superop2 = context_2.get_transformed_free_superop(full_system_vectors)
    superop_sum = sum_context.get_transformed_free_superop(full_system_vectors)
    
    if superop1 is not None and superop2 is not None:
        expected_superop = superop1 + superop2
        assert torch.allclose(superop_sum, expected_superop, atol=atol, rtol=rtol), \
            "Superoperators don't sum correctly"
    print("Sum test passed: dimensions, populations, densities, rates, superoperators")


def check_density_transformation_correctness(
    context: Context,
    full_system_vectors: torch.Tensor,
    atol: float = 1e-6,
    rtol: float = 1e-5
) -> None:
    """
    Validate density matrix transformation between bases.
    
    Physical rules tested:
    - Unitary transformation: ρ_new = U ρ_old U†
    - Hermiticity preservation: ρ = ρ†
    - Trace preservation: Tr(ρ) = 1 (for normalized states)
    - Eigenvalue spectrum invariance under unitary transformation
    """
    if context.init_density is None and context.init_populations is None:
        print("Density transformation test skipped: no initial density/populations")
        return
    
    dens_transformed = context.get_transformed_init_density(full_system_vectors)
    
    if dens_transformed is None:
        print("Density transformation test skipped: transformation returned None")
        return
    
    dim = context.spin_system_dim
    
    # 1. Hermiticity check
    assert torch.allclose(
        dens_transformed, 
        dens_transformed.conj().transpose(-2, -1), 
        atol=atol
    ), "Transformed density matrix not Hermitian"
    
    # 2. Trace preservation
    trace = torch.diagonal(dens_transformed, dim1=-2, dim2=-1).sum(-1)
    assert torch.allclose(trace, torch.ones_like(trace), atol=1e-5), \
        f"Trace not preserved: mean={trace.mean().item():.6f} ≠ 1.0"
    
    # 3. Positive semi-definiteness (eigenvalues ≥ 0)
    # Only check if matrix is small enough for eigendecomposition
    if dim <= 16:  # Practical limit for batched eigendecomposition
        eigvals = torch.linalg.eigvalsh(dens_transformed.real)  # Hermitian → real eigenvalues
        assert torch.all(eigvals >= -1e-8), \
            f"Negative eigenvalues found: min={eigvals.min().item():.2e}"
    
    # 4. If basis is identity, transformation should be identity
    if context.basis is not None and context.basis.shape[-1] == dim:
        # Check if basis is approximately identity
        identity = torch.eye(dim, device=context.basis.device, dtype=context.basis.dtype)
        if torch.allclose(context.basis, identity, atol=1e-6):
            # Transformation should leave density unchanged (up to numerical error)
            dens_original = context.init_density
            if dens_original is not None:
                assert torch.allclose(dens_transformed, dens_original, atol=atol, rtol=rtol), \
                    "Identity basis transformation modified density matrix"
    
    
    print("✓ Density transformation test passed: Hermiticity, trace, positivity, consistency")


def direct_sum_batched(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """
    Compute the direct sum (block diagonal) of two batched square matrices.
    
    Args:
        A: Tensor of shape (batch_size, n, n)
        B: Tensor of shape (batch_size, m, m)
    
    Returns:
        Tensor of shape (batch_size, n + m, n + m) with A and B on the diagonal blocks.
    """
    n, n2 = A.shape[-2:]
    m, m2 = B.shape[-2:]
    
    # Validate shapes
    assert n == n2, f"A must be square, got shape {A.shape}"
    assert m == m2, f"B must be square, got shape {B.shape}"
    
    batch_size = A.shape[:-2]
    total_dim = n + m
    
    # Initialize zero tensor for result
    result = torch.zeros(*batch_size, total_dim, total_dim, 
                         dtype=A.dtype, device=A.device)
    
    # Fill top-left block with A
    result[..., :n, :n] = A
    
    # Fill bottom-right block with B
    result[..., n:, n:] = B
    
    return result
    
    
def check_concatenation_correctness(
    context_1: Context,
    context_2: Context,
    full_system_vectors_1: torch.Tensor,
    full_system_vectors_2: torch.Tensor,
    atol: float = 1e-6,
    rtol: float = 1e-5
) -> None:
    """
    Validate concatenation of contexts representing distinct spin systems in a powder sample.
    
    Physical interpretation: Different molecular species/orientations coexisting in same sample.
    Mathematical operation: Direct sum (block-diagonal composition).
    """
    
    concat_context = mars.concat([context_1, context_2])
    full_system_vectors = direct_sum_batched(A=full_system_vectors_1, B=full_system_vectors_2)
    
    # 1. Dimensionality (direct sum: N_total = N₁ + N₂)
    dim1 = context_1.spin_system_dim
    dim2 = context_2.spin_system_dim
    dim_concat = concat_context.spin_system_dim
    assert dim_concat == dim1 + dim2, \
        f"Dimension mismatch: {dim1} ⊕ {dim2} = {dim_concat} (expected {dim1 + dim2})"
    
    # 2. Populations form block-diagonal structure
    pop1 = context_1.get_transformed_init_populations(full_system_vectors_1, normalize=False)
    pop2 = context_2.get_transformed_init_populations(full_system_vectors_2, normalize=False)
    
    
    pop_concat = concat_context.get_transformed_init_populations(full_system_vectors, normalize=False)
    
    if pop1 is not None and pop2 is not None and pop_concat is not None:
        expected_pop = torch.cat([pop1, pop2], dim=-1)
        assert torch.allclose(pop_concat, expected_pop, atol=atol, rtol=rtol), \
            "Populations don't concatenate correctly"
    
    # 3. Rate matrices form block-diagonal structure
    probs1 = context_1.get_transformed_free_probs(full_system_vectors_1)
    probs2 = context_2.get_transformed_free_probs(full_system_vectors_2)
    probs_concat = concat_context.get_transformed_free_probs(full_system_vectors)
    
    if probs1 is not None and probs2 is not None and probs_concat is not None:
        # Build expected block-diagonal matrix
        expected_probs = torch.zeros(
            probs1.shape[:-2] + (dim1 + dim2, dim1 + dim2),
            device=probs1.device, dtype=probs1.dtype
        )
        expected_probs[..., :dim1, :dim1] = probs1
        expected_probs[..., dim1:, dim1:] = probs2
        
        assert torch.allclose(probs_concat, expected_probs, atol=atol, rtol=rtol), \
            "Rate matrices don't form block-diagonal structure"
        
        # Each block should independently conserve probability
        col_sums1 = probs1.sum(-2)
        col_sums2 = probs2.sum(-2)
    
    # 4. Density matrices form block-diagonal structure
    dens1 = context_1.get_transformed_init_density(full_system_vectors_1)
    dens2 = context_2.get_transformed_init_density(full_system_vectors_2)
    dens_concat = concat_context.get_transformed_init_density(full_system_vectors)
    
    if dens1 is not None and dens2 is not None and dens_concat is not None:
        expected_dens = torch.zeros(
            dens1.shape[:-2] + (dim1 + dim2, dim1 + dim2),
            device=dens1.device, dtype=dens1.dtype
        )
        expected_dens[..., :dim1, :dim1] = dens1
        expected_dens[..., dim1:, dim1:] = dens2
        
        assert torch.allclose(dens_concat, expected_dens, atol=atol, rtol=rtol), \
            "Density matrices don't form block-diagonal structure"
        
        # Each block should have trace 1 (normalized states)

        trace1 = torch.diagonal(dens1, dim1=-2, dim2=-1).sum(-1)
        trace2 = torch.diagonal(dens2, dim1=-2, dim2=-1).sum(-1)
        
        assert torch.allclose(trace1, torch.ones_like(trace1), atol=1e-5, rtol=rtol), \
            "Block 1 density matrix trace ≠ 1"
        assert torch.allclose(trace2, torch.ones_like(trace2), atol=1e-5, rtol=rtol), \
            "Block 2 density matrix trace ≠ 1"
    
    print("✓ Concatenation test passed: block-diagonal structure for populations, rates, densities")

# 4. Create different type-context to for checking

In [7]:
def test_comprehensive_context_operations(dtype=torch.float64):
    sample_1, sample_2 = create_samples()
    vectors_1 = get_eigen_basis(sample_1, 0.1)
    vectors_2 = get_eigen_basis(sample_2, 0.1)
    
    for creator in context_creators:
        for ctx1, ctx2 in get_context_pairs(creator, sample_1, sample_2, use_same_sample=False):
            # Core operations
            check_multiplication_correctness(ctx1, ctx2, vectors_1, vectors_2)
            check_concatenation_correctness(ctx1, ctx2, vectors_1, vectors_2)
            check_density_transformation_correctness(ctx1, vectors_1)
    
    for creator in context_creators:
        for ctx1, ctx2 in get_context_pairs(creator, sample_1, sample_2, use_same_sample=True):
            check_sum_correctness(ctx1, ctx2, vectors_1)
    
    # Additional edge case: Context with explicit density matrix
    print("\nTesting explicit density matrix initialization...")
    rho0 = torch.tensor([[0.5, 0.1+0.05j, 0.0],
                         [0.1-0.05j, 0.3, 0.0],
                         [0.0, 0.0, 0.2]], dtype=torch.complex128)
    context_rho = Context(
        basis="zfs",
        sample=sample_1,
        init_density=rho0,
        free_probs=torch.tensor([[0.0, 1e5, 0.0], 
                                 [1e5, 0.0, 1e5], 
                                 [0.0, 1e5, 0.0]], dtype=dtype),
        dtype=dtype
    )
    check_density_transformation_correctness(context_rho, vectors_1)
    
    print("\n" + "=" * 60)
    print("✅ All context operation tests passed successfully!")

In [8]:
test_comprehensive_context_operations()

Multiplication test passed: dimensions, populations, densities, rates, superoperators
✓ Concatenation test passed: block-diagonal structure for populations, rates, densities
✓ Density transformation test passed: Hermiticity, trace, positivity, consistency
Multiplication test passed: dimensions, populations, densities, rates, superoperators
✓ Concatenation test passed: block-diagonal structure for populations, rates, densities
✓ Density transformation test passed: Hermiticity, trace, positivity, consistency
Multiplication test passed: dimensions, populations, densities, rates, superoperators
✓ Concatenation test passed: block-diagonal structure for populations, rates, densities
✓ Density transformation test passed: Hermiticity, trace, positivity, consistency
Multiplication test passed: dimensions, populations, densities, rates, superoperators
✓ Concatenation test passed: block-diagonal structure for populations, rates, densities
✓ Density transformation test passed: Hermiticity, trace, 