In [1]:
from qiskit_ibm_runtime import QiskitRuntimeService, Options, Sampler, Session, Estimator

from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager

from torchquantum.measurement import expval_joint_analytical, expval_joint_sampling, expval_joint_sampling_grouping

In [2]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torchquantum as tq

import torchquantum.functional as tqf

from torchquantum.plugin.qiskit import tq2qiskit


In [3]:
from dotenv import dotenv_values

# Read key/value settings from the local .env file; config["IBM_TOKEN"] is
# looked up later when the runtime service is created.
config = dotenv_values(".env")

In [4]:
# Open the IBM Quantum runtime service with the token taken from `config`
# (a missing IBM_TOKEN entry surfaces here as a KeyError on the lookup).
service = QiskitRuntimeService(channel="ibm_quantum", token=config["IBM_TOKEN"])

In [5]:
service.backends()

[<IBMBackend('simulator_statevector')>,
 <IBMBackend('simulator_stabilizer')>,
 <IBMBackend('ibm_kyoto')>,
 <IBMBackend('ibm_osaka')>,
 <IBMBackend('ibmq_qasm_simulator')>,
 <IBMBackend('simulator_mps')>,
 <IBMBackend('ibm_brisbane')>,
 <IBMBackend('ibm_sherbrooke')>,
 <IBMBackend('simulator_extended_stabilizer')>]

In [6]:
# Pick the "ibmq_qasm_simulator" entry from the backend listing above; it is
# the backend passed to Session(...) further down.
backend = service.backend("ibmq_qasm_simulator")

----

In [7]:
class MultiHeadAttentionBase(nn.Module):
    """Shared scaffolding for the multi-head attention layers in this notebook.

    The class stores the head geometry and the dropout module and provides
    the three reusable steps: splitting into heads, scaled dot-product
    attention, and merging the heads back. The projection attributes
    (``k_linear``, ``q_linear``, ``v_linear``, ``combine_heads``) are only
    initialised to ``None`` here, and no ``forward`` is defined on this class.

    Args:
        embed_dim: Size of the last axis of the tensors handled by the layer.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        dropout: Dropout probability applied to the attention weights.
        mask: Accepted for signature compatibility; not used by this class.
        use_bias: Accepted for signature compatibility; not used by this class.
    """

    def __init__(self,
                 embed_dim: int,
                 num_heads: int,
                 dropout: float = 0.1,
                 mask=None,
                 use_bias=False):
        super().__init__()

        assert embed_dim % num_heads == 0, f"Embedding dimension ({embed_dim}) should be divisible by number of heads ({num_heads})"

        # Head geometry: each head works on a d_k-wide slice of the embedding.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.d_k = embed_dim // num_heads

        # Projection slots, left empty at this level.
        self.k_linear = None
        self.q_linear = None
        self.v_linear = None
        self.combine_heads = None

        self.dropout = nn.Dropout(dropout)
        # Attention weights of the most recent downstream() call.
        self.attn_weights = None

    def separate_heads(self, x):
        """Split the last axis into heads.

        Reshapes ``(batch_size, seq_len, embed_dim)`` into
        ``(batch_size, seq_len, num_heads, d_k)`` and swaps axes 1 and 2,
        giving ``(batch_size, num_heads, seq_len, d_k)`` so that the matrix
        products in ``attention`` act on each head independently.
        """
        n_batch = x.size(0)
        per_head = x.view(n_batch, -1, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def attention(self, query, key, value, mask=None, dropout=None):
        """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

        Args:
            query, key, value: Tensors whose last two axes are ``(seq_len, d_k)``.
            mask: Optional tensor; a head axis is inserted at position 1 and
                positions where it equals 0 are filled with ``-1e9`` before
                the softmax.
            dropout: Optional callable applied to the attention weights.

        Returns:
            Tuple ``(attended_values, attention_weights)``.
        """
        # see also: https://tensorchiefs.github.io/dlday2018/tutorial/einsum.html
        logits = (query @ key.transpose(-1, -2)) / math.sqrt(self.d_k)
        if mask is not None:
            logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        weights = torch.softmax(logits, dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return weights @ value, weights

    def downstream(self, query, key, value, batch_size, mask=None):
        """Run head split, attention and head merge on projected inputs.

        Stores the attention weights on ``self.attn_weights`` and returns the
        merged tensor of shape ``(batch_size, -1, embed_dim)``. The
        ``combine_heads`` projection is not applied here.
        """
        q_heads, k_heads, v_heads = (
            self.separate_heads(t) for t in (query, key, value)
        )

        attended, self.attn_weights = self.attention(
            q_heads, k_heads, v_heads, mask, dropout=self.dropout
        )

        # Move the head axis back next to d_k and flatten the two together.
        merged = attended.transpose(1, 2).contiguous()
        return merged.view(batch_size, -1, self.embed_dim)


In [8]:
class MultiHeadAttentionClassical(MultiHeadAttentionBase):
    """Self-attention whose K/Q/V and output projections are ``nn.Linear`` layers.

    Every projection maps ``embed_dim -> embed_dim``; ``use_bias`` toggles the
    bias term on all four of them. The remaining arguments are forwarded to
    ``MultiHeadAttentionBase``.
    """

    def __init__(self, embed_dim: int,
                 num_heads: int,
                 dropout=0.1,
                 mask=None,
                 use_bias=False):
        super().__init__(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, mask=mask, use_bias=use_bias)

        # Layers are created in the order k, q, v, combine so that the random
        # initialisation sequence is the same on every construction.
        for slot in ("k_linear", "q_linear", "v_linear", "combine_heads"):
            setattr(self, slot, nn.Linear(embed_dim, embed_dim, bias=use_bias))

    def forward(self, x, mask=None):
        """Apply self-attention to ``x`` of shape ``(batch, seq_len, embed_dim)``.

        Raises:
            AssertionError: if the last axis of ``x`` differs from ``embed_dim``.
        """
        batch_size, _, embed_dim = x.size()
        assert embed_dim == self.embed_dim, f"Input embedding ({embed_dim}) does not match layer embedding size ({self.embed_dim})"

        keys = self.k_linear(x)
        queries = self.q_linear(x)
        values = self.v_linear(x)

        merged_heads = self.downstream(queries, keys, values, batch_size, mask)
        return self.combine_heads(merged_heads)

In [9]:
class QLayer(tq.QuantumModule):
    """Variational quantum layer: RX data encoding followed by a fixed ansatz.

    The circuit applied by ``forward`` is, in order:

    1. ``encoder``: one RX per wire, rotating wire ``i`` by input feature ``i``;
    2. a trainable RX on every ansatz wire (``rx_<i>``);
    3. a trainable RY on every ansatz wire (``ry_1_<i>``);
    4. a ring of CNOTs ``k -> k+1`` with the last wire wrapping back to wire 0;
    5. a second trainable RY on every ansatz wire (``ry_2_<i>``).

    Trainable gates are allocated in power-of-two tiers (2, 4, 8, 16, 32).
    Only tiers with ``n_qbits >= tier`` are created, so the "ansatz wires" are
    wires ``0 .. T-1`` where ``T`` is the largest tier not exceeding
    ``n_qbits``. For example ``n_qbits=6`` gets trainable rotations on wires
    0-3 only, and anything above 32 is capped at 32. The encoder and the CNOT
    ring always span all ``n_qbits`` wires.

    Gates are stored under the attribute names ``rx_<i>``, ``ry_1_<i>`` and
    ``ry_2_<i>`` and are registered tier by tier (all RX of the tier, then all
    first-stage RY, then all second-stage RY), so parameter names and
    registration order do not depend on how the gates are created.

    Args:
        n_qbits: Number of wires of the layer.
        D_ansatz: Accepted but currently unused.
        *args, **kwargs: Accepted but currently unused.
    """

    # Upper bounds (exclusive) of the successive wire tiers.
    _TIER_BOUNDS = (2, 4, 8, 16, 32)
    # Attribute prefix and gate class of each trainable rotation stage.
    _GATE_FAMILIES = (("rx", tq.RX), ("ry_1", tq.RY), ("ry_2", tq.RY))

    def __init__(self, n_qbits, D_ansatz=1, *args, **kwargs):
        super().__init__()
        self.n_wires = n_qbits
        self.encoder = tq.GeneralEncoder(
            [{'input_idx': [i], 'func': 'rx', 'wires': [i]} for i in range(self.n_wires)])

        tier_start = 0
        for tier_end in self._TIER_BOUNDS:
            if n_qbits < tier_end:
                break
            for prefix, gate_cls in self._GATE_FAMILIES:
                for wire in range(tier_start, tier_end):
                    setattr(self, f"{prefix}_{wire}",
                            gate_cls(has_params=True, trainable=True))
            tier_start = tier_end

    def _n_ansatz_wires(self):
        """Return the number of wires that carry trainable rotations."""
        covered = 0
        for bound in self._TIER_BOUNDS:
            if self.n_wires >= bound:
                covered = bound
        return covered

    def _apply_gate_family(self, q_device, prefix):
        """Apply gate ``<prefix>_<i>`` to wire ``i`` for every ansatz wire."""
        # Gate index and wire index are the same by construction; deriving both
        # from one loop variable removes the old hand-written mismatch where
        # gate 28 was applied to wire 26.
        for wire in range(self._n_ansatz_wires()):
            getattr(self, f"{prefix}_{wire}")(q_device, wires=wire)

    def ansatz_gate_forward_rx(self, q_device):
        """Apply the trainable RX stage (``rx_<i>`` on wire ``i``)."""
        self._apply_gate_family(q_device, "rx")

    def ansatz_gate_forward_ry_1(self, q_device):
        """Apply the first trainable RY stage (``ry_1_<i>`` on wire ``i``)."""
        self._apply_gate_family(q_device, "ry_1")

    def ansatz_gate_forward_ry_2(self, q_device):
        """Apply the second trainable RY stage (``ry_2_<i>`` on wire ``i``)."""
        self._apply_gate_family(q_device, "ry_2")

    @tq.static_support
    def forward(self, q_device, x, return_q_device=False):
        """Encode ``x`` on ``q_device`` and apply the ansatz in place.

        Args:
            q_device: Quantum device the gates are applied to.
            x: Input features consumed by the RX encoder (feature ``i`` drives
                wire ``i``).
            return_q_device: When true, return the device after the circuit.

        Returns:
            The device if ``return_q_device`` is true, otherwise ``None``.
        """
        self.encoder(q_device, x)

        self.ansatz_gate_forward_rx(q_device)
        self.ansatz_gate_forward_ry_1(q_device)

        # Entangling ring: wire k controls wire k+1, the last wire wraps to 0.
        for k in range(self.n_wires):
            tqf.cnot(q_device, wires=[k, (k + 1) % self.n_wires],
                     static=self.static_mode, parent_graph=self.graph)

        self.ansatz_gate_forward_ry_2(q_device)

        # NOTE(review): kept from the original; confirm this cast is wanted and
        # what effect it has on the device's state tensors.
        q_device = q_device.bfloat16()

        if return_q_device:
            return q_device


In [401]:
# Scratch check of the complex phase exp(+/- i * (pi/2) * v) for real v:
# `a` uses v = [-1, 0] with a positive sign, `b` uses v = [1, 1] with a
# negative sign (torch.complex(real, imag) builds the purely imaginary exponent).
a = torch.exp((torch.pi / 2) * torch.complex(torch.tensor([0, 0], dtype=torch.float32),torch.tensor([-1, 0], dtype=torch.float32)))
b = torch.exp(- (torch.pi / 2) * torch.complex(torch.tensor([0, 0], dtype=torch.float32),torch.tensor([1, 1], dtype=torch.float32)))

In [402]:
a

tensor([-4.3711e-08-1.j,  1.0000e+00+0.j])

In [403]:
b

tensor([-4.3711e-08-1.j, -4.3711e-08-1.j])

In [190]:
torch.zeros(2)

tensor([0., 0.])

In [427]:
# Scratch check of tensordot on two random (1, 2) tensors created on the GPU
# (this cell needs a CUDA device). dims=2 contracts both axes of `a` with both
# axes of `b`, so `c` is a 0-d tensor; the trailing .cuda() is redundant
# because the operands already live on the GPU.
a = torch.randn(1, 2, device='cuda')
b = torch.randn(1, 2, device='cuda')
c = torch.tensordot(a, b, dims=2).cuda()

In [428]:
a

tensor([[-1.6300, -1.2217]], device='cuda:0')

In [429]:
b

tensor([[0.1140, 1.2924]], device='cuda:0')

In [436]:
# Scratch check of broadcasting: q is (1, 3) and k.T is (3, 1), so the
# element-wise product below is the (3, 3) table with entry [i, j] = k[i] * q[j].
k = torch.tensor([[1, 2, 3]])

q = torch.tensor([[4, 5, 6]])

q * k.T

tensor([[ 4,  5,  6],
        [ 8, 10, 12],
        [12, 15, 18]])

In [439]:
# Small (3, 3) integer matrix used as a stand-in for V in the next scratch cell.
v = torch.tensor([[-1, 0, 1],[-2, 0, 2],[-3, 0, 3]])
v

tensor([[-1,  0,  1],
        [-2,  0,  2],
        [-3,  0,  3]])

In [440]:
# Element-wise product of the (3, 3) q/k table with v; note this is `*`
# (Hadamard product), not the matrix product `@`.
q * k.T * v

tensor([[ -4,   0,   6],
        [-16,   0,  24],
        [-36,   0,  54]])

In [430]:
# Same broadcasting check on the random (1, 2) GPU tensors: (1, 2) * (2, 1)
# broadcasts to a (2, 2) table.
a * b.T

tensor([[-0.1859, -0.1393],
        [-2.1066, -1.5789]], device='cuda:0')

In [415]:
# NOTE(review): the recorded IndexError ("expected to be in range of [-1, 0]")
# indicates `a` and `b` were 1-D when this cell ran (execution count 415, i.e.
# before they were redefined as (1, 2) tensors in cell 427), so axis 1 did not
# exist. Re-check against the current definitions before relying on this cell.
torch.tensordot(a, b, dims=([1, 0], [0,1]))

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [10]:
class MultiHeadAttentionQuantum(MultiHeadAttentionBase):
    """Attention layer whose K/Q/V are expectation values of quantum circuits.

    Each of ``k_layer``, ``q_layer`` and ``v_layer`` is a ``QLayer`` over
    ``n_qubits`` wires. For every sequence position the token is encoded on a
    fresh quantum device and a set of Pauli-string observables is measured:

    * K and Q use the single observable ``"ZZXY"`` padded with identities;
    * V uses one ``"ZX"`` observable per adjacent wire pair plus a leading
      ``"XX"``, i.e. ``n_qubits`` observables in total.

    Measurements are taken either by local shot sampling (``session=None``)
    or through a Qiskit Runtime ``Estimator`` bound to ``session``.

    NOTE(review): ``forward`` is still in a debugging state. It prints the
    intermediate tensors and then raises ``RuntimeError("STOP")``, so it never
    returns a value yet.

    Args:
        embed_dim: Embedding size; must equal ``n_qubits``.
        num_heads: Number of heads, forwarded to the base class.
        dropout: Dropout probability, forwarded to the base class.
        mask: Forwarded to the base class.
        use_bias: Forwarded to the base class.
        n_qubits: Number of wires of every quantum layer.
        n_qlayers: Accepted but currently unused.
        nb_shots: Number of shots per measurement.
        q_device: Value passed as ``device=`` to ``tq.QuantumDevice``.
            NOTE(review): the default "default.qubit" looks like a
            PennyLane-style device name; the notebook itself passes "cpu".
            Confirm the default is accepted by ``tq.QuantumDevice``.
    """

    def __init__(self,
                 embed_dim: int,
                 num_heads: int,
                 dropout=0.1,
                 mask=None,
                 use_bias=False,
                 n_qubits: int = 4,
                 n_qlayers: int = 1,
                 nb_shots=1024,
                 q_device="default.qubit"):
        super(MultiHeadAttentionQuantum, self).__init__(embed_dim, num_heads, dropout=dropout, mask=mask, use_bias=use_bias)

        # todo: add intermediate layer to "dress" quantum circuit
        # The message was previously a plain string, so the placeholders were
        # printed literally; it is now an f-string.
        assert n_qubits == embed_dim, f"Number of qubits ({n_qubits}) does not match embedding dim ({embed_dim})"
        self.n_qubits = n_qubits

        # NOTE(review): the K/Q observable hard-codes four Pauli letters, so
        # for n_qubits < 4 the string is longer than the register.
        self.k_observables = ["ZZXY" + "I"*(n_qubits-4)]
        self.q_observables = ["ZZXY" + "I"*(n_qubits-4)]
        self.v_observables = ["I"*i + "ZX"+ "I"*(n_qubits - 2 - i) for i in range(n_qubits - 1)] + ["XX" + "I"*(n_qubits - 2)]

        # Creation order k, q, v is kept so seeded parameter initialisation
        # is unchanged.
        self.k_layer = QLayer(n_qubits)
        self.q_layer = QLayer(n_qubits)
        self.v_layer = QLayer(n_qubits)
        self.q_device = q_device
        self.nb_shots = nb_shots

    def get_exp_from_observables(self, x, q_dev, quantum_layer, observables, session=None, return_quant_exec_time = False):
        """Measure ``observables`` on ``quantum_layer`` applied to input ``x``.

        Args:
            x: Input slice handed to the quantum layer.
            q_dev: Quantum device the layer runs on.
            quantum_layer: The ``QLayer`` to execute.
            observables: List of Pauli strings to measure.
            session: Optional Qiskit Runtime session. When given, the layer is
                converted to a Qiskit circuit and evaluated with ``Estimator``;
                otherwise expectation values are sampled locally.
            return_quant_exec_time: Only used with ``session``; when true the
                job's ``usage_estimation`` is returned as a second element.

        Returns:
            With ``session``: the estimator result ``values`` (optionally
            paired with ``job.usage_estimation``). Without ``session``: a list
            of expectation values ordered like ``observables`` when more than
            one observable is given, or the single sampled expectation value
            otherwise.
        """
        if session is not None:
            options = Options(optimization_level=1, execution={"shots":self.nb_shots})
            estimator = Estimator(session=session, options=options)

            # The circuit does not depend on the observable, so it is built
            # once and reused for every observable instead of being
            # regenerated per entry.
            circuit = tq2qiskit(q_device=q_dev, m=quantum_layer, x=x)
            job = estimator.run(circuits=[circuit] * len(observables),
                                observables=observables)

            values = job.result().values
            if return_quant_exec_time:
                return values, job.usage_estimation
            return values

        # Local sampling path. Pauli strings are reversed before being passed
        # to the sampling helpers and mapped back afterwards, so callers keep
        # using the same string ordering as on the session path.
        if len(observables) > 1:
            reversed_obs = [obs[::-1] for obs in observables]
            grouped = expval_joint_sampling_grouping(
                qdev=quantum_layer(q_dev, x, return_q_device=True),
                observables=reversed_obs,
                n_shots_per_group=self.nb_shots)

            by_original = {obs[::-1]: value for obs, value in grouped.items()}
            return [by_original[obs] for obs in observables]

        single_reversed = observables[0][::-1]
        return expval_joint_sampling(
            qdev=quantum_layer(q_dev, x, return_q_device=True),
            observable=single_reversed,
            n_shots=self.nb_shots)

    def forward(self, x, mask=None, session=None):
        """Measure V, K and Q for every position and print the attention product.

        NOTE(review): work in progress. After printing the intermediate
        tensors the method raises ``RuntimeError("STOP")``; the statements
        after the raise are unreachable.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, embed_dim)``.
            mask: Only referenced by the unreachable tail.
            session: Optional Qiskit Runtime session forwarded to
                ``get_exp_from_observables``.

        Raises:
            AssertionError: if the last axis of ``x`` differs from ``embed_dim``.
            RuntimeError: always, as a deliberate debugging stop.
        """
        batch_size, seq_len, embed_dim = x.size()
        assert embed_dim == self.embed_dim, f"Input embedding ({embed_dim}) does not match layer embedding size ({self.embed_dim})"

        v_exp_val = []
        k_exp_val = []
        q_exp_val = []

        # Measurement order per position is V, K, Q; it is kept fixed because
        # shot sampling consumes random numbers.
        stages = (
            (self.v_layer, self.v_observables, v_exp_val),
            (self.k_layer, self.k_observables, k_exp_val),
            (self.q_layer, self.q_observables, q_exp_val),
        )
        for t in range(seq_len):
            for layer, observables, collected in stages:
                # Every measurement starts from a freshly created device.
                q_dev = tq.QuantumDevice(n_wires=self.n_qubits, device=self.q_device, bsz=batch_size)
                collected.append(self.get_exp_from_observables(
                    x=x[:, t, :].clone(), q_dev=q_dev, quantum_layer=layer,
                    observables=observables, session=session))

        print(v_exp_val)
        print(k_exp_val)
        print(q_exp_val)

        # K and Q are encoded as unit-modulus phases exp(+/- i * (pi/2) * <O>),
        # V keeps the raw expectation values.
        V = torch.tensor(v_exp_val)
        K = torch.exp((torch.pi / 2) * torch.complex(torch.zeros(seq_len, dtype=torch.float32),torch.tensor(k_exp_val, dtype=torch.float32)))
        Q = torch.exp((- torch.pi / 2) * torch.complex(torch.zeros(seq_len, dtype=torch.float32),torch.tensor(q_exp_val, dtype=torch.float32)))

        print(V)
        print(K)
        print(Q)
        # (seq_len, seq_len) table of pairwise phase products, then its real
        # part applied to V.
        print(torch.unsqueeze(Q, dim=0) * torch.unsqueeze(K, dim=0).T)
        print((torch.unsqueeze(Q, dim=0) * torch.unsqueeze(K, dim=0).T).real @ V)

        # Deliberate debugging stop.
        raise RuntimeError("STOP")

        # NOTE(review): unreachable while the stop above is in place. This tail
        # appears to expect per-position sequences rather than the tensors
        # built above; confirm before re-enabling.
        K = torch.Tensor(pad_sequence(K))
        Q = torch.Tensor(pad_sequence(Q))
        V = torch.Tensor(pad_sequence(V))
        x = self.downstream(Q, K, V, batch_size, mask)
        return x


In [11]:
# Size of the last axis of the test input; also used as n_qubits below.
EMBED_DIM = 4

# Number of sequences in the test input.
BATCH_SIZE = 1

# Number of positions per sequence in the test input.
SEQ_LEN = 3


In [12]:
# Uniform random values in [0, 1) with shape (BATCH_SIZE, SEQ_LEN, EMBED_DIM).
# NOTE(review): no random seed is set, so the values differ on every run.
test_input = torch.tensor(np.random.rand(BATCH_SIZE, SEQ_LEN, EMBED_DIM), dtype=torch.float32)

In [13]:
# Classical reference layer: single head, dropout disabled.
classical_module = MultiHeadAttentionClassical(embed_dim=EMBED_DIM, num_heads=1, dropout=0.0)
# NOTE(review): the disabled line below passes session= to the constructor,
# which MultiHeadAttentionQuantum.__init__ does not accept.
#quantum_module = MultiHeadAttentionQuantum(embed_dim=EMBED_DIM, num_heads=4, dropout=0.0, n_qubits=EMBED_DIM, q_device="cuda", session=session)

In [14]:
# Forward pass through the classical layer; the shape is checked in the next cell.
output = classical_module(test_input)

In [15]:
output.shape

torch.Size([1, 3, 4])

In [16]:
# Quantum layer on CPU: one head, one qubit per embedding dimension, and
# 1024 shots per measurement.
quantum_module = MultiHeadAttentionQuantum(embed_dim=EMBED_DIM, num_heads=1, dropout=0.0, n_qubits=EMBED_DIM, q_device="cpu", nb_shots=1024)

In [17]:
# Local (no session) run. forward() currently prints its intermediate tensors
# and then raises RuntimeError("STOP"), so `output_q` is never assigned here.
output_q = quantum_module(test_input)

[[tensor([-0.2812]), tensor([-0.0527]), tensor([0.3652]), tensor([-0.3711])], [tensor([-0.3828]), tensor([0.0918]), tensor([0.4492]), tensor([-0.3867])], [tensor([-0.4512]), tensor([-0.0332]), tensor([0.3691]), tensor([-0.5098])]]
[tensor([-0.0156]), tensor([-0.0996]), tensor([-0.0449])]
[tensor([-0.0859]), tensor([-0.0527]), tensor([-0.0469])]
tensor([[-0.2812, -0.0527,  0.3652, -0.3711],
        [-0.3828,  0.0918,  0.4492, -0.3867],
        [-0.4512, -0.0332,  0.3691, -0.5098]])
tensor([0.9997-0.0245j, 0.9878-0.1558j, 0.9975-0.0705j])
tensor([0.9909+0.1346j, 0.9966+0.0827j, 0.9973+0.0736j])
tensor([[0.9939+0.1102j, 0.9983+0.0583j, 0.9988+0.0491j],
        [0.9998-0.0215j, 0.9973-0.0736j, 0.9966-0.0827j],
        [0.9979+0.0644j, 0.9999+0.0123j, 1.0000+0.0031j]])
tensor([[-1.1123,  0.0061,  1.1802, -1.2640],
        [-1.1126,  0.0057,  1.1810, -1.2647],
        [-1.1146,  0.0060,  1.1828, -1.2668]])


RuntimeError: STOP

In [277]:
# Same forward pass, but with measurements routed through a Qiskit Runtime
# session on `backend`. The recorded run logged "Cannot convert batch model
# tq module" and was interrupted by hand (KeyboardInterrupt).
with Session(backend=backend) as session:
    #quantum_module = MultiHeadAttentionQuantum(embed_dim=EMBED_DIM, num_heads=1, dropout=0.0, n_qubits=EMBED_DIM, q_device="cpu", session=session)
    output_q = quantum_module(test_input, session=session)

[32m[2024-05-12 15:03:11.838][0m [31m[1mCannot convert batch model tq module[0m
[33m[1mTraceback (most recent call last):[0m

  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jesshuan/miniconda3/envs/torch_quantum/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
    │   └ <bound method Application.launch_instance of <class 'ipykernel.kernelapp.IPKernelApp'>>
    └ <module 'ipykernel.kernelapp' from '/home/jesshuan/miniconda3/envs/torch_quantum/lib/python3.12/site-packages/ipykernel/kerne...
  File "/home/jesshuan/miniconda3/envs/torch_quantum/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
    │   └ <function IPKernelApp.start at 0x7dbc1d2fd260>
    └ <ipykernel.kernelapp.IPKernelApp object at 0x7dbc1ffb2480>
  File "/home/jesshuan/miniconda3/envs/torch_quantum/lib/python3.12/site-packages/ipyke

KeyboardInterrupt: 