GUI for the following tasks: 
1. Create circuit $U\ket{0}$ and get $f(\alpha) = \braket{0|U^\dagger P U|0}$ 
2. Get the circuit output after applying the Fourier Coefficient Extraction algorithm to $f(\alpha)$

# Part 1

## GUI Code

In [None]:
import pennylane as qml
from pennylane import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt

# ---------------- Shared mutable state ----------------
# Gate records in circuit order; each is {"type": <dropdown label>, "targets": [<wire>, ...]}.
gates = []
# Observable strings exactly as entered in the "Observable" text box.
observable_terms = []
# Circuit dimensions; overwritten by apply_dims() when "Apply" is clicked.
n_qubits = 1
dim_x = 1
dim_alpha = 1

# ---------------- Top-level widgets ----------------
# Dimension inputs: number of x entries, number of α entries and qubit count.
dim_x_widget = widgets.BoundedIntText(value=1, min=0, max=10, description="dim(x):")
dim_alpha_widget = widgets.BoundedIntText(value=1, min=0, max=10, description="dim(alpha):")
n_qubits_widget = widgets.BoundedIntText(value=1, min=1, max=10, description="Qubits:")
apply_dims_button = widgets.Button(description="Apply", button_style="info")
dims_out = widgets.Output()

# ---------------- Circuit builder widgets ----------------
# Everything below starts disabled; apply_dims() enables the controls.
gate_category = widgets.ToggleButtons(options=["Fixed", "Encoding"], description="Gate Type:", disabled=True)

# Only one of the two gate dropdowns is visible at a time (see update_gate_dropdown).
fixed_gate_widget = widgets.Dropdown(description="Fixed Gate:", disabled=True)
encoding_gate_widget = widgets.Dropdown(description="Encoding Gate:", disabled=True)
encoding_gate_widget.layout.display = "none"

# Whitespace-separated wire indices, e.g. "0" or "0 1".
target_qubits_widget = widgets.Text(value="0", description="Target(s):", disabled=True)

add_gate_button = widgets.Button(description="Add Gate", button_style="success", disabled=True)
remove_gate_widget = widgets.Dropdown(description="Remove Gate:", disabled=True)
remove_gate_button = widgets.Button(description="Remove", button_style="danger", disabled=True)

# Shows the current gate list (and target validation errors).
gate_display = widgets.Output()

observable_input = widgets.Text(value="Z0", description="Observable:", disabled=True)
add_obs_button = widgets.Button(description="Add Term", button_style="info", disabled=True)

remove_obs_widget = widgets.Dropdown(description="Remove Obs:", disabled=True)
remove_obs_button = widgets.Button(description="Remove", button_style="danger", disabled=True)

# Shows the current observable terms.
obs_display = widgets.Output()

# Selects which circuit show_circuit() draws and exports.
display_mode = widgets.RadioButtons(
    options=["U", "U P U†"], value="U P U†", description="Display:", disabled=True
)

show_circuit_button = widgets.Button(description="Show Circuit", button_style="warning", disabled=True)
clear_button = widgets.Button(description="Clear All", button_style="danger", disabled=True)

# Receives the circuit drawing and the generated code.
output_display = widgets.Output()


# ---------------- Helper functions ----------------
def show_code_block(code_text):
    """Render ``code_text`` in a styled ``<pre>`` block with a copy button.

    The text is HTML-escaped before it is placed inside the ``<pre>`` element
    and is embedded in the button's ``onclick`` handler as a JSON string
    literal, so quotes, backticks, ``<``/``&`` or ``${`` sequences in the code
    can neither break the markup nor change the text that gets copied.

    Args:
        code_text: Source code to show and to offer for copying.
    """
    # Function-scope imports keep this helper self-contained in the cell.
    import html
    import json

    escaped_code = html.escape(code_text, quote=False)
    # json.dumps produces a valid double-quoted JS string literal; escaping it
    # once more lets it sit safely inside the double-quoted HTML attribute.
    js_literal = html.escape(json.dumps(code_text), quote=True)
    markup = f"""
    <pre style="background:#f5f5f5; padding:10px; border:1px solid #ccc; font-size:14px;">
{escaped_code}
    </pre>
    <button onclick="navigator.clipboard.writeText({js_literal})">Copy to Clipboard</button>
    """
    display(widgets.HTML(markup))


def update_gate_options():
    """Rebuild the option lists of both gate dropdowns from the current dims.

    The "Fixed" list holds the Pauli/H/CNOT gates followed by Rx/Ry/Rz
    rotations for every ``x[i]``; the "Encoding" list holds Rx/Ry/Rz rotations
    for every ``α[j]``.
    """
    axes = ("x", "y", "z")
    x_rotations = [f"R{axis}(x[{i}])" for i in range(dim_x) for axis in axes]
    alpha_rotations = [f"R{axis}(α[{j}])" for j in range(dim_alpha) for axis in axes]

    fixed_gate_widget.options = ["X", "Y", "Z", "H", "CNOT"] + x_rotations
    encoding_gate_widget.options = alpha_rotations


def refresh_gate_display():
    """Reprint the gate list and resynchronise the "Remove Gate" dropdown."""
    listing = ["Current Gates:"]
    listing.extend(f"{position}: {gate}" for position, gate in enumerate(gates))
    with gate_display:
        clear_output()
        print("\n".join(listing))

    labels = []
    for position, gate in enumerate(gates):
        labels.append(f"{position}: {gate['type']} {gate['targets']}")
    remove_gate_widget.options = labels


def refresh_obs_display():
    """Reprint the observable terms and resynchronise the "Remove Obs" dropdown."""
    numbered = [f"{position}: {term}" for position, term in enumerate(observable_terms)]
    with obs_display:
        clear_output()
        print("\n".join(["Current Observable Terms:"] + numbered))
    remove_obs_widget.options = numbered


# ---------------- Event handlers ----------------
def apply_dims(b):
    """Store the chosen dimensions, refresh gate options and enable the builder.

    Args:
        b: The clicked button (unused).
    """
    global dim_x, dim_alpha, n_qubits
    dim_x, dim_alpha, n_qubits = (
        dim_x_widget.value,
        dim_alpha_widget.value,
        n_qubits_widget.value,
    )
    update_gate_options()

    # Every control that was created disabled becomes usable now.
    builder_controls = (
        gate_category, fixed_gate_widget, encoding_gate_widget,
        target_qubits_widget, add_gate_button, remove_gate_widget, remove_gate_button,
        observable_input, add_obs_button, remove_obs_widget, remove_obs_button,
        display_mode, show_circuit_button, clear_button,
    )
    for control in builder_controls:
        control.disabled = False

    summary = f"Applied: dim(x)={dim_x}, dim(alpha)={dim_alpha}, qubits={n_qubits}"
    with dims_out:
        clear_output()
        print(summary)


def update_gate_dropdown(change):
    """Show the dropdown matching the selected gate category and hide the other.

    Args:
        change: The traitlets change notification (unused).
    """
    fixed_selected = gate_category.value == "Fixed"
    fixed_gate_widget.layout.display = "block" if fixed_selected else "none"
    encoding_gate_widget.layout.display = "none" if fixed_selected else "block"


def add_gate_callback(b):
    """Validate the selected gate and targets, then append the gate.

    Targets may be separated by whitespace or commas. Any invalid input is
    reported in ``gate_display`` and nothing is added, so a malformed gate can
    no longer be stored and crash later when the circuit is built.

    Args:
        b: The clicked button (unused).
    """
    gate = fixed_gate_widget.value if gate_category.value == "Fixed" else encoding_gate_widget.value

    def report(message):
        with gate_display:
            print(f"❌ Error: {message}")

    # The encoding dropdown has no options (value None) when dim(alpha) is 0.
    if gate is None:
        report("no gate is selected.")
        return

    raw_targets = target_qubits_widget.value
    try:
        targets = [int(t) for t in raw_targets.replace(",", " ").split()]
    except ValueError:
        report(f"targets must be integers, got {raw_targets!r}.")
        return

    if not targets:
        report("at least one target qubit is required.")
        return

    for t in targets:
        if t < 0 or t >= n_qubits:
            report(f"target qubit {t} is out of range. Max={n_qubits - 1}")
            return

    # CNOT is the only two-qubit gate on offer; it needs control and target.
    if gate == "CNOT" and (len(targets) != 2 or targets[0] == targets[1]):
        report("CNOT needs exactly two distinct targets: control target.")
        return

    gates.append({"type": gate, "targets": targets})
    refresh_gate_display()


def remove_gate_callback(b):
    """Delete the gate selected in the "Remove Gate" dropdown, if any exist.

    Args:
        b: The clicked button (unused).
    """
    if not len(gates) > 0:
        return
    # Dropdown labels look like "<index>: <type> <targets>".
    index_text = remove_gate_widget.value.partition(":")[0]
    gates.pop(int(index_text))
    refresh_gate_display()


def add_obs_callback(b):
    """Validate the typed observable term and append it to the term list.

    A term must consist of Pauli factors such as ``Z0`` or ``X0 Y1``
    (optionally joined with ``+``), each a letter X/Y/Z followed by a wire
    index below the current qubit count. Invalid input is reported in
    ``obs_display`` and not stored, because ``build_P`` would otherwise fail
    on it when the circuit is shown.

    Args:
        b: The clicked button (unused).
    """
    import re  # function-scope import: this cell does not import re

    term = observable_input.value.strip()
    factors = term.replace("+", " ").split()

    def report(message):
        with obs_display:
            print(f"❌ Error: {message}")

    if not factors:
        report("observable term is empty.")
        return

    malformed = [f for f in factors if re.fullmatch(r"[XYZxyz]\d+", f) is None]
    if malformed:
        report(f"cannot parse {malformed}; expected factors like X0, Y1, Z2.")
        return

    out_of_range = [f for f in factors if int(f[1:]) >= n_qubits]
    if out_of_range:
        report(f"wire out of range in {out_of_range}. Max={n_qubits - 1}")
        return

    observable_terms.append(term)
    refresh_obs_display()


def remove_obs_callback(b):
    """Delete the term selected in the "Remove Obs" dropdown, if any exist.

    Args:
        b: The clicked button (unused).
    """
    if not len(observable_terms) > 0:
        return
    # Dropdown labels look like "<index>: <term>".
    index_text = remove_obs_widget.value.partition(":")[0]
    observable_terms.pop(int(index_text))
    refresh_obs_display()


def clear_all(b):
    """Empty the gate and observable lists and wipe all three displays.

    Args:
        b: The clicked button (unused).
    """
    for store in (gates, observable_terms):
        store.clear()
    refresh_gate_display()
    refresh_obs_display()
    with output_display:
        clear_output()


def build_P(observable_terms):
    """Build a PennyLane observable from a list of Pauli strings.

    Each list entry is a sum of products: factors separated by whitespace are
    tensored together, and ``+`` separates summands. All summands of all
    entries are added up, so ``["X0 Z1 + Y2"]`` and ``["X0 Z1", "Y2"]`` give
    the same observable ``X0 @ Z1 + Y2``. Empty summands are ignored.

    Args:
        observable_terms: Strings such as ``"Z0"``, ``"X0 Y1"`` or
            ``"X0 Z1 + Y2"`` (letter X/Y/Z followed by the wire index).

    Returns:
        The combined observable, or ``qml.PauliZ(0)`` when no factor was given.

    Raises:
        ValueError: If a factor is not a Pauli letter followed by an integer.
    """
    pauli_ops = {"X": qml.PauliX, "Y": qml.PauliY, "Z": qml.PauliZ}
    summands = []

    for term in observable_terms:
        # "+" separates summands; it must not be merged into one product.
        for product_text in term.split("+"):
            product = None
            for factor in product_text.split():
                letter = factor[0].upper()
                if letter not in pauli_ops:
                    raise ValueError(f"Unknown Pauli operator in {factor!r}")
                try:
                    wire = int(factor[1:])
                except ValueError:
                    raise ValueError(f"Invalid wire index in {factor!r}") from None

                op = pauli_ops[letter](wire)
                # Tensor product of the factors within one summand.
                product = op if product is None else product @ op

            # Skip empty pieces such as a trailing "+" or a blank term.
            if product is not None:
                summands.append(product)

    # Default: PauliZ(0) if nothing was provided.
    if not summands:
        return qml.PauliZ(0)

    P = summands[0]
    for summand in summands[1:]:
        P = P + summand
    return P



def apply_gate(g, x, α):
    """Queue the PennyLane operation(s) for one gate record.

    α-dependent Rx/Ry rotations are expressed through RZ conjugated by
    Hadamard (and S for Ry); x-dependent rotations are applied natively.

    Args:
        g: Gate record with ``"type"`` (dropdown label) and ``"targets"``.
        x: Indexable values for the ``x[i]`` angles.
        α: Indexable values for the ``α[j]`` angles.
    """
    label = g["type"]
    targets = g["targets"]

    if label in ["X", "Y", "Z"]:
        pauli = getattr(qml, f"Pauli{label}")
        pauli(wires=targets[0])
        return
    if label == "H":
        qml.Hadamard(wires=targets[0])
        return
    if label == "CNOT":
        qml.CNOT(wires=targets)
        return

    if "α[" in label:
        axis = label[1]
        index = int(label[label.index("[") + 1:label.index("]")])
        theta = α[index]
        kind = axis.lower()

        if kind == "x":
            # Rx(α) = H RZ H
            qml.Hadamard(wires=targets[0])
            qml.RZ(theta, wires=targets[0])
            qml.Hadamard(wires=targets[0])
        elif kind == "y":
            # Ry(α) = S† H RZ H S
            qml.adjoint(qml.S(wires=targets[0]))
            qml.Hadamard(wires=targets[0])
            qml.RZ(theta, wires=targets[0])
            qml.Hadamard(wires=targets[0])
            qml.S(wires=targets[0])
        else:
            # Rz needs no decomposition.
            qml.RZ(theta, wires=targets[0])
        return

    if "x[" in label:
        axis = label[1]
        index = int(label[label.index("[") + 1:label.index("]")])
        theta = x[index]
        rotation = getattr(qml, f"R{axis.upper()}")
        rotation(theta, wires=targets[0])




def build_gate_string_export(g, inverse=False):
    """Return the exported source text for one gate, without the ``qml.`` prefix.

    The caller prepends ``qml.`` to the returned string. Encoding gates
    (angles ``α[j]``) of type Rx/Ry are decomposed into RZ conjugated by
    Hadamard (and fixed RZ(±π/2) for Ry); gates depending on ``x[i]`` are
    exported as native rotations. ``Hadamard`` is spelled out everywhere so
    the exported code does not depend on the ``qml.H`` alias.

    Args:
        g: Gate record with ``"type"`` (dropdown label) and ``"targets"``.
        inverse: When true, the rotation angle is negated. Fixed gates
            (X, Y, Z, H, CNOT) are exported unchanged.

    Returns:
        A string of one or more ``;``-separated calls, or the raw gate name
        for labels that carry no ``α[..]``/``x[..]`` parameter.
    """
    name, t = g["type"], g["targets"]

    if name in ["X", "Y", "Z"]:
        return f"Pauli{name}(wires={t[0]})"
    if name == "H":
        return f"Hadamard(wires={t[0]})"
    if name == "CNOT":
        return f"CNOT(wires={t})"

    is_encoding = "α[" in name
    # Fallback: unknown labels are returned verbatim instead of failing on
    # the index parsing below.
    if not is_encoding and "x[" not in name:
        return name

    # Labels look like "Rx(α[0])": the axis is the second character.
    axis = name[1]
    idx = int(name[name.index("[") + 1:name.index("]")])
    sign = "-" if inverse else ""
    wire = t[0]

    # Only encoding gates (α) are decomposed.
    if is_encoding:
        param = f"{sign}α[{idx}]"
        if axis.lower() == "x":
            return (
                f"Hadamard(wires={wire}); qml.RZ({param}, wires={wire}); "
                f"qml.Hadamard(wires={wire}) # RX({param})"
            )
        if axis.lower() == "y":
            return (
                f"RZ(-np.pi/2.0, wires={wire}); qml.Hadamard(wires={wire}); "
                f"qml.RZ({param}, wires={wire}); qml.Hadamard(wires={wire}); "
                f"qml.RZ(np.pi/2.0, wires={wire}) # RY({param})"
            )
        # Rz stays as is.
        return f"RZ({param}, wires={wire})"

    # x-dependent gates are left as native rotations.
    param = f"{sign}x[{idx}]"
    return f"R{axis.upper()}({param}, wires={wire})"


def show_circuit(b):
    """Draw the selected circuit and show standalone code that recreates it.

    Depending on the "Display" selection, either ``U`` or ``U P U†`` is drawn
    with all parameters set to zero. The generated script always defines
    ``U`` and additionally ``UPU`` when that mode is selected.

    Args:
        b: The clicked button (unused).
    """
    dev = qml.device("default.qubit", wires=n_qubits)
    P = build_P(observable_terms)

    @qml.qnode(dev)
    def U(x=None, α=None):
        x = np.zeros(dim_x) if x is None else x
        α = np.zeros(dim_alpha) if α is None else α
        for gate in gates:
            apply_gate(gate, x, α)
        return qml.state()

    @qml.qnode(dev)
    def UPU(x=None, α=None):
        x = np.zeros(dim_x) if x is None else x
        α = np.zeros(dim_alpha) if α is None else α
        for gate in gates:
            apply_gate(gate, x, α)
        qml.apply(P)
        for gate in reversed(gates):
            apply_gate_inverse(gate, x, α)
        return qml.state()

    x0 = np.zeros(dim_x)
    α0 = np.zeros(dim_alpha)

    with output_display:
        clear_output()

        if display_mode.value == "U":
            qml.draw_mpl(U)(x0, α0)
            plt.title("U")
        else:
            qml.draw_mpl(UPU)(x0, α0)
            plt.title("U P U†")
        plt.show()

        for banner_line in (
            "\n============================",
            " Code to recreate this circuit",
            "============================\n",
        ):
            print(banner_line)

        export_upu = display_mode.value == "U P U†"

        # Imports, dimensions and observables.
        chunks = [
            "import pennylane as qml\n",
            "import numpy as np\n\n",
            f"dim_x = {dim_x}\n",
            f"dim_alpha = {dim_alpha}\n",
            f"observable_terms = {observable_terms}\n\n",
        ]

        # Verbatim copy of the observable builder for the exported script.
        chunks.extend([
            "def build_P(observable_terms):\n",
            "    import pennylane as qml\n",
            "    P = None\n",
            "    for term in observable_terms:\n",
            "        pieces = term.replace('+',' ').split()\n",
            "        local = None\n",
            "        for p in pieces:\n",
            "            op = p[0].upper()\n",
            "            idx = int(p[1:])\n",
            "            t = getattr(qml, f'Pauli{op}')(idx)\n",
            "            local = t if local is None else local @ t\n",
            "        P = local if P is None else P + local\n",
            "    return P if P is not None else qml.PauliZ(0)\n\n",
        ])

        # Device.
        chunks.append(f"dev = qml.device('default.qubit', wires={n_qubits})\n\n")

        # QNode U.
        chunks.append("@qml.qnode(dev)\n")
        chunks.append("def U(x, α):\n")
        chunks.extend(f"    qml.{build_gate_string_export(gate)}\n" for gate in gates)
        chunks.append("    return qml.state()\n\n")

        # QNode UPU: forward gates, observable construction, inverse gates.
        if export_upu:
            chunks.append("@qml.qnode(dev)\n")
            chunks.append("def UPU(x, α):\n")
            chunks.extend(f"    qml.{build_gate_string_export(gate)}\n" for gate in gates)
            chunks.append("    P = build_P(observable_terms)\n")
            chunks.extend(
                f"    qml.{build_gate_string_export(gate, inverse=True)}\n"
                for gate in reversed(gates)
            )
            chunks.append("    return qml.state()\n\n")

        # Example execution.
        chunks.append("x0 = np.zeros(dim_x)\n")
        chunks.append("α0 = np.zeros(dim_alpha)\n")
        chunks.append("qml.draw_mpl(U)(x0, α0);\n")
        if export_upu:
            chunks.append("qml.draw_mpl(UPU)(x0, α0);\n")

        show_code_block("".join(chunks))


def apply_gate_inverse(g, x, α):
    """Queue the inverse of one gate record.

    Fixed gates (X, Y, Z, H, CNOT) are re-applied unchanged; parameterised
    rotations use the negated angle with the same decomposition as
    ``apply_gate``.

    Args:
        g: Gate record with ``"type"`` (dropdown label) and ``"targets"``.
        x: Indexable values for the ``x[i]`` angles.
        α: Indexable values for the ``α[j]`` angles.
    """
    label = g["type"]
    targets = g["targets"]

    if label in ["X", "Y", "Z"]:
        pauli = getattr(qml, f"Pauli{label}")
        pauli(wires=targets[0])
        return
    if label == "H":
        qml.Hadamard(wires=targets[0])
        return
    if label == "CNOT":
        qml.CNOT(wires=targets)
        return

    if "α[" in label:
        axis = label[1]
        index = int(label[label.index("[") + 1:label.index("]")])
        theta = - α[index]
        kind = axis.lower()

        if kind == "x":
            # Rx†(α) = H RZ(-θ) H
            qml.Hadamard(wires=targets[0])
            qml.RZ(theta, wires=targets[0])
            qml.Hadamard(wires=targets[0])
        elif kind == "y":
            # Ry†(α) = S† H RZ(-θ) H S
            qml.adjoint(qml.S(wires=targets[0]))
            qml.Hadamard(wires=targets[0])
            qml.RZ(theta, wires=targets[0])
            qml.Hadamard(wires=targets[0])
            qml.S(wires=targets[0])
        else:
            qml.RZ(theta, wires=targets[0])
        return

    if "x[" in label:
        axis = label[1]
        index = int(label[label.index("[") + 1:label.index("]")])
        theta = - x[index]
        rotation = getattr(qml, f"R{axis.upper()}")
        rotation(theta, wires=targets[0])




# ---------------- Layout ----------------
# Widgets are rendered in the order of the display() calls below.
# Step 1: dimension inputs plus the "Apply" button and its status output.
display(widgets.HTML("<h3>Select dims and qubits first:</h3>"))
display(widgets.HBox([dim_x_widget, dim_alpha_widget, n_qubits_widget, apply_dims_button]), dims_out)

# Step 2: gate selection, targets and the add/remove controls.
display(gate_category)
display(fixed_gate_widget)
display(encoding_gate_widget)
display(target_qubits_widget)
display(add_gate_button)
display(remove_gate_widget, remove_gate_button)
display(gate_display)

# Step 3: observable terms.
display(widgets.HBox([observable_input, add_obs_button]))
display(widgets.HBox([remove_obs_widget, remove_obs_button]))
display(obs_display)

# Step 4: drawing/export controls and their output area.
display(display_mode)
display(show_circuit_button, clear_button)
display(output_display)
# Button that tears the GUI down via exit_gui().
exit_button = widgets.Button(description="Exit GUI", button_style="danger")
def exit_gui(b):
    """Close every Part 1 widget and print a confirmation message.

    Args:
        b: The clicked button (unused).
    """
    everything = (
        dim_x_widget, dim_alpha_widget, n_qubits_widget, apply_dims_button,
        gate_category, fixed_gate_widget, encoding_gate_widget,
        target_qubits_widget, add_gate_button, remove_gate_widget, remove_gate_button,
        observable_input, add_obs_button, remove_obs_widget, remove_obs_button,
        display_mode, show_circuit_button, clear_button, exit_button,
        gate_display, obs_display, output_display, dims_out,
    )
    for widget in everything:
        widget.close()

    clear_output(wait=True)
    print("GUI closed. You can now run the next cell.")

# Hook up and show the exit button.
exit_button.on_click(exit_gui)
display(exit_button)

# Connect each control to its event handler.
apply_dims_button.on_click(apply_dims)
# Fires whenever the Fixed/Encoding toggle changes value.
gate_category.observe(update_gate_dropdown, names="value")
add_gate_button.on_click(add_gate_callback)
remove_gate_button.on_click(remove_gate_callback)
add_obs_button.on_click(add_obs_callback)
remove_obs_button.on_click(remove_obs_callback)
show_circuit_button.on_click(show_circuit)
clear_button.on_click(clear_all)


# Part 2

In [None]:
import ast, math, re, textwrap, ipywidgets as widgets
from IPython.display import display, clear_output, HTML

# ---------------- Widgets ----------------
# Text area for the code to transform (generate_modified reads its value).
code_input = widgets.Textarea(
    placeholder="Paste your Python/PennyLane code here...",
    description="Code Input:",
    layout=widgets.Layout(width="100%", height="320px"),
)

# Number of shots written into the generated script.
count_input = widgets.IntText(
    value=10000,
    description="Shots:",
    layout=widgets.Layout(width="200px")
)
# Shown immediately, i.e. above the code input displayed further below.
display(count_input)


generate_button = widgets.Button(
    description="Generate Modified UPU with Registers + AuPAu",
    button_style="success"
)

# NOTE(review): this rebinds the name `output_display` that the Part 1
# callbacks also read as a global — confirm both cells are not meant to be
# used at the same time in one kernel.
output_display = widgets.Output()

# ---------------- AST helpers / Gate object ----------------
def ast_to_source(node):
    """Return the source text of an AST node, or ``"<expr>"`` if unparsing fails."""
    placeholder = "<expr>"
    try:
        text = ast.unparse(node)
    except Exception:
        text = placeholder
    return text

class Gate:
    """One parsed gate call: operation name, argument sources and wire indices."""

    def __init__(self, opname, args_src, wires):
        # op: operation name as written in the parsed code (e.g. "RZ").
        # args_src: source text of each positional argument.
        # wires: integer wire indices the gate acts on.
        self.op, self.args_src, self.wires = opname, args_src, wires

    def rendered(self, offset=0, replace_angle=None):
        """Return the ``qml.<op>(...)`` source line with all wires shifted.

        Args:
            offset: Integer added to every wire index.
            replace_angle: Optional callable applied to each argument source
                string; if it raises, the original string is kept.

        Returns:
            The call as a string. CNOT/CX, H/Hadamard and RX/RY/RZ use fixed
            spellings; any other operation is rendered generically.
        """
        shifted = [str(wire + offset) for wire in self.wires]
        name = self.op
        upper_name = name.upper()

        def transform(text):
            if replace_angle is None:
                return text
            try:
                return replace_angle(text)
            except Exception:
                return text

        if upper_name in ("CNOT", "CX"):
            return f"qml.CNOT(wires=({shifted[0]},{shifted[1]}))"

        if upper_name in ("H", "HADAMARD"):
            return f"qml.Hadamard(wires={shifted[0]})"

        is_rotation = (
            upper_name.startswith("R")
            and len(name) >= 2
            and name[1].upper() in ("X", "Y", "Z")
        )
        if is_rotation:
            # Only the first positional argument is treated as the angle.
            raw_angle = self.args_src[0] if self.args_src else "?"
            angle = transform(raw_angle)
            return f"qml.R{name[1].upper()}({angle}, wires={shifted[0]})"

        if self.args_src:
            arg_text = ", ".join([transform(arg) for arg in self.args_src])
        else:
            arg_text = ""
        if len(shifted) == 1:
            wires_text = f"wires={shifted[0]}"
        else:
            wires_text = f"wires=({', '.join(shifted)})"

        if arg_text:
            return f"qml.{name}({arg_text}, {wires_text})"
        return f"qml.{name}({wires_text})"

# ---------------- Parsing ----------------
def parse_pennylane_code(code_text):
    """Extract the pieces of an exported circuit script needed for regeneration.

    Only top-level statements are inspected: imports, the assignments to
    ``dim_x``/``dim_alpha``/``observable_terms``, a ``<x>.device(..., wires=N)``
    assignment, the ``build_P`` definition and the expression-statement calls
    directly inside functions named ``U`` and ``UPU``.

    Args:
        code_text: Python source to analyse.

    Returns:
        A 9-tuple ``(imports_src, build_P_src, dim_x, dim_alpha,
        observable_terms, orig_qubits, gates_U, gates_UPU, alpha_usage)``.
        ``dim_x``, ``dim_alpha`` and ``orig_qubits`` are ``None`` when not
        found. ``observable_terms`` is always a list: a non-literal or missing
        value yields ``[]`` and a single string is wrapped in a list.
        ``alpha_usage`` maps each α index to the number of references counted
        inside ``U``.

    Raises:
        SyntaxError: If ``code_text`` is not valid Python.
    """
    # Errors ast.literal_eval()/int() raise for non-literal or malformed input.
    eval_errors = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)
    alpha_ref = re.compile(r"(?:α|alpha)\s*\[\s*(\d+)\s*\]")
    negated_alpha_ref = re.compile(r"-\s*(?:α|alpha)\s*\[\s*(\d+)\s*\]")

    tree = ast.parse(code_text)
    imports_src, build_P_src = [], ""
    dim_x, dim_alpha, observable_terms, orig_qubits = None, None, [], None
    gates_U, gates_UPU, alpha_usage = [], [], {}

    def try_eval(node):
        try:
            return ast.literal_eval(node)
        except eval_errors:
            return None

    def literal_ints(elements):
        # Non-literal entries (e.g. variables) are skipped, not fatal.
        values = []
        for elt in elements:
            try:
                values.append(int(ast.literal_eval(elt)))
            except eval_errors:
                continue
        return values

    def count_alpha(text, include_negated):
        for idx in alpha_ref.findall(text):
            alpha_usage[int(idx)] = alpha_usage.get(int(idx), 0) + 1
        if include_negated:
            # NOTE(review): a negated reference is counted a second time here,
            # exactly as in the previous implementation — confirm this is
            # intended.
            for idx in negated_alpha_ref.findall(text):
                alpha_usage[int(idx)] = alpha_usage.get(int(idx), 0) + 1

    def process_call(call_node, inside_U=False):
        func = call_node.func
        if isinstance(func, ast.Attribute):
            opname = func.attr
        elif isinstance(func, ast.Name):
            opname = func.id
        else:
            return None

        wires = []
        for kw in call_node.keywords:
            if kw.arg != "wires":
                continue
            v = kw.value
            if isinstance(v, ast.Constant) and isinstance(v.value, int):
                wires = [v.value]
            elif isinstance(v, (ast.Tuple, ast.List)):
                wires = literal_ints(v.elts)
            else:
                # Fall back to every integer found in the expression text.
                wires = [int(n) for n in re.findall(r"-?\d+", ast_to_source(v))]
            break

        # Without a wires keyword, a trailing tuple/list argument is the wires.
        if not wires and call_node.args:
            last = call_node.args[-1]
            if isinstance(last, (ast.Tuple, ast.List)):
                wires = literal_ints(last.elts)

        args_src = []
        for arg in call_node.args:
            text = ast_to_source(arg)
            args_src.append(text)
            if inside_U:
                count_alpha(text, include_negated=True)

        if inside_U:
            for kw in call_node.keywords:
                if kw.arg != "wires":
                    count_alpha(ast_to_source(kw.value), include_negated=False)

        return Gate(opname, args_src, wires) if opname else None

    for node in tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            seg = ast.get_source_segment(code_text, node)
            if seg:
                imports_src.append(seg)

        elif isinstance(node, ast.Assign):
            for t in node.targets:
                target_name = getattr(t, "id", "")
                if target_name == "dim_x":
                    dim_x = try_eval(node.value)
                if target_name == "dim_alpha":
                    dim_alpha = try_eval(node.value)
                if target_name == "observable_terms":
                    terms = try_eval(node.value)
                    # The generated code iterates over this value, so it must
                    # never be None or a bare string.
                    if isinstance(terms, str):
                        observable_terms = [terms]
                    elif isinstance(terms, (list, tuple)):
                        observable_terms = list(terms)
                    else:
                        observable_terms = []
            if isinstance(node.value, ast.Call):
                func = node.value.func
                if isinstance(func, ast.Attribute) and getattr(func, "attr", "") == "device":
                    for kw in node.value.keywords:
                        if kw.arg == "wires":
                            try:
                                orig_qubits = int(ast.literal_eval(kw.value))
                            except eval_errors:
                                pass  # non-literal wire count: leave unset

        elif isinstance(node, ast.FunctionDef):
            if node.name == "build_P":
                seg = ast.get_source_segment(code_text, node)
                if seg:
                    build_P_src = seg
            elif node.name in ("U", "UPU"):
                inside_U = node.name == "U"
                target = gates_U if inside_U else gates_UPU
                for stmt in node.body:
                    if isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Call):
                        g = process_call(stmt.value, inside_U=inside_U)
                        if g:
                            target.append(g)

    return imports_src, build_P_src, dim_x, dim_alpha, observable_terms, orig_qubits, gates_U, gates_UPU, alpha_usage

# ---------------- Code generation helpers ----------------
def freq_size_from_count(n):
    """Return ``ceil(log2(2*n + 1))`` for positive ``n`` and 0 otherwise."""
    if not n > 0:
        return 0
    distinct_values = 2 * n + 1
    return math.ceil(math.log2(distinct_values))
def same_gate_signature(g1, g2):
    """Return True when both gates share the op name (case-insensitive) and wires."""
    names_match = g1.op.upper() == g2.op.upper()
    if not names_match:
        return False
    if len(g1.wires) != len(g2.wires):
        return False
    for left, right in zip(g1.wires, g2.wires):
        if not left == right:
            return False
    return True
def gate_depends_on_x(g):
    """Return True if any argument source of ``g`` references ``x[<digits>]``."""
    x_reference = r"\bx\s*\[\s*\d+\s*\]"
    return any(re.search(x_reference, text) for text in g.args_src)
def replace_x_with_x1(s):
    """Rewrite every ``x[<digits>]`` reference in ``s`` to ``x1[<digits>]``."""
    x_reference = re.compile(r"\bx\s*\[\s*(\d+)\s*\]")
    return x_reference.sub(lambda match: f"x1[{match.group(1)}]", s)

# ---------------- Main code generation ----------------
def generate_output_code(original_code,
                         imports_src, build_P_src,
                         dim_x, dim_alpha, observable_terms, orig_qubits,
                         gates_U, gates_UPU, alpha_usage, shots):
    """Assemble the source text of the extended script (U, UPU, AuPAu, kernel).

    The parameters after ``original_code`` are the values returned by
    ``parse_pennylane_code`` in the same order, followed by ``shots``.

    Args:
        original_code: The pasted source. It is not referenced in this body.
        imports_src: Import statements to copy; two default imports are
            emitted when this is empty.
        build_P_src: Source of the user's ``build_P``; a built-in copy is
            emitted when this is empty.
        dim_x: Number of x entries (``None`` is treated as 0).
        dim_alpha: Number of α entries (``None`` is treated as 0).
        observable_terms: Observable strings, written out with ``repr``.
        orig_qubits: Wire count of the original device (``None`` -> 0).
        gates_U: ``Gate`` objects parsed from ``U``.
        gates_UPU: ``Gate`` objects parsed from ``UPU``.
        alpha_usage: Mapping of α index to its reference count in ``U``.
        shots: Value written into the generated ``shots`` assignment.

    Returns:
        The generated script as one newline-joined string.
    """

    # Missing values from the parser are treated as zero.
    dim_x = 0 if dim_x is None else dim_x
    dim_alpha = 0 if dim_alpha is None else dim_alpha
    orig_qubits = 0 if orig_qubits is None else orig_qubits

    # One frequency register per α index, sized at twice
    # freq_size_from_count(<number of uses of that α>).
    freq_sizes={i: 2*freq_size_from_count(alpha_usage.get(i,0)) for i in range(dim_alpha)}
    n_freq=sum(freq_sizes.values())

    # Wire layout: [kernel | frequency registers | ancilla | original circuit].
    n_kernel=1; n_anc=1; n_circ=orig_qubits
    kernel_offset=0; freq_offset=1; anc_offset=1+n_freq; circ_offset=1+n_freq+1
    n_total=n_kernel+n_freq+n_anc+n_circ

    # Precompute start wires for each alpha register
    freq_start=freq_offset
    freq_wires_map={}
    for i in range(dim_alpha):
        size=freq_sizes[i]
        freq_wires_map[i]=list(range(freq_start,freq_start+size))
        freq_start+=size

    # `out` collects the generated script; entries are joined with newlines.
    out=[]
    out.extend(imports_src or ["import pennylane as qml","import numpy as np"])
    out.append("")
    # Inject VGate/VdGate definitions
    # NOTE(review): the template below is not a raw string, so the `\d` in
    # `$V^\dagger$` is an unrecognised escape sequence — confirm it should be
    # written as `\\d`.
    out.append(textwrap.dedent("""
    # Wrapper gates for controlled Adders
    class V(qml.operation.Operation):
        num_params = 2
        par_domain = "R"
        def __init__(self, value, wires, mod):
            super().__init__(value, mod, wires=wires)


        def decomposition(self):
            # Actual gate: controlled Adder
            return [qml.Adder(self.parameters[0], x_wires=self.wires, mod=self.parameters[1])]

    # Adjoint version
    class Vdagger(V):
        def __name__(self):
            return "V†"

        def decomposition(self):
            return [qml.Adder(self.parameters[0], x_wires=self.wires, mod=self.parameters[1])]
    Vdagger.__name__ = r"$V^\dagger$"
    """))

    out.append(f"dim_x={dim_x}")
    out.append(f"dim_alpha={dim_alpha}")
    out.append(f"observable_terms={observable_terms!r}")
    out.append("")
    out.append("# build_P function")
    # Use the user's build_P when one was parsed, else the built-in copy.
    out.append(build_P_src or textwrap.dedent("""
        def build_P(observable_terms):
            import pennylane as qml
            P=None
            for term in observable_terms:
                pieces=term.replace('+',' ').split()
                local=None
                for p in pieces:
                    op=p[0].upper()
                    idx=int(p[1:])
                    t=getattr(qml,f'Pauli{op}')(idx)
                    local=t if local is None else local@t
                P=local if P is None else P+local
            return P if P is not None else qml.PauliZ(0)
    """).strip())
    out.append("")
    # Register sizes and offsets are written as literals into the script.
    out.append(f"n_kernel={n_kernel}; freq_register_sizes={freq_sizes}; n_freq={n_freq}")
    out.append(f"n_anc={n_anc}; n_circ={n_circ}; n_total={n_total}")
    out.append(f"kernel_offset={kernel_offset}; freq_offset={freq_offset}")
    out.append(f"anc_offset={anc_offset}; circ_offset={circ_offset}")
    out.append(f"shots = {shots}")
    out.append("")
    out.append("dev=qml.device('default.qubit', wires=n_total, shots = shots)")
    out.append("")

    # U
    # The original gates, with every wire shifted by circ_offset.
    out.append("@qml.qnode(dev)")
    out.append("def U(x, α):")
    if not gates_U: out.append("    pass")
    else:
        for g in gates_U: out.append("    "+g.rendered(offset=circ_offset))
    out.append("    return qml.state()")
    out.append("")

    # UPU
    out.append("@qml.qnode(dev)")
    out.append("def UPU(x, α):")
    if not gates_U: out.append("    pass")
    else:
        for g in gates_U: out.append("    "+g.rendered(offset=circ_offset))
    # Emitted code that rewrites each observable factor to its shifted wire.
    # NOTE(review): the emitted code assigns P but no emitted line uses it
    # afterwards — confirm whether P is meant to be applied.
    out.append("    # build P on shifted observables")
    out.append("    shifted_terms=[]")
    out.append("    for term in observable_terms:")
    out.append("        parts=term.replace('+',' ').split()")
    out.append("        new_parts=[]")
    out.append("        for p in parts:")
    out.append("            op=p[0].upper()")
    out.append("            idx=int(p[1:])")
    out.append(f"            new_parts.append(f'{{op}}{{idx+{circ_offset}}}')")
    out.append("        shifted_terms.append(' '.join(new_parts))")
    out.append("    P=build_P(shifted_terms)")
    # If the parsed UPU starts with the same gate signatures as U, that prefix
    # is dropped and only the remainder is emitted as the adjoint part.
    adj_gates=[]
    if gates_UPU:
        drop=0
        if len(gates_UPU)>=len(gates_U) and gates_U:
            matches=True
            for i,g in enumerate(gates_U):
                if not same_gate_signature(g,gates_UPU[i]): matches=False; break
            if matches: drop=len(gates_U)
        adj_gates=gates_UPU[drop:]
    if adj_gates:
        out.append("    # user-provided adjoint (trimmed prefix)")
        for g in adj_gates: out.append("    "+g.rendered(offset=circ_offset))
    else:
        # No usable UPU gates: emit U's gates in reverse with negated angles.
        out.append("    # auto-generated adjoint")
        for g in reversed(gates_U):
            op=g.op.upper()
            w=[w+circ_offset for w in g.wires]
            if op.startswith("R"):
                # Negate the angle: strip a leading "-" or wrap in "-( )".
                angle=g.args_src[0].strip()
                neg=f"-({angle})" if not angle.startswith("-") else angle[1:].strip()
                out.append(f"    qml.R{op[1]}({neg}, wires={w[0]})")
            elif op in ("CNOT","CX"): out.append(f"    qml.CNOT(wires=({w[0]},{w[1]}))")
            elif op in ("H","HADAMARD"): out.append(f"    qml.Hadamard(wires={w[0]})")
            else: out.append("    "+g.rendered(offset=circ_offset))
    out.append("    return qml.state()")
    out.append("")

    # AuPAu
    # Starts with a Hadamard on wire 0 (the kernel wire, kernel_offset=0).
    out.append("def AuPAu(x, x1, α):")
    out.append("    qml.Hadamard(wires=0)\n")
    for g in gates_U:
        alpha_matches=re.findall(r"(?:α|alpha)\s*\[\s*(\d+)\s*\]"," ".join(g.args_src))
        w=[w+circ_offset for w in g.wires]
        if alpha_matches:
            # α-dependent gate: CNOTs from the gate's wires onto the ancilla,
            # ancilla-controlled V / Vdagger on that α's frequency register,
            # then the same CNOTs again.
            for alpha_idx in map(int, alpha_matches):
                freq_wires=freq_wires_map[alpha_idx]  # use precomputed map
                for ww in w: out.append(f"    qml.CNOT(wires=({ww},{anc_offset}))")
                out.append(f"    qml.ctrl(V, control = {anc_offset}, control_values = [0])( value = 1, wires={freq_wires}, mod=2**{freq_sizes[alpha_idx]})")
                out.append(f"    qml.ctrl(Vdagger, control = {anc_offset}, control_values = [1])( value = -1, wires={freq_wires}, mod=2**{freq_sizes[alpha_idx]})")
                for ww in w: out.append(f"    qml.CNOT(wires=({ww},{anc_offset}))")
        elif gate_depends_on_x(g) and g.op.upper().startswith("R"):
            # x-dependent rotation: controlled on wire 0, using x when the
            # control is 1 and x1 when it is 0.
            axis=g.op[1]; angle=g.args_src[0]; angle1=replace_x_with_x1(angle)
            out.append(f"    qml.ctrl(qml.R{axis}, control=0, control_values=[1])({angle}, wires={w[0]})")
            out.append(f"    qml.ctrl(qml.R{axis}, control=0, control_values=[0])({angle1}, wires={w[0]})")
        else: out.append("    "+g.rendered(offset=circ_offset))
    # P
    out.append("\n    shifted_terms=[]")
    out.append("    for term in observable_terms:")
    out.append("        parts=term.replace('+',' ').split()")
    out.append("        new_parts=[]")
    out.append("        for p in parts:")
    out.append("            op=p[0].upper()")
    out.append("            idx=int(p[1:])")
    out.append(f"            new_parts.append(f'{{op}}{{idx+{circ_offset}}}')")
    out.append("        shifted_terms.append(' '.join(new_parts))")
    out.append("    P=build_P(shifted_terms)\n")
    # adjoint
    # Same three gate categories as above, traversed in reverse order with the
    # control values swapped and the x / x1 angles negated.
    out.append("    # Adjoint – reverse of duplicated U")
    for g in reversed(gates_U):
        alpha_matches=re.findall(r"(?:α|alpha)\s*\[\s*(\d+)\s*\]"," ".join(g.args_src))
        w=[w+circ_offset for w in g.wires]
        if alpha_matches:
            for alpha_idx in map(int, alpha_matches):
                freq_wires=freq_wires_map[alpha_idx]  # precomputed map
                for ww in w: out.append(f"    qml.CNOT(wires=({ww},{anc_offset}))")
                out.append(f"    qml.ctrl(V, control = {anc_offset}, control_values = [1])( value = 1, wires={freq_wires}, mod=2**{freq_sizes[alpha_idx]})")
                out.append(f"    qml.ctrl(Vdagger, control = {anc_offset}, control_values = [0])( value = -1, wires={freq_wires}, mod=2**{freq_sizes[alpha_idx]})")
                for ww in w: out.append(f"    qml.CNOT(wires=({ww},{anc_offset}))")
        elif gate_depends_on_x(g) and g.op.upper().startswith("R"):
            axis=g.op[1]; angle=g.args_src[0]; angle1=replace_x_with_x1(angle)
            out.append(f"    qml.ctrl(qml.R{axis}, control=0, control_values=[0])(-{angle1}, wires={w[0]})")
            out.append(f"    qml.ctrl(qml.R{axis}, control=0, control_values=[1])(-{angle}, wires={w[0]})")
        else:
            op=g.op.upper()
            if op.startswith("R"):
                angle=g.args_src[0].strip(); neg=f"-({angle})" if not angle.startswith("-") else angle[1:].strip()
                out.append(f"    qml.R{op[1]}({neg}, wires={w[0]})")
            elif op in ("CNOT","CX"): out.append(f"    qml.CNOT(wires=({w[0]},{w[1]}))")
            elif op in ("H","HADAMARD"): out.append(f"    qml.Hadamard(wires={w[0]})")
            else: out.append("    "+g.rendered(offset=circ_offset))
    out.append("    qml.Hadamard(wires=0)")

    out.append("")
    # Fixed trailer: sampling QNode, example drawing calls, post-selected
    # kernel estimate and a demo kernel matrix.
    # NOTE(review): the demo data in the template uses one-element x vectors
    # and a one-element α regardless of dim_x / dim_alpha, and its "wire 4" /
    # "wire 5" comments are fixed text — confirm for other register sizes.
    out.append(textwrap.dedent("""
    @qml.qnode(dev)
    def AuPAu_sample(x, x1, α):
        AuPAu(x, x1, α)   # run your circuit
        return qml.sample(wires=range(n_total))

    
    # Example caller
    x0=np.zeros(dim_x); x1=np.copy(x0); α0=np.zeros(dim_alpha)
    qml.draw_mpl(U)(x0, α0); qml.draw_mpl(UPU)(x0, α0); qml.draw_mpl(AuPAu_sample)(x0, x1, α0);

    
    def conditional_Z_kernel(x, x1, α):
        # shots = 10000
        samples = AuPAu_sample(x, x1, α)
        
        # Extract each register
        kernel = samples[:, kernel_offset]     # wire 0
        anc = samples[:, anc_offset]           # wire 4
        circ = samples[:, circ_offset]         # wire 5

        # Condition: anc=0 and circ=0
        mask = (anc == 0) & (circ == 0)
        conditioned = kernel[mask]

        total = shots
        kept = conditioned.shape[0]

        if kept == 0:
            return {
                "total_shots": total,
                "kept_shots": 0,
                "expectation": None,
            }

        # convert bits to Z eigenvalues (+1 for 0, -1 for 1)
        zvals = 1 - 2*conditioned       # maps 0→+1, 1→−1

        return {
            "total_shots": total,
            "kept_shots": kept,
            "fraction_kept": kept / total,
            "expectation": np.mean(zvals),
        }
        
    def compute_kernel_matrix(X, α):
        n = len(X)
        K = np.zeros((n, n))

        for i in range(n):
            for j in range(n):
                res = conditional_Z_kernel(X[i], X[j], α)

                # Store expectation (or 0 if no successful post-selection)
                K[i, j] = res["expectation"] if res["expectation"] is not None else 0.0

                # # Print diagnostics
                # print(f"({i}, {j})")
                # print(f"  total shots:     {res['total_shots']}")
                # print(f"  kept shots:      {res['kept_shots']}")
                # print(f"  fraction kept:   {res['fraction_kept']:.6f}")
                # print(f"  expectation:     {res['expectation']}")
                # print()
        return K
    
    X = [np.array([0.0]), np.array([1.2]), np.array([-0.7])]
    α = np.array([0.5])

    K = compute_kernel_matrix(X, α)
    print("Kernel Matrix from Code:")
    print(K)
    """))
    return "\n".join(out)

# ---------------- UI ----------------
def show_code_block(code_text):
    """Show ``code_text`` in a read/copy text area inside ``output_display``.

    The text is HTML-escaped before being placed between the ``<textarea>``
    tags, so characters such as ``<``, ``>`` and ``&`` (or a literal
    ``</textarea>``) in the generated code cannot break the markup. The
    browser decodes the entities again, so the copied text is unchanged.

    Args:
        code_text: The generated source code to display.
    """
    # Function-scope import; the name `html` is avoided for local variables.
    from html import escape

    safe_text = escape(code_text, quote=False)
    markup = f"""
    <textarea id="code_area" style="width:100%; height:320px; background:#f5f5f5; border:1px solid #ccc; font-size:13px; white-space:pre-wrap; overflow:auto;">{safe_text}</textarea>
    <br>
    <button onclick="
        var area = document.getElementById('code_area');
        area.select();
        document.execCommand('copy');
    ">Copy to Clipboard</button>
    """
    with output_display:
        clear_output()
        display(HTML(markup))

def generate_modified(b):
    """Parse the pasted code, generate the extended script and display it.

    Any exception raised along the way is reported in ``output_display``.

    Args:
        b: The clicked button (unused).
    """
    pasted_source = code_input.value
    n_shots = count_input.value
    try:
        parse_result = parse_pennylane_code(pasted_source)
        generated = generate_output_code(pasted_source, *parse_result, n_shots)
        show_code_block(generated)
    except Exception as err:
        with output_display:
            clear_output()
            print("Error:", err)

# Connect the generate button and show the main Part 2 controls in order.
generate_button.on_click(generate_modified)
display(code_input, generate_button, output_display)

# Button that ends the Part 2 GUI via exit_gui().
exit_button=widgets.Button(description="Exit GUI", button_style="danger")
def exit_gui(b):
    """Close every Part 2 widget and print a confirmation message.

    The widgets are closed explicitly, as the Part 1 ``exit_gui`` does, so
    the GUI disappears instead of relying on ``clear_output`` alone.

    Args:
        b: The clicked button (unused).
    """
    for w in [code_input, count_input, generate_button, output_display, exit_button]:
        w.close()

    clear_output(wait=True)
    print("GUI closed. You can now run the next cell.")

# Hook up and show the exit button below the other Part 2 controls.
exit_button.on_click(exit_gui)
display(exit_button)


IntText(value=10000, description='Shots:', layout=Layout(width='200px'))

Textarea(value='', description='Code Input:', layout=Layout(height='320px', width='100%'), placeholder='Paste …

Button(button_style='success', description='Generate Modified UPU with Registers + AuPAu', style=ButtonStyle()…

Output()

Button(button_style='danger', description='Exit GUI', style=ButtonStyle())