In [15]:
%load_ext autoreload
%autoreload 2

import sys
import os
import math
from importlib import reload

import torch

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..', '..')))

import mars
from mars import spin_model, spectra_manager, constants, population, concat

import torch

from mars.population.transform import (
   basis_transformation,
   compute_liouville_basis_transformation,
   transform_superop_to_new_basis,
   Liouvilleator
)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


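## 1. Simple Check: Hadamard Basis Change (2×2)

Rotate the computational basis into the normalized Hadamard basis and check that transforming ρ in Hilbert space ($U\rho U^\dagger$) and in Liouville space ($T\,\mathrm{vec}(\rho)$, row-major vectorization) give the same result.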

In [16]:
# Old basis: computational basis (identity); new basis: normalized Hadamard vectors.
basis_old = torch.tensor([
   [1.0, 0.0],
   [0.0, 1.0]
], dtype=torch.complex64)

basis_new = torch.tensor([
   [1.0,  1.0],
   [1.0, -1.0]
], dtype=torch.complex64) / torch.sqrt(torch.tensor(2.0))

# 1. Get Hilbert-space transformation (for operators like ρ)
U = basis_transformation(basis_old, basis_new)
# Expected form: U = basis_new^† @ basis_old. This Hadamard basis is real and
# symmetric, so the check also holds under the reverse convention.
assert torch.allclose(U, basis_new.conj().T @ basis_old, atol=1e-6)

# 2. Get Liouville-space transformation (for vectorized ρ or superoperators)
T = compute_liouville_basis_transformation(basis_old, basis_new)

# 3. Transform a density matrix (Hilbert space)
rho_old = torch.tensor([[0.6, 0.1+0.2j],
                       [0.1-0.2j, 0.4]], dtype=torch.complex64)
rho_new = U @ rho_old @ U.conj().T

# 4. Transform via vectorization (Liouville space) – equivalent result
rho_old_vec = rho_old.flatten()          # Row-major: [0.6, 0.1+0.2j, 0.1-0.2j, 0.4]
rho_new_vec = T @ rho_old_vec
# atol absorbs complex64 (float32) round-off; same tolerance as the section-2 checks
assert torch.allclose(rho_new_vec.reshape(2, 2), rho_new, atol=1e-6)

## 2. Advanced Validation with Generic Complex Unitaries (3×3)

### 2.1. Prepare Bases and Transformations

In [17]:
# Seed the global torch RNG so the random bases and operators below are reproducible.
# The trailing semicolon suppresses the `<torch._C.Generator ...>` repr in the output.
torch.manual_seed(42);

<torch._C.Generator at 0x18a15c65110>

In [18]:
# ── 1. Generate a generic complex unitary without symmetries ──
# QR of a complex Gaussian (Ginibre) matrix plus the phase correction of
# Mezzadri (2007) yields a Haar-distributed unitary.
def random_unitary(n, dtype=torch.complex64):
    """Sample an ``n x n`` Haar-random unitary matrix.

    Args:
        n: Matrix dimension.
        dtype: Complex dtype of the result (default ``torch.complex64``).

    Returns:
        A unitary tensor of shape ``(n, n)`` and dtype ``dtype``.
    """
    # For complex dtypes torch.randn already samples independent real and
    # imaginary parts, so no separate "1j * randn" term is needed.
    X = torch.randn(n, n, dtype=dtype)
    Q, R = torch.linalg.qr(X)
    # Q is already unitary. Rescaling column j by the phase of R[j, j] removes
    # the QR phase ambiguity, which makes the distribution of Q Haar.
    d = torch.diagonal(R)
    ph = d / d.abs()
    return Q * ph.unsqueeze(0)

basis_old = random_unitary(3)   # Random orthonormal basis (generic, NOT the computational basis)
basis_new = random_unitary(3)   # Independent random orthonormal basis

# ── 2. Compute transformations ──
U = basis_transformation(basis_old, basis_new)                     # U = basis_new^† @ basis_old
T = compute_liouville_basis_transformation(basis_old, basis_new)  # T = kron(U, U.conj())

# Both maps must be unitary: section 2.4 uses T^† as the inverse of T.
assert torch.allclose(U.conj().T @ U, torch.eye(3, dtype=U.dtype), atol=1e-5)
assert torch.allclose(T.conj().T @ T, torch.eye(9, dtype=T.dtype), atol=1e-5)

### 2.2. Test Vectorized Operator Transformation

In [19]:
# Random Hermitian operator in the old basis (Hermitian part of a Gaussian matrix)
rho_old = torch.randn(3, 3, dtype=torch.complex64)
rho_old = (rho_old + rho_old.conj().transpose(-2, -1)) / 2

# Route 1 — Hilbert space: ρ_new = U ρ_old U†
rho_new_direct = U @ rho_old @ U.conj().transpose(-2, -1)

# Route 2 — Liouville space: vec(ρ_new) = T · vec(ρ_old), row-major vectorization
rho_old_vec = rho_old.reshape(-1)
rho_new_vec = torch.matmul(T, rho_old_vec)
rho_new_via_liouville = rho_new_vec.reshape(3, 3)

# The two routes must produce the same operator
assert torch.allclose(rho_new_direct, rho_new_via_liouville, atol=1e-6)

### 2.3. Validate Hamiltonian Superoperator in Liouville Space

In [20]:
# Random Hermitian Hamiltonian in the old basis (Hermitian part of a Gaussian matrix)
H_old = torch.randn(3, 3, dtype=torch.complex64)
H_old = (H_old + H_old.conj().transpose(-2, -1)) / 2

# (a) Reference value: -i[H, ρ] from ordinary matrix products
commutator_direct = -1j * (H_old @ rho_old - rho_old @ H_old)

# (b) The same quantity from the Liouville superoperator acting on the
#     row-major vec(ρ) built in the previous cell
liouv = Liouvilleator()
L_H_old = liouv.hamiltonian_superop(H_old)   # N²×N² superoperator (row-major vec)
commutator_via_liouville = torch.matmul(L_H_old, rho_old_vec).reshape(3, 3)

# The Liouville representation must reproduce the commutator
assert torch.allclose(commutator_direct, commutator_via_liouville, atol=1e-6)

### 2.4. Verify Superoperator Transforms Correctly Under Basis Change

In [21]:
# (c) Hamiltonian rotated into the new basis: H_new = U · H_old · U†
H_new = U @ H_old @ U.conj().T

# (d) Rotate the superoperator itself: L_new = T · L_old · T†
L_H_new_transformed = T @ L_H_old @ T.conj().transpose(-2, -1)

# (e) Reference superoperator built straight from H_new
L_H_new_direct = liouv.hamiltonian_superop(H_new)

# The rotated superoperator must coincide with the one built from H_new
assert torch.allclose(L_H_new_transformed, L_H_new_direct, atol=1e-6)

# (f) Equation of motion in the new basis, computed two independent ways:
#     Hilbert space:   -i [H_new, ρ_new]
#     Liouville space: L_new · (T · vec(ρ_old)), reshaped back to 3×3
rho_new = U @ rho_old @ U.conj().T
d_rho_new_direct = -1j * (H_new @ rho_new - rho_new @ H_new)
d_rho_new_via_transform = torch.matmul(
    L_H_new_transformed, torch.matmul(T, rho_old_vec)
).reshape(3, 3)

assert torch.allclose(d_rho_new_direct, d_rho_new_via_transform, atol=1e-6)

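### 2.5. Verify Vectorization Transformations

These helpers convert between the row-major vectorization of a composite operator, vec(ρ₁ ⊗ ρ₂ ⊗ …), and the tensor-product ordering vec(ρ₁) ⊗ vec(ρ₂) ⊗ …, for both vectors and superoperators.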

In [22]:
# Vector and superoperator reshapes between the Kronecker (row-major vec of ρ₁⊗ρ₂⊗…)
# ordering and the tensor-product (vec(ρ₁)⊗vec(ρ₂)⊗…) ordering. The inverse vector
# reshape was previously missing (the forward one was listed twice), so the tests
# below raised NameError on a fresh kernel.
from mars.population.transform import (
   reshape_vectorized_kronecker_to_tensor_product,
   reshape_vectorized_tensor_product_to_kronecker,
   reshape_superoperator_kronecker_to_tensor_basis,
   reshape_superoperator_tensor_to_kronecker_basis
)


def test_two_spin_transformations():
    """Test with two spins (dims = [2, 2])"""
    print("=" * 60)
    print("Test 1: Two-spin system [2, 2]")
    print("=" * 60)
    
    batch_size = 3
    dims = [2, 2]
    
    # Create random density matrices for each subsystem
    rho1 = torch.randn(batch_size, 2, 2, dtype=torch.complex64)
    rho2 = torch.randn(batch_size, 2, 2, dtype=torch.complex64)
    
    # Make them Hermitian (not required for test, but more physical)
    rho1 = (rho1 + rho1.conj().transpose(-1, -2)) / 2
    rho2 = (rho2 + rho2.conj().transpose(-1, -2)) / 2
    
    # Compute Kronecker product manually
    rho_kron = torch.zeros(batch_size, 4, 4, dtype=torch.complex64)
    for b in range(batch_size):
        rho_kron[b] = torch.kron(rho1[b], rho2[b])
    
    # Vectorize (row-stacked)
    vec_rho_kron = rho_kron.reshape(batch_size, -1)
    
    print(f"Original vec(rho_1 ⊗ rho_2) shape: {vec_rho_kron.shape}")
    
    # Test forward transformation
    vec_tensor = reshape_vectorized_kronecker_to_tensor_product(vec_rho_kron, dims)
    print(f"Transformed to vec(rho_1) ⊗ vec(rho_2) shape: {vec_tensor.shape}")
    
    # Manually compute vec(rho1) ⊗ vec(rho2)
    vec_rho1 = rho1.reshape(batch_size, -1)
    vec_rho2 = rho2.reshape(batch_size, -1)
    vec_tensor_expected = torch.zeros(batch_size, 16, dtype=torch.complex64)
    for b in range(batch_size):
        vec_tensor_expected[b] = torch.kron(vec_rho1[b], vec_rho2[b])
    
    error_forward = torch.max(torch.abs(vec_tensor - vec_tensor_expected))
    print(f"Forward transformation error: {error_forward.item():.2e}")
    
    # Test backward transformation
    vec_rho_kron_reconstructed = reshape_vectorized_tensor_product_to_kronecker(vec_tensor, dims)
    error_backward = torch.max(torch.abs(vec_rho_kron - vec_rho_kron_reconstructed))
    print(f"Backward transformation error: {error_backward.item():.2e}")
    
    # Test round-trip
    print(f"Round-trip successful: {error_backward < 1e-5}")
    print()


def test_three_subsystems():
    """Test with three subsystems of different dimensions"""
    print("=" * 60)
    print("Test 2: Three subsystems [2, 3, 2]")
    print("=" * 60)
    
    batch_size = 2
    dims = [2, 3, 2]
    
    # Create random density matrices
    rho1 = torch.randn(batch_size, 2, 2, dtype=torch.complex64)
    rho2 = torch.randn(batch_size, 3, 3, dtype=torch.complex64)
    rho3 = torch.randn(batch_size, 2, 2, dtype=torch.complex64)
    
    # Compute triple Kronecker product
    rho_kron = torch.zeros(batch_size, 12, 12, dtype=torch.complex64)
    for b in range(batch_size):
        temp = torch.kron(rho1[b], rho2[b])
        rho_kron[b] = torch.kron(temp, rho3[b])
    
    vec_rho_kron = rho_kron.reshape(batch_size, -1)
    
    print(f"Original vec(rho_1 ⊗ rho_2 ⊗ rho_3) shape: {vec_rho_kron.shape}")
    
    # Transform to tensor product
    vec_tensor = reshape_vectorized_kronecker_to_tensor_product(vec_rho_kron, dims)
    print(f"Transformed to vec(rho_1) ⊗ vec(rho_2) ⊗ vec(rho_3) shape: {vec_tensor.shape}")
    
    # Manually compute expected result
    vec_rho1 = rho1.reshape(batch_size, -1)
    vec_rho2 = rho2.reshape(batch_size, -1)
    vec_rho3 = rho3.reshape(batch_size, -1)
    vec_tensor_expected = torch.zeros(batch_size, 4 * 9 * 4, dtype=torch.complex64)
    for b in range(batch_size):
        temp = torch.kron(vec_rho1[b], vec_rho2[b])
        vec_tensor_expected[b] = torch.kron(temp, vec_rho3[b])
    
    error_forward = torch.max(torch.abs(vec_tensor - vec_tensor_expected))
    print(f"Forward transformation error: {error_forward.item():.2e}")
    
    # Test inverse
    vec_rho_kron_reconstructed = reshape_vectorized_tensor_product_to_kronecker(vec_tensor, dims)
    error_backward = torch.max(torch.abs(vec_rho_kron - vec_rho_kron_reconstructed))
    print(f"Backward transformation error: {error_backward.item():.2e}")
    print(f"Round-trip successful: {error_backward < 1e-5}")
    print()


def test_superoperator_transformations():
    """Test superoperator transformations"""
    print("=" * 60)
    print("Test 3: Superoperator transformations [2, 2]")
    print("=" * 60)
    
    batch_size = 2
    dims = [2, 2]
    total_dim = 4
    
    # Create a random superoperator in Kronecker basis
    L_kron = torch.randn(batch_size, 16, 16, dtype=torch.complex64)
    
    print(f"Original superoperator shape: {L_kron.shape}")
    
    # Transform to tensor product basis
    L_tensor = reshape_superoperator_kronecker_to_tensor_basis(L_kron, dims)
    print(f"Transformed superoperator shape: {L_tensor.shape}")
    
    # Transform back
    L_kron_reconstructed = reshape_superoperator_tensor_to_kronecker_basis(L_tensor, dims)
    
    error = torch.max(torch.abs(L_kron - L_kron_reconstructed))
    print(f"Round-trip error: {error.item():.2e}")
    print(f"Round-trip successful: {error < 1e-4}")
    print()


def test_superoperator_action():
    """Test that superoperator acts correctly on density matrices"""
    print("=" * 60)
    print("Test 4: Superoperator action consistency [2, 2]")
    print("=" * 60)
    
    batch_size = 2
    dims = [2, 2]
    
    # Create random density matrix in Kronecker form
    rho = torch.randn(batch_size, 4, 4, dtype=torch.complex64)
    rho = (rho + rho.conj().transpose(-1, -2)) / 2
    vec_rho_kron = rho.reshape(batch_size, -1)
    
    # Create a simple superoperator (e.g., dephasing)
    L_kron = torch.randn(batch_size, 16, 16, dtype=torch.complex64)
    
    # Apply in Kronecker basis
    vec_out_kron = torch.bmm(L_kron, vec_rho_kron.unsqueeze(-1)).squeeze(-1)
    
    # Convert everything to tensor product basis
    vec_rho_tensor = reshape_vectorized_kronecker_to_tensor_product(vec_rho_kron, dims)
    L_tensor = reshape_superoperator_kronecker_to_tensor_basis(L_kron, dims)
    
    # Apply in tensor product basis
    vec_out_tensor = torch.bmm(L_tensor, vec_rho_tensor.unsqueeze(-1)).squeeze(-1)
    
    # Convert output back to Kronecker basis
    vec_out_kron_from_tensor = reshape_vectorized_tensor_product_to_kronecker(vec_out_tensor, dims)
    
    error = torch.max(torch.abs(vec_out_kron - vec_out_kron_from_tensor))
    print(f"Consistency error: {error.item():.2e}")
    print(f"Superoperator action consistent: {error < 1e-4}")
    print()


def test_single_spin():
    """Test with single spin (edge case)"""
    print("=" * 60)
    print("Test 5: Single spin [4]")
    print("=" * 60)
    
    batch_size = 3
    dims = [4]
    
    rho = torch.randn(batch_size, 4, 4, dtype=torch.complex64)
    vec_rho = rho.reshape(batch_size, -1)
    
    print(f"Original shape: {vec_rho.shape}")
    
    # For single system, transformations should be identity
    vec_tensor = reshape_vectorized_kronecker_to_tensor_product(vec_rho, dims)
    vec_reconstructed = reshape_vectorized_tensor_product_to_kronecker(vec_tensor, dims)
    
    error = torch.max(torch.abs(vec_rho - vec_reconstructed))
    print(f"Round-trip error: {error.item():.2e}")
    print(f"Single spin test passed: {error < 1e-5}")
    print()


def test_reshape_flatten_consistency():
    """Test using reshape and flatten operations"""
    print("=" * 60)
    print("Test 6: Reshape/Flatten consistency [2, 3]")
    print("=" * 60)
    
    batch_size = 4
    dims = [2, 3]
    
    # Create density matrices using different reshaping approaches
    rho1 = torch.randn(batch_size, 2, 2)
    rho2 = torch.randn(batch_size, 3, 3)
    
    # Method 1: Using flatten
    vec1_flat = rho1.flatten(start_dim=1)
    vec2_flat = rho2.flatten(start_dim=1)
    
    # Method 2: Using reshape
    vec1_reshape = rho1.reshape(batch_size, -1)
    vec2_reshape = rho2.reshape(batch_size, -1)
    
    print(f"Flatten matches reshape for rho1: {torch.allclose(vec1_flat, vec1_reshape)}")
    print(f"Flatten matches reshape for rho2: {torch.allclose(vec2_flat, vec2_reshape)}")
    
    # Create Kronecker product
    rho_kron = torch.zeros(batch_size, 6, 6)
    for b in range(batch_size):
        rho_kron[b] = torch.kron(rho1[b], rho2[b])
    
    # Test with both flatten and reshape
    vec_kron_flat = rho_kron.flatten(start_dim=1)
    vec_kron_reshape = rho_kron.reshape(batch_size, -1)
    
    print(f"Flatten matches reshape for Kronecker: {torch.allclose(vec_kron_flat, vec_kron_reshape)}")
    
    # Transform using our function
    vec_tensor = reshape_vectorized_kronecker_to_tensor_product(vec_kron_flat, dims)
    
    # Verify we can reshape back
    reshaped = vec_tensor.reshape(batch_size, 4, 9)
    flattened = reshaped.flatten(start_dim=1)
    
    print(f"Can reshape and flatten result: {torch.allclose(vec_tensor, flattened)}")
    print()


if __name__ == "__main__":
    # Banner box with a 58-character interior. Every title line is centred to
    # that width so the right-hand border lines up (the "TEST SUITE" line was
    # previously 2 characters short).
    banner_width = 58
    print("\n")
    print("╔" + "═" * banner_width + "╗")
    print("║" + "QUANTUM DENSITY MATRIX TRANSFORMATIONS".center(banner_width) + "║")
    print("║" + "TEST SUITE".center(banner_width) + "║")
    print("╚" + "═" * banner_width + "╝")
    print("\n")

    # Run all tests in order
    test_two_spin_transformations()
    test_three_subsystems()
    test_superoperator_transformations()
    test_superoperator_action()
    test_single_spin()
    test_reshape_flatten_consistency()

    print("=" * 60)
    print("ALL TESTS COMPLETED")
    print("=" * 60)



╔══════════════════════════════════════════════════════════╗
║          QUANTUM DENSITY MATRIX TRANSFORMATIONS          ║
║                      TEST SUITE                        ║
╚══════════════════════════════════════════════════════════╝


Test 1: Two-spin system [2, 2]
Original vec(rho_1 ⊗ rho_2) shape: torch.Size([3, 16])
Transformed to vec(rho_1) ⊗ vec(rho_2) shape: torch.Size([3, 16])
Forward transformation error: 0.00e+00
Backward transformation error: 0.00e+00
Round-trip successful: True

Test 2: Three subsystems [2, 3, 2]
Original vec(rho_1 ⊗ rho_2 ⊗ rho_3) shape: torch.Size([2, 144])
Transformed to vec(rho_1) ⊗ vec(rho_2) ⊗ vec(rho_3) shape: torch.Size([2, 144])
Forward transformation error: 0.00e+00
Backward transformation error: 0.00e+00
Round-trip successful: True

Test 3: Superoperator transformations [2, 2]
Original superoperator shape: torch.Size([2, 16, 16])
Transformed superoperator shape: torch.Size([2, 16, 16])
Round-trip error: 0.00e+00
Round-trip successful: T