# Quantum k-means Clustering on local quantum simulator

To run this program, please follow these steps:
* Download the file.
* Install Qiskit and `qiskit-ibm-runtime` by following the instructions at https://docs.quantum.ibm.com/start/install#local
* Install the remaining dependencies with this command: `pip install scikit-learn matplotlib pandas`
* Restart Kernel and Run All Cells

In [None]:
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister
from qiskit.primitives import StatevectorSampler
from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager
from qiskit_ibm_runtime import QiskitRuntimeService, SamplerV2 as Sampler

import numpy as np
import math
import random
import itertools
import ast

import matplotlib.pyplot as plt
import pandas as pd

from sklearn.datasets import load_iris
from sklearn.preprocessing import MinMaxScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.cluster import KMeans

"""
Helper functions for converting Cartesian coordinates to polar coordinates
"""
def polar_angle(point):
    """
    Return the polar angle (in radians) of a 2-D point

    point: indexable with the x coordinate at index 0 and the y coordinate at index 1
    """
    x_coord, y_coord = point[0], point[1]
    return math.atan2(y_coord, x_coord)

def polar_radius(point):
    """
    Return the polar radius of a 2-D point, i.e. its distance from the origin

    point: indexable with the x coordinate at index 0 and the y coordinate at index 1
    """
    squared_norm = point[0] ** 2 + point[1] ** 2
    return math.sqrt(squared_norm)

def initialize_clusters(k, dataset):
    """
    Pick k random points of the dataset as starting centroids

    Returns a tuple (clusters, centroids): centroids is the list of sampled points and
    clusters maps the string form of every centroid to an empty list of member points
    """
    # Sample without replacement so that k distinct positions of the dataset are used
    centroids = random.sample(dataset, k)
    # One empty bucket per centroid, keyed by the centroid's string representation
    clusters = {f"{centroid}": [] for centroid in centroids}
    return clusters, centroids

def recompute_centroids(clusters, centroids):
    """
    Recompute every centroid as the mean of the points currently stored in its cluster

    clusters: dictionary mapping the string form of a centroid to its list of points
    centroids: list of the current centroids

    Returns (new_clusters, new_centroids, False) when at least one centroid moved; the new
    clusters are empty and empty input clusters are left out. Returns the unchanged
    (clusters, centroids, True) when no centroid moved.
    """
    moved = False
    updated_centroids = []
    for key, members in clusters.items():
        # The dictionary key is the string form of the centroid, so parse it back to a list
        previous_centroid = ast.literal_eval(key)
        if len(members) == 0:
            # An empty cluster has no mean and contributes no new centroid
            continue
        mean_point = np.mean(members, axis=0).tolist()
        moved = moved or (mean_point != previous_centroid)
        updated_centroids.append(mean_point)
    if not moved:
        return clusters, centroids, True
    emptied_clusters = {f"{mean_point}": [] for mean_point in updated_centroids}
    return emptied_clusters, updated_centroids, False

def quantum_circuit(qc, qr, cr, i, data_angle, centroid_angle):
    """
    Append the i-th distance-estimation sub-circuit to qc and return qc

    The sub-circuit uses qubit 2*i as control and qubit 2*i + 1 as target, and measures the
    control qubit into classical bit i. For the gate sequence below, the ideal probability of
    measuring 1 is sin^2(theta / 2), where theta = |data_angle - centroid_angle|.
    """
    control = qr[i * 2]
    target = qr[i * 2 + 1]
    theta = abs(data_angle - centroid_angle)
    # Put the control in superposition and entangle it with the target
    qc.h(control)
    qc.cx(control, target)
    # The two rotations cancel when the control is 0 and add up when it is 1
    qc.ry(-theta, target)
    qc.cx(control, target)
    qc.ry(theta, target)
    # Interference and measurement
    qc.h(control)
    qc.measure(control, cr[i])
    return qc

def nearest_centroids_dictionary(dataset):
    """
    Build the bookkeeping dictionary that records the nearest centroid of every data point

    Each key is the string form of a data point. Each value starts with an infinite
    'smallest distance' and no 'nearest centroid'. Points with the same string form share
    one entry.
    """
    return {
        f"{data_point}": {'smallest distance': float('inf'), 'nearest centroid': None}
        for data_point in dataset
    }

def khan_quantum_distance(calculated_distances, nearest_centroids_dict, num_shots=1024):
    """
    Turn the measured counts of |1> into Khan-style distances and update the nearest centroids

    calculated_distances: dictionary whose values hold a 'pair' ([centroid, data point]) and
        'count 1', the number of shots in which the pair's control qubit was measured as 1
    nearest_centroids_dict: bookkeeping dictionary keyed by the string form of each data point
    num_shots: total number of shots used to obtain the counts

    Prints the quantum and the classical distance of every pair and returns the updated
    nearest_centroids_dict (it is modified in place).
    """
    for record in calculated_distances.values():
        probability_1 = record['count 1'] / num_shots
        cen = record['pair'][0]
        data = record['pair'][1]
        # Distance estimate: sqrt(|data|^2 + |cen|^2) * sqrt(2 * P(1))
        normalize_value = math.sqrt(data[0] ** 2 + data[1] ** 2 + cen[0] ** 2 + cen[1] ** 2)
        quantum_dist = normalize_value * math.sqrt(2 * probability_1)
        # The classical distance is only printed for comparison
        classical_dist = classical_distance(data, cen)
        print(f"Khan quantum distance - {quantum_dist} - Classical distance - {classical_dist}")
        # Keep this centroid if it is the closest one seen so far for the data point
        best_so_far = nearest_centroids_dict[f"{data}"]
        if quantum_dist < best_so_far['smallest distance']:
            best_so_far['smallest distance'] = quantum_dist
            best_so_far['nearest centroid'] = cen
    return nearest_centroids_dict

def duong_quantum_distance(calculated_distances, nearest_centroids_dict, num_shots=1024):
    """
    Turn the measured counts of |1> into Duong-style distances and update the nearest centroids

    calculated_distances: dictionary whose values hold a 'pair' ([centroid, data point]) and
        'count 1', the number of shots in which the pair's control qubit was measured as 1
    nearest_centroids_dict: bookkeeping dictionary keyed by the string form of each data point
    num_shots: total number of shots used to obtain the counts

    Prints the quantum and the classical distance of every pair and returns the updated
    nearest_centroids_dict (it is modified in place).
    """
    for record in calculated_distances.values():
        probability_1 = record['count 1'] / num_shots
        cen = record['pair'][0]
        data = record['pair'][1]
        # Distance estimate: sqrt((r_data - r_cen)^2 + 4 * r_data * r_cen * P(1)),
        # which is the law of cosines when P(1) = sin^2(angle difference / 2)
        data_radius = polar_radius(data)
        centroid_radius = polar_radius(cen)
        quantum_dist = math.sqrt((data_radius - centroid_radius) ** 2 + 4 * data_radius * centroid_radius * probability_1)
        # The classical distance is only printed for comparison
        classical_dist = classical_distance(data, cen)
        print(f"Duong quantum distance - {quantum_dist} - Classical distance - {classical_dist}")
        # Keep this centroid if it is the closest one seen so far for the data point
        best_so_far = nearest_centroids_dict[f"{data}"]
        if quantum_dist < best_so_far['smallest distance']:
            best_so_far['smallest distance'] = quantum_dist
            best_so_far['nearest centroid'] = cen
    return nearest_centroids_dict

def classical_distance(point_1, point_2):
    """
    Return the Euclidean distance between two 2-D points, computed classically
    """
    delta_x = point_1[0] - point_2[0]
    delta_y = point_1[1] - point_2[1]
    return math.sqrt(delta_x ** 2 + delta_y ** 2)

def classical_selection(dataset, centroids):
    """
    Find the nearest centroid of every data point using the classical Euclidean distance

    Returns a dictionary keyed by the string form of each data point that stores the
    'smallest distance' found and the corresponding 'nearest centroid'.
    """
    nearest_centroids_dict = nearest_centroids_dictionary(dataset)
    for data in dataset:
        best_so_far = nearest_centroids_dict[f"{data}"]
        for centroid in centroids:
            candidate_dist = classical_distance(data, centroid)
            # Strict comparison: on a tie the centroid found first is kept
            if candidate_dist < best_so_far['smallest distance']:
                best_so_far['smallest distance'] = candidate_dist
                best_so_far['nearest centroid'] = centroid
    return nearest_centroids_dict

def khan_quantum_selection(dataset, centroids, num_qubits, num_shots=1024):
    """
    Find the nearest centroid of every data point with Khan et al.'s quantum distance

    All (centroid, data point) pairs are processed in batches. Every pair needs two qubits,
    so one circuit estimates num_qubits // 2 distances at the same time.

    dataset: list of 2-D points
    centroids: list of 2-D centroids
    num_qubits: number of qubits available for one circuit (must be at least 2)
    num_shots: number of shots per circuit

    Returns the dictionary built by nearest_centroids_dictionary, updated with the nearest
    centroid of every data point.
    Raises ValueError when there are pairs to process but num_qubits is smaller than 2.
    """
    # Number of distances that can be calculated at the same time
    num_distances = num_qubits // 2
    # Dictionary with the information about the nearest centroid of each data point
    nearest_centroids_dict = nearest_centroids_dictionary(dataset)
    # All pairs whose distance has to be calculated
    pairs = [[centroid, data] for centroid in centroids for data in dataset]
    if pairs and num_distances < 1:
        # Without this check no pair could ever be consumed and the loop below would never end
        raise ValueError("num_qubits must be at least 2 to calculate a distance")
    sampler = StatevectorSampler()
    # Loop until all pairs' distances are calculated
    while len(pairs) > 0:
        # Initialize quantum circuit
        qr = QuantumRegister(num_distances * 2, name='q')
        cr = ClassicalRegister(num_distances, name='c')
        qc = QuantumCircuit(qr, cr)
        # Pairs whose distances are calculated by this circuit, keyed by classical bit index
        calculated_distances = {}
        # Fill the circuit; the last batch may use fewer slots than the circuit offers
        for i in range(min(num_distances, len(pairs))):
            pair = pairs.pop()
            centroid_angle = polar_angle(pair[0])
            data_angle = polar_angle(pair[1])
            qc = quantum_circuit(qc, qr, cr, i, data_angle, centroid_angle)
            calculated_distances[i] = {'pair': pair, 'count 1': 0}
        job = sampler.run([qc], shots=num_shots)
        result = job.result()[0]
        count = result.data.c.get_counts()
        for state, state_count in count.items():
            # The bitstring always has one character per classical bit and its rightmost
            # character is classical bit 0, so walk it in reverse to get the bit index.
            # Bits that belong to unused slots of a partially filled circuit are skipped.
            for bit_index, bit in enumerate(reversed(state)):
                if bit == '1' and bit_index in calculated_distances:
                    calculated_distances[bit_index]['count 1'] += state_count
        nearest_centroids_dict = khan_quantum_distance(calculated_distances, nearest_centroids_dict, num_shots)
    return nearest_centroids_dict

def duong_quantum_selection(dataset, centroids, num_qubits, num_shots=1024):
    """
    Find the nearest centroid of every data point with Duong's quantum distance

    All (centroid, data point) pairs are processed in batches. Every pair needs two qubits,
    so one circuit estimates num_qubits // 2 distances at the same time.

    dataset: list of 2-D points
    centroids: list of 2-D centroids
    num_qubits: number of qubits available for one circuit (must be at least 2)
    num_shots: number of shots per circuit

    Returns the dictionary built by nearest_centroids_dictionary, updated with the nearest
    centroid of every data point.
    Raises ValueError when there are pairs to process but num_qubits is smaller than 2.
    """
    # Number of distances that can be calculated at the same time
    num_distances = num_qubits // 2
    # Dictionary with the information about the nearest centroid of each data point
    nearest_centroids_dict = nearest_centroids_dictionary(dataset)
    # All pairs whose distance has to be calculated
    pairs = [[centroid, data] for centroid in centroids for data in dataset]
    if pairs and num_distances < 1:
        # Without this check no pair could ever be consumed and the loop below would never end
        raise ValueError("num_qubits must be at least 2 to calculate a distance")
    sampler = StatevectorSampler()
    # Loop until all pairs' distances are calculated
    while len(pairs) > 0:
        # Initialize quantum circuit
        qr = QuantumRegister(num_distances * 2, name='q')
        cr = ClassicalRegister(num_distances, name='c')
        qc = QuantumCircuit(qr, cr)
        # Pairs whose distances are calculated by this circuit, keyed by classical bit index
        calculated_distances = {}
        # Fill the circuit; the last batch may use fewer slots than the circuit offers
        for i in range(min(num_distances, len(pairs))):
            pair = pairs.pop()
            centroid_angle = polar_angle(pair[0])
            data_angle = polar_angle(pair[1])
            qc = quantum_circuit(qc, qr, cr, i, data_angle, centroid_angle)
            calculated_distances[i] = {'pair': pair, 'count 1': 0}
        job = sampler.run([qc], shots=num_shots)
        result = job.result()[0]
        count = result.data.c.get_counts()
        for state, state_count in count.items():
            # The bitstring always has one character per classical bit and its rightmost
            # character is classical bit 0, so walk it in reverse to get the bit index.
            # Bits that belong to unused slots of a partially filled circuit are skipped.
            for bit_index, bit in enumerate(reversed(state)):
                if bit == '1' and bit_index in calculated_distances:
                    calculated_distances[bit_index]['count 1'] += state_count
        nearest_centroids_dict = duong_quantum_distance(calculated_distances, nearest_centroids_dict, num_shots)
    return nearest_centroids_dict

def khan_quantum_clustering(k, dataset, clusters, centroids, num_qubits, duong_iteration, num_shots=1024):
    """
    Run k-means clustering in which the nearest centroid is chosen by Khan et al.'s quantum distance

    k: number of clusters
    dataset: list of 2-D points
    clusters: dictionary keyed by the string form of each initial centroid; only its keys are
        used, the dictionary itself is not modified
    centroids: list of the initial centroids
    num_qubits: number of qubits available for one circuit
    duong_iteration: maximum number of iterations to run
    num_shots: number of shots per circuit

    Returns (clusters, centroids). The loop stops on convergence or when the iteration budget
    is used up; in both cases the returned clusters hold the points of the last assignment and
    are keyed by the centroids that assignment was made against.
    """
    iteration_count = 0
    # NOTE: dataset-specific workaround. The selection dictionary is keyed by the string form
    # of a point, so a point occurring twice in the dataset is assigned only once; this
    # hard-coded point is appended a second time below. Presumably it is the duplicated sample
    # of the prepared Iris data -- verify before using another dataset.
    repeated_point = [0.34600960858349594, -0.15629187416923862]
    # Start from empty clusters of our own: appending into the caller's dictionary would leak
    # points into (or inherit points from) any other run that shares the same dictionary
    clusters = {centroid_string: [] for centroid_string in clusters}
    while iteration_count < duong_iteration:
        iteration_count += 1
        # Run Khan's quantum program
        khan_dict = khan_quantum_selection(dataset, centroids, num_qubits, num_shots)
        classical_dict = classical_selection(dataset, centroids)
        # Compare Khan's choice with the classical choice
        same_count = 0
        different_count = 0
        for data_string, data_dict in khan_dict.items():
            nearest_centroid = data_dict['nearest centroid']
            if nearest_centroid == classical_dict[data_string]['nearest centroid']:
                same_count += 1
            else:
                different_count += 1
            # Store the point in the cluster of its nearest centroid
            clusters[f"{nearest_centroid}"].append(ast.literal_eval(data_string))
        print(f"Khan same count: {same_count} and different count: {different_count}")
        # Add the repeated point to the cluster that already contains it
        for cluster in clusters.values():
            if repeated_point in cluster:
                cluster.append(repeated_point)
        if iteration_count >= duong_iteration:
            # Budget used up: keep this populated assignment instead of replacing it with
            # the empty clusters of an iteration that will never run
            break
        if any(len(cluster) == 0 for cluster in clusters.values()):
            # Not enough clusters: start again with another k randomly chosen centroids
            clusters, centroids = initialize_clusters(k, dataset)
            continue
        # Recompute centroids by averaging all points of each cluster
        new_clusters, new_centroids, is_equal = recompute_centroids(clusters, centroids)
        if is_equal:
            break
        clusters, centroids = new_clusters, new_centroids
    print(f"Khan's iteration count is: {iteration_count}")
    return clusters, centroids

def duong_quantum_clustering(k, dataset, clusters, centroids, num_qubits, num_shots=1024):
    """
    Run k-means clustering in which the nearest centroid is chosen by Duong's quantum distance

    k: number of clusters
    dataset: list of 2-D points
    clusters: dictionary keyed by the string form of each initial centroid; only its keys are
        used, the dictionary itself is not modified
    centroids: list of the initial centroids
    num_qubits: number of qubits available for one circuit
    num_shots: number of shots per circuit

    Returns (clusters, centroids, iteration_count) once the recomputed centroids are equal to
    the centroids the points were assigned to. There is no iteration limit.
    """
    iteration_count = 0
    # NOTE: dataset-specific workaround. The selection dictionary is keyed by the string form
    # of a point, so a point occurring twice in the dataset is assigned only once; this
    # hard-coded point is appended a second time below. Presumably it is the duplicated sample
    # of the prepared Iris data -- verify before using another dataset.
    repeated_point = [0.34600960858349594, -0.15629187416923862]
    # Start from empty clusters of our own: appending into the caller's dictionary would leak
    # points into any other run that shares the same dictionary
    clusters = {centroid_string: [] for centroid_string in clusters}
    # Loop until convergence
    while True:
        iteration_count += 1
        # Run Duong's quantum program
        duong_dict = duong_quantum_selection(dataset, centroids, num_qubits, num_shots)
        classical_dict = classical_selection(dataset, centroids)
        # Compare Duong's choice with the classical choice
        same_count = 0
        different_count = 0
        for data_string, data_dict in duong_dict.items():
            nearest_centroid = data_dict['nearest centroid']
            if nearest_centroid == classical_dict[data_string]['nearest centroid']:
                same_count += 1
            else:
                different_count += 1
            # Store the point in the cluster of its nearest centroid
            clusters[f"{nearest_centroid}"].append(ast.literal_eval(data_string))
        print(f"Duong same count: {same_count} and different count: {different_count}")
        # Add the repeated point to the cluster that already contains it
        for cluster in clusters.values():
            if repeated_point in cluster:
                cluster.append(repeated_point)
        if any(len(cluster) == 0 for cluster in clusters.values()):
            # Not enough clusters: start again with another k randomly chosen centroids
            clusters, centroids = initialize_clusters(k, dataset)
            continue
        # Recompute centroids by averaging all points of each cluster
        new_clusters, new_centroids, is_equal = recompute_centroids(clusters, centroids)
        if is_equal:
            break
        clusters, centroids = new_clusters, new_centroids
    print(f"Duong's iteration count is: {iteration_count}")
    return clusters, centroids, iteration_count

def _plot_clusters(ax, original_centroids, clusters, title):
    """
    Draw one clustering result (initial centroids, final centroids and points) on the axes ax
    """
    # Initial centroids are drawn as black circles
    for centroid in original_centroids:
        ax.scatter(centroid[0], centroid[1], marker='o', s=100, color='black')
    for centroid_string, cluster in clusters.items():
        # The centroid of a cluster is stored as its dictionary key; draw it as a black cross
        centroid = ast.literal_eval(centroid_string)
        ax.scatter(centroid[0], centroid[1], marker='x', s=100, color='black')
        # One scatter call per cluster so that every cluster gets its own colour
        x_values = [point[0] for point in cluster]
        y_values = [point[1] for point in cluster]
        ax.scatter(x_values, y_values)
    ax.set_xlabel('Feature 1')
    ax.set_ylabel('Feature 2')
    ax.set_title(title)


def run_program(k, dataset, correct_clusters, num_qubits, num_shots=1024, output_path="k_means_compare1.png"):
    """
    Cluster the dataset with both quantum methods from the same start and plot the results

    k: number of clusters
    dataset: list of 2-D points
    correct_clusters: reference clustering, keyed by the string form of each cluster's centroid
    num_qubits: number of qubits available for one circuit
    num_shots: number of shots per circuit
    output_path: file the comparison figure is saved to

    Khan's method is limited to the number of iterations Duong's method needed.
    Returns (duong_clusters, khan_clusters, correct_clusters).
    """
    # Initialize empty clusters with randomly chosen centroids
    clusters, centroids = initialize_clusters(k, dataset)
    print(f"Original centroids {centroids}")
    # Give each method its own empty copy of the clusters: the clustering functions append
    # points to the dictionary they receive, so sharing one dictionary would hand the second
    # method the first method's assignments
    duong_start = {centroid_string: [] for centroid_string in clusters}
    khan_start = {centroid_string: [] for centroid_string in clusters}
    duong_clusters, duong_centroids, duong_iteration = duong_quantum_clustering(k, dataset, duong_start, centroids, num_qubits, num_shots)
    khan_clusters, khan_centroids = khan_quantum_clustering(k, dataset, khan_start, centroids, num_qubits, duong_iteration, num_shots)
    # Create one panel per clustering
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    _plot_clusters(axes[0], centroids, khan_clusters, 'Khan et al. result')
    _plot_clusters(axes[1], centroids, duong_clusters, 'New method result')
    _plot_clusters(axes[2], centroids, correct_clusters, 'Correct labels')
    # Save the plot to a file
    plt.savefig(output_path)
    return duong_clusters, khan_clusters, correct_clusters

# Load the Iris dataset
iris_data = load_iris()
labels = iris_data.target
# Rescale every feature with MinMaxScaler, then keep two principal components
features = MinMaxScaler().fit_transform(iris_data.data)
features = PCA(n_components=2).fit_transform(features)
# Plain Python lists of the points and of their correct labels
data_list = features.tolist()
correct_list = labels.tolist()
# Group the points by their correct label
correct_clusters = {0: [], 1: [], 2: []}
for point, label in zip(data_list, correct_list):
    # Every label other than 0 and 1 goes to group 2
    group = label if label in (0, 1) else 2
    correct_clusters[group].append(point)
# Re-key the groups by the string form of their centroid (the mean of their points)
new_correct_clusters = {}
for cluster in correct_clusters.values():
    new_correct_clusters[f"{np.mean(cluster, axis=0).tolist()}"] = cluster
correct_clusters = new_correct_clusters
# Run the experiment: 3 clusters, 10 qubits, 2048 shots
duong_clusters, khan_clusters, correct_clusters = run_program(3, data_list, correct_clusters, 10, 2048)