
# Moiré Cavity Pipeline — Full Refactor (v2)

This notebook follows the exact flow you described:

1) **Set parameters** manually for the best candidate (lattice + parameters, target band, target k-point).  
2) **Compute monolayer band diagram** once, highlight the target band and target k-point.  
3) **Construct a test moiré lattice** at 1.1° and **plot the stacking register**.  
4) **Construct mini moiré unit cells** for AA/AB/BA with sizes of **1×, 2×, 3×** the moiré lattice constant(s); visualize.  
5) **Compute band diagrams** (same polarization, 8 bands by default) **for all 9 cases** (3 registries × 3 sizes).  
   - For the moiré unit-cell band diagrams, **use the moiré lattice vectors from your Rust framework** (fetched by `get_moire_vectors_from_framework`, which falls back to a brute-force search if the bindings are missing).
   - All band plots highlight the target band.  
   - The 9 plots are saved individually and also collated into a **3×3 montage image** for convenience.  
6) **Optimize the moiré angle** (bounded 1D) using a physically grounded score (registry contrast + envelope bound states).  
   - For the **best angle**, compute **AA/AB/BA band diagrams** again and highlight the target band.

> **Note:** This notebook relies on your `moire_lattice_py` bindings for lattice/BZ utilities. Where API details are unknown, there are clear **TODO** markers to drop your existing calls.
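Step 6 names a bounded 1D optimization over the twist angle. Below is a sketch of one way to do that without SciPy. It assumes the score is exposed as a Python callable of the angle in degrees; `golden_section_maximize` is a hypothetical helper, not part of `moire_lattice_py`. Golden-section search needs only score evaluations, which suits an expensive MPB-backed score:

```python
import math
from typing import Callable, Tuple

def golden_section_maximize(score: Callable[[float], float],
                            lo: float, hi: float,
                            tol: float = 1e-4, max_iter: int = 200) -> Tuple[float, float]:
    """Maximize `score` on [lo, hi] by golden-section search.

    Hypothetical helper: `score` stands in for the registry-contrast / bound-state
    score and is assumed to have a single maximum inside the bracket.
    Returns (best_x, score(best_x)).
    """
    if not lo < hi:
        raise ValueError("need lo < hi")
    invphi = (math.sqrt(5.0) - 1.0) / 2.0   # 1/phi ≈ 0.618
    a, b = lo, hi
    c = b - invphi * (b - a)
    d = a + invphi * (b - a)
    fc, fd = score(c), score(d)
    for _ in range(max_iter):
        if b - a < tol:
            break
        if fc > fd:
            # Maximum lies in [a, d]; the old interior point c becomes the new d.
            b, d, fd = d, c, fc
            c = b - invphi * (b - a)
            fc = score(c)
        else:
            # Maximum lies in [c, b]; the old interior point d becomes the new c.
            a, c, fc = c, d, fd
            d = a + invphi * (b - a)
            fd = score(d)
    x = 0.5 * (a + b)
    return x, score(x)
```

For example, `golden_section_maximize(lambda th: -(th - 1.3) ** 2, 0.5, 3.0)` returns an angle within about 1e-4 of 1.3. The method assumes a single maximum inside the bracket; if the real score is bumpy, a coarse scan first and a golden-section refinement around the best sample is the safer pattern.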


In [1]:

import numpy as np
import math, os, io, warnings
from dataclasses import dataclass
from typing import Dict, Tuple, Optional, List

import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from PIL import Image

# IMPORTANT: matplotlib rules for this environment
# - use matplotlib (no seaborn)
# - one chart per figure (no subplots)
# - do not set explicit colors unless asked; we rely on defaults

try:
    import meep as mp
    from meep import mpb
    MEEP_AVAILABLE = True
except Exception:
    MEEP_AVAILABLE = False

try:
    import moire_lattice_py as ml
    ML_AVAILABLE = True
except Exception:
    ML_AVAILABLE = False

print("Meep/MPB available:", MEEP_AVAILABLE)
print("moire_lattice_py available:", ML_AVAILABLE)


Using MPI version 4.1, 1 processes
Meep/MPB available: True
moire_lattice_py available: True


## 1) Set candidate parameters

In [2]:

@dataclass
class Candidate:
    lattice_type: str   # 'triangular' | 'square' | 'rectangular' | 'oblique'
    a1: np.ndarray      # real-space basis vector 1
    a2: np.ndarray      # real-space basis vector 2
    hole_radius_a: float
    eps_bg: float
    polarization: str   # 'TE' or 'TM'
    band_index: int     # 0-based
    k_label: str        # e.g., 'Γ', 'X', 'M', 'K', ''

# ==== USER INPUTS ====
cand = Candidate(
    lattice_type="square",
    a1=np.array([1.0, 0.0]),
    a2=np.array([0.0, 1.0]),
    hole_radius_a=0.43,
    eps_bg=4.02,
    polarization="TM",
    band_index=3,
    k_label="M"
)

num_bands = 8           # how many bands to compute/plot
resolution = 40         # MPB resolution (pixels/a)
kpts_per_seg = 18       # points per segment on HS path
output_dir = "moire_pipeline_outputs"
os.makedirs(output_dir, exist_ok=True)
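Before launching MPB, it can help to sanity-check the inputs above. The sketch below assumes one hole per lattice site, so neighbouring holes overlap once the diameter reaches the shortest lattice vector. `check_candidate` and `shortest_lattice_vector` are hypothetical helpers that take plain values rather than the `Candidate` object:

```python
import itertools
import numpy as np

def shortest_lattice_vector(a1, a2, nmax=3):
    """Length of the shortest nonzero m*a1 + n*a2 with |m|, |n| <= nmax
    (enough for reasonably reduced bases; very skewed bases may need a larger nmax)."""
    a1 = np.asarray(a1, float); a2 = np.asarray(a2, float)
    best = np.inf
    for m, n in itertools.product(range(-nmax, nmax + 1), repeat=2):
        if m == 0 and n == 0:
            continue
        best = min(best, float(np.linalg.norm(m * a1 + n * a2)))
    return best

def check_candidate(a1, a2, hole_radius_a, polarization, band_index, num_bands):
    """Return a list of problems (an empty list means the inputs look consistent)."""
    problems = []
    a1 = np.asarray(a1, float); a2 = np.asarray(a2, float)
    area = abs(a1[0] * a2[1] - a1[1] * a2[0])
    if area < 1e-12:
        problems.append("a1 and a2 are (nearly) collinear")
    elif 2.0 * hole_radius_a >= shortest_lattice_vector(a1, a2):
        problems.append("neighbouring holes overlap (2r >= shortest lattice vector)")
    if polarization.strip().upper() not in ("TE", "TM"):
        problems.append(f"unexpected polarization {polarization!r}")
    if not 0 <= band_index < num_bands:
        problems.append(f"band_index {band_index} outside 0..{num_bands - 1}")
    return problems
```

With the values above, `check_candidate(cand.a1, cand.a2, cand.hole_radius_a, cand.polarization, cand.band_index, num_bands)` returns an empty list, since 2 × 0.43 < 1 and band 3 is among the 8 computed bands.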


## Utilities: lattice builders and high-symmetry paths

In [3]:

def rotation_matrix(theta_deg: float) -> np.ndarray:
    th = math.radians(theta_deg)
    return np.array([[math.cos(th), -math.sin(th)],
                     [math.sin(th),  math.cos(th)]], dtype=float)

def lattice_matrix(a1: np.ndarray, a2: np.ndarray) -> np.ndarray:
    A = np.zeros((2,2), dtype=float); A[:,0] = a1; A[:,1] = a2
    return A

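def mp_lattice_from_vectors(a1: np.ndarray, a2: np.ndarray):
    # MPB takes only the *directions* from basis1/basis2; the vector lengths come from basis_size.
    # Without basis_size every cell is normalized to unit-length vectors (the step 5 log reports
    # "Cell volume = 1"), so the 1x/2x/3x moiré cells would all collapse to the same 1x1 cell.
    # Note: at fixed resolution, the true moiré cells need proportionally more grid points.
    return mp.Lattice(size=mp.Vector3(1,1,0),
                      basis_size=mp.Vector3(float(np.linalg.norm(a1)), float(np.linalg.norm(a2)), 1),
                      basis1=mp.Vector3(float(a1[0]), float(a1[1]), 0),
                      basis2=mp.Vector3(float(a2[0]), float(a2[1]), 0))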

def hs_path_for_lattice(lattice_type: str, k_points_per_segment: int):
    lt = lattice_type.lower()
    labels = []
    label_inds = []

    def add_segment(p_from, p_to, lbl_from, lbl_to, pts):
        seg = mp.interpolate(k_points_per_segment, [p_from, p_to])
        if len(pts) > 0: seg = seg[1:]
        start_idx = len(pts)
        pts += seg
        if lbl_from is not None:
            if len(labels) == 0:
                labels.append(lbl_from); label_inds.append(start_idx)
        labels.append(lbl_to); label_inds.append(len(pts)-1)
        return pts

    pts = []
    if lt == "triangular":
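        # K = (1/3, 1/3) assumes a1, a2 at 120°; for a 60° basis (as in MPB's triangular-lattice example) use K = (-1/3, 1/3)
        G = mp.Vector3(0,0); M = mp.Vector3(0.5,0.0); K = mp.Vector3(1/3,1/3)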
        pts = add_segment(G,M,"Γ","M",pts)
        pts = add_segment(M,K,"M","K",pts)
        pts = add_segment(K,G,"K","Γ",pts)
    elif lt in ("square","rectangular"):
        G = mp.Vector3(0,0); X = mp.Vector3(0.5,0.0); M = mp.Vector3(0.5,0.5)
        pts = add_segment(G,X,"Γ","X",pts)
        pts = add_segment(X,M,"X","M",pts)
        pts = add_segment(M,G,"M","Γ",pts)
    else:
        # Oblique path (placeholder). For thorough oblique analysis, use BZ mesh in your other script.
        G = mp.Vector3(0,0); B1 = mp.Vector3(0.5,0.0); B2 = mp.Vector3(0.0,0.5)
        pts = add_segment(G,B1,"Γ","b1/2",pts)
        pts = add_segment(B1,B2,"b1/2","b2/2",pts)
        pts = add_segment(B2,G,"b2/2","Γ",pts)
    return pts, labels, label_inds

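def path_arclength(kpts):
    # Distances are taken in MPB's fractional reciprocal coordinates; they are proportional
    # to Cartesian |Δk| only when b1 ⊥ b2 and |b1| = |b2| (e.g. square lattices).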
    ks = np.array([[kp.x, kp.y] for kp in kpts], dtype=float)
    if len(ks) < 2: return np.zeros(len(ks))
    ds = np.linalg.norm(np.diff(ks, axis=0), axis=1)
    return np.concatenate([[0.0], np.cumsum(ds)])
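`path_arclength` measures distance in the fractional reciprocal coordinates that MPB uses for k-points, which distorts segment lengths on triangular or oblique lattices. Below is a sketch of a Cartesian variant. It assumes the k-points are given as fractional coordinates `(k1, k2)` on the reciprocal basis `b_i` defined by `a_i · b_j = 2π δ_ij`. `path_arclength_cartesian` is a hypothetical name and takes plain arrays instead of `mp.Vector3`:

```python
import numpy as np

def path_arclength_cartesian(k_frac, a1, a2):
    """Cumulative Cartesian arclength along a k-path.

    k_frac: (N, 2) fractional coordinates on the reciprocal basis.
    a1, a2: real-space lattice vectors.
    Returns an (N,) array starting at 0, in units of 1/length.
    """
    A = np.column_stack([a1, a2]).astype(float)
    B = 2.0 * np.pi * np.linalg.inv(A).T        # columns are b1, b2 with a_i . b_j = 2π δ_ij
    k_cart = np.asarray(k_frac, float) @ B.T    # row-wise k = k1*b1 + k2*b2
    if len(k_cart) < 2:
        return np.zeros(len(k_cart))
    ds = np.linalg.norm(np.diff(k_cart, axis=0), axis=1)
    return np.concatenate([[0.0], np.cumsum(ds)])
```

On the square lattice this reproduces the Γ→X→M→Γ spacing (π, π, π√2 for a = 1). On a triangular lattice with `a1 = (1, 0)` and `a2 = (1/2, √3/2)`, the Γ→M segment to fractional (1/2, 0) has length 2π/√3, which the fractional metric would not give.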


## MPB helpers and plotting

In [4]:
# Plotting helper for band diagrams used in step 5

def plot_band_diagram(result: dict, lattice_type: str, target_band: int, k_label: str, outfile: str, title_prefix: str = ""):
    import matplotlib.pyplot as plt
    s = result["s"]
    # Choose which frequency array to plot
    freqs = result.get("freqs")
    if freqs is None:
        freqs = result.get("freqs_tm")
    if freqs is None:
        freqs = result.get("freqs_te")
    if freqs is None:
        raise ValueError("No frequency data found in result (expected 'freqs' or 'freqs_tm'/'freqs_te').")

    k_labels = result.get("k_labels", [])
    k_inds = result.get("k_inds", [])

    plt.figure(figsize=(6, 4))
    # Plot all bands
    nbands = freqs.shape[1] if freqs.ndim == 2 else len(freqs)
    if freqs.ndim == 2:
        for b in range(nbands):
            lw = 2.5 if b == target_band else 1.2
            alpha = 1.0 if b == target_band else 0.9
            plt.plot(s, freqs[:, b], linewidth=lw, alpha=alpha, color=(0.1, 0.1, 0.1))
    else:
        lw = 2.5 if 0 == target_band else 1.2
        plt.plot(s, freqs, linewidth=lw)

    # High-symmetry ticks and guides
    if len(k_inds) == len(k_labels) and len(k_inds) > 0:
        for xi in k_inds:
            plt.axvline(s[xi], color="0.8", linewidth=0.8)
        plt.xticks([s[i] for i in k_inds], k_labels)
    plt.xlabel("k-path")
    plt.ylabel("Frequency (c/a)")
    ttl = f"{title_prefix} bands ({lattice_type})"
    if k_label:
        ttl += f" — target: {k_label}"
    plt.title(ttl)
    plt.grid(True, alpha=0.2)
    plt.tight_layout()
    plt.savefig(outfile, dpi=160)
    plt.close()
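Step 2 asks for the target k-point to be highlighted too, but `plot_band_diagram` only emphasizes the target band. One lightweight option is to look up every path position whose label matches `k_label` and mark the target band there. `target_k_markers` is a hypothetical helper; inside `plot_band_diagram` its output could feed `plt.scatter(xs, ys, zorder=3)` or `plt.axvline(x, linestyle='--')`, keeping default colors as the notebook's plotting rules ask:

```python
from typing import List, Sequence, Tuple
import numpy as np

def target_k_markers(s: Sequence[float], freqs, k_labels: Sequence[str],
                     k_inds: Sequence[int], k_label: str,
                     band: int) -> Tuple[List[float], List[float]]:
    """(x, y) points marking `band` at every path position labelled `k_label`.

    s: arclength along the path; freqs: (n_k, n_bands) array;
    k_labels / k_inds: as returned by hs_path_for_lattice.
    Returns two empty lists if the label does not occur on the path.
    """
    s = np.asarray(s, float)
    freqs = np.asarray(freqs, float)
    xs, ys = [], []
    for lbl, i in zip(k_labels, k_inds):
        if lbl == k_label:
            xs.append(float(s[i]))
            ys.append(float(freqs[i, band]))
    return xs, ys
```

Because Γ appears at both ends of the square-lattice path, a Γ target yields two markers; for the M target used here there is exactly one.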

In [5]:
# Geometry builder using reciprocal-based basis

def mpb_geometry_for_factor(L1_py, L2_py, factor: int, radius: float, eps_high: float, eps_low: float):
    import meep as mp
    # L1 real-space lattice vectors for MPB lattice
    (a1, a2) = L1_py.lattice_vectors()
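    # basis_size carries the vector lengths; MPB uses basis1/basis2 only for directions
    lat = mp.Lattice(size=mp.Vector3(1, 1),
                     basis_size=mp.Vector3(float(np.hypot(a1[0], a1[1])), float(np.hypot(a2[0], a2[1])), 1),
                     basis1=mp.Vector3(a1[0], a1[1]), basis2=mp.Vector3(a2[0], a2[1]))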
    # Compute basis in fractional coordinates of L1 from reciprocal strays
    basis_frac = mpb_basis_from_reciprocal_strays(L1_py, L2_py, factor)
    # Build geometry: one high-index cylinder at each basis site
    geometry = []
    for (u, v) in basis_frac:
        center = mp.Vector3(u, v)
        geometry.append(mp.Cylinder(radius=radius, material=mp.Medium(epsilon=eps_high), center=center, axis=mp.Vector3(0, 0, 1)))
    default_material = mp.Medium(epsilon=eps_low)
    return lat, geometry, default_material, basis_frac

## 2) Monolayer band diagram

In [6]:
# Reciprocal-space helpers
import numpy as np
from typing import List, Tuple, Optional


def reciprocal_rectangle_for_lattice(lattice_py, factor: int = 1) -> Tuple[float, float]:
    """
    Build an axis-aligned rectangle in reciprocal space whose area equals
    (factor * area of L1's reciprocal unit cell). We choose a square of side sqrt(area).
    Returns (width, height).
    """
    (g1, g2) = lattice_py.reciprocal_vectors()  # ((gx1, gy1), (gx2, gy2))
    g1 = np.array(g1)
    g2 = np.array(g2)
    # Parallelogram area = |G1 x G2|
    area = abs(g1[0]*g2[1] - g1[1]*g2[0])
    rect_area = max(1e-16, factor * float(area))
    side = float(np.sqrt(rect_area))
    width = side
    height = side
    return width, height


def fold_into_reciprocal_cell(L1_py, kx: float, ky: float) -> Tuple[float, float]:
    """
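    Reduce a k-point into L1's first Brillouin zone using the lattice API if available,
    else fold by solving for fractional coords in the reciprocal basis and wrapping each
    to [-0.5, 0.5] with np.round. The fallback returns the Γ-centred reciprocal-cell
    parallelogram, which equals the first BZ only when the reciprocal basis is orthogonal.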
    """
    try:
        rx, ry, _ = L1_py.reduce_to_brillouin_zone(kx, ky)
        return rx, ry
    except Exception:
        # Fallback folding using reciprocal basis
        (G1, G2) = L1_py.reciprocal_vectors()
        B = np.column_stack([G1, G2])  # 2x2
        frac = np.linalg.solve(B, np.array([kx, ky]))
        frac_wrapped = frac - np.round(frac)
        cart = B @ frac_wrapped
        return float(cart[0]), float(cart[1])


def reciprocal_stray_points(L1_py, L2_py, factor: int = 1) -> List[Tuple[float, float]]:
    """
    Get reciprocal lattice points of L2 within a rectangle whose area matches
    factor × (area of L1 reciprocal cell), fold into L1's 1st BZ. Return unique
    folded points excluding Γ (0,0).
    """
    width, height = reciprocal_rectangle_for_lattice(L1_py, factor)
    pts = L2_py.get_reciprocal_lattice_points_in_rectangle(width, height)
    folded = []
    for (kx, ky, _) in pts:
        rx, ry = fold_into_reciprocal_cell(L1_py, kx, ky)
        if abs(rx) < 1e-12 and abs(ry) < 1e-12:
            continue  # skip Gamma
        folded.append((rx, ry))
    # Uniquify with a tolerance
    uniq = []
    tol = 1e-9
    for p in folded:
        if not any((abs(p[0]-q[0]) < tol and abs(p[1]-q[1]) < tol) for q in uniq):
            uniq.append(p)
    return uniq


def mpb_basis_from_reciprocal_strays(
    L1_py,
    L2_py,
    factor: int = 1,
    include_all_strays: bool = False,
    max_additional: Optional[int] = None,
    min_frac_sep: float = 0.0,
) -> List[Tuple[float, float]]:
    """
    Map reciprocal stray points to a real-space fractional basis for MPB.
    - Always include the L1 site at (0,0).
    - Selection policy:
        * include_all_strays=True  -> add all stray-derived sites
        * else if max_additional   -> cap extras to max_additional
        * else                     -> soft cap at max(1, factor)
    - min_frac_sep: prune added sites with a minimal torus separation in fractional coords.
    """
    strays = reciprocal_stray_points(L1_py, L2_py, factor)
    basis = [(0.0, 0.0)]
    if not strays:
        return basis

    def k_to_frac(kx, ky) -> Tuple[float, float]:
        (G1, G2) = L1_py.reciprocal_vectors()
        B = np.column_stack([G1, G2])
        frac = np.linalg.solve(B, np.array([kx, ky]))  # reciprocal fractional coords
        # Heuristic: reuse the reciprocal fractional coords as real-space fractional positions, wrapped to [0,1)
        u = float(frac[0] % 1.0)
        v = float(frac[1] % 1.0)
        return u, v

    # Decide how many to include
    if include_all_strays:
        picked = strays
    else:
        soft_cap = max(1, factor)
        cap = max_additional if (max_additional is not None) else soft_cap
        picked = strays[: min(len(strays), cap)]

    # Map to fractional coords and apply minimal separation pruning on torus
    def torus_dist(p, q):
        du = abs(p[0] - q[0])
        dv = abs(p[1] - q[1])
        du = min(du, 1.0 - du)
        dv = min(dv, 1.0 - dv)
        return np.hypot(du, dv)

    added: List[Tuple[float, float]] = []
    for kx, ky in picked:
        u, v = k_to_frac(kx, ky)
        if min_frac_sep > 0.0:
            if any(torus_dist((u, v), q) < min_frac_sep for q in added):
                continue
        added.append((u, v))

    basis.extend(added)
    return basis


def plot_reciprocal_registry(L1_py, L2_py, factor: int = 1, ax=None):
    import matplotlib.pyplot as plt
    if ax is None:
        fig, ax = plt.subplots(figsize=(4, 4))
    # Draw rectangle (area-matched) for visualization
    width, height = reciprocal_rectangle_for_lattice(L1_py, factor)
    ax.add_patch(plt.Rectangle((-width/2, -height/2), width, height, fill=False, ls='--', lw=1.0, color='gray'))
    # Plot L2 reciprocal points in the rectangle
    pts = L2_py.get_reciprocal_lattice_points_in_rectangle(width, height)
    if pts:
        xs = [p[0] for p in pts]
        ys = [p[1] for p in pts]
        ax.scatter(xs, ys, s=8, c='tab:blue', alpha=0.5, label='L2 G-points')
    # Plot folded stray points
    strays = reciprocal_stray_points(L1_py, L2_py, factor)
    if strays:
        ax.scatter([p[0] for p in strays], [p[1] for p in strays], s=30, c='tab:red', label='Folded strays')
    ax.scatter([0.0], [0.0], c='k', s=20, label='Γ')
    ax.set_aspect('equal', adjustable='box')
    ax.legend(loc='upper right', fontsize=8)
    ax.set_title(f"Reciprocal registry (factor={factor})")
    return ax
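The fallback in `fold_into_reciprocal_cell` wraps to the reciprocal-cell parallelogram centred on Γ. For triangular or oblique lattices, a point in that parallelogram can still be closer to another reciprocal lattice point, so it is not in the first Brillouin zone. The sketch below performs a true first-BZ (Wigner-Seitz) reduction by subtracting the nearest reciprocal lattice vector. It assumes the reciprocal vectors are available as 2-component arrays. `reduce_to_first_bz` is a hypothetical name and is not the `moire_lattice_py` method `reduce_to_brillouin_zone`:

```python
import itertools
import numpy as np

def reduce_to_first_bz(k, g1, g2, search=2):
    """Map k to its image in the first Brillouin zone.

    Wraps to the Γ-centred parallelogram first (as the notebook's fallback does),
    then subtracts whichever nearby reciprocal lattice vector G minimizes |k - G|.
    `search` bounds the integer coefficients tried; 2 is ample for reduced bases.
    """
    B = np.column_stack([g1, g2]).astype(float)
    k = np.asarray(k, float)
    frac = np.linalg.solve(B, k)
    k_par = B @ (frac - np.round(frac))       # parallelogram image
    best = k_par
    for m, n in itertools.product(range(-search, search + 1), repeat=2):
        cand = k_par - B @ np.array([m, n], float)
        if np.linalg.norm(cand) < np.linalg.norm(best) - 1e-12:
            best = cand
    return best
```

On a square lattice the result matches the parallelogram fold. With reciprocal vectors at 60° (`g1 = (1, 0)`, `g2 = (1/2, √3/2)`), the point `0.5*g1 + 0.4*g2` is left unchanged by the parallelogram fold but reduces to `(-0.3, 0.2√3)`, which is strictly shorter.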

## 3) Test moiré lattice at 1.1° and stacking register plot

In [7]:
theta_test_deg = 3  # test twist angle in degrees; the section heading specifies 1.1°, set to 1.1 to match

def plot_stacking_register(a1: np.ndarray, a2: np.ndarray, theta_deg: float, outfile: str):
    # Create base lattice using moire_lattice_py framework
    if not ML_AVAILABLE:
        print("Warning: moire_lattice_py not available, using placeholder visualization")
        # Fallback to placeholder
        A = np.stack([a1, a2], axis=1)
        R = rotation_matrix(theta_deg)
        B = R @ A
        # sample a small set of lattice points
        def pts(a1,a2,rad=3):
            P = []
            for m in range(-rad,rad+1):
                for n in range(-rad,rad+1):
                    P.append(m*a1 + n*a2)
            return np.array(P)
        P1 = pts(a1,a2,30); P2 = pts(B[:,0],B[:,1],30)
        plt.figure(figsize=(6,6))
        plt.scatter(P1[:,0], P1[:,1], s=10)  # layer 1
        plt.scatter(P2[:,0], P2[:,1], s=10)  # layer 2
        plt.gca().set_aspect("equal","box")
        plt.title(f"Stacking register (heuristic view) @ {theta_deg:.2f}°")
        plt.tight_layout(); plt.savefig(outfile, dpi=140); plt.close()
        return
    
    # Create square lattice (assuming unit cell with a=1.0)
    lattice = ml.create_square_lattice(1.0)
    
    # Create moiré lattice at the specified angle
    moire = ml.py_twisted_bilayer(lattice, math.radians(theta_deg))
    
    # Get moiré lattice vectors
    (L1, L2) = moire.primitive_vectors()
    L1 = np.array([L1[0], L1[1]])  # convert to 2D
    L2 = np.array([L2[0], L2[1]])
    
    # Sample lattice points for both layers
    def generate_lattice_points(a1, a2, radius=30):
        points = []
        for m in range(-radius, radius+1):
            for n in range(-radius, radius+1):
                point = m*a1 + n*a2
                points.append(point)
        return np.array(points)
    
    # Layer 1 points (original lattice)
    P1 = generate_lattice_points(a1, a2, 30)
    
    # Layer 2 points (rotated lattice)  
    R = rotation_matrix(theta_deg)
    a1_rot = R @ a1
    a2_rot = R @ a2
    P2 = generate_lattice_points(a1_rot, a2_rot, 30)
    
    # Create the plot
    plt.figure(figsize=(8,8))
    
    # Plot both lattices
    plt.scatter(P1[:,0], P1[:,1], s=20, alpha=0.6, label="Layer 1")
    plt.scatter(P2[:,0], P2[:,1], s=20, alpha=0.6, label="Layer 2")
    
    # Sample points within moiré unit cell to identify stacking
    n_sample = 20
    xs = np.linspace(0, 1, n_sample)
    ys = np.linspace(0, 1, n_sample)
    
    AA_points = []
    AB_points = []
    BA_points = []
    
    for i, x in enumerate(xs):
        for j, y in enumerate(ys):
            # Position in real space
            pos = x * L1 + y * L2
            pos_3d = [pos[0], pos[1], 0.0]
            
            # Framework stacking query; the result is currently unused and the
            # geometric classification below decides the label
            stacking = moire.get_stacking_at(pos_3d)
            
            # For square lattice: AA at lattice points, AB/BA at half-lattice points
            # Use simple geometric approach for square lattice
            frac1 = np.array([x, y])  # fractional coordinates in moiré cell
            
            # Check proximity to AA sites (lattice points)
            dist_to_lattice = min(
                np.linalg.norm(frac1 - np.array([0, 0])),
                np.linalg.norm(frac1 - np.array([1, 0])),
                np.linalg.norm(frac1 - np.array([0, 1])),
                np.linalg.norm(frac1 - np.array([1, 1]))
            )
            
            # Check proximity to AB sites (1/2, 0) and (0, 1/2) type points
            dist_to_AB = min(
                np.linalg.norm(frac1 - np.array([0.5, 0])),
                np.linalg.norm(frac1 - np.array([0, 0.5])),
                np.linalg.norm(frac1 - np.array([0.5, 1])),
                np.linalg.norm(frac1 - np.array([1, 0.5]))
            )
            
            # Check proximity to BA sites (1/2, 1/2) type points  
            dist_to_BA = np.linalg.norm(frac1 - np.array([0.5, 0.5]))
            
            tol = 0.15  # tolerance for classification
            if dist_to_lattice < tol:
                AA_points.append(pos)
            elif dist_to_AB < tol:
                AB_points.append(pos)
            elif dist_to_BA < tol:
                BA_points.append(pos)
    
    # Plot stacking regions
    if AA_points:
        AA_points = np.array(AA_points)
        plt.scatter(AA_points[:,0], AA_points[:,1], s=100, c='red', marker='s', 
                   alpha=0.8, label='AA stacking')
    if AB_points:
        AB_points = np.array(AB_points)
        plt.scatter(AB_points[:,0], AB_points[:,1], s=100, c='blue', marker='^', 
                   alpha=0.8, label='AB stacking')
    if BA_points:
        BA_points = np.array(BA_points)
        plt.scatter(BA_points[:,0], BA_points[:,1], s=100, c='green', marker='v', 
                   alpha=0.8, label='BA stacking')
    
    # Outline the moiré unit cell
    moire_cell = np.array([[0,0], L1, L1+L2, L2, [0,0]])
    plt.plot(moire_cell[:,0], moire_cell[:,1], 'k-', linewidth=3, alpha=0.8, 
            label='Moiré unit cell')
    
    plt.gca().set_aspect("equal","box")
    plt.legend()
    plt.title(f"Stacking register @ {theta_deg:.2f}° (Moiré period ≈ {np.linalg.norm(L1):.1f})")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.tight_layout()
    plt.savefig(outfile, dpi=140)
    plt.close()

stack_plot = os.path.join(output_dir, "02_stacking_register_test_angle.png")
plot_stacking_register(cand.a1, cand.a2, theta_test_deg, stack_plot)
stack_plot

'moire_pipeline_outputs/02_stacking_register_test_angle.png'
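`plot_stacking_register` classifies points by their distance to fixed fractional positions in the moiré cell. An alternative ties the labels directly to the atomic registry through the first-order local shift. Near a point r, lattice sites x ≈ r satisfy R x ≈ x + (R - I) r, so the rotated layer looks like layer 1 translated by d(r) = (R - I) r, modulo a lattice vector. The sketch below works on a square lattice. The helper names are hypothetical, and the categories describe the shift itself (zero, half-diagonal, half-edge) rather than reusing the notebook's AB/BA names:

```python
import numpy as np

def rot_matrix(theta_rad):
    c, s = np.cos(theta_rad), np.sin(theta_rad)
    return np.array([[c, -s], [s, c]])

def local_registry_shift(r, A, theta_deg):
    """Fractional shift (layer-1 basis, columns of A) of the twisted layer relative
    to layer 1 near point r, using d(r) = (R - I) r, wrapped to [-0.5, 0.5)."""
    R = rot_matrix(np.radians(theta_deg))
    d = (R - np.eye(2)) @ np.asarray(r, float)
    frac = np.linalg.solve(np.asarray(A, float), d)
    return frac - np.floor(frac + 0.5)

def classify_shift(shift, tol=0.1):
    """Name the local stacking from the fractional shift (square-lattice categories)."""
    s = np.abs(np.asarray(shift, float))   # each component lies in [0, 0.5]
    near0 = s < tol
    nearhalf = s > 0.5 - tol
    if near0.all():
        return "AA"
    if nearhalf.all():
        return "half-diagonal"
    if (near0 | nearhalf).all():
        return "half-edge"
    return "intermediate"
```

For small twists, the shift at moiré-cell fractional position (u, v) is approximately (u, v) itself. The cell centre (½, ½) therefore carries the half-diagonal shift, which is the (0.5, 0.5) offset that step 4 calls AB, while `plot_stacking_register` labels the same region BA. This labeling difference is worth reconciling before comparing the two steps.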

## 4) Mini moiré unit cells (AA/AB/BA) with sizes 1×, 2×, 3×

In [8]:
def compute_moire_cell_approx(a1: np.ndarray, a2: np.ndarray, theta_deg: float,
                              Nmax: int = 60, tol: float = 1e-6):
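    # Brute-force integer search for a commensurate approximant (fallback if the framework API is unavailable).
    # Each search is O(Nmax^4): with Nmax=60 that is up to ~2e8 inner iterations in pure Python, so expect it to be very slow.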
    A = np.stack([a1,a2],axis=1)
    R = rotation_matrix(theta_deg); B = R @ A
    best = None; best_err = 1e9
    for m1 in range(-Nmax, Nmax+1):
        for m2 in range(-Nmax, Nmax+1):
            if m1==0 and m2==0: continue
            u = A @ np.array([m1,m2],float)
            for n1 in range(-Nmax, Nmax+1):
                for n2 in range(-Nmax, Nmax+1):
                    if n1==0 and n2==0: continue
                    v = B @ np.array([n1,n2],float)
                    e = np.linalg.norm(u-v)
                    if e < best_err:
                        best_err = e; best = (u,v)
                        if e < tol: break
                if best_err<tol: break
            if best_err<tol: break
        if best_err<tol: break
    if best is None: raise RuntimeError("Failed to find commensurate approximant.")
    # pick u as L1; find L2 by a secondary search encouraging independence
    L1 = best[0]
    best2 = None; best2_err = 1e9
    for m1 in range(-Nmax, Nmax+1):
        for m2 in range(-Nmax, Nmax+1):
            if m1==0 and m2==0: continue
            U = A @ np.array([m1,m2],float)
            # favor non-collinearity
            cross = abs(np.cross([U[0],U[1],0],[L1[0],L1[1],0])[-1])
            if cross < 1e-6: continue
            for n1 in range(-Nmax, Nmax+1):
                for n2 in range(-Nmax, Nmax+1):
                    if n1==0 and n2==0: continue
                    V = B @ np.array([n1,n2],float)
                    e = np.linalg.norm(U - V)
                    if e < best2_err:
                        best2_err = e; best2 = U
    L2 = best2
    return L1, L2

def get_moire_vectors_from_framework(a1: np.ndarray, a2: np.ndarray, theta_deg: float):
    # Use the moire_lattice_py framework to get moiré lattice vectors
    if not ML_AVAILABLE:
        print("Warning: moire_lattice_py not available, using fallback approximation")
        return compute_moire_cell_approx(a1, a2, theta_deg)
    
    # Create base lattice - assume square lattice with unit parameter
    lattice = ml.create_square_lattice(1.0)
    
    # Create moiré lattice at the specified angle
    moire = ml.py_twisted_bilayer(lattice, math.radians(theta_deg))
    
    # Get the moiré lattice primitive vectors 
    (L1_3d, L2_3d) = moire.primitive_vectors()
    
    # Convert to 2D numpy arrays
    L1 = np.array([L1_3d[0], L1_3d[1]], dtype=float)
    L2 = np.array([L2_3d[0], L2_3d[1]], dtype=float)
    
    return L1, L2

def scaled_moire_vectors(L1: np.ndarray, L2: np.ndarray, scale: int):
    return (scale*L1, scale*L2)

def registry_basis_fractional(registry: str):
    """Return two fractional basis positions (u,v) for the given registry.
    For square lattices:
      - AA: (0,0) and (0,0)
      - AB: (0,0) and (0.5, 0.5)
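      - BA: (0,0) and (0.5, -0.5)  (wrapped later)
    Note: on a square lattice (0.5, -0.5) wraps to (0.5, 0.5), so with this choice
    AB and BA produce identical geometries and identical band diagrams.
    """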
    if registry == "AA":
        return [(0.0, 0.0), (0.0, 0.0)]
    elif registry == "AB":
        return [(0.0, 0.0), (0.5, 0.5)]
    elif registry == "BA":
        return [(0.0, 0.0), (0.5, -0.5)]
    else:
        raise ValueError(f"Unknown registry: {registry}")

def frac_to_cart(a1: np.ndarray, a2: np.ndarray, u: float, v: float) -> np.ndarray:
    return u * a1 + v * a2

def wrap01(x: float) -> float:
    # Python's % with a positive modulus never returns a negative value, so no extra branch is needed
    return x % 1.0

def plot_moire_unit_cell(a1: np.ndarray, a2: np.ndarray, registry: str, theta_deg: float, outfile: str, title: str):
    """Plot moiré unit cell with two-point basis and one surrounding shell.
    Also overlay monolayer lattice points (layer 1 and rotated layer 2) clipped to the cell.
    """
    plt.figure(figsize=(8,8))

    # Draw central cell outline
    cell = np.array([[0,0], a1, a1+a2, a2, [0,0]], float)
    plt.plot(cell[:,0], cell[:,1], 'k-', linewidth=2.0, label='Moiré unit cell')

    # Draw one surrounding shell of cell outlines
    for m in (-1,0,1):
        for n in (-1,0,1):
            if m == 0 and n == 0:
                continue
            t = m*a1 + n*a2
            shell = cell + t
            plt.plot(shell[:,0], shell[:,1], color='gray', alpha=0.25, linewidth=1.0)

    # Basis fractional positions (wrap BA entry into [0,1))
    basis_frac_raw = registry_basis_fractional(registry)
    basis_frac = [(wrap01(u), wrap01(v)) for (u,v) in basis_frac_raw]
    basis_cart = [frac_to_cart(a1, a2, u, v) for (u,v) in basis_frac]

    # Overlay monolayer lattice points from framework within a bounding rectangle
    if ML_AVAILABLE:
        lattice = ml.create_square_lattice(1.0)
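        # Bounding rectangle of the central cell. The rectangle is assumed to be centred on the
        # origin (as plot_reciprocal_registry also assumes), while the cell starts at the origin,
        # so size it from the farthest corner rather than from the cell's extent.
        xmin, xmax = cell[:,0].min(), cell[:,0].max()
        ymin, ymax = cell[:,1].min(), cell[:,1].max()
        width = 2.2 * max(abs(xmin), abs(xmax))
        height = 2.2 * max(abs(ymin), abs(ymax))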
        pts1 = lattice.get_direct_lattice_points_in_rectangle(width, height)
        P1 = np.array([[p[0], p[1]] for p in pts1])
        # Rotate layer 2 by twist angle and translate by registry offset
        R = rotation_matrix(theta_deg)
        P2 = (P1 @ R.T)
        # Registry translation for layer 2
        (u2, v2) = basis_frac[1]
        t2 = frac_to_cart(a1, a2, u2, v2)
        P2 = P2 + t2

        # Point-in-cell test (central cell only)
        def in_cell(pt: np.ndarray) -> bool:
            # Solve for u,v in pt = u*a1 + v*a2
            det = a1[0]*a2[1] - a1[1]*a2[0]
            if abs(det) < 1e-12:
                return False
            u = (pt[0]*a2[1] - pt[1]*a2[0]) / det
            v = (a1[0]*pt[1] - a1[1]*pt[0]) / det
            return (0.0 <= u <= 1.0) and (0.0 <= v <= 1.0)

        L1_in = np.array([p for p in P1 if in_cell(p)])
        L2_in = np.array([p for p in P2 if in_cell(p)])

        if len(L1_in) > 0:
            plt.scatter(L1_in[:,0], L1_in[:,1], s=30, c='tab:red', alpha=0.65, marker='o', label='Layer 1 lattice')
        if len(L2_in) > 0:
            plt.scatter(L2_in[:,0], L2_in[:,1], s=30, c='tab:blue', alpha=0.65, marker='s', label='Layer 2 lattice')

    # Plot basis markers (central cell)
    plt.scatter(basis_cart[0][0], basis_cart[0][1], s=160, c='tab:red', marker='o', edgecolors='k', linewidths=1.0, label='Basis 1 (A)')
    plt.scatter(basis_cart[1][0], basis_cart[1][1], s=160, c='tab:blue', marker='s', edgecolors='k', linewidths=1.0, label='Basis 2 (B)')

    # Axes settings
    plt.gca().set_aspect('equal', 'box')
    pad = 0.15 * max(np.linalg.norm(a1), np.linalg.norm(a2))
    xmin, xmax = cell[:,0].min()-pad, cell[:,0].max()+pad
    ymin, ymax = cell[:,1].min()-pad, cell[:,1].max()+pad
    plt.xlim(xmin, xmax); plt.ylim(ymin, ymax)

    plt.title(title)
    plt.xlabel('x'); plt.ylabel('y')
    plt.legend(loc='best', framealpha=0.85)
    plt.grid(True, alpha=0.25)
    plt.tight_layout()
    plt.savefig(outfile, dpi=160)
    plt.close()

L1, L2 = get_moire_vectors_from_framework(cand.a1, cand.a2, theta_test_deg)
unit_sizes = [1,2,3]
cell_imgs = []
for reg in ["AA","AB","BA"]:
    for sc in unit_sizes:
        l1,l2 = scaled_moire_vectors(L1,L2,sc)
        img_path = os.path.join(output_dir, f"03_cell_{reg}_{sc}x.png")
        plot_moire_unit_cell(l1, l2, reg, theta_test_deg, img_path, f"{reg} moiré cell — {sc}×")
        cell_imgs.append(img_path)
cell_imgs

['moire_pipeline_outputs/03_cell_AA_1x.png',
 'moire_pipeline_outputs/03_cell_AA_2x.png',
 'moire_pipeline_outputs/03_cell_AA_3x.png',
 'moire_pipeline_outputs/03_cell_AB_1x.png',
 'moire_pipeline_outputs/03_cell_AB_2x.png',
 'moire_pipeline_outputs/03_cell_AB_3x.png',
 'moire_pipeline_outputs/03_cell_BA_1x.png',
 'moire_pipeline_outputs/03_cell_BA_2x.png',
 'moire_pipeline_outputs/03_cell_BA_3x.png']
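For a small twist of a single Bravais lattice, the continuum moiré lattice has a closed form. The moiré reciprocal vectors are the differences `g_i - R g_i`, which gives real-space moiré vectors `A_M = (I - R^-1)^-1 A` and a period of `a / (2 sin(θ/2))` on a square lattice. The sketch below is a cross-check for `get_moire_vectors_from_framework` and `compute_moire_cell_approx`, not a replacement for the Rust implementation. It describes the moiré pattern itself and does not search for an exactly commensurate supercell. `moire_vectors_continuum` is a hypothetical name:

```python
import numpy as np

def moire_vectors_continuum(a1, a2, theta_deg):
    """Continuum moiré vectors for layer 1 (a1, a2) and a copy rotated by theta_deg.

    Reciprocal: B_M = (I - R) B, hence real space A_M = (I - R^-1)^-1 A.
    Requires theta not a multiple of 360°; most meaningful for small angles.
    """
    th = np.radians(theta_deg)
    R = np.array([[np.cos(th), -np.sin(th)], [np.sin(th), np.cos(th)]])
    A = np.column_stack([a1, a2]).astype(float)
    AM = np.linalg.solve(np.eye(2) - R.T, A)   # R.T == R^-1 for a rotation
    return AM[:, 0], AM[:, 1]
```

At θ = 3° this gives two orthogonal vectors of length about 19.1 lattice constants. Their directions are parallel, up to sign, to the lattice-vector directions MPB printed in the step 5 log. That log also reports `Cell volume = 1`, so comparing cell sizes against this period is a quick way to confirm that the moiré lengths actually reach MPB.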

## 5) Band diagrams for 3 registries × 3 sizes (9 total)

In [9]:
def mpb_geometry_for_registry(registry: str, a1: np.ndarray, a2: np.ndarray,
                              hole_radius_a: float, eps_bg: float):
    # Two-point basis per unit cell from stacking (same table as registry_basis_fractional):
    # AA: (0,0) + (0,0)   both cylinders coincide, so AA is effectively a single hole
    # AB: (0,0) + (0.5, 0.5)
    # BA: (0,0) + (0.5, -0.5) wrapped to [0,1), the same site as AB on a square lattice
    basis_uv = [(wrap01(u), wrap01(v)) for (u, v) in registry_basis_fractional(registry)]

    lat = mp_lattice_from_vectors(a1, a2)

    # Build geometry: two air cylinders per unit cell placed at fractional centers
    geom: List[mp.GeometricObject] = []
    for (u, v) in basis_uv:
        center_frac = mp.Vector3(u, v, 0)
        geom.append(mp.Cylinder(radius=hole_radius_a, material=mp.air, center=center_frac))

    mat = mp.Medium(epsilon=eps_bg)
    return lat, geom, mat

def run_bands_for_cell(registry: str, l1: np.ndarray, l2: np.ndarray,
                       lattice_type: str, target_band: int, k_label: str,
                       out_prefix: str):
    lat, geom, mat = mpb_geometry_for_registry(registry, l1, l2, cand.hole_radius_a, cand.eps_bg)
    kpts, k_labels, k_inds = hs_path_for_lattice(lattice_type, kpts_per_seg)
    ms = mpb.ModeSolver(geometry_lattice=lat, geometry=geom, default_material=mat,
                        k_points=kpts, resolution=resolution, num_bands=num_bands, dimensions=2)
    pol = cand.polarization.strip().lower()
    result = {}
    if pol == "tm":
        ms.run_tm(); result["freqs"] = np.array(ms.all_freqs)
    elif pol == "te":
        ms.run_te(); result["freqs"] = np.array(ms.all_freqs)
    else:
        # Any other value: compute both polarizations, reusing the same solver
        ms.run_tm(); result["freqs_tm"] = np.array(ms.all_freqs)
        ms.run_te(); result["freqs_te"] = np.array(ms.all_freqs)
    result.update(dict(s=path_arclength(kpts), k_labels=k_labels, k_inds=np.array(k_inds)))
    outf = f"{out_prefix}.png"
    plot_band_diagram(result, lattice_type, target_band, k_label, outf,
                      title_prefix=f"Moiré {registry}")
    return outf

moire_band_plots = []
for reg in ["AA","AB","BA"]:
    for sc in unit_sizes:
        l1,l2 = scaled_moire_vectors(L1,L2,sc)
        outp = os.path.join(output_dir, f"04_bands_{reg}_{sc}x")
        png = run_bands_for_cell(reg, l1, l2, cand.lattice_type, cand.band_index, cand.k_label, outp)
        moire_band_plots.append(png)
moire_band_plots

Initializing eigensolver data
Computing 8 bands with 1e-07 tolerance
Working in 2 dimensions.
Grid size is 40 x 40 x 1.
Solving for 8 bands at a time.
Creating Maxwell data...
Mesh size is 3.
Lattice vectors:
     (-0.0261769, 0.999657, 0)
     (-0.999657, -0.0261769, 0)
     (0, 0, 1)
Cell volume = 1
Reciprocal lattice vectors (/ 2 pi):
     (-0.0261769, 0.999657, 0)
     (-0.999657, -0.0261769, 0)
     (0, -0, 1)
Geometric objects:
     cylinder, center = (0,0,0)
          radius 0.43, height 1e+20, axis (0, 0, 1)
     cylinder, center = (0,0,0)
          radius 0.43, height 1e+20, axis (0, 0, 1)
Geometric object tree has depth 1 and 2 object nodes (vs. 2 actual objects)
Initializing epsilon function...
Allocating fields...
Solving for band polarization: tm.
Initializing fields to random numbers...
58 k-points
  Vector3<0.0, 0.0, 0.0>
  Vector3<0.02631578947368421, 0.0, 0.0>
  Vector3<0.05263157894736842, 0.0, 0.0>
  Vector3<0.07894736842105263, 0.0, 0.0>
  Vector3<0.1052631578947368

['moire_pipeline_outputs/04_bands_AA_1x.png',
 'moire_pipeline_outputs/04_bands_AA_2x.png',
 'moire_pipeline_outputs/04_bands_AA_3x.png',
 'moire_pipeline_outputs/04_bands_AB_1x.png',
 'moire_pipeline_outputs/04_bands_AB_2x.png',
 'moire_pipeline_outputs/04_bands_AB_3x.png',
 'moire_pipeline_outputs/04_bands_BA_1x.png',
 'moire_pipeline_outputs/04_bands_BA_2x.png',
 'moire_pipeline_outputs/04_bands_BA_3x.png']
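Larger cells should show more bands in the same frequency window, because the supercell Brillouin zone is smaller and bands fold back into it. The homogeneous-medium (empty-lattice) limit lets you check this without MPB. In MPB units (frequency in c/a, k in 2π/a), each band is `|k + G| / √ε` over the reciprocal vectors G of the cell. The sketch below covers a square cell of side `scale`·a and ignores the holes entirely. `empty_lattice_bands` is a hypothetical helper:

```python
import itertools
import numpy as np

def empty_lattice_bands(k_frac, scale, eps, num_bands, gmax=6):
    """Lowest `num_bands` frequencies (c/a) of a homogeneous medium folded into a square
    cell of side scale*a, at fractional k = (k1, k2) on that cell's reciprocal basis.

    Uses |k + G| = |(k1 + m, k2 + n)| / scale in units of 2π/a; gmax must be large
    enough that the lowest num_bands plane waves are included.
    """
    k1, k2 = k_frac
    freqs = [np.hypot(k1 + m, k2 + n) / (scale * np.sqrt(eps))
             for m, n in itertools.product(range(-gmax, gmax + 1), repeat=2)]
    return np.sort(freqs)[:num_bands]
```

At Γ, a 1× cell (ε = 1) has only one band below 0.8 c/a, while a 2× cell already fits all eight computed bands below it. A cell with s times the side length folds s² copies of each band. As a result, a fixed `num_bands = 8` covers a narrower frequency window for each size, and the same `band_index` likely refers to a different state in each plot.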

### Collate the 9 band plots into a 3×3 montage image

In [10]:

def make_montage(images: List[str], rows: int, cols: int, outfile: str, pad: int = 10, bg=255):
    assert len(images) == rows*cols
    ims = [Image.open(p).convert("RGB") for p in images]
    w, h = ims[0].size
    W = cols*w + (cols+1)*pad
    H = rows*h + (rows+1)*pad
    canvas = Image.new("RGB", (W,H), (bg,bg,bg))
    idx = 0
    for r in range(rows):
        for c in range(cols):
            x = pad + c*(w+pad); y = pad + r*(h+pad)
            canvas.paste(ims[idx], (x,y)); idx += 1
    canvas.save(outfile)
    return outfile

# Order: rows = AA/AB/BA, cols = 1x/2x/3x
grid_order = []
for reg in ["AA","AB","BA"]:
    for sc in [1,2,3]:
        grid_order.append(os.path.join(output_dir, f"04_bands_{reg}_{sc}x.png"))

montage_path = os.path.join(output_dir, "05_moire_bands_montage_3x3.png")
make_montage(grid_order, rows=3, cols=3, outfile=montage_path)
montage_path


'moire_pipeline_outputs/05_moire_bands_montage_3x3.png'

## 6) Optimize moiré angle and produce final AA/AB/BA band diagrams

In [11]:
# Envelope model helpers: monolayer mass, registry edges, potential, solver

from typing import Dict, Tuple

def monolayer_geometry():
    lat = mp_lattice_from_vectors(cand.a1, cand.a2)
    # Single air hole at (0,0) in dielectric background
    geom = [mp.Cylinder(radius=cand.hole_radius_a, material=mp.air, center=mp.Vector3(0,0,0))]
    mat = mp.Medium(epsilon=cand.eps_bg)
    return lat, geom, mat


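def k0_from_label(lattice_type: str, k_label: str):
    """Return the high-symmetry k-point named `k_label` (fractional reciprocal coordinates)."""
    kpts, labels, inds = hs_path_for_lattice(lattice_type, kpts_per_seg)
    if k_label in labels:
        return kpts[inds[labels.index(k_label)]]
    # Fall back to Gamma, but say so: a silent fallback would hide a mislabelled target k-point
    warnings.warn(f"k-label {k_label!r} not found on the {lattice_type} path; using Gamma.")
    return mp.Vector3(0, 0)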


def run_mpb_single_k(lat, geom, mat, kpoint, num_bands, pol: str):
    """Return the `num_bands` frequencies (units of c/a) at a single k-point."""
    ms = mpb.ModeSolver(geometry_lattice=lat, geometry=geom, default_material=mat,
                        k_points=[kpoint], resolution=resolution, num_bands=num_bands, dimensions=2)
    pol = pol.strip().lower()
    if pol == "te":
        ms.run_te()
    else:
        if pol != "tm":
            warnings.warn(f"Unknown polarization {pol!r}; running TM.")
        ms.run_tm()
    return np.array(ms.all_freqs)[0]  # all_freqs: one row per k-point


def monolayer_band_edge_and_curvature(delta: float = 0.02) -> Tuple[float, "mp.Vector3", float]:
    """
    Returns (omega0, k0, alpha) for the target band at the target k-point.

    alpha is the mean of the central second differences d²ω/dk² along the two
    fractional reciprocal axes (k in units of the reciprocal basis vectors) and is
    used directly as the coefficient in the envelope operator -alpha * ∇^2.
    Caveats: the two axes are only orthogonal for square/rectangular lattices, and
    neither the 1/2 of the effective-mass expansion ω ≈ ω0 + ½ω''q² nor any moiré
    length rescaling is applied here.
    """
    lat, geom, mat = monolayer_geometry()
    k0 = k0_from_label(cand.lattice_type, cand.k_label)
    # Offsets in fractional coordinates
    ex = mp.Vector3(1.0, 0.0)
    ey = mp.Vector3(0.0, 1.0)

    def f(k):
        # Target-band frequency at a single k-point (one MPB run per call)
        return run_mpb_single_k(lat, geom, mat, k, num_bands, cand.polarization)[cand.band_index]

    f0 = f(k0)
    fxp, fxm = f(k0 + delta * ex), f(k0 - delta * ex)
    fyp, fym = f(k0 + delta * ey), f(k0 - delta * ey)
    # Central second differences, truncation error O(delta^2)
    d2x = (fxp + fxm - 2*f0) / (delta**2)
    d2y = (fyp + fym - 2*f0) / (delta**2)
    alpha = 0.5 * (d2x + d2y)
    return f0, k0, alpha


def registry_edge_frequencies(k0) -> Dict[str, float]:
    """Target-band frequency at k0 for each registry (AA/AB/BA), using the small-cell registry geometries."""
    # Force the monolayer lattice (cand vectors) so all registries share the same cell size;
    # the lattice returned by mpb_geometry_for_registry is deliberately ignored.
    lat = mp_lattice_from_vectors(cand.a1, cand.a2)
    edges: Dict[str, float] = {}
    for reg in ["AA", "AB", "BA"]:
        # Reuse the registry geometry builder on the base (L1) cell
        _, geom_r, mat_r = mpb_geometry_for_registry(reg, cand.a1, cand.a2, cand.hole_radius_a, cand.eps_bg)
        freqs = run_mpb_single_k(lat, geom_r, mat_r, k0, num_bands, cand.polarization)
        edges[reg] = float(freqs[cand.band_index])
    return edges


def moire_vectors(theta_deg: float) -> Tuple[np.ndarray, np.ndarray, object]:
    """Return moiré L1,L2 (2D) and moire object for stacking queries."""
    if not ML_AVAILABLE:
        L1m, L2m = get_moire_vectors_from_framework(cand.a1, cand.a2, theta_deg)
        return L1m, L2m, None
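    # NOTE: the base lattice is hard-coded as square with a = 1, which matches `cand`
    # only when cand.lattice_type is square.
    base = ml.create_square_lattice(1.0)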
    moire = ml.py_twisted_bilayer(base, math.radians(theta_deg))
    (v1, v2) = moire.primitive_vectors()
    L1m = np.array([v1[0], v1[1]], float)
    L2m = np.array([v2[0], v2[1]], float)
    return L1m, L2m, moire


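def stacking_type_at(moire_obj, pos_xy: Tuple[float, float]) -> str:
    """Classify the local registry at a Cartesian position as "AA", "AB" or "BA"."""
    if moire_obj is None:
        # No moiré object (bindings missing): classification is not implemented, so every
        # point is reported as "AA" and the resulting potential is flat (zero contrast).
        return "AA"
    st = moire_obj.get_stacking_at([pos_xy[0], pos_xy[1], 0.0])
    if st is None:
        return "AA"
    s = str(st).upper()
    if s in ("AA", "AB", "BA"):
        return s
    # Single-letter labels from the bindings are mapped onto the registry names used here
    if s == "A":
        return "AB"
    if s == "B":
        return "BA"
    return "AA"  # any other label falls back to AA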


def build_moire_potential(theta_deg: float, edges: Dict[str, float], omega0: float, N: int = 64) -> Tuple[np.ndarray, np.ndarray, np.ndarray, object]:
    """Return (U, V, Vgrid, moire_obj): U, V sample [0, 1) along L1m and L2m, and
    Vgrid[i, j] = ω_registry(U[i], V[j]) - omega0 is the (N, N) envelope potential."""
    L1m, L2m, moire_obj = moire_vectors(theta_deg)
    U = np.linspace(0.0, 1.0, N, endpoint=False)
    V = np.linspace(0.0, 1.0, N, endpoint=False)
    Vgrid = np.zeros((N, N), float)
    for i, u in enumerate(U):
        for j, v in enumerate(V):
            pos = u * L1m + v * L2m
            reg = stacking_type_at(moire_obj, (pos[0], pos[1]))
            Vgrid[i, j] = edges.get(reg, edges["AA"]) - omega0
    return U, V, Vgrid, moire_obj


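def solve_envelope(Vgrid: np.ndarray, alpha: float, bc: str = "dirichlet") -> Tuple[float, np.ndarray]:
    """Solve (-alpha ∇^2 + V) F = λ F on the unit (u, v) square with a 5-point FD Laplacian.

    Returns (λ0, F0) for the smallest eigenvalue. Only Dirichlet boundaries (F = 0 just
    outside the grid) are implemented; `bc` is currently ignored. If alpha < 0 (target band
    at a maximum), the smallest eigenvalue belongs to a grid-scale oscillating mode rather
    than a bound state, so check the sign of alpha before interpreting λ0.
    """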
    n = Vgrid.shape[0]
    N = n * n
    h = 1.0 / n
    try:
        import scipy.sparse as sp
        import scipy.sparse.linalg as spla
    except Exception as e:
        print("SciPy not available for sparse solve, using dense fallback (n<=48 recommended):", e)
        # Dense fallback with 5-point Laplacian scaled by 1/h^2
        def idx(i, j): return i*n + j
        L = np.zeros((N, N), float)
        for i in range(n):
            for j in range(n):
                k = idx(i, j)
                L[k, k] = 4.0
                for (ii, jj) in ((i-1,j),(i+1,j),(i,j-1),(i,j+1)):
                    if 0 <= ii < n and 0 <= jj < n:
                        L[k, idx(ii, jj)] = -1.0
        L *= (1.0 / (h*h))
        H = alpha * L + np.diag(Vgrid.reshape(N))
        w, v = np.linalg.eigh(H)
        return float(w[0]), v[:, 0].reshape(n, n)

    # Sparse path
    main = np.full(N, 4.0)
    off1 = np.full(N-1, -1.0)
    offn = np.full(N-n, -1.0)
    for k in range(1, n):
        off1[k*n - 1] = 0.0  # no ±1 coupling across row-block boundaries
    diags = [main, off1, off1, offn, offn]
    offsets = [0, -1, 1, -n, n]
    L = sp.diags(diags, offsets, shape=(N, N), format='csr')
    L = L * (1.0 / (h*h))
    H = alpha * L + sp.diags(Vgrid.reshape(N), 0, format='csr')
    vals, vecs = spla.eigsh(H, k=1, which='SA')
    F0 = vecs[:, 0].reshape(n, n)
    return float(vals[0]), F0


def objective_for_theta(theta_deg: float, omega0: float, alpha: float, edges: Dict[str, float], N: int = 64) -> Tuple[float, float]:
    U, V, Vgrid, _ = build_moire_potential(theta_deg, edges, omega0, N)
    lam0, _ = solve_envelope(Vgrid, alpha)
    # Also compute potential contrast for reference
    contrast = float(Vgrid.max() - Vgrid.min())
    return lam0, contrast
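The sparse path in `solve_envelope` builds the 5-point Laplacian by hand, zeroing the ±1 off-diagonal at each row-block boundary so that the last point of one grid row does not couple to the first point of the next. A quick way to check that construction is to compare it with the Kronecker-sum form of the 2D Dirichlet Laplacian and with its closed-form lowest eigenvalue. The sketch below is a standalone NumPy check on dense matrices with a small `n`; `laplacian_5pt_like_notebook` and `laplacian_kron` are hypothetical helper names, not part of the notebook.

```python
import numpy as np

def laplacian_5pt_like_notebook(n: int) -> np.ndarray:
    """Dense copy of the sparse stencil in solve_envelope (before the 1/h^2 scaling)."""
    N = n * n
    off1 = np.full(N - 1, -1.0)
    for k in range(1, n):
        off1[k * n - 1] = 0.0                 # no coupling across row blocks
    offn = np.full(N - n, -1.0)
    return (np.diag(np.full(N, 4.0))
            + np.diag(off1, 1) + np.diag(off1, -1)
            + np.diag(offn, n) + np.diag(offn, -n))

def laplacian_kron(n: int) -> np.ndarray:
    """Reference: 2D Dirichlet Laplacian as the Kronecker sum of 1D stencils."""
    T = 2.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
    I = np.eye(n)
    return np.kron(I, T) + np.kron(T, I)      # couplings along j (±1) and along i (±n)

n = 12
h = 1.0 / n                                   # same spacing as solve_envelope
L_nb = laplacian_5pt_like_notebook(n)
L_ref = laplacian_kron(n)
lam_min = np.linalg.eigvalsh(L_nb / h**2)[0]
# Lowest eigenvalue of the 2D discrete Dirichlet Laplacian: 2 * (2 - 2cos(pi/(n+1))) / h^2
lam_exact = 4.0 * (1.0 - np.cos(np.pi / (n + 1))) / h**2
print(np.allclose(L_nb, L_ref), lam_min, lam_exact)
```

Because the stencil is truncated rather than wrapped, the implied box is `(n + 1) * h` wide instead of 1, so with V = 0 and alpha = 1 the ground state approaches 2π²/((n+1)h)², slightly below the continuum value 2π² for the unit square.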

In [12]:
# Angle optimization and final outputs

# 1) Monolayer curvature and edge at k0
if not MEEP_AVAILABLE:
    raise RuntimeError("Meep/MPB not available; cannot run envelope optimization.")

omega0, k0, alpha = monolayer_band_edge_and_curvature(delta=0.02)
print(f"Monolayer target band edge ω0={omega0:.6f} at k0=({k0.x:.3f},{k0.y:.3f}), alpha={alpha:.6e}")

# 2) Registry band-edge shifts at same k0
edges = registry_edge_frequencies(k0)
print("Registry edges at k0:", {k: f"{v:.6f}" for k, v in edges.items()})

# 3) Angle scan and objective (ground-state envelope eigenvalue)
angle_range = np.linspace(0.5, 5.0, 19)  # degrees
objs = []
contrasts = []
for th in angle_range:
    lam0, contrast = objective_for_theta(th, omega0, alpha, edges, N=64)
    objs.append(lam0); contrasts.append(contrast)

objs = np.array(objs); contrasts = np.array(contrasts)
opt_idx = int(np.argmin(objs))
opt_theta = float(angle_range[opt_idx])
print(f"Optimal angle ≈ {opt_theta:.3f}° (min envelope eigenvalue {objs[opt_idx]:.6f})")

# 4) Plot objective and contrast vs angle
plt.figure(figsize=(6,4))
plt.plot(angle_range, objs, '-o', ms=3)
plt.xlabel('Twist angle (deg)'); plt.ylabel('Ground-state λ')
plt.title('Envelope objective vs angle')
plt.tight_layout(); plt.savefig(os.path.join(output_dir, '06_objective_vs_angle.png'), dpi=160); plt.close()

plt.figure(figsize=(6,4))
plt.plot(angle_range, contrasts, '-o', ms=3)
plt.xlabel('Twist angle (deg)'); plt.ylabel('Potential contrast (ω units)')
plt.title('Moiré potential contrast vs angle')
plt.tight_layout(); plt.savefig(os.path.join(output_dir, '06_contrast_vs_angle.png'), dpi=160); plt.close()

# 5) Build and plot final moiré lattice (direct and reciprocal)
L1m, L2m, moire_obj = moire_vectors(opt_theta)

# Direct space plot
plt.figure(figsize=(6,6))
cell = np.array([[0,0], L1m, L1m+L2m, L2m, [0,0]], float)
plt.plot(cell[:,0], cell[:,1], 'k-', lw=2)
for m in (-1,0,1):
    for n in (-1,0,1):
        if m==0 and n==0: continue
        t = m*L1m + n*L2m
        shell = cell + t
        plt.plot(shell[:,0], shell[:,1], color='0.75', lw=1)
plt.gca().set_aspect('equal','box')
plt.title(f"Final moiré lattice (direct) @ {opt_theta:.3f}°")
plt.tight_layout(); plt.savefig(os.path.join(output_dir, '07_final_moire_direct.png'), dpi=160); plt.close()

# Reciprocal space plot
if ML_AVAILABLE and moire_obj is not None:
    L1_py, L2_py = moire_obj.lattice_1(), moire_obj.lattice_2()
    try:
        ax = plot_reciprocal_registry(L1_py, L2_py, factor=1)
        ax.figure.savefig(os.path.join(output_dir, '07_final_moire_reciprocal.png'), dpi=160)
        plt.close(ax.figure)
    except Exception as e:
        print("Reciprocal plotting failed:", e)

# 6) Band diagrams for the 1× mini unit cells (AA/AB/BA),
#    using the moiré vectors at the optimal angle for geometry scaling
L1_opt, L2_opt = L1m, L2m

final_band_plots = []
for reg in ["AA", "AB", "BA"]:
    outp = os.path.join(output_dir, f"08_final_bands_{reg}_1x")
    png = run_bands_for_cell(reg, L1_opt, L2_opt, cand.lattice_type, cand.band_index, cand.k_label, outp)
    final_band_plots.append(png)

final_band_plots


Initializing eigensolver data
Computing 8 bands with 1e-07 tolerance
Working in 2 dimensions.
Grid size is 40 x 40 x 1.
Solving for 8 bands at a time.
Creating Maxwell data...
Mesh size is 3.
Lattice vectors:
     (1, 0, 0)
     (0, 1, 0)
     (0, 0, 1)
Cell volume = 1
Reciprocal lattice vectors (/ 2 pi):
     (1, -0, 0)
     (-0, 1, -0)
     (0, -0, 1)
Geometric objects:
     cylinder, center = (0,0,0)
          radius 0.43, height 1e+20, axis (0, 0, 1)
Geometric object tree has depth 1 and 1 object nodes (vs. 1 actual objects)
Initializing epsilon function...
Allocating fields...
Solving for band polarization: tm.
Initializing fields to random numbers...
1 k-points
  Vector3<0.5, 0.5, 0.0>
elapsed time for initialization: 0.012264251708984375
solve_kpoint (0.5,0.5,0):
tmfreqs:, k index, k1, k2, k3, kmag/2pi, tm band 1, tm band 2, tm band 3, tm band 4, tm band 5, tm band 6, tm band 7, tm band 8
Solving for bands 1 to 8...
    near maximum in trace
    iteration    1: trace = 15.321045

SystemError: <built-in function mode_solver_solve_kpoint> returned a result with an exception set
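The angle scan in step 3 picks the best of 19 grid angles spaced 0.25° apart, so `opt_theta` is resolved only to about ±0.125°. Once `omega0`, `alpha` and `edges` are known, each objective evaluation is an envelope solve rather than an MPB run, so one option is to refine inside the bracket around the grid minimum. The sketch below is a plain golden-section search; `refine_bracketed_min` is a hypothetical helper, and it assumes the objective is unimodal inside the bracket, which the coarse scan does not guarantee.

```python
import math
from typing import Callable, Tuple

def refine_bracketed_min(f: Callable[[float], float], lo: float, hi: float,
                         tol: float = 1e-3, max_iter: int = 100) -> Tuple[float, float]:
    """Golden-section search for a minimum of f on [lo, hi]; returns (x_min, f(x_min))."""
    invphi = (math.sqrt(5.0) - 1.0) / 2.0     # 1/phi ≈ 0.618
    a, b = lo, hi
    c = b - invphi * (b - a)                  # interior probe points, c < d
    d = a + invphi * (b - a)
    fc, fd = f(c), f(d)
    for _ in range(max_iter):
        if b - a < tol:
            break
        if fc < fd:                           # minimum lies in [a, d]; old c becomes new d
            b, d, fd = d, c, fc
            c = b - invphi * (b - a)
            fc = f(c)
        else:                                 # minimum lies in [c, b]; old d becomes new c
            a, c, fc = c, d, fd
            d = a + invphi * (b - a)
            fd = f(d)
    x = 0.5 * (a + b)
    return x, f(x)

# Toy check on an objective with a known minimum at 1.3 degrees
x_best, f_best = refine_bracketed_min(lambda th: (th - 1.3) ** 2, 0.5, 5.0)
print(f"{x_best:.4f} {f_best:.2e}")
```

In the notebook this could be called as `refine_bracketed_min(lambda th: objective_for_theta(th, omega0, alpha, edges, N=64)[0], angle_range[max(opt_idx - 1, 0)], angle_range[min(opt_idx + 1, len(angle_range) - 1)])`. SciPy's `scipy.optimize.minimize_scalar(..., bounds=(lo, hi), method="bounded")` is an equivalent off-the-shelf alternative.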

## Reciprocal-space unit-cell approximation (registry-based)

We switch to reciprocal space for registry visualization and basis selection, per your spec:
- Use the reciprocal unit cell of lattice 1 (or a scaled rectangle enclosing its reciprocal cell for mini-BZ variants).
- Query reciprocal lattice points of lattice 2 within that rectangle using the Python binding `get_reciprocal_lattice_points_in_rectangle`.
- Fold those points into lattice 1's reciprocal cell and treat the non-integer sites as “strays.”
- Define the MPB real-space basis as: one lattice-1 site at (0,0) plus one additional site per stray (usually one for factor=1). The number of strays scales with the mini-BZ factor.

Notes:
- All registry/stacking visualizations below are in reciprocal space.
- Geometry in MPB remains a real-space lattice; we only use the reciprocal stray count to determine how many basis sites to include. If you prefer a different mapping from reciprocal strays to real-space offsets, say so and we will change it.

Usage (expects Python binding objects):
- `L1_py = moire.lattice_1()` and `L2_py = moire.lattice_2()` from a built `PyMoire2D`.
- Call the plotting and basis helpers below with `L1_py`, `L2_py`, and `factor` (1 for mini-BZ 1, 2 for mini-BZ 2, ...).
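The folding step can be prototyped without the bindings. The sketch below makes three assumptions that differ from the binding-based workflow: lattice 2 is lattice 1 rotated by θ about the origin, the search window is the lattice-1 reciprocal parallelogram scaled by `factor` (not an enclosing rectangle), and lattice-2 reciprocal points are enumerated directly instead of calling `get_reciprocal_lattice_points_in_rectangle`. `reciprocal_basis` and `strays_in_cell` are hypothetical helpers. Coordinates are fractional in the lattice-1 reciprocal basis, so a stray is any folded point that does not land on a lattice-1 reciprocal point.

```python
import numpy as np

def reciprocal_basis(a1, a2) -> np.ndarray:
    """Columns b1, b2 satisfying a_i · b_j = 2π δ_ij."""
    A = np.column_stack([a1, a2]).astype(float)
    return 2.0 * np.pi * np.linalg.inv(A).T

def strays_in_cell(a1, a2, theta_deg: float, factor: int = 1, tol: float = 1e-6) -> np.ndarray:
    """Folded fractional coords, shape (K, 2), of lattice-2 reciprocal points off lattice 1.

    Assumes lattice 2 = lattice 1 rotated by theta_deg; the window is fractional
    coordinates in [0, factor)^2 of the lattice-1 reciprocal basis.
    """
    t = np.radians(theta_deg)
    R = np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]])
    B1 = reciprocal_basis(a1, a2)
    B2 = R @ B1                                    # rotating real space rotates reciprocal space
    nmax = 3 * factor + 2                          # generous index range (heuristic)
    idx = np.arange(-nmax, nmax + 1)
    m, n = np.meshgrid(idx, idx, indexing="ij")
    pts = B2 @ np.vstack([m.ravel(), n.ravel()])   # lattice-2 reciprocal points, shape (2, M)
    frac = np.linalg.solve(B1, pts)                # same points in lattice-1 fractional coords
    inside = np.all((frac > -tol) & (frac < factor - tol), axis=0)
    folded = np.mod(frac[:, inside], 1.0)          # fold into the unit cell [0, 1)^2
    folded[folded > 1.0 - tol] = 0.0               # treat 0.999999... as 0
    on_lattice1 = np.all(folded < tol, axis=0)     # integer coords: coincides with lattice 1
    return folded[:, ~on_lattice1].T

a1, a2 = np.array([1.0, 0.0]), np.array([0.0, 1.0])   # square lattice, a = 1
print(len(strays_in_cell(a1, a2, 0.0)))                # untwisted: no strays
print(strays_in_cell(a1, a2, 10.0))                    # factor=1 at 10 degrees
print(len(strays_in_cell(a1, a2, 10.0, factor=2)))     # larger window, more strays
```

For the square lattice at 10° this yields a single stray for `factor=1` (the rotated image of b1) and four for `factor=2`, consistent with the stray count growing with the mini-BZ factor. Mapping each stray to a real-space MPB basis offset remains the notebook's choice.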

In [None]:
# Visualize envelope potential and ground state at the optimal angle
N_vis = 96
U, V, Vgrid, _ = build_moire_potential(opt_theta, edges, omega0, N=N_vis)
lam0_opt, F0 = solve_envelope(Vgrid, alpha)
print(f"Optimal envelope ground-state λ0 = {lam0_opt:.6f}")

# Normalize F0 for plotting
F0_abs = np.abs(F0)
F0_abs /= (F0_abs.max() + 1e-12)

plt.figure(figsize=(5,4))
plt.imshow(Vgrid.T, origin='lower', extent=[0,1,0,1], cmap='viridis', aspect='equal')
plt.colorbar(label='V (ω units)')
plt.title('Moiré potential V(u,v) at optimal angle')
plt.tight_layout(); plt.savefig(os.path.join(output_dir, '06_potential_optimal.png'), dpi=160); plt.close()

plt.figure(figsize=(5,4))
plt.imshow(F0_abs.T, origin='lower', extent=[0,1,0,1], cmap='magma', aspect='equal')
plt.colorbar(label='|F| (normed)')
plt.title('Envelope ground-state |F(u,v)| at optimal angle')
plt.tight_layout(); plt.savefig(os.path.join(output_dir, '06_envelope_groundstate_optimal.png'), dpi=160); plt.close()
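Both plots are drawn in moiré fractional coordinates (u, v) ∈ [0, 1)², and `solve_envelope` works on that same unit square, so neither carries the physical moiré length. For a square lattice (a = 1) twisted by θ, each moiré reciprocal vector is b − R(θ)b with length 2|b|sin(θ/2), which gives the period L_m = a / (2 sin(θ/2)), about 52 a at 1.1°. The sketch below checks that closed form numerically and shows one possible conversion of the fractional-k curvature into a coefficient for −∇² in (u, v) units. That conversion is an assumption, not what the notebook currently does: it applies the chain rule k_frac = k·a/(2π), the effective-mass factor ½, and ∇²_r = ∇²_(u,v)/L_m², which holds only for a square moiré cell. `moire_period_numeric` and `alpha_in_uv_units` are hypothetical helpers.

```python
import numpy as np

def moire_period_numeric(theta_deg: float, a: float = 1.0) -> float:
    """Moiré period of a square lattice twisted by theta, from g = b - R(theta) b."""
    t = np.radians(theta_deg)
    R = np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]])
    b1 = np.array([2.0 * np.pi / a, 0.0])         # first reciprocal vector of layer 1
    g1 = b1 - R @ b1                               # corresponding moiré reciprocal vector
    return 2.0 * np.pi / np.linalg.norm(g1)        # square moiré lattice: L_m = 2π / |g|

def alpha_in_uv_units(d2w_dkfrac2: float, theta_deg: float, a: float = 1.0) -> float:
    """Hypothetical conversion of a fractional-k curvature to a -∇²_(u,v) coefficient.

    Assumes square lattices, k_frac = k * a / (2π), the effective-mass term ½ ω'' q²,
    and a square moiré cell of side L_m so that ∇²_r = ∇²_(u,v) / L_m².
    """
    d2w_dk2 = d2w_dkfrac2 * (a / (2.0 * np.pi)) ** 2   # chain rule: dk_frac/dk = a / (2π)
    L_m = moire_period_numeric(theta_deg, a)
    return 0.5 * d2w_dk2 / L_m**2

L_num = moire_period_numeric(1.1)
L_closed = 1.0 / (2.0 * np.sin(np.radians(1.1) / 2.0))
print(f"L_m(1.1°) = {L_num:.3f} a (closed form {L_closed:.3f} a)")
```

If a scaling of this kind were adopted, the kinetic term would shrink as sin²(θ/2) with decreasing angle, which is the θ dependence that an envelope solved on a fixed unit square does not capture by itself.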