# PSMSL: Geometric Energy-Based Models on THRML

This notebook demonstrates how to use PSMSL (Projected Symmetry Mirrored Semantic Lattice) models within the THRML framework. PSMSL encodes geometric constraints (mirror symmetry, phi-scaling, and spatial structure) into energy-based models that can run on thermodynamic sampling units (TSUs).

## Key Concepts

- **Data Plane**: Observable spin variables arranged in a 2D grid
- **Latent Plane**: Hidden spin variables co-located with the data plane
- **Mirror Pairs**: Geometric constraints linking spins via symmetry
- **Dyad Ties**: Vertical coupling between data and latent planes
- **Phi-Scaling**: Golden-ratio-based indexing for mirror relationships

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from thrml.models import PSMSLConfig, build_psmsl_model, SpinGibbsConditional
from thrml.block_sampling import BlockSamplingProgram, sample_states
from thrml.models.psmsl import DataSpin, LatentSpin

## 1. Build a PSMSL Model

We'll create a small PSMSL model with configurable geometric constraints.

In [None]:
# Configure PSMSL model
config = PSMSLConfig(
    rows=8,
    cols=8,
    j_local=0.6,    # Data-to-data coupling
    j_latent=0.4,   # Latent-to-latent coupling
    j_dyad=0.8,     # Data-to-latent coupling
    j_mirror=0.25,  # Mirror symmetry coupling
    h_data=0.0,     # Data bias
    h_latent=0.05,  # Latent bias
)

# Build model with phi-scaling mirror mode
factor, free_blocks, node_shape_dtypes = build_psmsl_model(
    config, 
    mirror_mode="phi"  # Options: "simple", "phi", "reflect"
)

print(f"Created PSMSL model with {len(factor.nodes)} nodes")
print(f"Number of pairwise couplings: {len(factor.pair_indices)}")
print(f"Free blocks for parallel sampling: {len(free_blocks)}")
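
To make the roles of the coupling parameters concrete, here is a minimal NumPy sketch of an Ising-style energy that a configuration like the one above could correspond to. It is an illustration under stated assumptions, not the energy that `build_psmsl_model` constructs: the function name `toy_psmsl_energy`, the sign convention (E = -Σ J·s·s - Σ h·s), the open boundaries, and the left-right reflection used to form mirror pairs are all hypothetical choices.

```python
import numpy as np


def toy_psmsl_energy(data, latent, j_local=0.6, j_latent=0.4, j_dyad=0.8,
                     j_mirror=0.25, h_data=0.0, h_latent=0.05):
    """Energy of a two-plane +/-1 spin grid (hypothetical sketch, not the THRML API)."""
    data = np.asarray(data, dtype=np.float64)
    latent = np.asarray(latent, dtype=np.float64)

    def nn_sum(s):
        # Sum of s_i * s_j over horizontal and vertical nearest neighbors (open boundaries)
        return np.sum(s[:, :-1] * s[:, 1:]) + np.sum(s[:-1, :] * s[1:, :])

    # Assumed mirror pairing: column c is paired with column (cols - 1 - c)
    half = data.shape[1] // 2
    mirror = np.sum(data[:, :half] * data[:, ::-1][:, :half])

    # Dyad term: each data spin is coupled to the latent spin at the same position
    dyad = np.sum(data * latent)

    return -(
        j_local * nn_sum(data)
        + j_latent * nn_sum(latent)
        + j_dyad * dyad
        + j_mirror * mirror
        + h_data * data.sum()
        + h_latent * latent.sum()
    )


aligned = np.ones((8, 8))
energy_aligned = toy_psmsl_energy(aligned, aligned)   # planes agree
energy_opposed = toy_psmsl_energy(aligned, -aligned)  # planes disagree
print(energy_aligned, energy_opposed)
```

Under this convention, positive couplings lower the energy of aligned spins, so the configuration in which the data and latent planes agree has lower energy than the one in which they disagree.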

## 2. Create Sampling Program

Build the block sampling program using Gibbs conditionals.

In [None]:
# Define conditional samplers
conditionals = {
    DataSpin: SpinGibbsConditional(),
    LatentSpin: SpinGibbsConditional(),
}

# Create sampling program
program = BlockSamplingProgram(
    free_blocks=free_blocks,
    clamped_blocks=[],
    node_shape_dtypes=node_shape_dtypes,
    factors=[factor],
    conditionals=conditionals,
)

print("Sampling program created successfully")
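
For intuition about what a Gibbs conditional does for a ±1 spin, the sketch below resamples spins from their conditional distribution given a local field. It assumes the energy convention E = -Σ J·s·s - Σ h·s, under which P(s = +1) = sigmoid(2·β·field). The helper names `prob_spin_up` and `gibbs_update` and the inverse temperature `beta` are hypothetical, and this is not a description of how `SpinGibbsConditional` is implemented.

```python
import numpy as np


def prob_spin_up(local_field, beta=1.0):
    """P(s = +1 | neighbors), where local_field = h_i + sum_j J_ij * s_j."""
    return 1.0 / (1.0 + np.exp(-2.0 * beta * np.asarray(local_field, dtype=np.float64)))


def gibbs_update(rng, local_field, beta=1.0):
    """Resample every spin in a block independently, given fixed neighbor fields."""
    p_up = prob_spin_up(local_field, beta)
    u = rng.random(np.shape(p_up))
    return np.where(u < p_up, 1, -1).astype(np.int8)


rng_np = np.random.default_rng(0)
fields = np.array([-50.0, 0.0, 50.0])  # strongly down, unbiased, strongly up
new_spins = gibbs_update(rng_np, fields)
print(prob_spin_up(0.0), new_spins)
```

Spins within one block can be updated in parallel like this only when no two spins in the block are coupled to each other, which is the purpose of splitting the nodes into free blocks.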

## 3. Initialize and Sample

Initialize random states and run Gibbs sampling.

In [None]:
# Initialize random states
n_chains = 16
rng = jax.random.key(42)

init_states = []
# Use a dedicated key stream for initialization, with one sub-key per block
init_key = jax.random.fold_in(rng, 0)
block_keys = jax.random.split(init_key, len(free_blocks))

for block_key, block in zip(block_keys, free_blocks):
    block_shape = (n_chains, len(block.nodes))
    spins = jax.random.choice(
        block_key,
        jnp.array([-1, 1], dtype=jnp.int8),
        shape=block_shape,
    )
    init_states.append(spins)

print(f"Initialized {n_chains} chains")

In [None]:
# Run sampling
n_steps = 500
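# Derive the sampling keys from a separate stream so they do not repeat
# the randomness used for initialization
sample_keys = jax.random.split(jax.random.fold_in(rng, 1), n_chains)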

final_states, observations = sample_states(
    program=program,
    init_state=init_states,
    schedule=program.default_schedule(),
    n_steps=n_steps,
    keys=sample_keys,
    observers=[],
    return_observations=True,
)

print(f"Completed {n_steps} sampling steps")

## 4. Visualize Results

Plot the mean spin values for data and latent planes.

In [None]:
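def plot_spin_field(states, rows, cols, title):
    """Plot the chain-averaged spins of one block as a 2D image.

    Assumes `states` has shape (n_chains, n_nodes) and holds one
    bipartite half of the grid (rows * cols // 2 nodes).
    """
    # Average over chains
    mean_spins = np.array(states).mean(axis=0)

    # Reshape to a (rows // 2, cols) image (assuming bipartite blocks).
    # This is simplified: mapping values back to their true grid
    # positions depends on the block structure.
    field = mean_spins[:rows*cols//2].reshape(rows//2, cols)
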
    plt.figure(figsize=(6, 5))
    plt.imshow(field, cmap='RdBu', vmin=-1, vmax=1)
    plt.colorbar(label='Mean spin')
    plt.title(title)
    plt.xlabel('Column')
    plt.ylabel('Row')
    plt.tight_layout()
    plt.show()

# Plot data plane (first two blocks)
plot_spin_field(final_states[0], config.rows, config.cols, "Data Plane (Even Block)")
plot_spin_field(final_states[1], config.rows, config.cols, "Data Plane (Odd Block)")

# Plot latent plane (last two blocks)
plot_spin_field(final_states[2], config.rows, config.cols, "Latent Plane (Even Block)")
plot_spin_field(final_states[3], config.rows, config.cols, "Latent Plane (Odd Block)")
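
The reshape used in `plot_spin_field` shows each block on its own compressed grid. If the even and odd blocks are checkerboard color classes stored in row-major order (an assumption about the block layout that should be checked against the model), the two halves can be interleaved back into a full `rows × cols` field. The helper `merge_checkerboard` below is a hypothetical sketch of that step.

```python
import numpy as np


def merge_checkerboard(even_vals, odd_vals, rows, cols):
    """Interleave two half-grids into one grid, assuming checkerboard blocks.

    even_vals fills sites where (row + col) is even, odd_vals the rest,
    both in row-major order.
    """
    even_vals = np.asarray(even_vals)
    odd_vals = np.asarray(odd_vals)
    rr, cc = np.indices((rows, cols))
    even_mask = (rr + cc) % 2 == 0
    grid = np.empty((rows, cols), dtype=np.result_type(even_vals, odd_vals))
    grid[even_mask] = even_vals   # boolean-mask assignment fills in row-major order
    grid[~even_mask] = odd_vals
    return grid


# Toy example: even sites are +1 and odd sites are -1
full = merge_checkerboard(np.ones(4), -np.ones(4), rows=2, cols=4)
print(full)
```

With real samples, `even_vals` and `odd_vals` would be the chain-averaged values of the two data-plane blocks, and `full` could be passed to `plt.imshow` directly.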

## 5. Analyze Geometric Structure

Examine how mirror constraints affect the spin configuration.

In [None]:
# Placeholder for analysis: measuring the correlation between mirrored
# positions requires tracking the mirror pairs, which this cell does not do.
# For now, report the coupling strengths that drive the structure.

print("Geometric structure analysis:")
print(f"  Mirror coupling strength: {config.j_mirror}")
print(f"  Dyad coupling strength: {config.j_dyad}")
print("  Expected correlation between mirrored spins: high (not measured here)")
print("  Expected correlation between data-latent dyads: high (not measured here)")
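
Once samples are available as full grids, the two correlations named above can be estimated directly. The sketch below assumes arrays of shape `(n_chains, rows, cols)` and, purely for illustration, a left-right reflection as the mirror pairing. The pairs used by the `"phi"` mode are not specified in this notebook, so the pairing would need to be replaced with the model's own. The function names are hypothetical.

```python
import numpy as np


def mirror_correlation(data):
    """Mean of s_i * s_mirror(i), pairing column c with column (cols - 1 - c)."""
    data = np.asarray(data, dtype=np.float64)
    half = data.shape[-1] // 2
    left = data[..., :half]
    right = data[..., ::-1][..., :half]
    return float(np.mean(left * right))


def dyad_correlation(data, latent):
    """Mean of data_i * latent_i over co-located spins and chains."""
    return float(np.mean(np.asarray(data, dtype=np.float64) * np.asarray(latent, dtype=np.float64)))


rng_np = np.random.default_rng(0)
left_half = rng_np.choice([-1, 1], size=(4, 8, 4))
symmetric = np.concatenate([left_half, left_half[..., ::-1]], axis=-1)  # exactly mirrored

print(mirror_correlation(symmetric))            # perfectly mirrored sample
print(dyad_correlation(symmetric, symmetric))   # latent copies data
print(dyad_correlation(symmetric, -symmetric))  # latent opposes data
```

Values near +1 indicate strongly aligned pairs, values near 0 indicate no relationship, and values near -1 indicate anti-alignment.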

## 6. Multi-Layer Denoising

Demonstrate progressive refinement through multiple sampling layers.

In [None]:
from thrml.models import PSMSLDenoiser

# Create denoiser
denoiser = PSMSLDenoiser(
    config=config,
    layers=3,
    steps_per_layer=200,
    mirror_mode="phi",
)

# Get sampling program
denoise_program = denoiser.get_sampling_program()

# Run denoising layers
state = init_states
trajectory = []

for layer in range(denoiser.layers):
    # Use fresh keys for each layer so layers do not replay the same randomness
    layer_keys = jax.random.split(jax.random.fold_in(rng, 2 + layer), n_chains)
    state, _ = sample_states(
        program=denoise_program,
        init_state=state,
        schedule=denoise_program.default_schedule(),
        n_steps=denoiser.steps_per_layer,
        keys=layer_keys,
        observers=[],
        return_observations=True,
    )
    trajectory.append(state)
    print(f"Completed denoising layer {layer + 1}/{denoiser.layers}")

# Visualize progression
for i, layer_state in enumerate(trajectory):
    plot_spin_field(layer_state[0], config.rows, config.cols, f"Layer {i+1} - Data Plane")
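
One simple way to quantify progressive refinement is the fraction of spins that change between consecutive layers. The sketch below assumes each trajectory entry is a list of per-block arrays, as in the loop above. The helper `flip_fraction` and the toy data are hypothetical and are not part of `PSMSLDenoiser`.

```python
import numpy as np


def flip_fraction(prev_blocks, curr_blocks):
    """Fraction of spins that differ between two states given as lists of block arrays."""
    prev = np.concatenate([np.asarray(b).ravel() for b in prev_blocks])
    curr = np.concatenate([np.asarray(b).ravel() for b in curr_blocks])
    return float(np.mean(prev != curr))


# Toy trajectory: three layers, two blocks each, one chain of four spins per block
toy_trajectory = [
    [np.array([[1, -1, 1, -1]]), np.array([[1, 1, 1, 1]])],
    [np.array([[1, 1, 1, -1]]), np.array([[1, 1, 1, 1]])],
    [np.array([[1, 1, 1, -1]]), np.array([[1, 1, 1, 1]])],
]
changes = [
    flip_fraction(a, b)
    for a, b in zip(toy_trajectory[:-1], toy_trajectory[1:])
]
print(changes)
```

A falling value from layer to layer would be consistent with the state settling. A Gibbs sampler at finite temperature keeps fluctuating, so the value need not reach zero.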

## Summary

This notebook demonstrated:

1. **Building PSMSL models** with geometric constraints
2. **Configuring mirror modes** (simple, phi-scaling, reflection)
3. **Sampling with block Gibbs** for parallel updates
4. **Multi-layer denoising** for progressive refinement
5. **Visualizing** data and latent planes

PSMSL provides a framework for incorporating geometric structure into energy-based models, enabling applications in:
- Physics simulation with symmetry constraints
- Generative modeling with geometric priors
- Constrained optimization problems
- Pattern generation with spatial structure

Because PSMSL is built on THRML, these models can be simulated on GPUs today and are intended to run on thermodynamic sampling unit (TSU) hardware in the future, where energy efficiency could potentially improve by orders of magnitude.