In [None]:
from peptides.classicalNeuralNetwork import Classical_NeuralNetwork
from peptides.quantumNeuralNetwork import Quantum_NeuralNetwork
from peptides.dataLoader import dataLoader

# Load and format data

In [None]:
# Load the dataset for allele HLA-A*03:01. X_pca is (presumably) the
# PCA-projected feature matrix and y the labels; the QNN cells below pass
# num_features=4, which presumably has to agree with PCA_dim here.
dl = dataLoader()
X_pca, y = dl.load_data(
    allele='HLA-A*03:01',
    PCA_dim=4,
    size_dataset=100,
)

In [None]:
# Presumably plots how much variance each PCA component retains, as a sanity
# check on the PCA_dim chosen in load_data — confirm in peptides.dataLoader.
dl.plot_pca_variance()

In [None]:
# Presumably a 2-D view of the PCA-projected samples, useful for eyeballing how
# separable the classes are — confirm in peptides.dataLoader.
dl.plot_pca_2D()

# Train the classical neural network

In [None]:
# Classical baseline, trained on the same (X_pca, y) as the quantum models
# below. The returned losses are plotted against the quantum runs at the end.
CNN = Classical_NeuralNetwork(X_pca, y)
loss_classical = CNN.train(
    n_splits=2,
    n_iterations=80,
    hidden_layer_sizes=50,
)

# Train on the noiseless Aer (QASM) simulator

In [None]:
from qiskit_aer import AerSimulator

# AerSimulator() with no noise model attached: an ideal (noiseless) baseline for
# the quantum model, to compare with the noisy ibm_quebec simulation below.
simulator = AerSimulator()

QNN = Quantum_NeuralNetwork(X_pca, y, backend=simulator)
# n_iterations matches the simulated ibm_quebec run so the two quantum loss
# curves in the final plot cover the same number of training iterations.
# num_features presumably must equal PCA_dim=4 from load_data — confirm.
loss_quantum_qasm = QNN.train(
    n_splits=2,
    n_iterations=50,
    num_features=4,
    reps=5,
)

# Train on a noisy simulation of ibm_quebec

In [None]:
from qiskit_aer import AerSimulator
from qiskit_ibm_provider import IBMProvider

# Fetch the ibm_quebec backend description and build a local Aer simulator
# from it (AerSimulator.from_backend), so training sees device-like noise
# without submitting jobs to the real hardware.
provider = IBMProvider(instance="pinq-quebec-hub/ecole-dhiver/qml-workshop")
quebec = provider.get_backend('ibm_quebec')
quebec_simulator = AerSimulator.from_backend(quebec)

QNN = Quantum_NeuralNetwork(X_pca, y, backend=quebec_simulator)
loss_quantum_quebec = QNN.train(
    n_splits=2,
    n_iterations=50,
    num_features=4,
    reps=5,
)

# Train on the real ibm_quebec hardware (commented out by default)

In [None]:
from qiskit_ibm_provider import IBMProvider

# Real ibm_quebec device. It is bound to `quebec_backend` rather than
# `simulator` so the noiseless AerSimulator created earlier under that name is
# not silently replaced by real hardware.
provider = IBMProvider(instance="pinq-quebec-hub/ecole-dhiver/qml-workshop")
quebec_backend = provider.get_backend('ibm_quebec')

# Training is left commented out; uncomment to run on the real device (jobs on
# IBM hardware are queued, so expect this to be slow). num_features is 4 to
# match PCA_dim=4 and the other QNN runs in this notebook.
#QNN = Quantum_NeuralNetwork(X_pca, y, backend=quebec_backend)
#loss_quantum_hardware = QNN.train(n_splits=2,
#                                  n_iterations=50,
#                                  num_features=4,
#                                  reps=5)

# Plot losses

In [None]:
import matplotlib.pyplot as plt

# Compare the training-loss curves of the models trained above.
with plt.style.context("seaborn-v0_8"):
    fig, ax = plt.subplots(figsize=(8, 4), dpi=100)
    ax.plot(loss_classical, label='Classical')
    ax.plot(loss_quantum_qasm, label='Quantum QASM')
    # This loss comes from AerSimulator.from_backend(ibm_quebec): a noisy
    # simulation of the device, not a run on the hardware itself.
    ax.plot(loss_quantum_quebec, label='Quantum (simulated IBM Quebec)')
    ax.set_title("Training loss")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Cross Entropy")
    ax.legend()
    plt.show()