In [None]:
import pennylane as qml
from pennylane import numpy as np

def PQC(parameters, wires, time_step_index, depth=3):
    """Apply a layered Parameterized Quantum Circuit (PQC) for time evolution.

    Each of the ``depth`` layers applies an RX and an RY rotation to every
    wire, entangles neighbouring wires with a chain of CNOT gates, and then
    closes the chain into a ring with one more CNOT between the first and
    the last wire.

    Args:
        parameters (array[float]): Array of shape (depth, len(wires), 2) containing
            rotation angles. ``parameters[d, k, 0]`` is the RX angle and
            ``parameters[d, k, 1]`` the RY angle for the k-th entry of
            ``wires`` in layer ``d``.
        wires (Sequence[int]): Wire indices to apply the circuit on.
        time_step_index (int): Index of the current time step. Not used by
            this function at present; kept so the call signature is stable.
        depth (int): Number of layers in the PQC.
    """
    # Materialise once so ranges, lists and other iterables all support
    # len() and positional indexing below.
    wires = list(wires)
    num_wires = len(wires)

    for depth_index in range(depth):
        # Index the angles by position in `wires`, not by wire label, so the
        # documented (depth, len(wires), 2) shape also works when the wire
        # labels do not start at 0.
        for position, wire in enumerate(wires):
            qml.RX(parameters[depth_index, position, 0], wires=wire)
            qml.RY(parameters[depth_index, position, 1], wires=wire)

        # Nearest-neighbour entanglement: wire i controls wire i - 1.
        # The upper bound is num_wires (not num_wires + 1) to stay in range.
        for i in range(1, num_wires):
            qml.CNOT(wires=[wires[i], wires[i - 1]])

        # Close the ring. A CNOT needs two distinct wires, so skip this when
        # fewer than two wires are given.
        if num_wires > 1:
            qml.CNOT(wires=[wires[0], wires[-1]])
        
class QFM():
    def __init__(self, input_samples, output_samples, n_ancilla=3, num_time_steps=3, depth_per_time_step=3):
        """Quantum Flow Model (QFM) for learning time evolution.

        Builds the trainable parameters, the simulator device and the QNode
        that prepares an input state and applies the PQC.

        Args:
            input_samples (array[float]): Array of shape (num_samples, num_features) with
                initial states. The number of system qubits is taken as
                ``int(log2(num_features))``.
            output_samples (array[float]): Array of shape (num_samples, num_features) with
                target states. Must have the same shape as ``input_samples``.
            n_ancilla (int): Number of extra ancilla qubits added on top of the
                system qubits.
            num_time_steps (int): Number of time steps to model.
            depth_per_time_step (int): Depth of the PQC for each time step.

        Raises:
            AssertionError: If the input and output samples differ in their
                number of samples or number of features.
        """
        # Validate the shapes before deriving any quantity from them.
        assert input_samples.shape[0] == output_samples.shape[0], "Input and output samples must have the same number of samples."
        assert input_samples.shape[1] == output_samples.shape[1], "Input and output samples must have the same number of features."

        self.input_samples = input_samples
        self.output_samples = output_samples
        self.num_time_steps = num_time_steps
        self.depth_per_time_step = depth_per_time_step

        # Qubit counts for amplitude-encoded samples.
        # NOTE(review): int() truncates when the feature count is not a power
        # of two; the state preparation below presumably needs exactly
        # 2**n_input amplitudes -- confirm against the data pipeline.
        self.n_input = int(np.log2(input_samples.shape[1]))
        self.n_output = int(np.log2(output_samples.shape[1]))

        # Builtin abs() keeps the wire counts plain Python ints, which is what
        # the device constructor and the parameter shape below are given.
        self.n_ancilla = abs(self.n_output - self.n_input) + n_ancilla
        self.num_wires = max(self.n_input, self.n_output) + self.n_ancilla # add ancilla qubits to approximate non-unitary evolution

        # Wires that receive the amplitude-encoded input state inside the
        # circuit defined below.
        # NOTE(review): assumes the system register is the first n_input
        # wires -- confirm the intended register layout.
        self.sys_wires = list(range(self.n_input))

        # Initialize parameters for the PQC: one (depth, num_wires, 2) block of
        # rotation angles per time step.
        self.parameters = np.random.uniform(0, 2 * np.pi,
                                            (num_time_steps, depth_per_time_step, self.num_wires, 2),
                                            requires_grad=True)

        # Define the quantum device
        self.dev = qml.device("lightning.qubit", wires=self.num_wires + 1) # +1 for a qubit for swap-test

        # Define the quantum node
        @qml.qnode(self.dev, interface='autograd')
        def circuit(params, t, x):
            # Load the sample x onto the system wires, then apply the PQC on
            # all model wires (system + ancilla).
            qml.templates.MottonenStatePreparation(x, wires=self.sys_wires)
            PQC(params, wires=range(self.num_wires), time_step_index=t, depth=self.depth_per_time_step)
            # NOTE(review): no measurement is returned here; a QNode needs one
            # (e.g. the swap-test readout the extra device wire is reserved
            # for) before this circuit can be executed.

        self.circuit_per_time_step = circuit