In [24]:
! pip install qiskit --quiet
! pip install qiskit-aer --quiet
! pip install torchvision --quiet

In [25]:
import numpy as np
import torchvision
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import keras

from qiskit import transpile, QuantumCircuit
from qiskit_aer import AerSimulator
from qiskit.visualization import plot_histogram
from PIL import Image
from keras.datasets import mnist

In [26]:
def calc_angle(prob):
    """Return the RY angle that places probability ``prob`` on |0>.

    RY(theta)|0> = cos(theta/2)|0> + sin(theta/2)|1>, so requiring
    cos^2(theta/2) == prob gives theta = 2 * arccos(sqrt(prob)).
    Works element-wise on numpy arrays as well as on scalars.
    """
    amplitude_zero = np.sqrt(prob)
    half_angle = np.arccos(amplitude_zero)
    return half_angle * 2

def create_quantum_circuit(null_prior, null_probs, positive_probs):
    """Build the (measured) naive-Bayes state-preparation circuit.

    Qubit 0 is the label: ``ry(calc_angle(null_prior))`` makes it read |0>
    (y = 0, the null class) with probability ``null_prior`` and |1> (y = 1)
    otherwise. Qubits 1..n are the features. In the y = 0 branch, feature qubit
    i+1 is rotated by ``calc_angle(null_probs[i])``; in the y = 1 branch, by
    ``calc_angle(positive_probs[i])``. By calc_angle's convention, a feature
    qubit that starts in |0> then reads |0> with the conditional probability
    of its branch.

    Args:
        null_prior: probability of the null class (label qubit |0>).
        null_probs: per-feature conditionals for the null class.
        positive_probs: per-feature conditionals for the positive class;
            must have the same length as ``null_probs``.

    Returns:
        QuantumCircuit on len(null_probs) + 1 qubits, ending in measure_all().
    """
    assert len(null_probs) == len(positive_probs), \
        "null_probs and positive_probs must have the same length"
    n = len(null_probs)  # number of features

    theta_y = calc_angle(null_prior)
    null_rotations = [calc_angle(conditional) for conditional in null_probs]
    pos_rotations = [calc_angle(conditional) for conditional in positive_probs]

    circ = QuantumCircuit(n + 1)  # one extra qubit for the label
    circ.ry(theta_y, 0)  # P(label qubit = |0>) = P(y = 0)

    # CRY fires only when its control is |1>. The null branch (y = 0) is |0> on
    # qubit 0, so flip it before applying the null-class conditionals...
    circ.x(0)
    for i, theta in enumerate(null_rotations):
        circ.cry(theta, 0, i + 1)
    circ.x(0)  # restore the label qubit
    # ...and apply the positive-class conditionals on y = 1 (qubit 0 = |1>) directly.
    for i, theta in enumerate(pos_rotations):
        circ.cry(theta, 0, i + 1)

    circ.measure_all()  # measure the label and every feature qubit
    return circ

In [27]:
class BinaryPreprocess:
    """Two-class MNIST feature extraction for the quantum naive-Bayes classifier.

    Each image is reduced to 9 real features:
    1. The 28x28 image is cropped to 27x27 by dropping its first row and column.
    2. Each 8x8 window on a stride-9 grid is averaged, giving a 3x3 grid.
    3. Each feature is binarized against a per-feature threshold built from the
       two classes' pooled means and variances (see ``_binarize``).
    """

    def __init__(self, diffclasses: list, mnist_path: str = 'data', dd: bool = True):
        """Load the MNIST training split and estimate the classifier's probabilities.

        Args:
            diffclasses: the two digit labels to discriminate. ``diffclasses[0]``
                is the null class and ``diffclasses[1]`` the positive class.
            mnist_path: root directory passed to ``torchvision.datasets.MNIST``.
            dd: forwarded as ``download=`` to ``torchvision.datasets.MNIST``.

        Sets:
            classes: ``diffclasses``.
            features: (2, 2, 9) tensor. ``[k][0]`` holds the pooled means and
                ``[k][1]`` the pooled variances of class ``diffclasses[k]``.
            prior: fraction of the two-class subset that belongs to the null class.
            null_probs / positive_probs: (9,) tensors giving, per feature, the
                fraction of that class's images whose binarized feature is 1.

        Raises:
            ValueError: if ``diffclasses`` does not hold two distinct labels, or
                if either label has no training images.
        """
        if len(diffclasses) != 2 or diffclasses[0] == diffclasses[1]:
            raise ValueError(f"diffclasses must hold two distinct labels, got {diffclasses!r}")

        dataset = torchvision.datasets.MNIST(root=mnist_path, train=True, download=dd)

        # Pool every image once. The result feeds both the per-class statistics
        # and the conditional probabilities, so MNIST is traversed a single time.
        stats, labels = BinaryPreprocess._pool_dataset(dataset)
        class_features = BinaryPreprocess._class_statistics(stats, labels)

        self.classes = diffclasses
        self.features = class_features[diffclasses]  # 2 classes x 2 statistics x 9 features

        null_mask = labels == diffclasses[0]
        positive_mask = labels == diffclasses[1]
        prior_count = int(null_mask.sum())
        positive_count = int(positive_mask.sum())
        if prior_count == 0 or positive_count == 0:
            raise ValueError(f"no training images for one of the classes {diffclasses!r}")

        self.prior = prior_count / (prior_count + positive_count)
        self.null_probs = self._binarize(stats[null_mask]).mean(dim=0)
        self.positive_probs = self._binarize(stats[positive_mask]).mean(dim=0)

    def inference_features(self, img):
        """Binarize one image into its 9 features.

        Args:
            img: a 28x28 image (PIL image, numpy array or tensor; anything
                ``np.array`` accepts).

        Returns:
            A (1, 9) float tensor of 0/1 values. A feature is 1 in either case:
            - its pooled value is above the threshold and class 1 has the larger
              mean, or
            - its pooled value is at/below the threshold and class 0 has the
              larger mean.
        """
        return self._binarize(BinaryPreprocess._pool(img))

    def _binarize(self, pooled):
        """Threshold pooled features (last dim 9) into a 0/1 float tensor of the same shape."""
        mu0, var0 = self.features[0, 0], self.features[0, 1]
        mu1, var1 = self.features[1, 0], self.features[1, 1]
        # NOTE(review): this is not the usual equal-density point of two Gaussians,
        # and it is inf/nan wherever var0 == var1. TODO confirm the intended formula.
        divpoint = (mu0 * var1 - mu1 * var0) / (var0 - var1)

        over_div = pooled > divpoint
        mu_diff = mu0 > mu1
        return torch.logical_xor(over_div, mu_diff).float()

    @staticmethod
    def _pool(img):
        """Crop to 27x27 and average 8x8 windows at stride 9; returns a (1, 9) float tensor."""
        pixels = torch.from_numpy(np.array(img)).float()[1:, 1:]
        avg_kernel = torch.ones((8, 8)) / 64
        return F.conv2d(pixels[None, :], avg_kernel[None, None, :], stride=9).reshape(1, -1)

    @staticmethod
    def _pool_dataset(dataset):
        """Pool every (image, label) pair into ((N, 9) float features, (N,) int64 labels)."""
        pooled, labels = [], []
        for data, label in dataset:
            pooled.append(BinaryPreprocess._pool(data))
            labels.append(int(label))
        if not pooled:
            return torch.zeros((0, 9)), torch.zeros(0, dtype=torch.long)
        # Concatenate once; growing a tensor with torch.cat per image is quadratic.
        return torch.cat(pooled), torch.tensor(labels, dtype=torch.long)

    @staticmethod
    def _class_statistics(stats, labels):
        """Per-digit (0-9) mean and variance of each pooled feature, as a (10, 2, 9) tensor."""
        class_features = torch.zeros((10, 2, 9))  # Classes x Statistics x Features
        for digit in range(len(class_features)):
            subset = stats[labels == digit]
            feat_mean = subset.mean(dim=0)
            feat_var = subset.var(dim=0)
            class_features[digit] = torch.stack((feat_mean, feat_var))  # [c][0] means, [c][1] variances
        return class_features

    @staticmethod
    def _processMNIST(dataset):
        """Return (10, 2, 9) per-digit pooled-feature statistics for ``dataset``.

        ``[c][0]`` holds the means and ``[c][1]`` the variances (``torch.var``
        default correction) of the 9 pooled features over images labeled ``c``.
        """
        stats, labels = BinaryPreprocess._pool_dataset(dataset)
        return BinaryPreprocess._class_statistics(stats, labels)

In [None]:
# Digits to discriminate. null_class becomes label 0 (label qubit |0>) and
# alt_class becomes label 1.
null_class = 0
alt_class = 1
SHOTS = 3000  # simulator shots per image

if __name__ == "__main__":

    (train_X, train_y), (test_X, test_y) = mnist.load_data()

    # Keep only the two digits of interest, preserving order. Relabel them 0/1
    # so targets line up with the circuit's label qubit.
    keep = (train_y == null_class) | (train_y == alt_class)
    training_X = train_X[keep]
    training_y = (train_y[keep] == alt_class).astype(int)

    preproc = BinaryPreprocess([null_class, alt_class])

    # null prior, null probs, positive probs
    naive_qbc = create_quantum_circuit(preproc.prior, preproc.null_probs.tolist(), preproc.positive_probs.tolist())

    simulator = AerSimulator()
    n_features = len(preproc.null_probs)
    # Qiskit bitstrings are little-endian: the last character is qubit 0 (the
    # label) and the first n_features characters are the feature qubits.
    accepted_features = "1" * n_features

    correctly_predicted_samples = 0
    for img, y_true in zip(training_X, training_y):
        binarized_features = preproc.inference_features(img).squeeze()

        # Start feature qubit i+1 in |x_i>.
        # - create_quantum_circuit leaves a qubit that starts in |0> reading |0>
        #   with probability P(x_i = 1 | y).
        # - Hence a qubit that starts in |x_i> reads |1> with probability
        #   P(x_i | y), for x_i = 0 or 1.
        # - Post-selecting on "every feature qubit reads 1" therefore weights
        #   label y by P(y) * prod_i P(x_i | y), the naive-Bayes numerator.
        pre_circ = QuantumCircuit(n_features + 1)
        for i, bit in enumerate(binarized_features.tolist()):
            if bit == 1:
                pre_circ.x(i + 1)

        # Prepend the flips onto naive_qbc: compose() rejects adding a circuit
        # that has more classical bits (naive_qbc's measurements) onto pre_circ.
        composite = transpile(naive_qbc.compose(pre_circ, front=True), simulator)
        counts = simulator.run(composite, shots=SHOTS).result().get_counts(composite)

        # Sum counts per label over shots whose feature qubits all read 1.
        label_counts = {"0": 0, "1": 0}
        for bitstring, count in counts.items():
            if bitstring[:-1] == accepted_features:
                label_counts[bitstring[-1]] += count

        if label_counts["0"] == label_counts["1"]:
            # No accepted shots (or an exact tie): fall back to the prior.
            y_pred = 0 if preproc.prior >= 0.5 else 1
        else:
            y_pred = int(max(label_counts, key=label_counts.get))

        if y_pred == y_true:
            correctly_predicted_samples += 1

    # NOTE: this measures accuracy on the same training images the probabilities were fit on.
    total_samples = len(training_X)
    print(correctly_predicted_samples)
    print(total_samples)
    accuracy = float(correctly_predicted_samples) / total_samples
    print(accuracy)