In [1]:
import os
import random
import warnings
import xml.etree.ElementTree as ET
from typing import List

import numpy as np


def _clusters_to_array(clusters) -> np.ndarray:
    """Convert a list of cell-index clusters into an array for ``np.save``.

    Equal-length clusters give a regular (n_clusters, cluster_size) array,
    as ``np.array`` would. Ragged clusters give a 1-D object array whose
    items are 1-D index arrays, because NumPy >= 1.24 refuses to build an
    array from ragged nested sequences. Loading such a file requires
    ``np.load(..., allow_pickle=True)``.
    """
    if len({len(cluster) for cluster in clusters}) <= 1:
        return np.array(clusters)
    out = np.empty(len(clusters), dtype=object)
    for i, cluster in enumerate(clusters):
        out[i] = np.asarray(cluster)
    return out


def _add_pattern(
    parent: ET.Element,
    pattern_id: int,
    cluster,
    xpix: np.ndarray,
    ypix: np.ndarray,
    roi_px: str,
    spiral_dur: int,
    iterations: int,
    power: float,
    post_spiral: int,
) -> None:
    """Append one ``<Pattern>`` element, with its ``<ROI>`` children, to *parent*."""
    pattern = ET.SubElement(parent, "Pattern")
    pattern.set("name", f"PatA{pattern_id}")
    pattern.set("patternID", f"{pattern_id}")
    pattern.set("shape", "Ellipse")
    pattern.set("roiWidthPx", roi_px)
    pattern.set("roiHeightPx", roi_px)
    # Display colour only; a fresh random colour is drawn for every pattern.
    pattern.set("red", str(random.randint(0, 255)))
    pattern.set("green", str(random.randint(0, 255)))
    pattern.set("blue", str(random.randint(0, 255)))
    pattern.set("pxSpacing", "1")
    pattern.set("durationMS", str(spiral_dur))
    pattern.set("iterations", str(iterations))
    pattern.set("power", str(power))
    pattern.set("prePatIdleMS", "0")
    pattern.set("postPatIdleMS", str(post_spiral))
    pattern.set("preIteIdleMS", "0")
    pattern.set("postIteIdleMS", "0")
    pattern.set("measurePowerMW", "0")
    pattern.set("measurePowerMWPerUM2", "0")

    # The first ROI duplicates the cluster's first cell (the original code's
    # comment calls it the "galvo center"); then every cell follows.
    # Unpacking into a new list also works when *cluster* is an ndarray,
    # where ``[c[0]] + c`` would broadcast-add instead of concatenating.
    cells = [cluster[0], *cluster]
    for sub_id, cell in enumerate(cells, start=1):
        roi = ET.SubElement(pattern, "ROI")
        roi.set("subID", f"{sub_id}")
        roi.set("centerX", str(xpix[cell]))
        roi.set("centerY", str(ypix[cell]))


def export_thorimage_slm(
    final_clusters: List[List[int]],
    xpix: np.ndarray,
    ypix: np.ndarray,
    out_dir: str,
    prefix: str = "stim",
    spiraldia: float = 8,
    umperpix: float = 1.0,
    spiral_dur: int = 10,
    iterations: int = 10,
    power: float = 13,
    post_spiral: int = 0,
    runs: int = 1,
    limit_patterns: int = 250,
) -> None:
    """
    Export stimulation clusters to ThorImageSLM-compatible XML files.

    Only the first ``limit_patterns`` clusters are exported; a warning is
    issued if any are dropped. For each run ``r`` in ``range(runs)`` this
    writes:

    * ``{prefix}_part_{r+1}.xml``: one ``Pattern`` per exported cluster
      plus one ``SequenceEpoch`` per pattern. All runs contain the same
      patterns; only the random display colours differ.
    * ``{prefix}_final_{r}.npy``: the exported clusters. This is a 2-D
      integer array when all clusters have the same size. Otherwise it is
      an object array, which needs ``np.load(..., allow_pickle=True)``.

    ``xs.npy`` and ``ys.npy`` (the coordinate arrays) are written once.

    Parameters
    ----------
    final_clusters : list[list[int]]
        Stimulation clusters, each a non-empty sequence of cell indices
        into ``xpix``/``ypix``. Lists or 1-D integer arrays are accepted.
    xpix, ypix : np.ndarray
        Cell coordinates (pixels).
    out_dir : str
        Directory to save XML and .npy files (created if missing).
    prefix : str
        Prefix for output files.
    spiraldia : float
        Diameter of stimulation spiral in microns.
    umperpix : float
        Microns per pixel scaling factor.
    spiral_dur : int
        Spiral duration in ms.
    iterations : int
        Iterations per pattern.
    power : float
        Laser power value written to each pattern's ``power`` attribute
        (documented as mW by the original author; confirm against ThorImage).
    post_spiral : int
        Idle time after pattern in ms.
    runs : int
        Number of XML/.npy file pairs to write.
    limit_patterns : int
        Maximum number of patterns per XML (the original author notes
        the ThorImage limit is usually 250).

    Raises
    ------
    ValueError
        If any exported cluster is empty. This is checked before any
        file is written.
    """
    exported = list(final_clusters[:limit_patterns])

    # Validate before touching the filesystem so bad input leaves no
    # partial output behind.
    for idx, cluster in enumerate(exported):
        if len(cluster) == 0:
            raise ValueError(
                f"final_clusters[{idx}] is empty; every cluster needs at "
                "least one cell"
            )
    if len(final_clusters) > len(exported):
        warnings.warn(
            f"{len(final_clusters)} clusters given but only the first "
            f"{len(exported)} are exported (limit_patterns={limit_patterns})",
            stacklevel=2,
        )

    os.makedirs(out_dir, exist_ok=True)

    # Save x/y coords once
    np.save(os.path.join(out_dir, "xs.npy"), np.array(xpix))
    np.save(os.path.join(out_dir, "ys.npy"), np.array(ypix))

    # Loop-invariant values: the ROI size string and the array of clusters
    # that every run's XML actually contains.
    roi_px = f"{spiraldia / umperpix:.2f}"
    clusters_arr = _clusters_to_array(exported)

    for run in range(runs):
        root = ET.Element("ThorImageSLM")
        slm_patterns = ET.SubElement(root, "SLMPatterns")
        slm_sequences = ET.SubElement(root, "SLMSequences")

        # --- define STIM PATTERNS ---
        for pattern_id, cluster in enumerate(exported, start=1):
            _add_pattern(
                slm_patterns, pattern_id, cluster, xpix, ypix, roi_px,
                spiral_dur, iterations, power, post_spiral,
            )

        # --- define SEQUENCES (one epoch per exported pattern) ---
        for seq_id in range(1, len(exported) + 1):
            epoch = ET.SubElement(slm_sequences, "SequenceEpoch")
            epoch.set("sequenceID", f"{seq_id}")
            epoch.set("sequence", f"{seq_id}")
            epoch.set("sequenceEpochCount", "1")

        # Save the clusters this run's XML contains (previously every run
        # after the first saved an empty slice).
        np.save(os.path.join(out_dir, f"{prefix}_final_{run}.npy"), clusters_arr)

        # Write XML file (UTF-8 bytes; no XML declaration, as before)
        xml_path = os.path.join(out_dir, f"{prefix}_part_{run+1}.xml")
        with open(xml_path, "wb") as fh:
            fh.write(ET.tostring(root, encoding="utf-8"))

        print(f"Run {run+1}: wrote {xml_path}")
