In [1]:
from qiskit_ibm_runtime import QiskitRuntimeService, Options, Sampler, Session, Estimator

from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager

from torchquantum.measurement import expval_joint_analytical, expval_joint_sampling, expval_joint_sampling_grouping

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torchquantum as tq

import torchquantum.functional as tqf

from torchquantum.plugin.qiskit import tq2qiskit


In [3]:
from dotenv import dotenv_values

# Read key/value pairs from the local ".env" file; the "IBM_TOKEN" entry is
# looked up in a later cell, so the token is never hardcoded in the notebook.
config = dotenv_values(".env")

In [4]:
# Authenticate against IBM Quantum with the token loaded from ".env".
# config["IBM_TOKEN"] raises KeyError if the entry is missing from the file.
# NOTE(review): this call needs network access; the recorded output below is a
# name-resolution failure, so every later cell that uses `service` depends on
# re-running this cell successfully.
service = QiskitRuntimeService(channel="ibm_quantum", token=config["IBM_TOKEN"])

RequestsApiError: 'HTTPSConnectionPool(host=\'auth.quantum-computing.ibm.com\', port=443): Max retries exceeded with url: /api/version (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x72a978defd40>: Failed to resolve \'auth.quantum-computing.ibm.com\' ([Errno -3] Temporary failure in name resolution)"))'

In [None]:
# List the backends visible to this account (requires the `service` object
# created in the previous cell).
service.backends()

[<IBMBackend('simulator_mps')>,
 <IBMBackend('simulator_statevector')>,
 <IBMBackend('simulator_stabilizer')>,
 <IBMBackend('ibm_brisbane')>,
 <IBMBackend('ibm_kyoto')>,
 <IBMBackend('ibm_sherbrooke')>,
 <IBMBackend('ibmq_qasm_simulator')>,
 <IBMBackend('simulator_extended_stabilizer')>,
 <IBMBackend('ibm_osaka')>]

In [None]:
# Select the backend by name; `backend` is used by the Session cell at the end
# of the notebook.
backend = service.backend("ibmq_qasm_simulator")

----

In [4]:
class MultiHeadAttentionBase(nn.Module):
    """Shared plumbing for multi-head attention layers.

    Holds the head geometry and the dropout module, and provides head
    splitting, scaled dot-product attention and head re-merging. The
    projection attributes (``k_linear``, ``q_linear``, ``v_linear``,
    ``combine_heads``) are initialised to ``None`` and no ``forward`` is
    defined in this class.
    """

    def __init__(self,
                 embed_dim: int,
                 num_heads: int,
                 dropout: float = 0.1,
                 mask=None,
                 use_bias=False):
        # `mask` and `use_bias` are accepted but not used by this constructor.
        super().__init__()

        assert embed_dim % num_heads == 0, f"Embedding dimension ({embed_dim}) should be divisible by number of heads ({num_heads})"

        self.embed_dim, self.num_heads = embed_dim, num_heads
        # Width of a single head.
        self.d_k = embed_dim // num_heads

        # Projection slots, left empty here.
        self.k_linear = None
        self.q_linear = None
        self.v_linear = None
        self.combine_heads = None

        self.dropout = nn.Dropout(dropout)
        # Overwritten by every call to `downstream` with the latest attention weights.
        self.attn_weights = None

    def separate_heads(self, x):
        """Split the last axis of ``x`` into heads.

        Views ``x`` as (batch, -1, num_heads, d_k) and swaps axes 1 and 2, so
        the result is (batch, num_heads, seq_len, d_k) and a matmul over the
        last two axes acts on each head independently.
        """
        n_batch = x.size(0)
        per_head = x.view(n_batch, -1, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def attention(self, query, key, value, mask=None, dropout=None):
        """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

        Positions where ``mask == 0`` are filled with -1e9 before the softmax
        (the mask gets an extra axis at position 1 so it broadcasts over
        heads). If ``dropout`` is given it is applied to the weights.

        Returns:
            (attended values, attention weights)
        """
        scale = math.sqrt(self.d_k)
        logits = torch.matmul(query, key.transpose(-2, -1)) / scale
        if mask is not None:
            logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        weights = F.softmax(logits, dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return torch.matmul(weights, value), weights

    def downstream(self, query, key, value, batch_size, mask=None):
        """Run attention on already-projected Q/K/V and merge the heads.

        Stores the attention weights on ``self.attn_weights`` and returns the
        concatenated heads reshaped to (batch_size, -1, embed_dim). The
        ``combine_heads`` projection is NOT applied here.
        """
        heads_q, heads_k, heads_v = (self.separate_heads(t) for t in (query, key, value))

        context, self.attn_weights = self.attention(heads_q, heads_k, heads_v, mask, dropout=self.dropout)

        merged = context.transpose(1, 2).contiguous()
        return merged.view(batch_size, -1, self.embed_dim)


In [5]:
class MultiHeadAttentionClassical(MultiHeadAttentionBase):
    """Multi-head self-attention with ordinary linear Q/K/V projections.

    Fills the projection slots of the base class with square
    ``nn.Linear(embed_dim, embed_dim)`` layers and applies the
    ``combine_heads`` projection to the merged heads.
    """

    def __init__(self, embed_dim: int,
                 num_heads: int,
                 dropout=0.1,
                 mask=None,
                 use_bias=False):
        super().__init__(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, mask=mask, use_bias=use_bias)

        # Created in the order k, q, v, combine so parameter registration
        # (and therefore random initialisation order) is unchanged.
        for slot in ("k_linear", "q_linear", "v_linear", "combine_heads"):
            setattr(self, slot, nn.Linear(embed_dim, embed_dim, bias=use_bias))

    def forward(self, x, mask=None):
        """Self-attention over ``x`` of shape (batch, seq_len, embed_dim).

        Returns a tensor with the same three-axis layout after the output
        projection.
        """
        batch_size, _, embed_dim = x.size()
        assert embed_dim == self.embed_dim, f"Input embedding ({embed_dim}) does not match layer embedding size ({self.embed_dim})"

        keys = self.k_linear(x)
        queries = self.q_linear(x)
        values = self.v_linear(x)

        merged_heads = self.downstream(queries, keys, values, batch_size, mask)
        return self.combine_heads(merged_heads)

In [6]:
class QLayer(tq.QuantumModule):
    """Parameterised circuit: RX data encoding, trainable RX/RY layers and a CNOT ring.

    Circuit order (see ``forward``): encoder -> RX layer -> first RY layer ->
    ring of CNOTs -> second RY layer.

    Trainable gates are allocated in tiers: wires 0-1 when ``n_qbits >= 2``,
    wires 2-3 when ``>= 4``, 4-7 when ``>= 8``, 8-15 when ``>= 16`` and 16-31
    when ``>= 32``. Wires outside the reached tiers (for example wires 4-5
    when ``n_qbits == 6``) are still encoded and entangled but receive no
    trainable gate.

    Gates are stored under the attribute names ``rx_<w>``, ``ry_1_<w>`` and
    ``ry_2_<w>`` where ``<w>`` is the wire index.

    Args:
        n_qbits: number of wires of the circuit.
        D_ansatz, *args, **kwargs: accepted for call compatibility; unused.
    """

    # Qubit-count thresholds; reaching a threshold adds the wires between the
    # previous threshold and this one to the trainable ansatz.
    _ANSATZ_TIERS = (2, 4, 8, 16, 32)

    def __init__(self, n_qbits, D_ansatz=1, *args, **kwargs):
        super().__init__()
        self.n_wires = n_qbits
        # One RX rotation per wire, each driven by the input feature of the same index.
        self.encoder = tq.GeneralEncoder(
            [{'input_idx': [i], 'func': 'rx', 'wires': [i]} for i in range(self.n_wires)])

        # Registration order inside each tier is all rx_*, then all ry_1_*,
        # then all ry_2_*, so sub-module / parameter ordering stays stable.
        for tier_wires in self._ansatz_wire_groups():
            for w in tier_wires:
                setattr(self, f"rx_{w}", tq.RX(has_params=True, trainable=True))
            for w in tier_wires:
                setattr(self, f"ry_1_{w}", tq.RY(has_params=True, trainable=True))
            for w in tier_wires:
                setattr(self, f"ry_2_{w}", tq.RY(has_params=True, trainable=True))

    def _ansatz_wire_groups(self):
        """Return the wire ranges (one per reached tier) that carry trainable gates."""
        groups = []
        tier_start = 0
        for threshold in self._ANSATZ_TIERS:
            if self.n_wires >= threshold:
                groups.append(range(tier_start, threshold))
            tier_start = threshold
        return groups

    def _apply_gate_family(self, prefix, q_device):
        """Apply gate ``<prefix>_<w>`` to wire ``w`` for every ansatz wire, in wire order."""
        for tier_wires in self._ansatz_wire_groups():
            for w in tier_wires:
                # Each gate acts on its own wire. The previous hand-unrolled code
                # applied gate 28 to wire 26, leaving wire 28 without a gate.
                getattr(self, f"{prefix}_{w}")(q_device, wires=w)

    def ansatz_gate_forward_rx(self, q_device):
        """Apply the trainable RX layer."""
        self._apply_gate_family("rx", q_device)

    def ansatz_gate_forward_ry_1(self, q_device):
        """Apply the first trainable RY layer (before the CNOT ring)."""
        self._apply_gate_family("ry_1", q_device)

    def ansatz_gate_forward_ry_2(self, q_device):
        """Apply the second trainable RY layer (after the CNOT ring)."""
        self._apply_gate_family("ry_2", q_device)

    @tq.static_support
    def forward(self, q_device, x, return_q_device=False):
        """Run the circuit on ``q_device`` with input ``x``.

        Args:
            q_device: quantum device the gates are applied to.
            x: input passed to the encoder.
            return_q_device: when True the device is returned; otherwise the
                method returns ``None`` and the caller reads the state from
                the device it passed in.
        """
        self.encoder(q_device, x)

        self.ansatz_gate_forward_rx(q_device)
        self.ansatz_gate_forward_ry_1(q_device)

        # Ring of CNOTs: wire k controls wire k+1, and the last wire wraps to wire 0.
        for k in range(self.n_wires):
            tqf.cnot(q_device, wires=[k, (k + 1) % self.n_wires], static=self.static_mode, parent_graph=self.graph)

        self.ansatz_gate_forward_ry_2(q_device)

        # Kept from the original implementation.
        # NOTE(review): confirm what this cast is meant to do to the device state.
        q_device = q_device.bfloat16()

        if return_q_device:
            return q_device


In [7]:
class MultiHeadAttentionQuantum(MultiHeadAttentionBase):
    """Work-in-progress attention layer whose Q/K/V are quantum expectation values.

    Three ``QLayer`` circuits (k, q, v) are evaluated per time step. K and Q
    are each measured with a single ``"ZX" * (n_qubits // 2)`` observable
    (one scalar per token); V is measured with ``n_qubits`` observables (one
    vector per token).

    NOTE(review): ``forward`` currently stops with ``RuntimeError("STOP")``
    after printing intermediate tensors; the code after that point is
    unreachable.

    Args:
        embed_dim, num_heads, dropout, mask, use_bias: forwarded to the base class.
        n_qubits: number of wires; must equal ``embed_dim``.
        n_qlayers: accepted but unused.
        nb_shots: shots per expectation-value estimate.
        coeff_amp: number of times ``sin(pi/2 * .)`` is applied to K and Q in ``forward``.
        q_device: passed as ``device=`` to ``tq.QuantumDevice``.
            NOTE(review): the default "default.qubit" does not look like a
            torch device string; the notebook passes "cpu" - confirm.
    """

    def __init__(self,
                 embed_dim: int,
                 num_heads: int,
                 dropout=0.1,
                 mask=None,
                 use_bias=False,
                 n_qubits: int = 4,
                 n_qlayers: int = 1,
                 nb_shots=1024,
                 coeff_amp=0,
                 q_device="default.qubit"):
        super(MultiHeadAttentionQuantum, self).__init__(embed_dim, num_heads, dropout=dropout, mask=mask, use_bias=use_bias)

        # todo: add intermediate layer to "dress" quantum circuit
        # The message is an f-string so the actual values are shown on failure.
        assert n_qubits == embed_dim, f"Number of qubits ({n_qubits}) does not match embedding dim ({embed_dim})"
        self.n_qubits = n_qubits
        self.coeff_amp = coeff_amp

        # Single joint observable for keys and queries.
        self.k_observables = ["ZX" *(n_qubits//2)]
        self.q_observables = ["ZX" *(n_qubits//2)]
        # "ZX" on every adjacent pair (i, i+1), plus "XX" on the first two positions:
        # n_qubits observables in total.
        self.v_observables = ["I"*i + "ZX"+ "I"*(n_qubits - 2 - i) for i in range(n_qubits - 1)] + ["XX" + "I"*(n_qubits - 2)]

        self.k_layer = QLayer(n_qubits)
        self.q_layer = QLayer(n_qubits)
        self.v_layer = QLayer(n_qubits)
        self.q_device = q_device
        self.nb_shots = nb_shots

    def get_exp_from_observables(self, x, quantum_layer, observables, session=None, return_quant_exec_time = False):
        """Estimate the expectation values of ``observables`` after running ``quantum_layer`` on ``x``.

        Args:
            x: batch of inputs for one time step; ``x.shape[0]`` is used as the batch size.
            quantum_layer: circuit module to evaluate.
            observables: list of Pauli strings.
            session: when given, circuits are converted with ``tq2qiskit`` and
                run one sample at a time through an ``Estimator`` bound to this
                session; otherwise torchquantum's sampling helpers are used.
            return_quant_exec_time: only honoured on the ``session`` path; when
                True the per-job ``usage_estimation`` values are returned as well.

        Returns:
            A float tensor of expectation values (session path: one row per
            sample; local path: ``[bsz, len(observables)]``), optionally
            together with the list of per-job usage estimates.
        """
        q_dev = tq.QuantumDevice(n_wires=self.n_qubits, device=self.q_device, bsz=x.shape[0])

        if session is not None:
            options = Options(optimization_level=1, execution={"shots":self.nb_shots})
            estimator = Estimator(session=session, options=options)

            all_batch = []
            all_time = []
            for i in range(x.shape[0]):
                # One circuit per observable (the circuit itself does not depend on the observable).
                job = estimator.run(circuits=[tq2qiskit(q_device=q_dev, m=quantum_layer, x=torch.unsqueeze(x[i], dim=0)) for _ in range(len(observables))],
                                     observables=observables)
                all_batch.append(job.result().values)
                all_time.append(job.usage_estimation)

            all_batch = torch.tensor(all_batch).float()

            if return_quant_exec_time:
                return all_batch, all_time

            else:
                return all_batch

        else:

            # The sampling helpers are called with reversed Pauli strings.
            # NOTE(review): presumably to match a different qubit-ordering
            # convention - confirm against the session/qiskit path.
            if len(observables) > 1:

                observables_reversed = ["".join(reversed(obs)) for obs in observables]
                re_order_dict = {}
                expval = expval_joint_sampling_grouping(qdev=quantum_layer(q_dev, x, return_q_device=True), observables=observables_reversed, n_shots_per_group=self.nb_shots)

                # Key the results by the original (un-reversed) strings so they
                # can be emitted in the caller's order.
                for obs, value in expval.items():
                    re_order_dict["".join(reversed(obs))] = value

                # Debug output.
                print([re_order_dict[obs] for obs in observables])
                return torch.stack([re_order_dict[obs] for obs in observables], dim=-1).float() # dim : [bsz, embed_dim]
            else:

                observable_reversed = "".join(reversed(observables[0]))
                return expval_joint_sampling(qdev=quantum_layer(q_dev, x, return_q_device=True), observable=observable_reversed, n_shots=self.nb_shots).reshape([x.shape[0],-1]).float() # dim : [bsz, embed_dim]


    def forward(self, x, mask=None, session=None):
        """Compute V, K and Q from the quantum layers for every time step.

        NOTE(review): this method is unfinished. It prints the intermediate
        tensors and then raises ``RuntimeError("STOP")``; nothing is returned.

        Args:
            x: input of shape (batch_size, seq_len, embed_dim).
            mask: only used by the unreachable tail of the method.
            session: forwarded to ``get_exp_from_observables``.

        Raises:
            RuntimeError: always, once the intermediate tensors have been printed.
        """
        batch_size, seq_len, embed_dim = x.size()
        assert embed_dim == self.embed_dim, f"Input embedding ({embed_dim}) does not match layer embedding size ({self.embed_dim})"

        v_exp_val = []
        k_exp_val = []
        q_exp_val = []

        for t in range(seq_len):

            v_exp_val.append(self.get_exp_from_observables(x=x[:, t, :].clone(), quantum_layer=self.v_layer, observables=self.v_observables, session=session))
            k_exp_val.append(self.get_exp_from_observables(x=x[:, t, :].clone(), quantum_layer=self.k_layer, observables=self.k_observables, session=session))
            q_exp_val.append(self.get_exp_from_observables(x=x[:, t, :].clone(), quantum_layer=self.q_layer, observables=self.q_observables, session=session))

        # Debug output.
        print(v_exp_val)
        print(k_exp_val)
        print(q_exp_val)

        # Stack over time, then move the batch axis first.
        V = torch.transpose(torch.stack(v_exp_val), 0, 1)

        # K and Q carry a single value per token; drop the trailing axis.
        K = torch.squeeze(torch.transpose(torch.stack(k_exp_val), 0, 1), dim= -1)
        Q = torch.squeeze(torch.transpose(torch.stack(q_exp_val), 0, 1), dim= -1)

        # Repeatedly apply sin(pi/2 * .) to K and Q, `coeff_amp` times.
        for _ in range(self.coeff_amp):
            K = torch.sin((torch.pi / 2) * K)
            Q = torch.sin((torch.pi / 2) * Q)

        # Turn the scalars into unit-modulus phases: K -> exp(+i*pi/2*K), Q -> exp(-i*pi/2*Q).
        K = torch.exp((torch.pi / 2) * torch.complex(torch.zeros([batch_size, seq_len], dtype=torch.float32), K))

        Q = torch.exp( - (torch.pi / 2) * torch.complex(torch.zeros([batch_size, seq_len], dtype=torch.float32), Q))

        # Debug output.
        print(V)
        print(K)
        print(Q)
        #print(torch.unsqueeze(Q, dim=0) * torch.unsqueeze(K, dim=0).T)
        #print((torch.unsqueeze(Q, dim=0) * torch.unsqueeze(K, dim=0).T).real @ V)

        # Deliberate debug halt left in place.
        # NOTE(review): everything below is unreachable until this is removed.
        raise RuntimeError("STOP")

        K = torch.Tensor(pad_sequence(K))
        Q = torch.Tensor(pad_sequence(Q))
        V = torch.Tensor(pad_sequence(V))
        x = self.downstream(Q, K, V, batch_size, mask)
        return x


In [8]:
# Scratch check of the sin(pi/2 * x) map that MultiHeadAttentionQuantum.forward
# applies `coeff_amp` times to K and Q.
torch.sin((torch.pi/2)*torch.tensor([-0.7, 0.5, 0.999]))

tensor([-0.8910,  0.7071,  1.0000])

In [9]:
# Size of the last axis of the test input. It is also passed as n_qubits below,
# because MultiHeadAttentionQuantum asserts n_qubits == embed_dim.
EMBED_DIM = 8

# Number of sequences in the test batch.
BATCH_SIZE = 2

# Number of time steps per sequence.
SEQ_LEN = 3


In [10]:
# Uniform [0, 1) test batch of shape (BATCH_SIZE, SEQ_LEN, EMBED_DIM).
# Drawn from a locally seeded generator so Restart & Run All reproduces it.
_test_input_rng = np.random.default_rng(0)
test_input = torch.tensor(_test_input_rng.random((BATCH_SIZE, SEQ_LEN, EMBED_DIM)), dtype=torch.float32)

In [16]:
# Display the test batch.
test_input

tensor([[[0.6386, 0.4780, 0.7459, 0.7453, 0.0297, 0.5790, 0.9100, 0.2792],
         [0.0707, 0.8234, 0.5156, 0.9364, 0.3913, 0.1841, 0.6590, 0.7179],
         [0.9926, 0.4099, 0.0640, 0.2724, 0.3471, 0.5662, 0.7794, 0.8999]],

        [[0.4892, 0.0719, 0.7854, 0.0900, 0.1699, 0.6680, 0.9759, 0.7946],
         [0.9557, 0.8346, 0.5996, 0.9486, 0.9890, 0.3669, 0.3968, 0.6575],
         [0.7172, 0.9581, 0.6715, 0.5754, 0.2382, 0.9599, 0.9547, 0.2736]]])

In [13]:
# One uniform [0, 1) scalar per (batch, time step), used in the broadcasting
# checks below. Drawn from a locally seeded generator for reproducibility.
_k_rng = np.random.default_rng(1)
k = torch.tensor(_k_rng.random((BATCH_SIZE, SEQ_LEN)), dtype=torch.float32)

In [17]:
# Display the per-token scalars.
k

tensor([[0.6273, 0.4358, 0.6089],
        [0.5537, 0.7491, 0.5370]])

In [18]:
# Broadcasting check: k gets a trailing axis of size 1, so each token's scalar
# multiplies that token's whole feature vector in test_input.
torch.unsqueeze(k, dim=-1) * test_input

tensor([[[0.4006, 0.2999, 0.4679, 0.4675, 0.0186, 0.3632, 0.5708, 0.1751],
         [0.0308, 0.3588, 0.2247, 0.4080, 0.1705, 0.0802, 0.2872, 0.3128],
         [0.6044, 0.2495, 0.0390, 0.1659, 0.2114, 0.3447, 0.4745, 0.5479]],

        [[0.2709, 0.0398, 0.4349, 0.0498, 0.0941, 0.3699, 0.5403, 0.4400],
         [0.7159, 0.6252, 0.4492, 0.7106, 0.7408, 0.2748, 0.2972, 0.4925],
         [0.3851, 0.5145, 0.3606, 0.3090, 0.1279, 0.5154, 0.5127, 0.1469]]])

In [15]:
# Same product with the operands swapped; the displayed result matches the
# previous cell.
test_input * torch.unsqueeze(k, dim=-1)

tensor([[[0.4006, 0.2999, 0.4679, 0.4675, 0.0186, 0.3632, 0.5708, 0.1751],
         [0.0308, 0.3588, 0.2247, 0.4080, 0.1705, 0.0802, 0.2872, 0.3128],
         [0.6044, 0.2495, 0.0390, 0.1659, 0.2114, 0.3447, 0.4745, 0.5479]],

        [[0.2709, 0.0398, 0.4349, 0.0498, 0.0941, 0.3699, 0.5403, 0.4400],
         [0.7159, 0.6252, 0.4492, 0.7106, 0.7408, 0.2748, 0.2972, 0.4925],
         [0.3851, 0.5145, 0.3606, 0.3090, 0.1279, 0.5154, 0.5127, 0.1469]]])

In [85]:
# Classical reference layer: a single head, dropout disabled.
classical_module = MultiHeadAttentionClassical(embed_dim=EMBED_DIM, num_heads=1, dropout=0.0)
# Earlier experiment, kept for reference. NOTE(review): MultiHeadAttentionQuantum.__init__
# has no `session` parameter, so this call would fail if uncommented.
#quantum_module = MultiHeadAttentionQuantum(embed_dim=EMBED_DIM, num_heads=4, dropout=0.0, n_qubits=EMBED_DIM, q_device="cuda", session=session)

In [105]:
# Run the classical attention layer on the test batch (no mask).
output = classical_module(test_input)

In [106]:
# The output keeps the input layout: (BATCH_SIZE, SEQ_LEN, EMBED_DIM).
output.shape

torch.Size([2, 3, 8])

In [107]:
# Scratch arithmetic. `**` binds tighter than unary minus, so this evaluates
# -(0.2 ** (1/3)), a negative real number, not (-0.2) ** (1/3).
-0.2**(1/3)

-0.5848035476425733

In [108]:
# Quantum attention layer with one qubit per embedding feature, on CPU, using
# 1024 shots per estimate and 5 rounds of the sin(pi/2 * .) map on K and Q.
quantum_module = MultiHeadAttentionQuantum(embed_dim=EMBED_DIM, num_heads=1, dropout=0.0, n_qubits=EMBED_DIM, q_device="cpu", nb_shots=1024, coeff_amp=5)

In [109]:
# Local (no session) run of the quantum layer.
# NOTE: forward currently prints its intermediate tensors and then raises
# RuntimeError("STOP"), so `output_q` is never assigned by this cell.
output_q = quantum_module(test_input)

[tensor([-0.1445, -0.2227]), tensor([-0.4277, -0.3965]), tensor([0.0215, 0.0059]), tensor([0.4629, 0.5801]), tensor([-0.3418, -0.3574]), tensor([0.1934, 0.2461]), tensor([0.2715, 0.3027]), tensor([0.5469, 0.7734])]
[tensor([-0.1719, -0.2246]), tensor([-0.3867, -0.2734]), tensor([-0.0449, -0.0371]), tensor([0.4863, 0.6133]), tensor([-0.3750, -0.4004]), tensor([0.2852, 0.0527]), tensor([0.2812, 0.3652]), tensor([0.5879, 0.7246])]
[tensor([-0.2598, -0.1445]), tensor([-0.3438, -0.4277]), tensor([-0.0371, -0.0215]), tensor([0.5352, 0.5488]), tensor([-0.3730, -0.2969]), tensor([0.1348, 0.2793]), tensor([0.1934, 0.2070]), tensor([0.6211, 0.7188])]
[tensor([[-0.1445, -0.4277,  0.0215,  0.4629, -0.3418,  0.1934,  0.2715,  0.5469],
        [-0.2227, -0.3965,  0.0059,  0.5801, -0.3574,  0.2461,  0.3027,  0.7734]]), tensor([[-0.1719, -0.3867, -0.0449,  0.4863, -0.3750,  0.2852,  0.2812,  0.5879],
        [-0.2246, -0.2734, -0.0371,  0.6133, -0.4004,  0.0527,  0.3652,  0.7246]]), tensor([[-0.2598, 

RuntimeError: STOP

In [249]:
# Same forward pass, but with expectation values estimated through an IBM
# runtime session. Needs the `backend` chosen in the earlier cell, which in
# turn needs a working `service` connection.
# NOTE: forward still ends in RuntimeError("STOP"), so `output_q` is not assigned.
with Session(backend=backend) as session:
    #quantum_module = MultiHeadAttentionQuantum(embed_dim=EMBED_DIM, num_heads=1, dropout=0.0, n_qubits=EMBED_DIM, q_device="cpu", session=session)
    output_q = quantum_module(test_input, session=session)

  job = estimator.run(circuits=[tq2qiskit(q_device=q_dev, m=quantum_layer, x=torch.unsqueeze(x[i], dim=0)) for o in range(len(observables))],


[tensor([[ 0.0098, -0.0957, -0.8359, -0.1680],
        [-0.0117, -0.0625, -0.8262,  0.0918]]), tensor([[-0.0312, -0.0566, -0.7070,  0.3027],
        [ 0.0254, -0.0234, -0.6504, -0.0742]]), tensor([[-0.0449, -0.1055, -0.5586,  0.3535],
        [ 0.0547, -0.0605, -0.7441, -0.1270]])]
[tensor([[-0.0664],
        [-0.2734]]), tensor([[-0.2441],
        [-0.0391]]), tensor([[-0.0605],
        [-0.2148]])]
[tensor([[ 0.0410],
        [-0.0293]]), tensor([[-0.0430],
        [ 0.0293]]), tensor([[-0.0215],
        [ 0.1094]])]
tensor([[[ 0.0098, -0.0957, -0.8359, -0.1680],
         [-0.0312, -0.0566, -0.7070,  0.3027],
         [-0.0449, -0.1055, -0.5586,  0.3535]],

        [[-0.0117, -0.0625, -0.8262,  0.0918],
         [ 0.0254, -0.0234, -0.6504, -0.0742],
         [ 0.0547, -0.0605, -0.7441, -0.1270]]])
tensor([[0.9946-0.1041j, 0.9274-0.3742j, 0.9955-0.0950j],
        [0.9092-0.4164j, 0.9981-0.0613j, 0.9436-0.3311j]])
tensor([[0.9979-0.0644j, 0.9977+0.0674j, 0.9994+0.0337j],
        [0.998

RuntimeError: STOP