# 1. Import Libraries

In [None]:
import time
import warnings
import pandas as pd
import os

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output
from qiskit import ClassicalRegister, QuantumRegister
from qiskit import QuantumCircuit
from qiskit_algorithms.optimizers import COBYLA
from qiskit.circuit.library import RealAmplitudes
from qiskit.quantum_info import Statevector
from qiskit.utils import algorithm_globals
from qiskit_machine_learning.circuit.library import RawFeatureVector
from qiskit_machine_learning.neural_networks import SamplerQNN
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister
from qiskit.extensions import RZGate, CXGate
from qiskit_machine_learning.algorithms.classifiers import VQC
from qiskit.primitives import Sampler
from qiskit.circuit.library import EfficientSU2

import warnings
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
from torchvision import datasets, transforms

# 2. Setup Environment Variables

In [None]:
# Seed Qiskit's global random generator; the initial encoder parameters are
# later drawn from algorithm_globals.random.
algorithm_globals.random_seed = 42

In [None]:
# Silence every warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings raised by the imported libraries.
warnings.filterwarnings('ignore')

In [None]:
# Dataset selector: 'FMNIST' or 'KMNIST'; any other value makes the
# data-loading cell fall back to MNIST.
dataset = 'FMNIST'

In [None]:
# Qubit budget: the encoder ansatz acts on num_latent + num_trash qubits; the
# swap test adds num_trash further qubits plus one auxiliary qubit.
num_latent = 4
num_trash = 2

In [None]:
# Images are resized to image_shape and flattened to image_pixels values.
# The feature map is created with 2 ** (num_latent + num_trash) inputs, so
# image_pixels has to equal that number (64 with the values used here).
image_shape = (8,8)
image_pixels = image_shape[0]*image_shape[1]

# Images kept per class, and number of classes (labels 0 .. num_classes - 1).
num_samples = 25
num_classes = 2

In [None]:
# One batch is as large as the whole selected subset (num_samples per class).
batch_size = num_samples*num_classes
# Passed to SamplerQNN as its output_shape.
projection_units = 32
# Divides the similarity logits inside the contrastive cost function.
temperature = 0.1

# In the visible cells this is only mentioned in the commented-out `tol`
# argument of the COBYLA optimizer.
tolerance = 0.9

In [None]:
# Tag that encodes every hyper-parameter of this run; used as the folder name.
experiment_name = (
    f"DATASET_{dataset}"
    f"-IMG_{image_shape[0]}X{image_shape[1]}"
    f"-CLASSES_{num_classes}"
    f"-SAMPLES_{num_samples}"
    f"-BATCH_{batch_size}"
    f"-PROJUNITS_{projection_units}"
    f"-TEMP_{temperature}"
    f"-LATENT_{num_latent}"
    f"-TRASH_{num_trash}"
)

In [None]:
# Create the folder for this run. os.makedirs also creates a missing
# 'experiments' parent, where os.mkdir raised FileNotFoundError on a fresh
# checkout. An already existing folder is reported and reused.
try:
    os.makedirs('experiments/{}'.format(experiment_name))
except FileExistsError:
    print("done")

In [None]:
# Relative path under which the later cells store the artifacts of this run.
experiment_folder = 'experiments/{}'.format(experiment_name)

# 2. Implement SupCon Learning Framework

![q-sup-con](./images/q_sup_con_model.png)

## 2.1 Quantum Encoder Circuit

In [None]:
def ansatz(num_qubits):
    """Return a RealAmplitudes circuit on ``num_qubits`` qubits with 5 repetitions."""
    trainable_circuit = RealAmplitudes(num_qubits, reps=5)
    return trainable_circuit

In [None]:
def auto_encoder_circuit(num_latent, num_trash):
    """Build the encoder ansatz followed by a swap test on the trash qubits.

    Args:
        num_latent: Number of latent qubits.
        num_trash: Number of trash qubits; the same number of extra qubits is
            allocated as swap partners.

    Returns:
        A QuantumCircuit on ``num_latent + 2 * num_trash + 1`` qubits and one
        classical bit that receives the measurement of the auxiliary qubit.
    """
    encoder_width = num_latent + num_trash
    # The auxiliary (control) qubit is the last qubit of the register.
    aux = encoder_width + num_trash

    qubits = QuantumRegister(aux + 1, "q")
    clbits = ClassicalRegister(1, "c")
    circuit = QuantumCircuit(qubits, clbits)

    # Trainable encoder on the latent + trash qubits.
    circuit.compose(ansatz(encoder_width), range(encoder_width), inplace=True)
    circuit.barrier()

    # Swap test: H, controlled swaps of each trash qubit with its partner, H.
    circuit.h(aux)
    for offset in range(num_trash):
        circuit.cswap(aux, num_latent + offset, encoder_width + offset)
    circuit.h(aux)

    circuit.measure(aux, clbits[0])
    return circuit

## 2.2. Data Augmentation

In [None]:
def data_augmentation_circuit(num_latent, num_trash):
    """Build the fixed gate layer applied to the latent and trash qubits.

    For each of the first ``num_latent + num_trash`` qubits, in order, the
    circuit applies SX, RZ(pi/2), SX, RZ(pi/2) and then a CX onto the next
    qubit (except after the last one). The remaining qubits of the register
    are left untouched.

    Args:
        num_latent: Number of latent qubits.
        num_trash: Number of trash qubits.

    Returns:
        A QuantumCircuit on ``num_latent + 2 * num_trash + 1`` qubits without
        classical bits.

    Raises:
        ValueError: If ``num_latent + num_trash`` is smaller than 1.
    """
    active_qubits = num_latent + num_trash
    if active_qubits < 1:
        raise ValueError("Number of qubits must be at least 1.")

    register = QuantumRegister(num_latent + 2 * num_trash + 1, "q")
    circuit = QuantumCircuit(register)

    last_index = active_qubits - 1
    for index in range(active_qubits):
        qubit = register[index]
        # Two identical SX / RZ(pi/2) pairs on the current qubit.
        for _ in range(2):
            circuit.sx(qubit)
            circuit.rz(np.pi / 2, qubit)

        # Chain to the neighbour before that neighbour receives its own gates.
        if index != last_index:
            circuit.cx(qubit, register[index + 1])

    return circuit

# 3. Data Loading

In [None]:
# Convert every image to a tensor and resize it to image_shape.
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize(image_shape)])

# Map the `dataset` selector to its torchvision class. Unknown names keep the
# previous fallback to MNIST. Only the selected dataset is downloaded and
# loaded, instead of instantiating all three and discarding two.
dataset_classes = {'FMNIST': datasets.FashionMNIST, 'KMNIST': datasets.KMNIST}
selected_dataset = dataset_classes.get(dataset, datasets.MNIST)(
    root='./data', train=True, transform=transform, download=True
)

# Keep the first num_samples training images of each class label
# 0 .. num_classes - 1.
selected_indices = []
for class_label in range(num_classes):
    indices = torch.where(selected_dataset.targets == class_label)[0][:num_samples]
    selected_indices.extend(indices.tolist())

train_dataset = torch.utils.data.Subset(selected_dataset, selected_indices)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Create an iterator from the DataLoader
# (no test loader is defined in the visible cells, hence the disabled line below)
train_iter = iter(train_loader)
# test_iter = iter(test_loader)

In [None]:
# Draw one batch once; the cost function reads this global on every
# evaluation, so the encoder is optimized on this single batch.
batch = next(train_iter)

# 4. Quantum Supervised Contrastive Learning

## 4.1 Build QSCL model

In [None]:
# Feature map with one input per basis state of the num_latent + num_trash
# encoder qubits.
fm = RawFeatureVector(2 ** (num_latent + num_trash))

# Fixed gate layer inserted between the feature map and the encoder.
d_aug = data_augmentation_circuit(num_latent, num_trash)

# Trainable encoder followed by the swap test and one measurement.
ae = auto_encoder_circuit(num_latent, num_trash)

# Full model circuit: feature map on the encoder qubits, then the fixed layer,
# then the encoder with swap test, all on the complete register.
qc = QuantumCircuit(num_latent + 2 * num_trash + 1, 1)
qc = qc.compose(fm, range(num_latent + num_trash))
qc = qc.compose(d_aug)
qc = qc.compose(ae)

In [None]:
def identity_interpret(x):
    """Return ``x`` unchanged; used as the ``interpret`` callable of the QNN."""
    unchanged = x
    return unchanged

In [None]:
# Sampler-based network: image values bind the feature-map parameters and the
# encoder parameters are the trainable weights.
# NOTE(review): the circuit has a single classical bit while output_shape is
# projection_units (32) — presumably only the first two outputs can be
# non-zero; confirm this is intended.
sup_con_qnn = SamplerQNN(
    circuit=qc,
    input_params=fm.parameters,
    weight_params=ae.parameters,
    interpret=identity_interpret,
    output_shape=projection_units,
)

## 4.2 Build Cost Function

### 4.2.1 Supervised Contrastive Loss

In [None]:
# History of cost values; the cost function appends one entry per evaluation.
objective_func_vals_sup_con = []

In [None]:
def cost_func_digits_sup(params_values):
    """Evaluate the encoder cost on the fixed global ``batch``.

    Args:
        params_values: Trainable weights forwarded to ``sup_con_qnn.forward``.

    Returns:
        The mean loss over the batch.

    Raises:
        ValueError: If an image of the batch is all zeros and therefore cannot
            be scaled to unit norm.

    Side effects:
        Appends the cost to ``objective_func_vals_sup_con`` and redraws the
        live loss plot.
    """
    batch_images, batch_labels = batch

    # Flatten every image to a vector of image_pixels values.
    batch_images = np.array(batch_images.reshape(len(batch_images), image_pixels))

    # Scale each image to unit L2 norm in one vectorized step. An all-zero
    # image would otherwise turn into NaNs silently.
    image_norms = np.linalg.norm(batch_images, axis=1, keepdims=True)
    if np.any(image_norms == 0):
        raise ValueError("Cannot normalize an all-zero image to unit norm.")
    batch_images = batch_images / image_norms

    probabilities = sup_con_qnn.forward(batch_images, params_values)

    # Normalize feature vectors
    feature_vectors_magnitude = np.linalg.norm(probabilities, axis=1, ord=2, keepdims=True)
    feature_vectors_normalized = probabilities / feature_vectors_magnitude

    # Pairwise similarities between all samples of the batch, scaled by the temperature.
    logits = np.dot(feature_vectors_normalized, feature_vectors_normalized.T) / temperature

    # Log-softmax over each row; subtracting the row maximum first leaves the
    # result unchanged mathematically but keeps exp() from overflowing.
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    log_softmax = shifted - np.log(np.sum(np.exp(shifted), axis=1, keepdims=True))

    # NOTE(review): the class label is used as a column index into the
    # sample-by-sample matrix, i.e. each row is scored against the sample in
    # column `label`. This differs from a loss built on a same-label mask —
    # confirm that this is the intended objective.
    labels = np.asarray(batch_labels).reshape(-1)
    loss = -log_softmax[np.arange(len(labels)), labels]
    cost = np.mean(loss)

    # plotting part
    clear_output(wait=True)
    objective_func_vals_sup_con.append(cost)
    plt.title("Objective function value against iteration")
    plt.xlabel("Iteration")
    plt.ylabel("Objective function value")
    plt.plot(range(len(objective_func_vals_sup_con)), objective_func_vals_sup_con)
    plt.show()

    return cost

In [None]:
# Starting point of the encoder optimization: reuse the stored vector when the
# experiment folder already has one, otherwise draw a new one and store it.
enc_initial_points_file = experiment_folder+'/enc_initial_points.npy'
initial_points = algorithm_globals.random.random(ae.num_parameters)
try:
    initial_points = np.load(enc_initial_points_file)
except FileNotFoundError:
    initial_points = algorithm_globals.random.random(ae.num_parameters)
    np.save(enc_initial_points_file, np.array(initial_points))

## 4.3 QSCL Encoder Training

In [None]:
# Gradient-free optimizer limited to 1000 iterations; the tolerance argument is disabled.
opt = COBYLA(maxiter=1000) # , tol=tolerance

# make the plot nicer
plt.rcParams["figure.figsize"] = (12, 6)

# Minimize the cost starting from initial_points and report the wall-clock time.
start = time.time()
opt_result = opt.minimize(cost_func_digits_sup, initial_points)
elapsed = time.time() - start
print(f"Fit in {elapsed:0.2f} seconds")

In [None]:
# Store the recorded cost history of the encoder training.
np.save(experiment_folder+'/enc_loss_values.npy', np.array(objective_func_vals_sup_con))

In [None]:
# Optimized encoder parameters: reuse the stored vector when present,
# otherwise take the optimizer result and store it. The fallback previously
# assigned fresh random values to opt_result_x while writing opt_result.x to
# disk, so the variable and the file disagreed.
# NOTE(review): an existing file from an earlier run takes precedence over the
# result of the current run — confirm that this is desired.
try:
    opt_result_x = np.load(experiment_folder+'/enc_opt_x.npy')
except FileNotFoundError:
    opt_result_x = np.array(opt_result.x)
    np.save(experiment_folder+'/enc_opt_x.npy', opt_result_x)

## 4.4 Generate Feature Vectors

In [None]:
# Test
# Encoding circuit without the swap test: feature map followed by a fresh
# ansatz instance on the num_latent + num_trash qubits.
test_qc = QuantumCircuit(num_latent + num_trash)
test_qc = test_qc.compose(fm)
ansatz_qc = ansatz(num_latent + num_trash)
test_qc = test_qc.compose(ansatz_qc)
test_qc.barrier()

# Reset the trash qubits (indices num_latent .. num_latent + num_trash - 1).
for i in range(num_latent, num_latent+num_trash):
    test_qc.reset(i)
# test_qc.barrier()
# test_qc = test_qc.compose(ansatz_qc.inverse())

### 4.4.1 Generate Feature Vector for Train Data

In [None]:
# Encode every training image with the trained encoder. The feature vector of
# an image is the first 2**num_latent entries of the output probabilities.
X_encoded = []
y_encoded = []

for images, labels in train_loader:
    for image in images:
        image = np.array(image.reshape(image_pixels))

        # The statevector of the raw input was computed here on every
        # iteration but never used; that dead simulation has been removed.

        # Parameter vector: pixel values followed by the optimized weights.
        param_values = np.concatenate((image, opt_result.x))
        output_qc = test_qc.assign_parameters(param_values)
        output_sv = Statevector(output_qc).data
        output_sv = np.reshape(np.abs(output_sv) ** 2, image_pixels)

        X_encoded.append(output_sv[:2**num_latent])
    y_encoded.extend(labels.cpu().numpy())

### 4.4.2 Generate Feature Vector for Test Data

In [None]:
# Test
# Same construction as test_qc, but the inverse ansatz is appended after the
# trash qubits are reset.
# NOTE(review): test_qc2 is not used by any of the visible cells, and no test
# data is encoded in this section — confirm whether this cell is still needed.
test_qc2 = QuantumCircuit(num_latent + num_trash)
test_qc2 = test_qc2.compose(fm)
ansatz_qc2 = ansatz(num_latent + num_trash)
test_qc2 = test_qc2.compose(ansatz_qc2)
test_qc2.barrier()

for i in range(num_latent, num_latent+num_trash):
    test_qc2.reset(i)
test_qc2.barrier()
test_qc2 = test_qc2.compose(ansatz_qc2.inverse())

In [None]:
# Turn the labels into a column so they can be joined to the features.
y_encoded = np.reshape(y_encoded, (len(y_encoded), 1))

In [None]:
# One row per image: the encoded features followed by the label.
X_y_np = np.concatenate((X_encoded, y_encoded), axis=1)

In [None]:
# Write the encoded dataset to disk: feature columns f0, f1, ... and label column y.
feature_columns = [f'f{index}' for index in range(2**num_latent)]
X_y_df = pd.DataFrame(X_y_np, columns=feature_columns + ['y'])
X_y_df.to_csv(experiment_folder+'/enc_df.csv')

# 8. VQC Classification

In [None]:
# Feature map of the classifier: one input per encoded feature (2**num_latent).
fm2 = RawFeatureVector(2**num_latent)

In [None]:
# Reload the encoded dataset; the first CSV column is the row index written by to_csv.
df = pd.read_csv(experiment_folder+'/enc_df.csv', index_col=0)

In [None]:
# Split the encoded samples 80/20 with a fixed seed; every column except the
# last is a feature and column y is the label.
train_features, test_features, train_labels, test_labels = train_test_split(
    np.array(df.iloc[:,:-1]), np.array(df.y), train_size=0.8, random_state=42
)

In [None]:
# History of objective values filled by callback_graph.
objective_func_vals = []

In [None]:
# callback function that draws a live plot when the .fit() method is called
def callback_graph(weights, obj_func_eval):
    """Record one objective value and redraw the live loss curve.

    Args:
        weights: Current weights handed over by the caller; not used here.
        obj_func_eval: Objective value appended to ``objective_func_vals``.
    """
    clear_output(wait=True)
    objective_func_vals.append(obj_func_eval)

    iterations = range(len(objective_func_vals))
    plt.plot(iterations, objective_func_vals)
    plt.title("Objective function value against iteration")
    plt.xlabel("Iteration")
    plt.ylabel("Objective function value")
    plt.show()

In [None]:
sampler = Sampler()

# Named vqc_ansatz so that the ansatz() helper defined earlier is not
# overwritten. With the old name, re-running a circuit-building cell after this
# one failed because `ansatz` was then a circuit object, not the function.
vqc_ansatz = EfficientSU2(num_qubits=num_latent, reps=3)
# NOTE(review): maxiter=1 stops the classifier after a single optimizer
# iteration — confirm that this is intended and not a debugging leftover.
optimizer = COBYLA(maxiter=1)

# Variational classifier on the encoded features.
vqc = VQC(
    sampler=sampler,
    feature_map=fm2,
    ansatz=vqc_ansatz,
    optimizer=optimizer,
    callback=callback_graph,
)

In [None]:
# create empty array for callback to store evaluations of the objective function
# (rebinding the global is enough: callback_graph looks the name up on each call)
objective_func_vals = []
plt.rcParams["figure.figsize"] = (12, 6)

# fit classifier to data

start = time.time()
vqc.fit(train_features, train_labels)
elapsed = time.time() - start

print(f"Training time: {round(elapsed)} seconds")

# return to default figsize
plt.rcParams["figure.figsize"] = (6, 4)

In [None]:
# Score the fitted classifier on the training and the held-out split.
train_score_q2_eff = vqc.score(train_features, train_labels)
test_score_q2_eff = vqc.score(test_features, test_labels)

print(f"Quantum VQC on the training dataset using EfficientSU2: {train_score_q2_eff:.2f}")
print(f"Quantum VQC on the test dataset using EfficientSU2:     {test_score_q2_eff:.2f}")

In [None]:
# Store the objective history recorded while fitting the classifier.
np.save(experiment_folder+'/clasif_loss_values.npy', np.array(objective_func_vals))

In [None]:
# Write the training score (first line) and the test score (second line).
# NOTE(review): both scores are recomputed here although the previous cell
# already holds them in train_score_q2_eff / test_score_q2_eff.
with open(experiment_folder+'/results.txt', 'w') as file:
    file.write(f'{vqc.score(train_features, train_labels)}\n')
    file.write(f'{vqc.score(test_features, test_labels)}\n')

In [None]:
# Persist the fitted classifier in the experiment folder.
# NOTE(review): 'classifire' looks like a typo of 'classifier'; left unchanged
# because it is the on-disk file name that other code may load.
vqc.save(experiment_folder+'/classifire.model')