In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import idx2numpy
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import os
from PIL import Image
import pennylane as qml
import csv
from datetime import datetime

In [2]:
def CheckUnitary(U, tol=1e-5):
    I = torch.eye(U.shape[0], dtype=U.dtype, device=U.device)
    UU_dagger = U @ U.conj().T
    deviation = torch.norm(UU_dagger - I)
    #print(f"Deviation from unitarity: {deviation.item()}")
    return deviation < tol

In [3]:
class QuantumLayer(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        # Represent a unitary via a parameterized Hermitian matrix (U = exp(-iH))
        self.H = nn.Parameter(torch.randn(input_dim, input_dim, dtype=torch.complex128))

    def forward(self, psi):
        # Create unitary with U = exp(-iH) using matrix exponential
        H_hermitian = 0.5 * (self.H + self.H.conj().T)
        U = torch.matrix_exp(-1j * H_hermitian)

        if CheckUnitary(U):
            out_state = U @ psi.T
            #print("Unitary check passed ✅")
        else:
            out_state = U @ psi.T
            print("Unitary check failed ❌")
        
        return out_state.T

In [4]:
class AncillaLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Controlled or entangling block
        self.H = nn.Parameter(torch.randn(2*dim, 2*dim, dtype=torch.complex128))  # Ancilla + psi

    def forward(self, psi):
        batch_size = psi.shape[0]
        dim = psi.shape[1]

        ancilla = torch.zeros(batch_size, 2, dtype=torch.complex128, device=psi.device)
        ancilla[:, 0] = 1.0  # |0⟩

        joint_state = (ancilla.unsqueeze(2) * psi.unsqueeze(1)).reshape(batch_size, 2 * dim)

        H_hermitian = 0.5 * (self.H + self.H.conj().T)
        U = torch.matrix_exp(-1j * H_hermitian)
        out_state = U @ joint_state.T
        return out_state.T

In [5]:
class Measurement(nn.Module):
    def forward(self, psi, num_ancilla):
        vector_length = psi.shape[1]
        # Extract ancilla = 0 portion
        psi_sys0 = psi[:, :vector_length//(2**num_ancilla)]  # assuming ancilla is the first qubit
        # Normalize (since you're post-selecting, it's not a full probability distribution)
        psi_sys0 = psi_sys0 / psi_sys0.norm(dim=1, keepdim=True)
        
        '''
        # Extract ancilla = 1 portion
        psi_sys1 = psi[:, vector_length//2:]  # assuming ancilla is the first qubit
        # Normalize
        psi_sys1 = psi_sys1 / psi_sys1.norm(dim=1, keepdim=True)
        '''

        # Measurement:
        probs = torch.abs(psi_sys0) ** 2
        return psi_sys0, probs

In [6]:
class QiNN(nn.Module):
    def __init__(self, input_dim, num_layers, num_ancilla):
        super().__init__()
        self.num_ancilla = num_ancilla
        self.quantum_layers_with_ancilla = nn.ModuleList([QuantumLayer((2**num_ancilla)*input_dim) for _ in range(num_layers)])
        self.measure = Measurement()

    def forward(self, psi):
        
        '''
        joint_state = psi.clone()

        for i in range(self.num_ancilla):
            batch_size = joint_state.shape[0]
            dim = joint_state.shape[1]
            ancilla = torch.zeros(batch_size, 2, dtype=torch.complex128, device=psi.device)
            ancilla[:, 0] = 1.0  # |0>
            joint_state = (ancilla.unsqueeze(2) * joint_state.unsqueeze(1)).reshape(batch_size, 2 * dim)

            joint_state = joint_state / joint_state.norm(dim=1, keepdim=True)
        '''

        batch_size = psi.shape[0]
        dim = psi.shape[1]

        joint_state = torch.tensor(np.zeros((batch_size,(2**self.num_ancilla)*dim), dtype=complex), dtype=torch.complex128)
 

        ancilla_1 = torch.tensor([1, 0], dtype=torch.complex128)
        ancilla = torch.tensor([1], dtype=torch.complex128)

        for j in range(self.num_ancilla):     
            ancilla =  torch.kron(ancilla_1, ancilla)


        for j in range(batch_size):
            #print(ancilla)
            joint_state[j,:] = torch.kron(ancilla, psi[j,:])


        joint_state = joint_state / joint_state.norm(dim=1, keepdim=True)
        
        for layer in self.quantum_layers_with_ancilla:
            joint_state = layer(joint_state)
            
        psi_sys0, probs = self.measure(joint_state, self.num_ancilla)
        return psi_sys0, probs


In [7]:
def load_data(N, new_size, number_train_samples, data_dir):
    def load_MNIST_dataset_Images(file_path): # for now trial on mnist dataset
        images = idx2numpy.convert_from_file(file_path)
        images = images.reshape(images.shape[0], -1)
        images = images.astype(np.float64)
        return images

    def load_MNIST_dataset_Labels(file_path):
        labels = idx2numpy.convert_from_file(file_path)
        return labels

    # Construct full paths
    train_images_file = os.path.join(data_dir, "train-images.idx3-ubyte")
    train_labels_file = os.path.join(data_dir, "train-labels.idx1-ubyte")


    mnist_images = load_MNIST_dataset_Images(train_images_file)
    mnist_labels = load_MNIST_dataset_Labels(train_labels_file)

    # Separating the load dataset into train labels and the train digits 

    def separate_images_and_labels(images , labels , digit):
        digit_images = []
        digit_labels = []
        for image,label in zip(images,labels):
            if label == digit:
                digit_images.append(image)
                digit_labels.append(label)
        return digit_images,digit_labels

    digits_to_separate = [0,1,2,3,4,5,6,7,8,9]

    digit_image = {}
    digit_label = {}

    for digit in digits_to_separate:
        digit_image[digit], digit_label[digit] = separate_images_and_labels(mnist_images,mnist_labels,digit)

    num_samples = 10 
    '''
    for digit in digit_image.keys():
        sample_images = digit_image[digit][:num_samples]
        plt.figure(figsize = (10,2))
        for i in range(num_samples):
            plt.subplot(1, num_samples, i + 1)
            plt.imshow(sample_images[i].reshape(28,28), cmap = 'gray')
            plt.title(f"Digit: {digit}")
            plt.axis('off')
        plt.show()
    '''
    # Resize the image to a higher accuracy and match the dimensions of the qubits

    from skimage.transform import resize

    def resize_images(images, new_size = (32,32)):
        resized_images = []
        for image in images:
            resized_image = resize(image.reshape(28,28), new_size)
            resized_images.append(resized_image.flatten())
        return resized_images

    resized_digit_images = {}
    for digit in digit_image.keys():
        resized_digit_images[digit] = resize_images(digit_image[digit], new_size = new_size)

    '''
    for digit in digit_image.keys():
        sample_images = resized_digit_images[digit][:num_samples]
        plt.figure(figsize =(10,2))
        for i in range(num_samples):
            plt.subplot(1, num_samples, i+1)
            plt.imshow(sample_images[i].reshape(new_size), cmap = 'magma')
            plt.title(f"Digit:{digit}")
            plt.axis('off')
        plt.show()
    '''

    # Converting all the digits in to arrays
    Digit_zero = []

    for image in resized_digit_images[0]:
        Digit_zero_ = image
        Digit_zero.append(Digit_zero_)

    Digit_zero = np.array(Digit_zero)

    #print('Digit zero',Digit_zero)

    Digit_one = []

    for image in resized_digit_images[1]:
        Digit_one_ = image
        Digit_one.append(Digit_one_)

    Digit_one = np.array(Digit_one)

    #print('Digit one',Digit_one)

    Digit_two = []

    for image in resized_digit_images[2]:
        Digit_two_ = image
        Digit_two.append(Digit_two_)

    Digit_two = np.array(Digit_two)

    #print('Digit two',Digit_two)

    Digit_three = []

    for image in resized_digit_images[3]:
        Digit_three_ = image
        Digit_three.append(Digit_three_)

    Digit_three = np.array(Digit_three)

    #print('Digit three',Digit_three)

    Digit_four = []

    for image in resized_digit_images[4]:
        Digit_four_ = image
        Digit_four.append(Digit_four_)

    Digit_four = np.array(Digit_four)

    #print('Digit four',Digit_four)

    Digit_fifth = []

    for image in resized_digit_images[5]:
        Digit_fifth_ = image
        Digit_fifth.append(Digit_fifth_)

    Digit_fifth = np.array(Digit_fifth)

    #print('Digit fifth',Digit_fifth)

    Digit_sixth = []

    for image in resized_digit_images[6]:
        Digit_sixth_ = image
        Digit_sixth.append(Digit_sixth_)

    Digit_sixth = np.array(Digit_sixth)

    #print('Digit Sixth',Digit_sixth)

    Digit_seventh = []

    for image in resized_digit_images[7]:
        Digit_seventh_ = image
        Digit_seventh.append(Digit_seventh_)

    Digit_seventh = np.array(Digit_seventh)

    #print('Digit seventh',Digit_seventh)

    Digit_eigth = []

    for image in resized_digit_images[8]:
        Digit_eigth_ = image
        Digit_eigth.append(Digit_eigth_)

    Digit_eigth = np.array(Digit_eigth)

    #print('Digit eigth',Digit_eigth)

    Digit_nineth = []

    for image in resized_digit_images[9]:
        Digit_nineth_ = image
        Digit_nineth.append(Digit_nineth_)

    Digit_nineth = np.array(Digit_nineth)

    #print('Digit nineth',Digit_nineth)

    # Digit zero to Nine grey scale matrix conversion by reshaping a row array into a (N,N) size

    # Quantum Amplitude Encoding of MNIST Dataset

    Digit_zero_ = []

    for i in range(len(Digit_zero)):
        Digit_zero_.append((Digit_zero[i].reshape(N**2,1))/np.linalg.norm(Digit_zero[i]))

    Digit_one_ = []

    for i in range(len(Digit_one)):
        Digit_one_.append((Digit_one[i].reshape(N**2,1))/np.linalg.norm(Digit_one[i]))

    Digit_two_ = []

    for i in range(len(Digit_two)):
        Digit_two_.append((Digit_two[i].reshape(N**2,1))/np.linalg.norm(Digit_two[i]))

    Digit_three_ = []

    for i in range(len(Digit_three)):
        Digit_three_.append((Digit_three[i].reshape(N**2,1))/np.linalg.norm(Digit_three[i]))

    Digit_four_ = []

    for i in range(len(Digit_four)):
        Digit_four_.append((Digit_four[i].reshape(N**2,1))/np.linalg.norm(Digit_four[i]))

    Digit_five_ = []

    for i in range(len(Digit_fifth)):
        Digit_five_.append((Digit_fifth[i].reshape(N**2,1))/np.linalg.norm(Digit_fifth[i]))

    Digit_sixth_ = []

    for i in range(len(Digit_sixth)):
        Digit_sixth_.append((Digit_sixth[i].reshape(N**2,1))/np.linalg.norm(Digit_sixth[i]))

    Digit_seventh_ = []

    for i in range(len(Digit_seventh)):
        Digit_seventh_.append((Digit_seventh[i].reshape(N**2,1))/np.linalg.norm(Digit_seventh[i]))

    Digit_eigth_ = []

    for i in range(len(Digit_eigth)):
        Digit_eigth_.append((Digit_eigth[i].reshape(N**2,1))/np.linalg.norm(Digit_eigth[i]))

    Digit_nineth_ = []

    for i in range(len(Digit_nineth)):
        Digit_nineth_.append((Digit_nineth[i].reshape(N**2,1))/np.linalg.norm(Digit_nineth[i]))


    Input_state_vector_zero = np.array(Digit_zero_[:number_train_samples])
    Input_state_vector_one = np.array(Digit_one_[:number_train_samples])
    Input_state_vector_two = np.array(Digit_two_[:number_train_samples])
    Input_state_vector_three = np.array(Digit_three_[:number_train_samples])
    Input_state_vector_four = np.array(Digit_four_[:number_train_samples])
    Input_state_vector_fifth = np.array(Digit_five_[:number_train_samples])
    Input_state_vector_sixth = np.array(Digit_sixth_[:number_train_samples])
    Input_state_vector_seventh = np.array(Digit_seventh_[:number_train_samples])
    Input_state_vector_eigth = np.array(Digit_eigth_[:number_train_samples])
    Input_state_vector_nineth = np.array(Digit_nineth_[:number_train_samples])

    input_state_vectors = []

    input_state_vectors.append(Input_state_vector_zero)
    input_state_vectors.append(Input_state_vector_one)
    input_state_vectors.append(Input_state_vector_two)
    input_state_vectors.append(Input_state_vector_three)
    input_state_vectors.append(Input_state_vector_four)
    input_state_vectors.append(Input_state_vector_fifth)
    input_state_vectors.append(Input_state_vector_sixth)
    input_state_vectors.append(Input_state_vector_seventh)
    input_state_vectors.append(Input_state_vector_eigth)
    input_state_vectors.append(Input_state_vector_nineth)

    input_state_vectors = np.array(input_state_vectors, dtype=complex)

    input_state_vectors = input_state_vectors.reshape(number_train_samples*10,N**2)

    input_state_vectors = np.array(input_state_vectors, dtype=complex)

    # Convert numpy array to PyTorch tensor
    input_vectors = torch.tensor(input_state_vectors, dtype=torch.complex128)

    target_states_combined = np.zeros((number_train_samples*10,N**2), dtype=complex)

    encode_length = (N**2)//9

    for i in range(10):
        start_row, end_row = i*number_train_samples, (i+1)*number_train_samples
        start_col = i*encode_length

        target_states_combined[start_row:end_row, start_col] = 1

    target_states_combined = torch.tensor(target_states_combined, dtype=torch.complex128)

    # Load MNIST test dataset

    # Construct full paths
    test_images_file = os.path.join(data_dir, "t10k-images.idx3-ubyte")
    test_labels_file = os.path.join(data_dir, "t10k-labels.idx1-ubyte")


    images = idx2numpy.convert_from_file(test_images_file)# Download the MNIST Dataset from Kaggle and Change this directory to your Directory
    labelss = idx2numpy.convert_from_file(test_labels_file)

    # Resize images to 32x32
    images_resized = np.array([np.array(Image.fromarray(img).resize(new_size)) for img in images])

    # Normalize vectors
    test_images = images_resized 

    input_state_test_vectors = []


    # Digit zero to Nine grey scale matrix conversion by reshaping a row array into a (N,N) size

    for i in range(len(images_resized)):
        input_state_test_vectors.append((images_resized[i].reshape(N**2,1))/np.linalg.norm(images_resized[i]))

    input_state_test_vectors = np.array(input_state_test_vectors,dtype = np.complex128)
    input_state_test_vectors = torch.tensor(input_state_test_vectors,dtype = torch.complex128)
    input_state_test_vectors = input_state_test_vectors.squeeze(2)
   

    return input_vectors, target_states_combined, input_state_test_vectors, labelss, encode_length


In [None]:
number_layers_arr = [4, 8, 12, 24]
N_arr = [4, 8, 16]
number_train_samples_arr = [1000, 2000, 3000] #number of train images per class
number_ancilla_arr = [1, 2]

In [None]:
# Create a new directory with current timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = f"results_{timestamp}"
os.makedirs(output_dir, exist_ok=True)

# Define output file path
output_file = os.path.join(output_dir, "results.csv")

# Write the header directly since it's a new file
with open(output_file, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(["Iteration", "No: Pixels", "No: Train samples", "No: Layers", "No: Ancilla Qubits", "Train Accuracy", "Test Accuracy", "Time Stamp"])

print(f"{'Iteration':<12}{'Pixels':<10}{'Train Samples':<15}{'Layers':<10}{'Ancilla Qubits':<18}{'Train Acc':<12}{'Test Acc':<10}{'Timestamp':<20}")

iter_no = 1

for number_ancilla in number_ancilla_arr:
    for N in N_arr:
        for number_train_samples in number_train_samples_arr:
            for number_layers in number_layers_arr:
                '''
                N = 8
                number_train_samples = 100 #number of train images per class
                number_layers = 6
                number_ancilla = 1
                '''
                new_size = (N,N)
                number_test_samples = 5000
                # Define the base data directory
                data_dir = r"E:\Technical\SchrodingerAI\data\RawDataSets\MNIST"


                input_vectors, target_states_combined, input_state_test_vectors, labelss, encode_length = load_data(N, new_size, number_train_samples, data_dir)

                model = QiNN(input_dim=N**2, num_layers=number_layers, num_ancilla = number_ancilla)
                optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
                #loss_fn = nn.CrossEntropyLoss()
                loss_fn = nn.MSELoss()

                num_epochs = 50
                batch_size = number_train_samples * 10

                x = input_vectors[0:batch_size, :]
                target_labels = target_states_combined[0:batch_size, :]

                for epoch in range(num_epochs):
                    optimizer.zero_grad()

                    output_joint_state, _ = model(x)

                    # Normalize
                    output_joint_state = output_joint_state / output_joint_state.norm(dim=1, keepdim=True)
                    target_labels = target_labels / target_labels.norm(dim=1, keepdim=True)

                    # Convert to probability
                    output_probs = torch.abs(output_joint_state) ** 2
                    target_probs = torch.abs(target_labels) ** 2

                    # Compute loss
                    loss = loss_fn(output_probs, target_probs)
                    loss.backward()
                    optimizer.step()

                #print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.6f}")

                test_sample_number = 2499

                # Select test sample and target
                #x_test = input_vectors[test_sample_number, :].unsqueeze(0)
                #target_labels_test = target_states_combined[test_sample_number, :].unsqueeze(0)

                x_test = input_vectors #[test_sample_number, :].unsqueeze(0)
                target_labels_test = target_states_combined #[test_sample_number, :].unsqueeze(0)


                # Run through model
                output_joint_state_test, _ = model(x_test)


                # Normalize
                output_joint_state_test = output_joint_state_test / output_joint_state_test.norm(dim=1, keepdim=True)
                target_labels_test = target_labels_test / target_labels_test.norm(dim=1, keepdim=True)

                # Convert to probability
                output_probs_test = torch.abs(output_joint_state_test) ** 2
                target_probs_test = torch.abs(target_labels_test) ** 2

                predicted_class_test = torch.argmax(output_probs_test, dim=1) // encode_length
                true_class_test = torch.argmax(target_probs_test, dim=1) // encode_length

                # Compare predicted vs true classes
                correct = (predicted_class_test == true_class_test)  # Boolean
                num_correct = correct.sum().item()                   # Number of True predictions
                total = predicted_class_test.shape[0]                # Total number of predictions

                train_accuracy = (num_correct / total) * 100               # Accuracy in percentage

                #print(f"Correct Predictions: {num_correct}/{total}")
                #print(f"Train Accuracy: {train_accuracy:.2f}%")


                # Select test sample and target
                #x_test = input_vectors[test_sample_number, :].unsqueeze(0) 
                #target_labels_test = target_states_combined[test_sample_number, :].unsqueeze(0)  
                x_test = input_state_test_vectors #[test_sample_number, :].unsqueeze(0)  

                # Run model
                output_joint_state_test, _ = model(x_test)

                # Normalize
                output_joint_state_test = output_joint_state_test / output_joint_state_test.norm(dim=1, keepdim=True)

                # Convert to probability
                output_probs_test = torch.abs(output_joint_state_test) ** 2

                predicted_class_test = torch.argmax(output_probs_test, dim=1) // encode_length
                true_class_test = torch.tensor(labelss, dtype = torch.complex128)

                # Compare predicted vs true classes
                correct = (predicted_class_test == true_class_test)  # Boolean 
                num_correct = correct.sum().item()                   # Number of True  predictions
                total = predicted_class_test.shape[0]                # Total number of predictions

                test_accuracy = (num_correct / total) * 100               # Accuracy in percentage

                #print(f"Correct Predictions: {num_correct}/{total}")
                #print(f"Test Accuracy: {test_accuracy:.2f}%")

                timestamp_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print(f"{iter_no:<12}{N:<10}{number_train_samples:<15}{number_layers:<10}{number_ancilla:<18}{train_accuracy:<12.4f}{test_accuracy:<10.4f}{timestamp_now}")

                #print([str(N), str(number_train_samples), str(number_layers), str(number_ancilla),  str(train_accuracy), str(test_accuracy)])
                with open(output_file, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([
                        str(iter_no),
                        str(N),
                        str(number_train_samples),
                        str(number_layers),
                        str(number_ancilla),
                        f"{train_accuracy:.4f}",
                        f"{test_accuracy:.4f}",
                        timestamp_now
                    ])
                                                                    
                iter_no += 1


Iteration   Pixels    Train Samples  Layers    Ancilla Qubits    Train Acc   Test Acc  Timestamp           
1           16        1000           4         2                 90.0200     89.4000   2025-07-06 23:13:04
2           16        1000           8         2                 88.1600     87.2600   2025-07-07 01:40:00
3           16        1000           12        2                 89.1900     88.4300   2025-07-07 05:20:27
4           16        1000           24        2                 63.6800     61.7900   2025-07-07 14:28:20
5           16        2000           4         2                 86.0350     86.6700   2025-07-07 16:21:55
6           16        2000           8         2                 88.3900     88.7300   2025-07-07 19:56:02
7           16        2000           12        2                 84.2700     84.8900   2025-07-08 00:53:34
