In [None]:
#!pip install qlearnkit['pennylane']
#!pip install --upgrade scipy pennylane
#!pip install pennylane
#!pip install --upgrade numpy pennylane
#!pip install pennylane-lightning

In [None]:
import torch
import torch.nn as nn
import pennylane as qml

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


In [None]:
class QGRU(nn.Module):
    """GRU cell whose gates are computed by variational quantum circuits.

    At every time step the concatenation ``[h, x]`` is projected by a shared
    classical linear layer (``clayer_in``) down to ``n_qubits`` features, fed
    through one of three variational circuits (reset / update / new), and
    projected back to ``hidden_size`` by a shared linear layer
    (``clayer_out``).

    Args:
        input_size: Number of features of each input time step.
        hidden_size: Size of the hidden state.
        n_qubits: Number of qubits (wires) of every circuit.
        n_qlayers: Number of entangling + rotation layers per circuit.
        batch_first: If True, inputs are ``(batch, seq, feature)``;
            otherwise ``(seq, batch, feature)``.
        return_sequences: If True, ``forward`` returns the hidden state of
            every time step instead of only the last one.
        return_state: Stored on the instance; ``forward`` does not use it.
        backend: Name of the PennyLane device used for the three circuits.
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 n_qubits=4,
                 n_qlayers=4,
                 batch_first=True,
                 return_sequences=False,
                 return_state=False,
                 backend='default.qubit'):
        super().__init__()
        self.n_inputs = input_size
        self.hidden_size = hidden_size
        self.concat_size = self.n_inputs + self.hidden_size
        self.n_qubits = n_qubits
        self.n_qlayers = n_qlayers
        self.backend = backend  # "default.qubit", "qiskit.basicaer", "qiskit.ibm"

        self.batch_first = batch_first
        self.return_sequences = return_sequences
        self.return_state = return_state

        # One device per gate, created on the requested backend (the original
        # code ignored ``backend`` and always used 'default.qubit').
        self.dev_reset = qml.device(self.backend, wires=range(self.n_qubits))
        self.dev_update = qml.device(self.backend, wires=range(self.n_qubits))
        self.dev_new = qml.device(self.backend, wires=range(self.n_qubits))

        # The three gates share the same circuit structure; only the device
        # and the trainable weights differ.
        self.qlayer_reset = self._build_qnode(self.dev_reset)
        self.qlayer_update = self._build_qnode(self.dev_update)
        self.qlayer_new = self._build_qnode(self.dev_new)

        weight_shapes = {"weights": (self.n_qlayers, self.n_qubits, 3)}

        self.clayer_in = torch.nn.Linear(self.concat_size, self.n_qubits)

        # Register each quantum layer as a submodule attribute so that its
        # weights show up in ``parameters()`` / ``state_dict()`` and follow
        # ``.to(device)``. A plain dict alone hides them from nn.Module, and
        # nn.ModuleDict cannot be used because the key 'update' clashes with
        # the ``ModuleDict.update`` method.
        self.vqc_reset = qml.qnn.TorchLayer(self.qlayer_reset, weight_shapes)
        self.vqc_update = qml.qnn.TorchLayer(self.qlayer_update, weight_shapes)
        self.vqc_new = qml.qnn.TorchLayer(self.qlayer_new, weight_shapes)

        # Kept for backward compatibility: lookup table over the same layers.
        self.VQC = {
            'reset': self.vqc_reset,
            'update': self.vqc_update,
            'new': self.vqc_new,
        }
        self.clayer_out = torch.nn.Linear(self.n_qubits, self.hidden_size)

    def _build_qnode(self, device):
        """Return a torch-interface QNode running the gate circuit on *device*."""
        wires = range(self.n_qubits)

        # The argument must be called ``inputs`` for qml.qnn.TorchLayer.
        def _circuit(inputs, weights):
            self.custom_encoding(inputs, wires=wires)
            self.custom_entangler_layer(weights, wires=wires)
            return [qml.expval(qml.PauliZ(wires=w)) for w in wires]

        return qml.QNode(_circuit, device, interface="torch")

    def custom_encoding(self, inputs, wires):
        """Encode classical *inputs* into the qubits on *wires*.

        Applies a Hadamard on every wire, then a Y-rotation embedding of
        ``sin(inputs)`` followed by a Z-rotation embedding of
        ``cos(inputs ** 2)``.
        """
        # Apply Hadamard to each qubit to create an unbiased initial state
        for wire in wires:
            qml.Hadamard(wires=wire)

        qml.templates.AngleEmbedding(torch.sin(inputs), rotation='Y', wires=wires)
        qml.templates.AngleEmbedding(torch.cos(inputs ** 2), rotation='Z', wires=wires)

    def custom_entangler_layer(self, weights, wires):
        """Apply ``n_qlayers`` blocks of CNOT entanglement and ``Rot`` gates.

        Each block entangles the wires with two CNOT rings: nearest
        neighbour (stride 1) and next-nearest neighbour (stride 2), both
        wrapping around. For four wires this is exactly the sequence
        (0,1) (1,2) (2,3) (3,0) (0,2) (1,3) (2,0) (3,1).

        Args:
            weights: Indexed as ``weights[layer, wire_index, :]``; the three
                values are the angles of one ``qml.Rot`` gate.
            wires: The wires the layer acts on.
        """
        wires = list(wires)
        n_wires = len(wires)

        for layer in range(self.n_qlayers):
            for stride in (1, 2):
                for i in range(n_wires):
                    j = (i + stride) % n_wires
                    # With fewer than three wires a stride can wrap back onto
                    # the control itself; a CNOT needs two distinct wires.
                    if j != i:
                        qml.CNOT(wires=[wires[i], wires[j]])

            # Apply general rotation for each qubit
            for i, wire in enumerate(wires):
                qml.Rot(*weights[layer, i, :], wires=wire)

    def forward(self, x, init_states=None):
        '''
        Run the quantum GRU over a sequence.

        x.shape is (batch_size, seq_length, feature_size) when
        ``batch_first`` is True, else (seq_length, batch_size, feature_size).
        init_states, if given, is used as the initial hidden state.
        recurrent_activation -> sigmoid
        activation -> tanh

        Returns the last hidden state, shape (batch_size, hidden_size), or,
        when ``return_sequences`` is True, the hidden state of every step in
        the same batch/sequence layout as ``x``.
        '''
        # Work internally in (batch, seq, feature) layout.
        if not self.batch_first:
            x = x.transpose(0, 1)
        batch_size, seq_length, _ = x.size()

        if init_states is None:
            h_t = torch.zeros(batch_size, self.hidden_size, device=x.device)  # hidden state
        else:
            h_t = init_states

        output_seq = []

        for t in range(seq_length):
            x_t = x[:, t, :]

            # Concatenate input and hidden state
            v_t = self.clayer_in(torch.cat((h_t, x_t), dim=1))

            r_t = torch.sigmoid(self.clayer_out(self.VQC['reset'](v_t)))  # reset gate
            z_t = torch.sigmoid(self.clayer_out(self.VQC['update'](v_t)))  # update gate

            # Candidate state is computed from the reset-gated hidden state.
            y_t_new = self.clayer_in(torch.cat((r_t * h_t, x_t), dim=1))
            h_tilde = torch.tanh(self.clayer_out(self.VQC['new'](y_t_new)))  # new gate

            # Compute the new hidden state
            h_t = z_t * h_t + (1 - z_t) * h_tilde

            if self.return_sequences:
                output_seq.append(h_t)

        if not self.return_sequences:
            return h_t

        if output_seq:
            stacked = torch.stack(output_seq, dim=1)
        else:
            # Empty sequence: nothing to stack, return an empty time axis.
            stacked = h_t.new_zeros((batch_size, 0, self.hidden_size))

        # Hand the sequence back in the caller's layout.
        if not self.batch_first:
            stacked = stacked.transpose(0, 1)
        return stacked