In [None]:
"""
Full Benchmark: Original vs Optimized Swap Operations
Sr(Ti,Fe)(O,VO)3 Perovskite System - Functional Style (No Classes)

Compared operations:
1. swap_by_idx: single swap
2. sample_sublattice_swap: single sublattice swap
3. apply_n_swaps: N consecutive swaps (scan)
4. sample_sublattice_swap_beam: beam search

Benchmark:
- batch_size = 10000
- n_swaps = 10000
"""

import jax
import jax.numpy as jnp
from jax import random, lax
from functools import partial
from typing import Tuple, Optional, NamedTuple
import time

# Report the devices JAX sees, framed by 60-character rules.
print("="*60)
print("JAX Devices:", jax.devices())
print("="*60)


# =============================================================================
# Data Structures
# =============================================================================

class SwapResult(NamedTuple):
    """Outcome of one swap: the updated configuration and the swapped sites."""
    swapped: jnp.ndarray  # configuration after the swap has been applied
    indices: jnp.ndarray  # swapped site-index pair, stacked on the last axis

class BeamResult(NamedTuple):
    """Outcome of a beam search over candidate swaps."""
    swapped: jnp.ndarray    # candidate configurations, one per beam entry
    indices: jnp.ndarray    # swapped site-index pair for each candidate
    log_probs: jnp.ndarray  # summed log-softmax score of each candidate pair

# Integer code stored in configuration arrays for each species.
# "VO" is presumably the oxygen vacancy sharing the O sublattice - confirm.
ATOM_TYPES = {"Sr": 0, "Ti": 1, "Fe": 2, "O": 3, "VO": 4}


# =============================================================================
# Structure Generator
# =============================================================================

@partial(jax.jit, static_argnums=(1, 2, 3, 4, 5, 6))
def generate_structures(
    key: random.PRNGKey,
    batch_size: int,
    n_Sr: int, n_Ti: int, n_Fe: int, n_O: int, n_VO: int,
) -> jnp.ndarray:
    """Build a batch of randomly ordered configurations.

    Every row is laid out as ``n_Sr`` Sr codes, followed by a random
    arrangement of ``n_Ti`` Ti and ``n_Fe`` Fe codes, followed by a random
    arrangement of ``n_O`` O and ``n_VO`` VO codes (codes from ``ATOM_TYPES``).

    Args:
        key: PRNG key, split once into a B-site key and an O-site key.
        batch_size: number of rows to generate (static under jit).
        n_Sr, n_Ti, n_Fe, n_O, n_VO: per-species counts (static under jit).

    Returns:
        int32 array of shape [batch_size, n_Sr + n_Ti + n_Fe + n_O + n_VO].
    """
    def shuffled_rows(rng, species_counts):
        # Fixed-composition template, e.g. n_Ti Ti codes then n_Fe Fe codes.
        template = jnp.concatenate([
            jnp.full((count,), ATOM_TYPES[name], dtype=jnp.int32)
            for name, count in species_counts
        ])
        # Ranking independent uniform noise per row gives every batch element
        # its own random ordering of the template.
        noise = random.uniform(rng, (batch_size, template.shape[0]))
        return template[jnp.argsort(noise, axis=-1)]

    rng_b, rng_o = random.split(key)

    sr_block = jnp.full((batch_size, n_Sr), ATOM_TYPES["Sr"], dtype=jnp.int32)
    b_block = shuffled_rows(rng_b, (("Ti", n_Ti), ("Fe", n_Fe)))
    o_block = shuffled_rows(rng_o, (("O", n_O), ("VO", n_VO)))

    return jnp.concatenate([sr_block, b_block, o_block], axis=-1)


# =============================================================================
# ORIGINAL: your code, unchanged
# =============================================================================

@jax.jit
def orig_swap_by_idx(x: jnp.ndarray, idx: jnp.ndarray) -> jnp.ndarray:
    """Baseline swap: exchange two entries per row with two chained updates.

    Args:
        x: [batch, N] array.
        idx: [batch, 2] column indices to exchange in each row.

    Returns:
        [batch, N] array with the two selected entries of every row swapped.
    """
    rows = jnp.arange(x.shape[0])
    first, second = idx[:, 0], idx[:, 1]

    # Read both values from the untouched input before writing anything.
    first_vals = x[rows, first]
    second_vals = x[rows, second]

    return x.at[rows, first].set(second_vals).at[rows, second].set(first_vals)


@partial(jax.jit, static_argnums=(2, 3))
def orig_sample_sublattice_swap(
    key: random.PRNGKey,
    atom_types: jnp.ndarray,
    sublattice_indices: Tuple[int, ...],
    type_a: int,
    type_b: int,
    scores: Optional[jnp.ndarray] = None,
) -> SwapResult:
    """Baseline: sample one type_a/type_b site pair per row and swap it.

    One site of each type is drawn with the Gumbel-max trick (argmax of
    score plus Gumbel noise over the eligible sublattice sites), then the
    pair is exchanged with ``orig_swap_by_idx``.

    Args:
        key: PRNG key, split into one key per type.
        atom_types: [batch, N] type codes.
        sublattice_indices: site indices eligible for the swap (static).
        type_a: first type code (static).
        type_b: second type code (traced: not listed in static_argnums).
        scores: [batch, N] per-site scores; None means all-zero scores.

    Returns:
        SwapResult with the swapped [batch, N] array and [batch, 2] indices.
    """
    sites = jnp.array(sublattice_indices)
    n_rows, _ = atom_types.shape
    n_sites = len(sites)

    local_types = atom_types[:, sites]
    if scores is None:
        local_scores = jnp.zeros((n_rows, n_sites))
    else:
        local_scores = scores[:, sites]

    def pick(rng, wanted_type):
        # Sites holding another type get -inf so argmax can never choose them
        # (unless no site has the wanted type, in which case argmax yields 0).
        noise = random.gumbel(rng, (n_rows, n_sites))
        masked = jnp.where(local_types == wanted_type, local_scores + noise, -jnp.inf)
        return sites[jnp.argmax(masked, axis=-1)]

    rng_a, rng_b = random.split(key)
    indices = jnp.stack([pick(rng_a, type_a), pick(rng_b, type_b)], axis=-1)

    return SwapResult(swapped=orig_swap_by_idx(atom_types, indices), indices=indices)


@partial(jax.jit, static_argnums=(3, 4, 5, 6))
def orig_apply_n_swaps(
    key: random.PRNGKey,
    atom_types: jnp.ndarray,
    scores: Optional[jnp.ndarray],
    sublattice_indices: Tuple[int, ...],
    type_a: int,
    type_b: int,
    n_swaps: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Baseline: chain ``n_swaps`` sampled swaps with ``lax.scan``.

    Args:
        key: PRNG key, split into one key per step.
        atom_types: [batch, N] starting configuration.
        scores: [batch, N] scores reused at every step, or None.
        sublattice_indices: eligible site indices (static).
        type_a, type_b: type codes to exchange (static).
        n_swaps: number of scan steps (static).

    Returns:
        (final configuration, indices stacked over steps).
    """
    def step(current, step_key):
        swapped, picked = orig_sample_sublattice_swap(
            step_key, current, sublattice_indices, type_a, type_b, scores
        )
        return swapped, picked

    step_keys = random.split(key, n_swaps)
    return lax.scan(step, atom_types, step_keys)


@partial(jax.jit, static_argnums=(2, 3, 5))
def orig_beam_search(
    atom_types: jnp.ndarray,
    sublattice_indices: Tuple[int, ...],
    type_a: int,
    type_b: int,
    scores: jnp.ndarray,
    beam_size: int = 4,
) -> BeamResult:
    """Baseline beam search; candidate swaps are applied with a nested vmap.

    Args:
        atom_types: [batch, N] type codes.
        sublattice_indices: eligible site indices (traced here: position 1
            is not in static_argnums).
        type_a, type_b: type codes to exchange (static).
        scores: [batch, N] per-site scores.
        beam_size: number of candidates kept per row (static).

    Returns:
        BeamResult with [batch, beam, N] candidates, [batch, beam, 2] site
        pairs and [batch, beam] log-probabilities.
    """
    sites = jnp.array(sublattice_indices)
    n_rows, _ = atom_types.shape

    local_types = atom_types[:, sites]
    local_scores = scores[:, sites]

    def shortlist(wanted_type):
        # -inf removes sites of the wrong type from top_k and the softmax.
        masked = jnp.where(local_types == wanted_type, local_scores, -jnp.inf)
        top_vals, top_pos = lax.top_k(masked, beam_size)
        return masked, top_vals, top_pos

    masked_a, best_a, best_a_pos = shortlist(type_a)
    masked_b, best_b, best_b_pos = shortlist(type_b)

    # Score all beam_size x beam_size shortlist combinations, keep the best.
    combined = (best_a[:, :, None] + best_b[:, None, :]).reshape(n_rows, -1)
    _, best_pairs = lax.top_k(combined, beam_size)

    # Decode the flattened pair index into the two shortlist positions.
    pick_a = best_pairs // beam_size
    pick_b = best_pairs % beam_size

    row_col = jnp.arange(n_rows)[:, None]
    local_a = best_a_pos[row_col, pick_a]
    local_b = best_b_pos[row_col, pick_b]

    indices_candidates = jnp.stack([sites[local_a], sites[local_b]], axis=-1)

    # Nested vmap (the approach under test): inner over beam entries of one
    # row, outer over rows.
    def swap_one(row, pair):
        return orig_swap_by_idx(row[None, :], pair[None, :])[0]

    swap_row_beams = jax.vmap(swap_one, in_axes=(None, 0))
    swapped_candidates = jax.vmap(swap_row_beams)(atom_types, indices_candidates)

    log_p_a = jax.nn.log_softmax(masked_a, axis=-1)
    log_p_b = jax.nn.log_softmax(masked_b, axis=-1)
    log_probs = log_p_a[row_col, local_a] + log_p_b[row_col, local_b]

    return BeamResult(swapped=swapped_candidates, indices=indices_candidates, log_probs=log_probs)


# =============================================================================
# OPTIMIZED: my proposed optimizations
# =============================================================================

@jax.jit
def opt_swap_by_idx(x: jnp.ndarray, idx: jnp.ndarray) -> jnp.ndarray:
    """Swap two entries per row using a single scatter on the flattened array.

    Args:
        x: [batch, N] array.
        idx: [batch, 2] column indices to exchange in each row. Negative
            values count from the end of the row, as in ordinary indexing.

    Returns:
        [batch, N] array with the two selected entries of every row swapped.
    """
    batch_size, N = x.shape
    batch_idx = jnp.arange(batch_size)

    # Flat offsets are computed as row * N + column, so a negative column
    # would land in a different row. Wrap negatives into [0, N) first so the
    # scatter targets agree with the gathers below.
    idx = jnp.where(idx < 0, idx + N, idx)
    idx_a, idx_b = idx[:, 0], idx[:, 1]

    val_a = x[batch_idx, idx_a]
    val_b = x[batch_idx, idx_b]

    # One scatter: all "a" positions first, then all "b" positions; each
    # position receives the value read from its partner.
    row_offset = batch_idx * N
    flat_indices = jnp.concatenate([row_offset + idx_a, row_offset + idx_b])
    scatter_values = jnp.concatenate([val_b, val_a])

    x_swapped = x.reshape(-1).at[flat_indices].set(scatter_values)
    return x_swapped.reshape(batch_size, N)


@jax.jit
def opt_swap_by_idx_v2(x: jnp.ndarray, idx: jnp.ndarray) -> jnp.ndarray:
    """Swap two entries per row by gathering through a per-row permutation.

    Args:
        x: [batch, N] array.
        idx: [batch, 2] column indices to exchange in each row.

    Returns:
        [batch, N] array with the two selected entries of every row swapped.
    """
    n_rows, n_cols = x.shape
    rows = jnp.arange(n_rows)
    first, second = idx[:, 0], idx[:, 1]

    # Start from the identity mapping and cross-wire only the two swapped
    # columns, so each of them reads from the other.
    identity = jnp.broadcast_to(jnp.arange(n_cols), (n_rows, n_cols))
    gather_map = identity.at[rows, first].set(second).at[rows, second].set(first)

    return jnp.take_along_axis(x, gather_map, axis=1)


@partial(jax.jit, static_argnums=(2, 3))
def opt_sample_sublattice_swap(
    key: random.PRNGKey,
    atom_types: jnp.ndarray,
    sublattice_indices: Tuple[int, ...],
    type_a: int,
    type_b: int,
    scores: Optional[jnp.ndarray] = None,
) -> SwapResult:
    """Sample one type_a/type_b site pair per row; swap via ``opt_swap_by_idx``.

    Site selection is the Gumbel-max trick over the eligible sublattice
    sites; only the final swap differs from the baseline implementation.

    Args:
        key: PRNG key, split into one key per type.
        atom_types: [batch, N] type codes.
        sublattice_indices: eligible site indices (static).
        type_a: first type code (static).
        type_b: second type code (traced: not listed in static_argnums).
        scores: [batch, N] per-site scores; None means all-zero scores.

    Returns:
        SwapResult with the swapped [batch, N] array and [batch, 2] indices.
    """
    sites = jnp.array(sublattice_indices)
    n_rows, _ = atom_types.shape
    n_sites = len(sites)

    local_types = atom_types[:, sites]
    if scores is None:
        local_scores = jnp.zeros((n_rows, n_sites))
    else:
        local_scores = scores[:, sites]

    rng_a, rng_b = random.split(key)

    def draw(rng, wanted_type):
        # Ineligible sites are masked to -inf before the argmax.
        perturbed = local_scores + random.gumbel(rng, (n_rows, n_sites))
        masked = jnp.where(local_types == wanted_type, perturbed, -jnp.inf)
        return sites[jnp.argmax(masked, axis=-1)]

    indices = jnp.stack([draw(rng_a, type_a), draw(rng_b, type_b)], axis=-1)
    return SwapResult(swapped=opt_swap_by_idx(atom_types, indices), indices=indices)


@partial(jax.jit, static_argnums=(3, 4, 5, 6))
def opt_apply_n_swaps(
    key: random.PRNGKey,
    atom_types: jnp.ndarray,
    scores: Optional[jnp.ndarray],
    sublattice_indices: Tuple[int, ...],
    type_a: int,
    type_b: int,
    n_swaps: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Chain ``n_swaps`` sampled swaps with ``lax.scan`` (optimized swap).

    Args:
        key: PRNG key, split into one key per step.
        atom_types: [batch, N] starting configuration.
        scores: [batch, N] scores reused at every step, or None.
        sublattice_indices: eligible site indices (static).
        type_a, type_b: type codes to exchange (static).
        n_swaps: number of scan steps (static).

    Returns:
        (final configuration, indices stacked over steps).
    """
    def step(current, step_key):
        swapped, picked = opt_sample_sublattice_swap(
            step_key, current, sublattice_indices, type_a, type_b, scores
        )
        return swapped, picked

    step_keys = random.split(key, n_swaps)
    return lax.scan(step, atom_types, step_keys)


@partial(jax.jit, static_argnums=(2, 3, 5))
def opt_beam_search(
    atom_types: jnp.ndarray,
    sublattice_indices: Tuple[int, ...],
    type_a: int,
    type_b: int,
    scores: jnp.ndarray,
    beam_size: int = 4,
) -> BeamResult:
    """Beam search that applies all candidate swaps in one flattened call.

    Args:
        atom_types: [batch, N] type codes.
        sublattice_indices: eligible site indices (traced here: position 1
            is not in static_argnums).
        type_a, type_b: type codes to exchange (static).
        scores: [batch, N] per-site scores.
        beam_size: number of candidates kept per row (static).

    Returns:
        BeamResult with [batch, beam, N] candidates, [batch, beam, 2] site
        pairs and [batch, beam] log-probabilities.
    """
    sites = jnp.array(sublattice_indices)
    n_rows, n_cols = atom_types.shape

    local_types = atom_types[:, sites]
    local_scores = scores[:, sites]

    def shortlist(wanted_type):
        # -inf removes sites of the wrong type from top_k and the softmax.
        masked = jnp.where(local_types == wanted_type, local_scores, -jnp.inf)
        top_vals, top_pos = lax.top_k(masked, beam_size)
        return masked, top_vals, top_pos

    masked_a, best_a, best_a_pos = shortlist(type_a)
    masked_b, best_b, best_b_pos = shortlist(type_b)

    # Score all beam_size x beam_size shortlist combinations, keep the best.
    combined = (best_a[:, :, None] + best_b[:, None, :]).reshape(n_rows, -1)
    _, best_pairs = lax.top_k(combined, beam_size)

    # Decode the flattened pair index into the two shortlist positions.
    pick_a = best_pairs // beam_size
    pick_b = best_pairs % beam_size

    row_col = jnp.arange(n_rows)[:, None]
    local_a = best_a_pos[row_col, pick_a]
    local_b = best_b_pos[row_col, pick_b]

    indices_candidates = jnp.stack([sites[local_a], sites[local_b]], axis=-1)

    # Key optimization: treat every (row, beam) pair as its own row, swap
    # them all in one call, then fold the beam axis back out.
    n_flat = n_rows * beam_size
    tiled = jnp.broadcast_to(atom_types[:, None, :], (n_rows, beam_size, n_cols))
    swapped_flat = opt_swap_by_idx(
        tiled.reshape(n_flat, n_cols),
        indices_candidates.reshape(n_flat, 2),
    )
    swapped_candidates = swapped_flat.reshape(n_rows, beam_size, n_cols)

    log_p_a = jax.nn.log_softmax(masked_a, axis=-1)
    log_p_b = jax.nn.log_softmax(masked_b, axis=-1)
    log_probs = log_p_a[row_col, local_a] + log_p_b[row_col, local_b]

    return BeamResult(swapped=swapped_candidates, indices=indices_candidates, log_probs=log_probs)


# =============================================================================
# BENCHMARK
# =============================================================================

def benchmark(name, fn, *args, n_warmup=3, n_trials=10):
    """Time ``fn(*args)`` and print the mean and std wall-clock time in ms.

    Args:
        name: label printed in front of the timing.
        fn: callable to time.
        *args: positional arguments forwarded to ``fn``.
        n_warmup: untimed calls made first (absorbs jit compilation).
        n_trials: timed calls; must be at least 1.

    Returns:
        Mean time per call in seconds.

    Raises:
        ValueError: if ``n_trials`` is smaller than 1.
    """
    if n_trials < 1:
        raise ValueError("n_trials must be at least 1")

    def run_once():
        # JAX dispatches asynchronously. Block on every array in the output
        # (NamedTuple, tuple or bare array) so the timer covers the whole
        # computation rather than just the dispatch.
        return jax.block_until_ready(fn(*args))

    # Warmup
    for _ in range(n_warmup):
        run_once()

    # Benchmark
    times = []
    for _ in range(n_trials):
        start = time.perf_counter()
        run_once()
        times.append(time.perf_counter() - start)

    avg = sum(times) / len(times)
    # Population standard deviation over the timed trials.
    std = (sum((t - avg)**2 for t in times) / len(times)) ** 0.5
    print(f"{name:40s}: {avg*1000:10.2f} ms ± {std*1000:.2f}")
    return avg


def run_all_benchmarks():
    """Run the full original-vs-optimized timing suite and print speedups.

    Generates BATCH_SIZE random structures and random per-site scores, times
    each original/optimized pair through ``benchmark`` on the B-site
    sublattice with type codes 1 and 2, then prints orig/opt time ratios.

    Returns:
        dict mapping a case key (e.g. "orig_swap") to the value returned by
        ``benchmark`` for that case.
    """

    # Setup
    BATCH_SIZE = 10000
    N_SWAPS = 10000
    BEAM_SIZE = 4

    composition = {"Sr": 32, "Ti": 8, "Fe": 24, "O": 84, "VO": 12}
    N = sum(composition.values())

    n_B = composition["Ti"] + composition["Fe"]
    n_O_total = composition["O"] + composition["VO"]

    # Convert to plain Python ints explicitly (a JAX Array here would raise a
    # hash error when the tuple is used as a static jit argument).
    sr_end = int(composition["Sr"])
    b_end = int(composition["Sr"] + n_B)
    b_site_idx = tuple(range(sr_end, b_end))
    o_site_idx = tuple(range(b_end, int(N)))

    print(f"\n{'='*60}")
    print(f"BENCHMARK SETUP")
    print(f"{'='*60}")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"N swaps: {N_SWAPS}")
    print(f"Atoms per structure: {N}")
    print(f"B-site (Ti↔Fe): {len(b_site_idx)} positions")
    print(f"O-site (O↔VO): {len(o_site_idx)} positions")
    print(f"Beam size: {BEAM_SIZE}")

    # Generate structures
    key = random.PRNGKey(42)
    key, subkey = random.split(key)

    structures = generate_structures(
        subkey, BATCH_SIZE,
        composition["Sr"], composition["Ti"], composition["Fe"],
        composition["O"], composition["VO"]
    )
    structures.block_until_ready()
    print(f"\nGenerated {BATCH_SIZE} structures: {structures.shape}")

    # Random scores
    key, subkey = random.split(key)
    scores = random.normal(subkey, structures.shape)

    results = {}

    # =========================================================================
    # Test 1: Single swap_by_idx
    # =========================================================================
    print(f"\n{'='*60}")
    print("TEST 1: swap_by_idx (단일 스왑)")
    print(f"{'='*60}")

    # Random index pairs over all N sites (not restricted to a sublattice).
    key, subkey = random.split(key)
    idx = random.randint(subkey, (BATCH_SIZE, 2), 0, N)

    results["orig_swap"] = benchmark("Original (2x .at[].set())", orig_swap_by_idx, structures, idx)
    results["opt_swap"] = benchmark("Optimized (single scatter)", opt_swap_by_idx, structures, idx)
    results["opt_swap_v2"] = benchmark("Optimized v2 (permutation)", opt_swap_by_idx_v2, structures, idx)

    # =========================================================================
    # Test 2: Single sublattice swap
    # =========================================================================
    print(f"\n{'='*60}")
    print("TEST 2: sample_sublattice_swap (단일 sublattice 스왑)")
    print(f"{'='*60}")

    # Both variants are timed with the same subkey, captured by the lambdas.
    key, subkey = random.split(key)
    results["orig_sublattice"] = benchmark(
        "Original",
        lambda: orig_sample_sublattice_swap(subkey, structures, b_site_idx, 1, 2, scores)
    )
    results["opt_sublattice"] = benchmark(
        "Optimized",
        lambda: opt_sample_sublattice_swap(subkey, structures, b_site_idx, 1, 2, scores)
    )

    # =========================================================================
    # Test 3: N swaps (the key test)
    # =========================================================================
    print(f"\n{'='*60}")
    print(f"TEST 3: apply_n_swaps ({N_SWAPS}번 연속 스왑) ⭐")
    print(f"{'='*60}")

    # Fewer warmup/trial rounds here because each call runs N_SWAPS scan steps.
    key, subkey = random.split(key)
    results["orig_n_swaps"] = benchmark(
        "Original",
        orig_apply_n_swaps, subkey, structures, scores, b_site_idx, 1, 2, N_SWAPS,
        n_warmup=2, n_trials=5
    )

    # Note: the optimized run uses a fresh subkey, unlike Test 2.
    key, subkey = random.split(key)
    results["opt_n_swaps"] = benchmark(
        "Optimized",
        opt_apply_n_swaps, subkey, structures, scores, b_site_idx, 1, 2, N_SWAPS,
        n_warmup=2, n_trials=5
    )

    # =========================================================================
    # Test 4: Beam search
    # =========================================================================
    print(f"\n{'='*60}")
    print(f"TEST 4: beam_search (beam_size={BEAM_SIZE})")
    print(f"{'='*60}")

    results["orig_beam"] = benchmark(
        "Original (nested vmap)",
        orig_beam_search, structures, b_site_idx, 1, 2, scores, BEAM_SIZE
    )
    results["opt_beam"] = benchmark(
        "Optimized (flatten→swap→reshape)",
        opt_beam_search, structures, b_site_idx, 1, 2, scores, BEAM_SIZE
    )

    # =========================================================================
    # Summary
    # =========================================================================
    print(f"\n{'='*60}")
    print("SPEEDUP SUMMARY")
    print(f"{'='*60}")

    # (label, baseline key, optimized key) for each reported ratio.
    comparisons = [
        ("swap_by_idx", "orig_swap", "opt_swap"),
        ("swap_by_idx (v2)", "orig_swap", "opt_swap_v2"),
        ("sublattice_swap", "orig_sublattice", "opt_sublattice"),
        (f"apply_n_swaps ({N_SWAPS}x)", "orig_n_swaps", "opt_n_swaps"),
        ("beam_search", "orig_beam", "opt_beam"),
    ]

    # Speedup is baseline time divided by optimized time (> 1 means faster).
    for name, orig_key, opt_key in comparisons:
        speedup = results[orig_key] / results[opt_key]
        print(f"{name:30s}: {speedup:.2f}x speedup")

    return results


def verify_correctness():
    """Assert that each optimized function reproduces its original.

    Builds 100 random structures and scores, then compares the swap,
    sublattice-swap and beam-search variants with ``jnp.allclose``.

    Raises:
        AssertionError: if any optimized output differs from the original.
    """
    print("\n" + "="*60)
    print("CORRECTNESS CHECK")
    print("="*60)

    key = random.PRNGKey(123)
    batch_size = 100

    composition = {"Sr": 32, "Ti": 8, "Fe": 24, "O": 84, "VO": 12}
    N = sum(composition.values())
    n_B = composition["Ti"] + composition["Fe"]
    # Convert to plain Python ints explicitly
    sr_end = int(composition["Sr"])
    b_end = int(sr_end + n_B)
    b_site_idx = tuple(range(sr_end, b_end))

    key, subkey = random.split(key)
    structures = generate_structures(
        subkey, batch_size,
        composition["Sr"], composition["Ti"], composition["Fe"],
        composition["O"], composition["VO"]
    )

    key, subkey = random.split(key)
    scores = random.normal(subkey, structures.shape)

    # Test swap_by_idx
    # Index pairs are drawn from [0, N), so they are never negative here.
    key, subkey = random.split(key)
    idx = random.randint(subkey, (batch_size, 2), 0, N)

    out_orig = orig_swap_by_idx(structures, idx)
    out_opt = opt_swap_by_idx(structures, idx)
    out_opt_v2 = opt_swap_by_idx_v2(structures, idx)

    assert jnp.allclose(out_orig, out_opt), "swap_by_idx mismatch!"
    assert jnp.allclose(out_orig, out_opt_v2), "swap_by_idx_v2 mismatch!"
    print("✓ swap_by_idx: all versions match")

    # Test sublattice swap
    # The same key is passed to both variants so they draw the same noise.
    test_key = random.PRNGKey(999)

    res_orig = orig_sample_sublattice_swap(test_key, structures, b_site_idx, 1, 2, scores)
    res_opt = opt_sample_sublattice_swap(test_key, structures, b_site_idx, 1, 2, scores)

    assert jnp.allclose(res_orig.indices, res_opt.indices), "sublattice indices mismatch!"
    assert jnp.allclose(res_orig.swapped, res_opt.swapped), "sublattice swapped mismatch!"
    print("✓ sample_sublattice_swap: all versions match")

    # Test beam search
    beam_orig = orig_beam_search(structures, b_site_idx, 1, 2, scores, 4)
    beam_opt = opt_beam_search(structures, b_site_idx, 1, 2, scores, 4)

    assert jnp.allclose(beam_orig.indices, beam_opt.indices), "beam indices mismatch!"
    assert jnp.allclose(beam_orig.swapped, beam_opt.swapped), "beam swapped mismatch!"
    assert jnp.allclose(beam_orig.log_probs, beam_opt.log_probs), "beam log_probs mismatch!"
    print("✓ beam_search: all versions match")

    print("\n✅ All correctness checks passed!")


# =============================================================================
# MAIN
# =============================================================================

# Script entry point: check orig/opt agreement first, then run the timings.
if __name__ == "__main__":
    verify_correctness()
    run_all_benchmarks()

JAX Devices: [CudaDevice(id=0)]

CORRECTNESS CHECK
✓ swap_by_idx: all versions match
✓ sample_sublattice_swap: all versions match
✓ beam_search: all versions match

✅ All correctness checks passed!

BENCHMARK SETUP
Batch size: 10000
N swaps: 10000
Atoms per structure: 160
B-site (Ti↔Fe): 32 positions
O-site (O↔VO): 96 positions
Beam size: 4

Generated 10000 structures: (10000, 160)

TEST 1: swap_by_idx (단일 스왑)
Original (2x .at[].set())               :       0.11 ms ± 0.02
Optimized (single scatter)              :       0.10 ms ± 0.03
Optimized v2 (permutation)              :       0.11 ms ± 0.01

TEST 2: sample_sublattice_swap (단일 sublattice 스왑)
Original                                :       0.16 ms ± 0.02
Optimized                               :       0.17 ms ± 0.01

TEST 3: apply_n_swaps (10000번 연속 스왑) ⭐
Original                                :     509.10 ms ± 0.92
Optimized                               :     469.10 ms ± 1.07

TEST 4: beam_search (beam_size=4)
Original (nested vm

In [None]:
"""
Sublattice-Constrained Swap Operations (JAX)
Sr(Ti,Fe)(O,VO)3 Perovskite System

Original Implementation - Proven Fastest
"""

import jax
import jax.numpy as jnp
from jax import random, lax
from functools import partial
from typing import Tuple, Optional, NamedTuple
import time

# Report the devices JAX sees, framed by 60-character rules.
print("="*60)
print("JAX Devices:", jax.devices())
print("="*60)


# =============================================================================
# Data Structures
# =============================================================================

class SwapResult(NamedTuple):
    """Outcome of one swap: the updated configuration and the swapped sites."""
    swapped: jnp.ndarray  # configuration after the swap has been applied
    indices: jnp.ndarray  # swapped site-index pair, stacked on the last axis

class BeamResult(NamedTuple):
    """Outcome of a beam search over candidate swaps."""
    swapped: jnp.ndarray    # candidate configurations, one per beam entry
    indices: jnp.ndarray    # swapped site-index pair for each candidate
    log_probs: jnp.ndarray  # summed log-softmax score of each candidate pair

# Integer code stored in configuration arrays for each species.
# "VO" is presumably the oxygen vacancy sharing the O sublattice - confirm.
ATOM_TYPES = {"Sr": 0, "Ti": 1, "Fe": 2, "O": 3, "VO": 4}


# =============================================================================
# Structure Generator
# =============================================================================

@partial(jax.jit, static_argnums=(1, 2, 3, 4, 5, 6))
def generate_structures(
    key: random.PRNGKey,
    batch_size: int,
    n_Sr: int, n_Ti: int, n_Fe: int, n_O: int, n_VO: int,
) -> jnp.ndarray:
    """Build a batch of randomly ordered configurations.

    Every row is laid out as ``n_Sr`` Sr codes, followed by a random
    arrangement of ``n_Ti`` Ti and ``n_Fe`` Fe codes, followed by a random
    arrangement of ``n_O`` O and ``n_VO`` VO codes (codes from ``ATOM_TYPES``).

    Args:
        key: PRNG key, split once into a B-site key and an O-site key.
        batch_size: number of rows to generate (static under jit).
        n_Sr, n_Ti, n_Fe, n_O, n_VO: per-species counts (static under jit).

    Returns:
        int32 array of shape [batch_size, n_Sr + n_Ti + n_Fe + n_O + n_VO].
    """
    def shuffled_rows(rng, species_counts):
        # Fixed-composition template, e.g. n_Ti Ti codes then n_Fe Fe codes.
        template = jnp.concatenate([
            jnp.full((count,), ATOM_TYPES[name], dtype=jnp.int32)
            for name, count in species_counts
        ])
        # Ranking independent uniform noise per row gives every batch element
        # its own random ordering of the template.
        noise = random.uniform(rng, (batch_size, template.shape[0]))
        return template[jnp.argsort(noise, axis=-1)]

    rng_b, rng_o = random.split(key)

    sr_block = jnp.full((batch_size, n_Sr), ATOM_TYPES["Sr"], dtype=jnp.int32)
    b_block = shuffled_rows(rng_b, (("Ti", n_Ti), ("Fe", n_Fe)))
    o_block = shuffled_rows(rng_o, (("O", n_O), ("VO", n_VO)))

    return jnp.concatenate([sr_block, b_block, o_block], axis=-1)


# =============================================================================
# Core Swap Operations
# =============================================================================

@jax.jit
def swap_by_idx(x: jnp.ndarray, idx: jnp.ndarray) -> jnp.ndarray:
    """
    Exchange two entries in every row.
    Args:
        x: [batch, N] tensor
        idx: [batch, 2] column indices to exchange
    Returns:
        x_swapped: [batch, N]
    """
    rows = jnp.arange(x.shape[0])
    first, second = idx[:, 0], idx[:, 1]

    # Read both values from the untouched input before writing anything.
    first_vals = x[rows, first]
    second_vals = x[rows, second]

    return x.at[rows, first].set(second_vals).at[rows, second].set(first_vals)


@partial(jax.jit, static_argnums=(2, 3))
def sample_sublattice_swap(
    key: random.PRNGKey,
    atom_types: jnp.ndarray,
    sublattice_indices: Tuple[int, ...],
    type_a: int,
    type_b: int,
    scores: Optional[jnp.ndarray] = None,
) -> SwapResult:
    """
    Sample one constrained swap per batch element.

    One site of each type is drawn with the Gumbel-max trick (argmax of
    score plus Gumbel noise over the eligible sublattice sites).
    Args:
        key: JAX random key (split into one key per type)
        atom_types: [batch, N] atom type indices
        sublattice_indices: tuple of indices in sublattice (static for JIT)
        type_a: first type to swap (static for JIT)
        type_b: second type to swap (traced: not in static_argnums)
        scores: [batch, N] swap scores (None = all-zero scores)
    Returns:
        SwapResult with swapped tensor and [batch, 2] indices
    """
    sites = jnp.array(sublattice_indices)
    n_rows, _ = atom_types.shape
    n_sites = len(sites)

    local_types = atom_types[:, sites]
    if scores is None:
        local_scores = jnp.zeros((n_rows, n_sites))
    else:
        local_scores = scores[:, sites]

    def pick(rng, wanted_type):
        # Sites holding another type get -inf so argmax can never choose them
        # (unless no site has the wanted type, in which case argmax yields 0).
        noise = random.gumbel(rng, (n_rows, n_sites))
        masked = jnp.where(local_types == wanted_type, local_scores + noise, -jnp.inf)
        return sites[jnp.argmax(masked, axis=-1)]

    rng_a, rng_b = random.split(key)
    indices = jnp.stack([pick(rng_a, type_a), pick(rng_b, type_b)], axis=-1)

    return SwapResult(swapped=swap_by_idx(atom_types, indices), indices=indices)


@partial(jax.jit, static_argnums=(2, 3))
def sample_sublattice_swap_deterministic(
    atom_types: jnp.ndarray,
    sublattice_indices: Tuple[int, ...],
    type_a: int,
    type_b: int,
    scores: jnp.ndarray,
) -> SwapResult:
    """
    Noise-free swap: exchange the highest-scoring site of each type.
    Args:
        atom_types: [batch, N] atom type indices
        sublattice_indices: tuple of sublattice indices
        type_a, type_b: types to swap (static for JIT)
        scores: [batch, N] swap scores
    Returns:
        SwapResult with swapped tensor and [batch, 2] indices

    NOTE(review): static_argnums=(2, 3) covers type_a and type_b here, so
    sublattice_indices is traced rather than static - confirm this is
    intended.
    """
    sites = jnp.array(sublattice_indices)
    local_types = atom_types[:, sites]
    local_scores = scores[:, sites]

    def best_site(wanted_type):
        # Mask sites of other types with -inf, then take the top score.
        masked = jnp.where(local_types == wanted_type, local_scores, -jnp.inf)
        return sites[jnp.argmax(masked, axis=-1)]

    indices = jnp.stack([best_site(type_a), best_site(type_b)], axis=-1)
    return SwapResult(swapped=swap_by_idx(atom_types, indices), indices=indices)


# =============================================================================
# Multiple Swap Steps
# =============================================================================

@partial(jax.jit, static_argnums=(3, 4, 5, 6))
def apply_n_swaps(
    key: random.PRNGKey,
    atom_types: jnp.ndarray,
    scores: Optional[jnp.ndarray],
    sublattice_indices: Tuple[int, ...],
    type_a: int,
    type_b: int,
    n_swaps: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Chain n sampled swaps on one sublattice using lax.scan.
    Args:
        key: random key (split into one key per step)
        atom_types: [batch, N] starting configuration
        scores: [batch, N] or None; the same scores are used at every step
        sublattice_indices: tuple (static for JIT)
        type_a, type_b: types (static for JIT)
        n_swaps: number of steps (static for JIT)
    Returns:
        final: [batch, N]
        all_indices: [n_swaps, batch, 2]
    """
    def step(current, step_key):
        swapped, picked = sample_sublattice_swap(
            step_key, current, sublattice_indices, type_a, type_b, scores
        )
        return swapped, picked

    step_keys = random.split(key, n_swaps)
    return lax.scan(step, atom_types, step_keys)


@partial(jax.jit, static_argnums=(3, 4, 5, 6, 7, 8, 9))
def apply_n_swaps_both(
    key: random.PRNGKey,
    atom_types: jnp.ndarray,
    scores: Optional[jnp.ndarray],
    b_site_indices: Tuple[int, ...],
    o_site_indices: Tuple[int, ...],
    type_ti: int,
    type_fe: int,
    type_o: int,
    type_vo: int,
    n_swaps: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Apply n swap steps, choosing the B or O sublattice at random each step
    (like PyTorch 'both' mode).

    The choice is one scalar draw per step, so every batch row swaps on the
    same sublattice within a given step.
    Args:
        key: random key
        atom_types: [batch, N]
        scores: [batch, N] or None
        b_site_indices: B-site tuple (static for JIT)
        o_site_indices: O-site tuple (static for JIT)
        type_ti, type_fe: B-site types (static for JIT)
        type_o, type_vo: O-site types (static for JIT)
        n_swaps: number of steps (static for JIT)
    Returns:
        final: [batch, N]
        all_indices: [n_swaps, batch, 2]
    """
    # Two keys per step: one for the sublattice choice, one for the swap.
    # The trailing key dimensions are taken from the split result instead of
    # being hard-coded, so the grouping does not depend on the key layout.
    flat_keys = random.split(key, n_swaps * 2)
    keys = flat_keys.reshape((n_swaps, 2) + flat_keys.shape[1:])

    def scan_fn(carry, keys_i):
        x = carry
        key_choice, key_swap = keys_i[0], keys_i[1]

        # Random choice: B-site or O-site (50/50)
        do_b = random.uniform(key_choice) < 0.5

        # Compute both swaps; only one is kept, so sharing key_swap is safe.
        result_b = sample_sublattice_swap(
            key_swap, x, b_site_indices, type_ti, type_fe, scores
        )
        result_o = sample_sublattice_swap(
            key_swap, x, o_site_indices, type_o, type_vo, scores
        )

        # Select based on random choice
        swapped = jnp.where(do_b, result_b.swapped, result_o.swapped)
        indices = jnp.where(do_b, result_b.indices, result_o.indices)

        return swapped, indices

    final, all_indices = lax.scan(scan_fn, atom_types, keys)
    return final, all_indices


# =============================================================================
# Beam Search
# =============================================================================

@partial(jax.jit, static_argnums=(2, 3, 5))
def sample_sublattice_swap_beam(
    atom_types: jnp.ndarray,
    sublattice_indices: Tuple[int, ...],
    type_a: int,
    type_b: int,
    scores: jnp.ndarray,
    beam_size: int = 4,
) -> BeamResult:
    """
    Beam search for the top-k constrained swap candidates.
    Args:
        atom_types: [batch, N]
        sublattice_indices: tuple of sublattice indices (traced here:
            position 1 is not in static_argnums)
        type_a, type_b: types to swap (static for JIT)
        scores: [batch, N] swap scores
        beam_size: number of candidates (static for JIT)
    Returns:
        BeamResult with [batch, beam, N] candidates, [batch, beam, 2]
        indices, and [batch, beam] log_probs
    """
    sites = jnp.array(sublattice_indices)
    n_rows, _ = atom_types.shape

    local_types = atom_types[:, sites]
    local_scores = scores[:, sites]

    def shortlist(wanted_type):
        # -inf removes sites of the wrong type from top_k and the softmax.
        masked = jnp.where(local_types == wanted_type, local_scores, -jnp.inf)
        top_vals, top_pos = lax.top_k(masked, beam_size)
        return masked, top_vals, top_pos

    masked_a, best_a, best_a_pos = shortlist(type_a)
    masked_b, best_b, best_b_pos = shortlist(type_b)

    # Score all beam_size x beam_size shortlist combinations, keep the best.
    combined = (best_a[:, :, None] + best_b[:, None, :]).reshape(n_rows, -1)
    _, best_pairs = lax.top_k(combined, beam_size)

    # Decode the flattened pair index into the two shortlist positions.
    pick_a = best_pairs // beam_size
    pick_b = best_pairs % beam_size

    row_col = jnp.arange(n_rows)[:, None]
    local_a = best_a_pos[row_col, pick_a]
    local_b = best_b_pos[row_col, pick_b]

    indices_candidates = jnp.stack([sites[local_a], sites[local_b]], axis=-1)

    # Apply each candidate swap: inner vmap over the beam entries of one
    # row, outer vmap over rows.
    def swap_one(row, pair):
        return swap_by_idx(row[None, :], pair[None, :])[0]

    swap_row_beams = jax.vmap(swap_one, in_axes=(None, 0))
    swapped_candidates = jax.vmap(swap_row_beams)(atom_types, indices_candidates)

    log_p_a = jax.nn.log_softmax(masked_a, axis=-1)
    log_p_b = jax.nn.log_softmax(masked_b, axis=-1)
    log_probs = log_p_a[row_col, local_a] + log_p_b[row_col, local_b]

    return BeamResult(swapped=swapped_candidates, indices=indices_candidates, log_probs=log_probs)


# =============================================================================
# Log Probability
# =============================================================================

@partial(jax.jit, static_argnums=(1, 2, 3))
def log_prob_sublattice_swap(
    scores: jnp.ndarray,
    sublattice_indices: Tuple[int, ...],
    type_a: int,
    type_b: int,
    atom_types: jnp.ndarray,
    swap_indices: jnp.ndarray,
) -> jnp.ndarray:
    """
    Compute log P(swap_indices | scores) for one sublattice swap.

    The probability is the product of two softmaxes taken over the
    sublattice: one restricted to sites holding ``type_a`` and one restricted
    to sites holding ``type_b``. ``swap_indices[:, 0]`` is scored under the
    ``type_a`` softmax and ``swap_indices[:, 1]`` under the ``type_b`` one.

    Args:
        scores: [batch, N]
        sublattice_indices: tuple of global site indices (static under jit)
        type_a, type_b: swapped types (static under jit)
        atom_types: [batch, N] BEFORE swap
        swap_indices: [batch, 2] global site indices (type_a site, type_b site)
    Returns:
        log_probs: [batch]. A row is -inf when the chosen site holds the
        wrong type or when either index lies outside ``sublattice_indices``.
        Rows with no ``type_a`` (or ``type_b``) site in the sublattice have an
        empty softmax normaliser, so their value is not meaningful.
    """
    sub_idx = jnp.array(sublattice_indices)
    batch_idx = jnp.arange(scores.shape[0])

    sub_types = atom_types[:, sub_idx]
    sub_scores = scores[:, sub_idx]

    # Sites of the wrong type are masked to -inf so they get zero probability.
    score_a = jnp.where(sub_types == type_a, sub_scores, -jnp.inf)
    score_b = jnp.where(sub_types == type_b, sub_scores, -jnp.inf)

    log_prob_a = jax.nn.log_softmax(score_a, axis=-1)
    log_prob_b = jax.nn.log_softmax(score_b, axis=-1)

    # Translate global site indices into positions within the sublattice.
    match_a = sub_idx[None, :] == swap_indices[:, 0][:, None]
    match_b = sub_idx[None, :] == swap_indices[:, 1][:, None]
    local_a = jnp.argmax(match_a, axis=-1)
    local_b = jnp.argmax(match_b, axis=-1)

    # argmax yields 0 when no position matches, which would silently score
    # sublattice position 0 instead; such swaps are impossible, so force -inf.
    in_sublattice = match_a.any(axis=-1) & match_b.any(axis=-1)

    log_probs = log_prob_a[batch_idx, local_a] + log_prob_b[batch_idx, local_b]
    return jnp.where(in_sublattice, log_probs, -jnp.inf)


# =============================================================================
# Utility
# =============================================================================

def get_sublattice_indices(
    atom_types: jnp.ndarray,
    target_types: Tuple[int, ...],
) -> Tuple[int, ...]:
    """Get indices where atom_types is in target_types.

    Args:
        atom_types: array of type ids. For a 1-D array the result is the list
            of matching positions; for higher-rank input only the index along
            the first axis of each match is returned.
        target_types: type ids to select; an empty tuple selects nothing.
    Returns:
        Tuple of plain Python ints in ascending order, usable as a hashable
        (static) argument.
    """
    mask = jnp.isin(atom_types, jnp.asarray(target_types))
    # One bulk transfer to host instead of converting each index separately.
    return tuple(jnp.where(mask)[0].tolist())


# =============================================================================
# Benchmark
# =============================================================================

def benchmark(name, fn, *args, n_warmup=3, n_trials=10):
    """Time ``fn(*args)`` and print its mean and std wall-clock time in ms.

    Args:
        name: label printed in front of the timing.
        fn: callable to time; called as ``fn(*args)``.
        *args: positional arguments forwarded to ``fn``.
        n_warmup: untimed calls made first (absorbs compilation cost).
        n_trials: timed calls; must be at least 1.
    Returns:
        Mean time per call in seconds.
    """

    def run_once():
        # JAX dispatches work asynchronously, so wait on every array in the
        # output (any pytree: tuple, NamedTuple, dict, ...) before the clock
        # is read; non-array leaves are passed through untouched.
        jax.block_until_ready(fn(*args))

    for _ in range(n_warmup):
        run_once()

    times = []
    for _ in range(n_trials):
        start = time.perf_counter()
        run_once()
        times.append(time.perf_counter() - start)

    avg = sum(times) / len(times)
    # Population standard deviation over the timed trials.
    std = (sum((t - avg)**2 for t in times) / len(times)) ** 0.5
    print(f"{name:40s}: {avg*1000:10.2f} ms ± {std*1000:.2f}")
    return avg


def run_benchmark():
    """Run the full timing suite on a fixed composition and print results.

    Builds a batch of random structures and random per-site scores, then
    times, in order: a single index swap, a sampled sublattice swap (B-site
    and O-site), repeated B-site swaps, repeated mixed B/O swaps, and beam
    search on each sublattice.
    """
    batch, n_steps, beam_k = 10000, 10000, 4
    ti, fe, ox, vac = 1, 2, 3, 4  # type ids handed to the swap routines

    counts = {"Sr": 32, "Ti": 8, "Fe": 24, "O": 84, "VO": 12}
    n_sites = sum(counts.values())

    # The ranges below assume sites are ordered Sr, then B-site, then O-site.
    b_start = int(counts["Sr"])
    o_start = int(b_start + counts["Ti"] + counts["Fe"])
    b_sites = tuple(range(b_start, o_start))
    o_sites = tuple(range(o_start, int(n_sites)))
    sublattices = (
        ("B-site", b_sites, ti, fe),
        ("O-site", o_sites, ox, vac),
    )

    rule = "=" * 60
    print(f"\n{rule}")
    print(f"BENCHMARK: batch={batch}, n_swaps={n_steps}, N={n_sites}")
    print(rule)
    print(f"B-site: {len(b_sites)} positions (Ti={counts['Ti']}, Fe={counts['Fe']})")
    print(f"O-site: {len(o_sites)} positions (O={counts['O']}, VO={counts['VO']})")

    rng = random.PRNGKey(42)

    def fresh_key():
        # Advances the same split chain as `key, subkey = random.split(key)`.
        nonlocal rng
        rng, sub = random.split(rng)
        return sub

    structures = generate_structures(
        fresh_key(), batch,
        counts["Sr"], counts["Ti"], counts["Fe"],
        counts["O"], counts["VO"],
    )
    structures.block_until_ready()

    scores = random.normal(fresh_key(), structures.shape)

    # [1] Single swap at uniformly random index pairs
    print("\n[1] swap_by_idx")
    pairs = random.randint(fresh_key(), (batch, 2), 0, n_sites)
    benchmark("swap_by_idx", swap_by_idx, structures, pairs)

    # [2] One sampled swap per sublattice
    print("\n[2] sample_sublattice_swap")
    for label, sites, kind_a, kind_b in sublattices:
        benchmark(f"{label} swap",
                  sample_sublattice_swap,
                  fresh_key(), structures, sites, kind_a, kind_b, scores)

    # [3] N consecutive swaps, B-site only
    print(f"\n[3] apply_n_swaps B-site only ({n_steps}x)")
    benchmark("B-site only",
              apply_n_swaps, fresh_key(), structures, scores, b_sites, ti, fe, n_steps,
              n_warmup=2, n_trials=5)

    # [4] N consecutive swaps, BOTH mode (the fair comparison)
    print(f"\n[4] apply_n_swaps BOTH mode ({n_steps}x) ⭐ PyTorch 비교용")
    benchmark("BOTH (B+O random)",
              apply_n_swaps_both, fresh_key(), structures, scores,
              b_sites, o_sites, ti, fe, ox, vac, n_steps,
              n_warmup=2, n_trials=5)

    # [5] Beam search per sublattice
    print(f"\n[5] beam_search (k={beam_k})")
    for label, sites, kind_a, kind_b in sublattices:
        benchmark(f"{label} beam",
                  sample_sublattice_swap_beam,
                  structures, sites, kind_a, kind_b, scores, beam_k)


# =============================================================================
# Example
# =============================================================================

def example():
    """Small end-to-end demo on a batch of four structures.

    Generates structures and random scores, then prints the result of one
    sampled B-site swap, a beam search with four candidates, and ten
    consecutive B-site swaps.
    """
    def key_stream(seed):
        # Yields the subkeys of a `key, subkey = random.split(key)` chain.
        state = random.PRNGKey(seed)
        while True:
            state, sub = random.split(state)
            yield sub

    keys = key_stream(42)
    ti, fe = 1, 2  # type ids swapped on the B-site sublattice

    counts = {"Sr": 32, "Ti": 8, "Fe": 24, "O": 84, "VO": 12}
    structures = generate_structures(
        next(keys), 4,
        counts["Sr"], counts["Ti"], counts["Fe"],
        counts["O"], counts["VO"],
    )

    # B-site positions are taken to follow the Sr block: indices 32..63.
    b_start = counts["Sr"]
    b_sites = tuple(range(b_start, b_start + counts["Ti"] + counts["Fe"]))

    scores = random.normal(next(keys), structures.shape)

    print("=== Example ===")
    print(f"Structures shape: {structures.shape}")

    # One sampled swap
    single = sample_sublattice_swap(next(keys), structures, b_sites, ti, fe, scores)
    print(f"Swap indices: {single.indices}")

    # Beam search with four candidates
    candidates = sample_sublattice_swap_beam(structures, b_sites, ti, fe, scores, 4)
    print(f"Beam candidates shape: {candidates.swapped.shape}")
    print(f"Beam log_probs: {candidates.log_probs[0]}")

    # Ten consecutive swaps; the second returned value is not used here
    final, _unused = apply_n_swaps(next(keys), structures, scores, b_sites, ti, fe, 10)
    print(f"After 10 swaps: {final.shape}")


# Script entry point: run the small usage demo first, then the full timing suite.
if __name__ == "__main__":
    example()
    run_benchmark()



JAX Devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]
=== Example ===
Structures shape: (4, 160)
Swap indices: [[60 57]
 [59 63]
 [42 46]
 [61 63]]
Beam candidates shape: (4, 4, 160)
Beam log_probs: [-2.664374  -2.982929  -3.3346233 -3.4151826]
After 10 swaps: (4, 160)

BENCHMARK: batch=10000, n_swaps=10000, N=160
B-site: 32 positions (Ti=8, Fe=24)
O-site: 96 positions (O=84, VO=12)

[1] swap_by_idx
swap_by_idx                             :       0.58 ms ± 0.02

[2] sample_sublattice_swap
B-site swap                             :       0.67 ms ± 0.01
O-site swap                             :       0.86 ms ± 0.01

[3] apply_n_swaps B-site only (10000x)
B-site only                             :    3614.92 ms ± 0.41

[4] apply_n_swaps BOTH mode (10000x) ⭐ PyTorch 비교용
BOTH (B+O random)                       :    8956.15 ms ± 0.26

[5] beam_search (k=4)
B-site beam                             :       3.89 ms ± 0.08
O-site beam                             :       6.