In [2]:
import pennylane as qml
import torch
import numpy as np
import random
import torch.nn as nn
from tqdm import tqdm

In [3]:
import functools
import inspect
import math
from collections.abc import Iterable
from typing import Callable, Dict, Union, Any

from pennylane.qnode import QNode

try:
    import torch
    from torch.nn import Module

    TORCH_IMPORTED = True
except ImportError:
    # The following allows this module to be imported even if PyTorch is not installed. Users
    # will instead see an ImportError when instantiating the TorchLayer.
    from unittest.mock import Mock

    Module = Mock
    TORCH_IMPORTED = False


class TorchLayer(Module):
    def __init__(self,qnode,weights):
        if not TORCH_IMPORTED:
            raise ImportError(
                "TorchLayer requires PyTorch. PyTorch can be installed using:\n"
                "pip install torch\nAlternatively, "
                "visit https://pytorch.org/get-started/locally/ for detailed "
                "instructions."
            )
        super().__init__()

        #weight_shapes = {
        #    weight: (tuple(size) if isinstance(size, Iterable) else () if size == 1 else (size,))
        #    for weight, size in weight_shapes.items()
        #}

        # validate the QNode signature, and convert to a Torch QNode.
        # TODO: update the docstring regarding changes to restrictions when tape mode is default.
        #self._signature_validation(qnode, weight_shapes)
        self.qnode = qnode
        self.qnode.interface = "torch"

        self.qnode_weights = weights

    def forward(self, inputs):  # pylint: disable=arguments-differ
        """Evaluates a forward pass through the QNode based upon input data and the initialized
        weights.

        Args:
            inputs (tensor): data to be processed

        Returns:
            tensor: output data
        """

        if len(inputs.shape) > 1:
            # If the input size is not 1-dimensional, unstack the input along its first dimension,
            # recursively call the forward pass on each of the yielded tensors, and then stack the
            # outputs back into the correct shape
            reconstructor = [self.forward(x) for x in torch.unbind(inputs)]
            return torch.stack(reconstructor)

        # If the input is 1-dimensional, calculate the forward pass as usual
        return self._evaluate_qnode(inputs)


    def _evaluate_qnode(self, x):
        """Evaluates the QNode for a single input datapoint.

        Args:
            x (tensor): the datapoint

        Returns:
            tensor: output datapoint
        """
        kwargs = {
            **{self.input_arg: x},
            **{arg: weight.to(x) for arg, weight in self.qnode_weights.items()},
        }
        res = self.qnode(**kwargs)

        if isinstance(res, torch.Tensor):
            return res.type(x.dtype)

        return torch.hstack(res).type(x.dtype)

    def __str__(self):
        detail = "<Quantum Torch Layer: func={}>"
        return detail.format(self.qnode.func.__name__)

    __repr__ = __str__
    _input_arg = "inputs"

    @property
    def input_arg(self):
        """Name of the argument to be used as the input to the Torch layer. Set to ``"inputs"``."""
        return self._input_arg

In [4]:
class QSAL_pennylane(torch.nn.Module):
    def __init__(self,S,n,Denc,D):
        super().__init__()
        self.seq_num=S
        self.init_params_Q=torch.nn.Parameter(torch.stack([(np.pi/4) * (2 * torch.randn(n*(D+2)) - 1) for _ in range(self.seq_num)]))
        self.init_params_K=torch.nn.Parameter(torch.stack([(np.pi/4) * (2 * torch.randn(n*(D+2)) - 1) for _ in range(self.seq_num)]))
        self.init_params_V=torch.nn.Parameter(torch.stack([(np.pi/4) * (2 * torch.randn(n*(D+2)) - 1) for _ in range(self.seq_num)]))
        self.num_q=n
        self.Denc=Denc
        self.D=D
        self.d=n*(Denc+2)
        self.dev = qml.device("default.qubit", wires=self.num_q)
        
        self.vqnod=qml.QNode(self.circuit_v, self.dev, interface="torch")
        self.qnod=qml.QNode(self.circuit_qk, self.dev, interface="torch")
        self.weight_v = [{"weights": self.init_params_V[i]} for i in range(self.seq_num)]
        self.weight_q = [{"weights": self.init_params_Q[i]} for i in range(self.seq_num)]
        self.weight_k = [{"weights": self.init_params_K[i]} for i in range(self.seq_num)]
        #self.v_linear ={} #[qml.qnn.TorchLayer(self.vqnod[i], self.weight_shapes) for i in range(self.seq_num)]
        #for i in range(self.seq_num):
        self.v_linear = [TorchLayer(self.vqnod, self.weight_v[i]) for i in range(self.seq_num)]
        self.q_linear = [TorchLayer(self.qnod, self.weight_q[i]) for i in range(self.seq_num)]
        self.k_linear = [TorchLayer(self.qnod, self.weight_k[i]) for i in range(self.seq_num)]
        #self.qqnod=[qml.QNode(self.circuit_qk, self.dev, interface="torch") for i in range(self.seq_num)]

    def random_op(self):
        a=random.randint(0, 4)
        if a==0:
            op=qml.Identity(0)
        elif a==1:
            op=qml.PauliX(0)
        elif a==2:
            op=qml.PauliY(0)
        else:
            op=qml.PauliZ(0)

        op_elimated=qml.Identity(0)
        for i in range(1,self.num_q):
            op_elimated=op_elimated@qml.Identity(i)
        Select_wrong=True
        while Select_wrong:
            for i in range(1,self.num_q):
                a=random.randint(0, 4)
                if a==0:
                    op=op@qml.Identity(i)
                elif a==1:
                    op=op@qml.PauliX(i)
                elif a==2:
                    op=op@qml.PauliY(i)
                else:
                    op=op@qml.PauliZ(i)
            if op!=op_elimated:
                Select_wrong=False
        return op

    def circuit_v(self,inputs,weights):
            op=self.random_op()
            # feature_map
            indx=0
            for j in range(self.num_q):
                qml.RX(inputs[indx],j)
                qml.RY(inputs[indx+1],j)
                indx+=2
            for i in range(self.Denc):
                for j in range(self.num_q):
                    qml.CNOT(wires=(j,(j+1)%self.num_q))

                for j in range(self.num_q):
                    qml.RY(inputs[indx],j)
                    indx+=1
            # Ansatz
            indx=0
            for j in range(self.num_q):
                qml.RX(weights[indx],j)
                qml.RY(weights[indx+1],j)
                indx+=2
            for i in range(self.D):
                for j in range(self.num_q):
                    qml.CNOT(wires=(j,(j+1)%self.num_q))
                    
                for j in range(self.num_q):
                    #qc.rx(params[indx],j)
                    qml.RY(weights[indx],j)
                    indx+=1
            return [qml.expval(op) for i in range(self.d)] 

    def circuit_qk(self,inputs,weights):
        op=self.random_op()
        # feature_map
        indx=0
        for j in range(self.num_q):
            qml.RX(inputs[indx],j)
            qml.RY(inputs[indx+1],j)
            indx+=2
        for i in range(self.Denc):
            for j in range(self.num_q):
                qml.CNOT(wires=(j,(j+1)%self.num_q))

            for j in range(self.num_q):
                qml.RY(inputs[indx],j)
                indx+=1
        # Ansatz
        indx=0
        for j in range(self.num_q):
            qml.RX(weights[indx],j)
            qml.RY(weights[indx+1],j)
            indx+=2
        for i in range(self.D):
            for j in range(self.num_q):
                qml.CNOT(wires=(j,(j+1)%self.num_q))
                
            for j in range(self.num_q):
                #qc.rx(params[indx],j)
                qml.RY(weights[indx],j)
                indx+=1
        return [qml.expval(qml.PauliZ(0))]

    def forward(self,input):

        Q_output=torch.stack([self.q_linear[i](input[:,i]) for i in range(self.seq_num)])
        K_output=torch.stack([self.k_linear[i](input[:,i]) for i in range(self.seq_num)])
        V_output=torch.stack([self.v_linear[i](input[:,i]) for i in range(self.seq_num)])
        
        batch_size=len(input)
        Q_output=Q_output.transpose(0,2).repeat((self.seq_num,1,1))
        K_output=K_output.transpose(0,2).repeat((self.seq_num,1,1)).transpose(0,2)
        #print(V_output.size())
        #Q_grid, K_grid=torch.meshgrid(Q_output, K_output, indexing='ij')
        alpha=torch.exp(-(Q_output-K_output)**2)
        alpha=alpha.transpose(0,1)
        V_output=V_output.transpose(0,1)
        output=[]

        for i in range(self.seq_num):
            
            Sum_a=torch.sum(alpha[:,i,:],-1)
            div_sum_a=(1/Sum_a).repeat(self.d,self.seq_num,1).transpose(0,2)
            
            Sum_w=torch.sum(alpha[:,:,i].repeat((self.d,1,1)).transpose(0,2).transpose(0,1)*V_output*div_sum_a,1)
            output.append(Sum_w)
        return input+torch.stack(output).transpose(0,1)

class QSANN_pennylane(torch.nn.Module):
    def __init__(self,S,n,Denc,D,num_layers):
        """
        # input: input data
        # weight: trainable parameter
        # n: # of of qubits
        # d: embedding dimension which is equal to n(Denc+2)
        # Denc: the # number of layers for encoding 
        # D: the # of layers of variational layers
        # type "K": key, "Q": Query, "V": value
        """
        super().__init__()
        self.qsal_lst=[QSAL_pennylane(S,n,Denc,D) for _ in range(num_layers)]
        self.qnn=nn.Sequential(*self.qsal_lst)

    def forward(self,input):

        return self.qnn(input)

class QSANN_text_classifier(torch.nn.Module):
    def __init__(self,S,n,Denc,D,num_layers):
        """
        # input: input data
        # weight: trainable parameter
        # n: # of of qubits
        # d: embedding dimension which is equal to n(Denc+2)
        # Denc: the # number of layers for encoding 
        # D: the # of layers of variational layers
        # type "K": key, "Q": Query, "V": value
        """
        super().__init__()
        self.Qnn=QSANN_pennylane(S,n,Denc,D,num_layers)
        self.final_layer=nn.Linear(n*(Denc+2)*S, 1)
        self.final_layer=self.final_layer.float()

    def forward(self,input):

        x=self.Qnn(input)
        x=torch.flatten(x,start_dim=1)
        
        return torch.sigmoid(self.final_layer(x))

In [5]:
model=QSANN_text_classifier(4, 4, 2, 1, 1)

# (4, 16)

# seq * (num_qubits) * (num_layers + 2)

In [6]:
model

QSANN_text_classifier(
  (Qnn): QSANN_pennylane(
    (qnn): Sequential(
      (0): QSAL_pennylane()
    )
  )
  (final_layer): Linear(in_features=64, out_features=1, bias=True)
)

# Binary Classification with Sklearn Image Dataset

# Sklearn Image Dataset (Patches prepared Row-wise)

In [109]:
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
digits = load_digits()
X, y = digits.images, digits.target
X_train, X_test, y_train, y_test = train_test_split(X, y)

train_mask = np.isin(y_train, [1, 7])
X_train, y_train = X_train[train_mask], y_train[train_mask]

test_mask = np.isin(y_test, [1, 7])
X_test, y_test = X_test[test_mask], y_test[test_mask]

X_train = X_train.reshape(X_train.shape[0], 4, 16)
X_test = X_test.reshape(X_test.shape[0], 4, 16)

In [110]:
X_train.shape

(273, 4, 16)

In [111]:
model=QSANN_text_classifier(4, 4, 2, 1, 1)

# seq * (num_qubits) * (num_layers + 2)

In [112]:
model

QSANN_text_classifier(
  (Qnn): QSANN_pennylane(
    (qnn): Sequential(
      (0): QSAL_pennylane()
    )
  )
  (final_layer): Linear(in_features=64, out_features=1, bias=True)
)

In [113]:
optimizer = torch.optim.Adam(lr=0.01, params=model.parameters())

In [114]:
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable_params)

209


In [115]:
criterion = torch.nn.CrossEntropyLoss()

In [116]:
def binary_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """

    #round predictions to the closest integer
    rounded_preds = (torch.round(torch.sign(preds-0.5))+1)//2
    correct = (rounded_preds == y).float() #convert into float for division 
    acc = correct.sum() / len(correct)
    return acc

In [117]:
for iepoch in tqdm(range(30)):
    optimizer.zero_grad()
    X_tensor=torch.tensor(X_train)
    predictions=model(X_tensor.float()).squeeze(1)
    #predictions=torch.sign(predictions)
    #print(predictions)
    label=torch.tensor(y_train)
    for i in range(len(label)):
        if label[i]==1:
            label[i] = 0
        else:
            label[i]=1
    #print(label)
    loss = criterion(predictions, label.float())
    acc = binary_accuracy(predictions, label)
    print('')
    print('Accuracy:',acc)
    print('')
    print(loss)
    loss.backward()
    optimizer.step()


  0%|          | 0/30 [00:00<?, ?it/s]


Accuracy: tensor(0.4945)

tensor(778.4664, grad_fn=<DivBackward1>)



  3%|▎         | 1/30 [01:28<42:55, 88.80s/it]


Accuracy: tensor(0.5971)

tensor(763.2916, grad_fn=<DivBackward1>)



  7%|▋         | 2/30 [03:02<42:53, 91.92s/it]


Accuracy: tensor(0.6447)

tensor(756.7378, grad_fn=<DivBackward1>)



 10%|█         | 3/30 [04:34<41:18, 91.81s/it]


Accuracy: tensor(0.6923)

tensor(751.3929, grad_fn=<DivBackward1>)



 13%|█▎        | 4/30 [06:05<39:41, 91.58s/it]


Accuracy: tensor(0.7436)

tensor(745.1588, grad_fn=<DivBackward1>)



 17%|█▋        | 5/30 [07:38<38:18, 91.95s/it]


Accuracy: tensor(0.8242)

tensor(738.8176, grad_fn=<DivBackward1>)



 20%|██        | 6/30 [09:10<36:46, 91.94s/it]


Accuracy: tensor(0.8571)

tensor(735.7882, grad_fn=<DivBackward1>)



 23%|██▎       | 7/30 [10:41<35:12, 91.84s/it]


Accuracy: tensor(0.8718)

tensor(734.6052, grad_fn=<DivBackward1>)



 27%|██▋       | 8/30 [12:14<33:42, 91.94s/it]


Accuracy: tensor(0.8718)

tensor(732.9946, grad_fn=<DivBackward1>)



 30%|███       | 9/30 [13:47<32:19, 92.34s/it]


Accuracy: tensor(0.8938)

tensor(730.7499, grad_fn=<DivBackward1>)



 33%|███▎      | 10/30 [15:19<30:48, 92.42s/it]


Accuracy: tensor(0.9121)

tensor(728.5879, grad_fn=<DivBackward1>)



 37%|███▋      | 11/30 [16:52<29:14, 92.32s/it]


Accuracy: tensor(0.9231)

tensor(726.6227, grad_fn=<DivBackward1>)



 40%|████      | 12/30 [18:25<27:45, 92.53s/it]


Accuracy: tensor(0.9267)

tensor(725.1233, grad_fn=<DivBackward1>)



 43%|████▎     | 13/30 [19:57<26:10, 92.40s/it]


Accuracy: tensor(0.9414)

tensor(723.7494, grad_fn=<DivBackward1>)



 47%|████▋     | 14/30 [21:29<24:38, 92.40s/it]


Accuracy: tensor(0.9560)

tensor(722.3354, grad_fn=<DivBackward1>)



 50%|█████     | 15/30 [23:01<23:05, 92.34s/it]


Accuracy: tensor(0.9670)

tensor(721.1648, grad_fn=<DivBackward1>)



 53%|█████▎    | 16/30 [24:35<21:39, 92.79s/it]


Accuracy: tensor(0.9780)

tensor(719.8289, grad_fn=<DivBackward1>)



 57%|█████▋    | 17/30 [26:07<20:04, 92.66s/it]


Accuracy: tensor(0.9817)

tensor(719.0279, grad_fn=<DivBackward1>)



 60%|██████    | 18/30 [27:41<18:36, 93.05s/it]


Accuracy: tensor(0.9963)

tensor(718.3773, grad_fn=<DivBackward1>)



 63%|██████▎   | 19/30 [29:17<17:11, 93.78s/it]


Accuracy: tensor(0.9927)

tensor(718.1243, grad_fn=<DivBackward1>)



 67%|██████▋   | 20/30 [30:51<15:39, 93.92s/it]


Accuracy: tensor(0.9927)

tensor(717.8967, grad_fn=<DivBackward1>)



 70%|███████   | 21/30 [32:25<14:06, 94.03s/it]


Accuracy: tensor(0.9927)

tensor(717.8036, grad_fn=<DivBackward1>)



 73%|███████▎  | 22/30 [34:01<12:35, 94.43s/it]


Accuracy: tensor(0.9927)

tensor(717.6062, grad_fn=<DivBackward1>)



 77%|███████▋  | 23/30 [35:37<11:04, 94.96s/it]


Accuracy: tensor(0.9927)

tensor(717.5109, grad_fn=<DivBackward1>)



 80%|████████  | 24/30 [37:12<09:30, 95.12s/it]


Accuracy: tensor(1.)

tensor(717.4036, grad_fn=<DivBackward1>)



 83%|████████▎ | 25/30 [38:48<07:56, 95.27s/it]


Accuracy: tensor(1.)

tensor(717.3091, grad_fn=<DivBackward1>)



 87%|████████▋ | 26/30 [40:22<06:19, 94.94s/it]


Accuracy: tensor(1.)

tensor(717.2590, grad_fn=<DivBackward1>)



 90%|█████████ | 27/30 [41:57<04:44, 94.79s/it]


Accuracy: tensor(1.)

tensor(717.2249, grad_fn=<DivBackward1>)



 93%|█████████▎| 28/30 [43:32<03:09, 94.88s/it]


Accuracy: tensor(1.)

tensor(717.2009, grad_fn=<DivBackward1>)



 97%|█████████▋| 29/30 [45:03<01:33, 93.68s/it]


Accuracy: tensor(1.)

tensor(717.1962, grad_fn=<DivBackward1>)


100%|██████████| 30/30 [46:40<00:00, 93.35s/it]


In [118]:
X_tensor=torch.tensor(X_test)
predictions=model(X_tensor.float()).squeeze(1)
label=torch.tensor(y_test)
for i in range(len(label)):
        if label[i]==1:
            label[i] = 0
        else:
            label[i]=1
loss = criterion(predictions, label.float())
acc = binary_accuracy(predictions, label.float())
print('')
print('Accuracy:',acc)
print('')
print(loss)


Accuracy: tensor(0.9886)

tensor(171.4812, grad_fn=<DivBackward1>)


# Sklearn Image Dataset (Patches Prepared Column-wise)

In [83]:
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
digits = load_digits()
X, y = digits.images[0:500], digits.target[0:500]
X_train, X_test, y_train, y_test = train_test_split(X, y)

train_mask = np.isin(y_train, [1, 7])
X_train, y_train = X_train[train_mask], y_train[train_mask]

test_mask = np.isin(y_test, [1, 7])
X_test, y_test = X_test[test_mask], y_test[test_mask]

#X_train = X_train.reshape(X_train.shape[0], 4, 16)
#X_test = X_test.reshape(X_test.shape[0], 4, 16)
X_train_c = X_train.transpose(0, 2, 1)
X_test_c = X_test.transpose(0, 2, 1)

X_train_c = X_train_c.reshape(X_train.shape[0], 4, 16)
X_test_c = X_test_c.reshape(X_test.shape[0], 4, 16)

In [79]:
model=QSANN_text_classifier(4,4,2,1,1)
optimizer = torch.optim.Adam(lr=0.01, params=model.parameters())
criterion = torch.nn.CrossEntropyLoss()

In [81]:
for iepoch in tqdm(range(10)):
    optimizer.zero_grad()
    X_tensor=torch.tensor(X_train_c)
    predictions=model(X_tensor.float()).squeeze(1)
    #predictions=torch.sign(predictions)
    #print(predictions)
    label=torch.tensor(y_train)
    for i in range(len(label)):
        if label[i]==1:
            label[i] = 0
        else:
            label[i]=1
    #print(label)
    loss = criterion(predictions, label.float())
    acc = binary_accuracy(predictions, label)
    print('')
    print('Accuracy:',acc)
    print('')
    print(loss)
    loss.backward()
    optimizer.step()


  0%|          | 0/10 [00:00<?, ?it/s]


Accuracy: tensor(0.9481)

tensor(146.3840, grad_fn=<DivBackward1>)



 10%|█         | 1/10 [00:24<03:43, 24.87s/it]


Accuracy: tensor(0.9481)

tensor(145.6917, grad_fn=<DivBackward1>)



 20%|██        | 2/10 [00:50<03:21, 25.24s/it]


Accuracy: tensor(0.9740)

tensor(143.9053, grad_fn=<DivBackward1>)



 30%|███       | 3/10 [01:15<02:56, 25.27s/it]


Accuracy: tensor(1.)

tensor(142.9731, grad_fn=<DivBackward1>)



 40%|████      | 4/10 [01:41<02:32, 25.37s/it]


Accuracy: tensor(0.9740)

tensor(142.8075, grad_fn=<DivBackward1>)



 50%|█████     | 5/10 [02:06<02:07, 25.52s/it]


Accuracy: tensor(0.9740)

tensor(142.7293, grad_fn=<DivBackward1>)



 60%|██████    | 6/10 [02:33<01:42, 25.74s/it]


Accuracy: tensor(0.9870)

tensor(142.6147, grad_fn=<DivBackward1>)



 70%|███████   | 7/10 [02:59<01:17, 25.88s/it]


Accuracy: tensor(0.9870)

tensor(142.3875, grad_fn=<DivBackward1>)



 80%|████████  | 8/10 [03:28<00:53, 26.82s/it]


Accuracy: tensor(0.9870)

tensor(142.2015, grad_fn=<DivBackward1>)



 90%|█████████ | 9/10 [03:53<00:26, 26.36s/it]


Accuracy: tensor(1.)

tensor(142.0063, grad_fn=<DivBackward1>)


100%|██████████| 10/10 [04:19<00:00, 25.91s/it]


In [84]:
X_tensor=torch.tensor(X_test_c)
predictions=model(X_tensor.float()).squeeze(1)
print(predictions)
label=torch.tensor(y_test)
for i in range(len(label)):
        if label[i]==1:
            label[i] = 0
        else:
            label[i]=1
print(label)
loss = criterion(predictions, label.float())
acc = binary_accuracy(predictions, label.float())
print('')
print('Accuracy:',acc)
print('')
print(loss)

tensor([5.2172e-03, 9.9299e-01, 9.9714e-01, 2.4885e-02, 3.4008e-03, 1.7592e-03,
        9.9949e-01, 3.8367e-03, 9.9953e-01, 9.9986e-01, 3.1252e-03, 8.2579e-02,
        9.9762e-01, 9.9997e-01, 3.1298e-03, 7.1565e-03, 9.9605e-01, 9.9876e-01,
        9.9963e-01, 1.4773e-03, 9.2540e-04], grad_fn=<SqueezeBackward1>)
tensor([0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0])

Accuracy: tensor(1.)

tensor(26.4663, grad_fn=<DivBackward1>)
