## 1. Import libraries

In [None]:
import os
from os import listdir
import pandas as pd
from collections import Counter
import time

import torch
import torch.nn as nn
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold

from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import roc_curve

import pennylane as qml
from pennylane import numpy as np

from embedding import *

## 2. Set directories

In [None]:
# Dataset directories holding the PCA-reduced CSV files (one file per group / PCA size).
# TODO: the original paths were left blank, which is a SyntaxError; point these
# placeholders at the real metabolomics ('meta') and proteomics ('prot') folders.
meta_pca_dir = "data/meta_pca"
prot_pca_dir = "data/prot_pca"

In [None]:
# Root save directory.
# TODO: the original path was left blank, which is a SyntaxError; replace this
# placeholder with the real output location.
save_dir = "results"

## 3. Define a quantum kernel

In [4]:
# One "lightning.qubit" simulator device per circuit width used by the kernels below.
SIMULATOR_NAME = "lightning.qubit"

dev_1_qubits = qml.device(SIMULATOR_NAME, wires=1)
dev_2_qubits = qml.device(SIMULATOR_NAME, wires=2)
dev_3_qubits = qml.device(SIMULATOR_NAME, wires=3)
dev_4_qubits = qml.device(SIMULATOR_NAME, wires=4)
dev_8_qubits = qml.device(SIMULATOR_NAME, wires=8)
dev_16_qubits = qml.device(SIMULATOR_NAME, wires=16)

In [5]:
def _first_basis_state_projector(n_qubits):
    """Return the dense (2**n_qubits x 2**n_qubits) matrix whose only non-zero entry is [0, 0] = 1."""
    dim = 2 ** n_qubits
    projector = np.zeros((dim, dim))
    projector[0, 0] = 1
    return projector


# Observables measured by the kernel circuits, one per circuit width.
projector_1_qubits = _first_basis_state_projector(1)
projector_2_qubits = _first_basis_state_projector(2)
projector_3_qubits = _first_basis_state_projector(3)
projector_4_qubits = _first_basis_state_projector(4)
projector_8_qubits = _first_basis_state_projector(8)
# NOTE(review): this is a dense 65536 x 65536 float matrix (about 32 GiB if fully
# materialised) and none of the 16-qubit kernels in this notebook reference it
# (they return qml.probs instead). Confirm whether it is needed at all.
projector_16_qubits = _first_basis_state_projector(16)

- PCA 2 kernels

In [6]:
@qml.qnode(dev_1_qubits)
def PC2_amp_kernel(x1, x2):
    """Amplitude-embedding kernel circuit on 1 qubit (PCA-2 features).

    Embeds ``x1`` with a normalised ``AmplitudeEmbedding``, applies the adjoint
    of the same embedding built from ``x2``, and returns the expectation value
    of ``projector_1_qubits`` (non-zero only at entry [0, 0]).
    """
    qubits = range(1)
    qml.AmplitudeEmbedding(x1, wires=qubits, normalize=True)
    undo_embedding = qml.adjoint(qml.AmplitudeEmbedding)
    undo_embedding(x2, wires=qubits, normalize=True)
    observable = qml.Hermitian(projector_1_qubits, wires=qubits)
    return qml.expval(observable)

def kernel_mat_PC2_amp(A, B):
    """Evaluate ``PC2_amp_kernel`` on every pair (a, b) with a in A and b in B.

    Returns an array with one row per element of ``A`` and one column per
    element of ``B``; this callable is passed to ``SVC(kernel=...)``.
    """
    rows = []
    for a in A:
        rows.append([PC2_amp_kernel(a, b) for b in B])
    return np.array(rows)

##-------------------------##
@qml.qnode(dev_2_qubits)
def PC2_ang_kernel(x1, x2):
    """Angle-embedding kernel circuit on 2 qubits (PCA-2 features).

    Applies ``AngleEmbedding`` of ``x1`` followed by the adjoint
    ``AngleEmbedding`` of ``x2`` and returns the expectation value of
    ``projector_2_qubits`` (non-zero only at entry [0, 0]).
    """
    qubits = range(2)
    qml.AngleEmbedding(x1, wires=qubits)
    undo_embedding = qml.adjoint(qml.AngleEmbedding)
    undo_embedding(x2, wires=qubits)
    observable = qml.Hermitian(projector_2_qubits, wires=qubits)
    return qml.expval(observable)

def kernel_mat_PC2_ang(A, B):
    """Evaluate ``PC2_ang_kernel`` on every pair (a, b) with a in A and b in B.

    Returns an array with one row per element of ``A`` and one column per
    element of ``B``; this callable is passed to ``SVC(kernel=...)``.
    """
    gram_rows = []
    for sample_a in A:
        gram_rows.append([PC2_ang_kernel(sample_a, sample_b) for sample_b in B])
    return np.array(gram_rows)

##-------------------------##
@qml.qnode(dev_2_qubits)
def PC2_ZZ_kernel(x1, x2):
    """ZZ-embedding kernel circuit on 2 qubits (PCA-2 features).

    Runs ``QuantumEmbedding_2qubits`` on ``x1`` and then
    ``QuantumEmbedding_2qubits_inverse`` on ``x2`` (both come from the
    star-imported ``embedding`` module; presumably a ZZ feature map and its
    inverse -- confirm there), and returns the expectation value of
    ``projector_2_qubits``.
    """
    QuantumEmbedding_2qubits(x1)
    QuantumEmbedding_2qubits_inverse(x2)
    observable = qml.Hermitian(projector_2_qubits, wires=range(2))
    return qml.expval(observable)
    
def kernel_mat_PC2_ZZ(A, B):
    """Evaluate ``PC2_ZZ_kernel`` on every pair (a, b) with a in A and b in B.

    Returns an array with one row per element of ``A`` and one column per
    element of ``B``; this callable is passed to ``SVC(kernel=...)``.
    """
    entries = []
    for left in A:
        entries.append([PC2_ZZ_kernel(left, right) for right in B])
    return np.array(entries)

- PCA 4 kernels

In [7]:
@qml.qnode(dev_2_qubits)
def PC4_amp_kernel(x1, x2):
    """Amplitude-embedding kernel circuit on 2 qubits (PCA-4 features).

    Embeds ``x1`` with a normalised ``AmplitudeEmbedding``, applies the adjoint
    of the same embedding built from ``x2``, and returns the expectation value
    of ``projector_2_qubits`` (non-zero only at entry [0, 0]).
    """
    qubits = range(2)
    qml.AmplitudeEmbedding(x1, wires=qubits, normalize=True)
    undo_embedding = qml.adjoint(qml.AmplitudeEmbedding)
    undo_embedding(x2, wires=qubits, normalize=True)
    observable = qml.Hermitian(projector_2_qubits, wires=qubits)
    return qml.expval(observable)

def kernel_mat_PC4_amp(A, B):
    """Evaluate ``PC4_amp_kernel`` on every pair (a, b) with a in A and b in B.

    Returns an array with one row per element of ``A`` and one column per
    element of ``B``; this callable is passed to ``SVC(kernel=...)``.
    """
    rows = []
    for a in A:
        rows.append([PC4_amp_kernel(a, b) for b in B])
    return np.array(rows)
    
##-------------------------##
@qml.qnode(dev_4_qubits)
def PC4_ang_kernel(x1, x2):
    """Angle-embedding kernel circuit on 4 qubits (PCA-4 features).

    Applies ``AngleEmbedding`` of ``x1`` followed by the adjoint
    ``AngleEmbedding`` of ``x2`` and returns the expectation value of
    ``projector_4_qubits`` (non-zero only at entry [0, 0]).
    """
    qubits = range(4)
    qml.AngleEmbedding(x1, wires=qubits)
    undo_embedding = qml.adjoint(qml.AngleEmbedding)
    undo_embedding(x2, wires=qubits)
    observable = qml.Hermitian(projector_4_qubits, wires=qubits)
    return qml.expval(observable)

def kernel_mat_PC4_ang(A, B):
    """Evaluate ``PC4_ang_kernel`` on every pair (a, b) with a in A and b in B.

    Returns an array with one row per element of ``A`` and one column per
    element of ``B``; this callable is passed to ``SVC(kernel=...)``.
    """
    gram_rows = []
    for sample_a in A:
        gram_rows.append([PC4_ang_kernel(sample_a, sample_b) for sample_b in B])
    return np.array(gram_rows)

##-------------------------##
@qml.qnode(dev_4_qubits)
def PC4_ZZ_kernel(x1, x2):
    """ZZ-embedding kernel circuit on 4 qubits (PCA-4 features).

    Runs ``QuantumEmbedding_4qubits`` on ``x1`` and then
    ``QuantumEmbedding_4qubits_inverse`` on ``x2`` (both come from the
    star-imported ``embedding`` module; presumably a ZZ feature map and its
    inverse -- confirm there), and returns the expectation value of
    ``projector_4_qubits``.
    """
    QuantumEmbedding_4qubits(x1)
    QuantumEmbedding_4qubits_inverse(x2)
    observable = qml.Hermitian(projector_4_qubits, wires=range(4))
    return qml.expval(observable)

def kernel_mat_PC4_ZZ(A, B):
    """Evaluate ``PC4_ZZ_kernel`` on every pair (a, b) with a in A and b in B.

    Returns an array with one row per element of ``A`` and one column per
    element of ``B``; this callable is passed to ``SVC(kernel=...)``.

    The previous version called ``PC2_ZZ_kernel`` (the 2-qubit circuit) here,
    so PCA-4 inputs were never evaluated with the 4-qubit ZZ kernel.
    """
    return np.array([[PC4_ZZ_kernel(a, b) for b in B] for a in A])

- PCA 8 kernels

In [None]:
@qml.qnode(dev_3_qubits)
def PC8_amp_kernel(x1, x2):
    """Amplitude-embedding kernel circuit on 3 qubits (PCA-8 features).

    Embeds ``x1`` with a normalised ``AmplitudeEmbedding``, applies the adjoint
    of the same embedding built from ``x2``, and returns the expectation value
    of ``projector_3_qubits`` (non-zero only at entry [0, 0]).
    """
    qubits = range(3)
    qml.AmplitudeEmbedding(x1, wires=qubits, normalize=True)
    undo_embedding = qml.adjoint(qml.AmplitudeEmbedding)
    undo_embedding(x2, wires=qubits, normalize=True)
    observable = qml.Hermitian(projector_3_qubits, wires=qubits)
    return qml.expval(observable)

def kernel_mat_PC8_amp(A, B):
    """Evaluate ``PC8_amp_kernel`` on every pair (a, b) with a in A and b in B.

    Returns an array with one row per element of ``A`` and one column per
    element of ``B``; this callable is passed to ``SVC(kernel=...)``.
    """
    rows = []
    for a in A:
        rows.append([PC8_amp_kernel(a, b) for b in B])
    return np.array(rows)

##-------------------------##
@qml.qnode(dev_8_qubits)
def PC8_ang_kernel(x1, x2):
    """Angle-embedding kernel circuit on 8 qubits (PCA-8 features).

    Applies ``AngleEmbedding`` of ``x1`` followed by the adjoint
    ``AngleEmbedding`` of ``x2`` and returns the expectation value of
    ``projector_8_qubits`` (non-zero only at entry [0, 0]).
    """
    qubits = range(8)
    qml.AngleEmbedding(x1, wires=qubits)
    undo_embedding = qml.adjoint(qml.AngleEmbedding)
    undo_embedding(x2, wires=qubits)
    observable = qml.Hermitian(projector_8_qubits, wires=qubits)
    return qml.expval(observable)

def kernel_mat_PC8_ang(A, B):
    """Evaluate ``PC8_ang_kernel`` on every pair (a, b) with a in A and b in B.

    Returns an array with one row per element of ``A`` and one column per
    element of ``B``; this callable is passed to ``SVC(kernel=...)``.
    """
    gram_rows = []
    for sample_a in A:
        gram_rows.append([PC8_ang_kernel(sample_a, sample_b) for sample_b in B])
    return np.array(gram_rows)

##-------------------------##
@qml.qnode(dev_8_qubits)
def PC8_ZZ_kernel(x1, x2):
    """ZZ-embedding kernel circuit on 8 qubits (PCA-8 features).

    Runs ``QuantumEmbedding_8qubits`` on ``x1`` and then
    ``QuantumEmbedding_8qubits_inverse`` on ``x2`` (both come from the
    star-imported ``embedding`` module; presumably a ZZ feature map and its
    inverse -- confirm there).

    Unlike the 2- and 4-qubit ZZ kernels, this returns ``qml.probs`` over all
    8 wires rather than a projector expectation value, so the result is a
    probability vector and not a single number.
    """
    measured_wires = range(8)
    QuantumEmbedding_8qubits(x1)
    QuantumEmbedding_8qubits_inverse(x2)
    return qml.probs(wires=measured_wires)

 def kernel_mat_PC8_ZZ(A, B):
    return np.array([[PC8_ZZ_kernel(a, b) for b in B] for a in A])


- PCA 16 kernels

In [None]:
@qml.qnode(dev_4_qubits)
def PC16_amp_kernel(x1, x2):
    """Amplitude-embedding kernel circuit on 4 qubits (PCA-16 features).

    Embeds ``x1`` with a normalised ``AmplitudeEmbedding``, applies the adjoint
    of the same embedding built from ``x2``, and returns the expectation value
    of ``projector_4_qubits`` (non-zero only at entry [0, 0]).
    """
    qubits = range(4)
    qml.AmplitudeEmbedding(x1, wires=qubits, normalize=True)
    undo_embedding = qml.adjoint(qml.AmplitudeEmbedding)
    undo_embedding(x2, wires=qubits, normalize=True)
    observable = qml.Hermitian(projector_4_qubits, wires=qubits)
    return qml.expval(observable)

def kernel_mat_PC16_amp(A, B):
    """Evaluate ``PC16_amp_kernel`` on every pair (a, b) with a in A and b in B.

    Returns an array with one row per element of ``A`` and one column per
    element of ``B``; this callable is passed to ``SVC(kernel=...)``.
    """
    rows = []
    for a in A:
        rows.append([PC16_amp_kernel(a, b) for b in B])
    return np.array(rows)

##-------------------------##
@qml.qnode(dev_16_qubits)
def PC16_ang_kernel(x1, x2):
    """Angle-embedding kernel circuit on 16 qubits (PCA-16 features).

    Applies ``AngleEmbedding`` of ``x1`` followed by the adjoint
    ``AngleEmbedding`` of ``x2``.

    Unlike the smaller angle kernels, this returns ``qml.probs`` over all 16
    wires rather than a projector expectation value, so the result is a
    probability vector and not a single number.
    """
    qubits = range(16)
    qml.AngleEmbedding(x1, wires=qubits)
    undo_embedding = qml.adjoint(qml.AngleEmbedding)
    undo_embedding(x2, wires=qubits)
    return qml.probs(wires=qubits)

def kernel_mat_PC16_ang(A, B):
    """Evaluate ``PC16_ang_kernel`` on every pair (a, b) with a in A and b in B.

    ``PC16_ang_kernel`` returns the probability vector from ``qml.probs``; only
    its first entry (index 0) is the kernel value, so it is selected here in
    the same way ``kernel_mat_PC16_ZZ`` does. Without the index the result is a
    3-D array rather than a kernel matrix.

    Returns an array with one row per element of ``A`` and one column per
    element of ``B``; this callable is passed to ``SVC(kernel=...)``.
    """
    return np.array([[PC16_ang_kernel(a, b)[0] for b in B] for a in A])

##-------------------------##
@qml.qnode(dev_16_qubits)
def PC16_ZZ_kernel(x1, x2):
    """ZZ-embedding kernel circuit on 16 qubits (PCA-16 features).

    Runs ``QuantumEmbedding_16qubits`` on ``x1`` and then
    ``QuantumEmbedding_16qubits_inverse`` on ``x2`` (both come from the
    star-imported ``embedding`` module; presumably a ZZ feature map and its
    inverse -- confirm there), and returns ``qml.probs`` over all 16 wires.
    ``kernel_mat_PC16_ZZ`` keeps only entry 0 of that vector.
    """
    measured_wires = range(16)
    QuantumEmbedding_16qubits(x1)
    QuantumEmbedding_16qubits_inverse(x2)
    return qml.probs(wires=measured_wires)

 def kernel_mat_PC16_ZZ(A, B):
    return np.array([[PC16_ZZ_kernel(a, b)[0] for b in B] for a in A])

## 4. Execute QSVM algorithms with the optimized C and kernel functions

In [10]:
# List the PCA files available for each dataset and show them.
# NOTE(review): listdir() returns entries in arbitrary order, and the main loop
# below slices this list by position -- confirm whether a sorted list is wanted.
meta_filelist, prot_filelist = listdir(meta_pca_dir), listdir(prot_pca_dir)
for filelist in (meta_filelist, prot_filelist):
    print(filelist)

['GROUP6_PCA_2.csv', 'GROUP1_PCA_16.csv', 'GROUP6_PCA_4.csv', 'GROUP1_PCA_4.csv', 'GROUP7_PCA_2.csv', 'GROUP6_PCA_16.csv', 'GROUP1_PCA_2.csv', 'GROUP7_PCA_4.csv', 'GROUP2_PCA_8.csv', 'GROUP2_PCA_4.csv', 'GROUP7_PCA_8.csv', 'GROUP2_PCA_2.csv', 'GROUP1_PCA_8.csv', 'GROUP2_PCA_16.csv', 'GROUP6_PCA_8.csv', 'GROUP7_PCA_16.csv']
['GROUP6_PCA_2.csv', 'GROUP1_PCA_16.csv', 'GROUP6_PCA_4.csv', 'GROUP1_PCA_4.csv', 'GROUP7_PCA_2.csv', 'GROUP6_PCA_16.csv', 'GROUP1_PCA_2.csv', 'GROUP7_PCA_4.csv', 'GROUP2_PCA_8.csv', 'GROUP2_PCA_4.csv', 'GROUP7_PCA_8.csv', 'GROUP2_PCA_2.csv', 'GROUP1_PCA_8.csv', 'GROUP2_PCA_16.csv', 'GROUP6_PCA_8.csv', 'GROUP7_PCA_16.csv']


In [11]:
# Dataset to analyse; the cells below recognise 'meta' and 'prot'.
dataset = "prot"

In [12]:
# Select the file list and directory that belong to the chosen dataset.
if dataset == 'meta':
    pca_filelist = meta_filelist
    pca_dir = meta_pca_dir
elif dataset == 'prot':
    pca_filelist = prot_filelist
    pca_dir = prot_pca_dir
else:
    # Fail here instead of leaving pca_filelist / pca_dir undefined (or stale
    # from an earlier run of this cell) for the cells below.
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'meta' or 'prot'")

In [13]:
# 3-fold stratified cross-validation splitter, shuffled with a fixed seed.
skf = StratifiedKFold(shuffle=True, random_state=42, n_splits=3)

In [None]:
# Output directories, one per result type written by the cross-validation loop.
# TODO: the original paths were left blank, which is a SyntaxError; replace these
# placeholders with the real output locations.
save_dir_cf_mat = "results/cf_matrices"
save_dir_auc_score = "results/auc_scores"
save_dir_auroc_curve = "results/auroc_curves"
save_dir_ba_score = "results/ba_scores"

In [15]:
# Regularisation strength C for the three QSVMs of one file, keyed as
# QSVM_C[dataset][group][PCA size]. Values are the ones previously hard-coded
# in the per-case if-blocks.
QSVM_C = {
    'meta': {
        '1': {'2': 1000, '4': 0.1, '8': 1.0, '16': 10000.0},
        '2': {'2': 1.0, '4': 0.1, '8': 0.1, '16': 1.0},
        '6': {'2': 10000.0, '4': 1.0, '8': 100.0, '16': 1.0},
        '7': {'2': 1.0, '4': 1.0, '8': 10.0, '16': 1.0},
    },
    'prot': {
        '1': {'2': 1000, '4': 0.1, '8': 1.0, '16': 1.0},
        '2': {'2': 0.1, '4': 0.1, '8': 0.1, '16': 1.0},
        '6': {'2': 10000.0, '4': 1.0, '8': 100.0, '16': 1.0},
        '7': {'2': 1.0, '4': 1.0, '8': 100.0, '16': 1.0},
    },
}

# Embedding labels written to the result files, in (amplitude, angle, ZZ) order.
EMBEDDING_NAMES = ("Amplitude", "Angle", "ZZ")

# NOTE(review): only files from this position of pca_filelist onwards are
# processed (the original loop hard-coded [15:]). Because listdir() order is
# arbitrary, confirm this is intended; set to 0 to process every file.
START_INDEX = 15


def _qsvm_kernels(PCA_info):
    """Return the (amplitude, angle, ZZ) kernel-matrix callables for one PCA size.

    The kernel names are resolved only for the requested size, so a missing
    definition for another size does not prevent this one from running.
    """
    if PCA_info == '2':
        return kernel_mat_PC2_amp, kernel_mat_PC2_ang, kernel_mat_PC2_ZZ
    if PCA_info == '4':
        return kernel_mat_PC4_amp, kernel_mat_PC4_ang, kernel_mat_PC4_ZZ
    if PCA_info == '8':
        return kernel_mat_PC8_amp, kernel_mat_PC8_ang, kernel_mat_PC8_ZZ
    if PCA_info == '16':
        return kernel_mat_PC16_amp, kernel_mat_PC16_ang, kernel_mat_PC16_ZZ
    raise ValueError(f"No quantum kernels defined for PCA size {PCA_info!r}")


def build_qsvm_models(dataset, group_info, PCA_info):
    """Create the amplitude, angle and ZZ SVC estimators for one input file.

    All three share the tuned C from ``QSVM_C`` and ``class_weight='balanced'``;
    ``random_state`` is 1, 2 and 3 respectively.

    Raises:
        ValueError: if no C value or kernel set is configured for the
            combination. Previously this case left the estimators undefined
            or silently reused the ones built for the previous file.
    """
    try:
        C = QSVM_C[dataset][group_info][PCA_info]
    except KeyError:
        raise ValueError(
            f"No tuned C for dataset={dataset!r}, group={group_info!r}, PCA={PCA_info!r}"
        ) from None
    return tuple(
        SVC(kernel=kernel, C=C, random_state=seed, class_weight='balanced')
        for seed, kernel in enumerate(_qsvm_kernels(PCA_info), start=1)
    )


def save_fold_results(rows, out_dir, prefix, group_info, PCA_info, fold_no):
    """Write one fold's result rows to ``<out_dir>/<prefix>_group<G>_pc<P>_fold<N>.csv`` (no header, no index)."""
    out_path = out_dir+"/"+prefix+"_group"+str(group_info)+"_pc"+str(PCA_info)+"_fold"+str(fold_no)+".csv"
    pd.DataFrame(rows).to_csv(out_path, index=False, header=None)


# Make sure every output directory exists before the (slow) fits start.
for out_dir in (save_dir_cf_mat, save_dir_auc_score, save_dir_auroc_curve, save_dir_ba_score):
    os.makedirs(out_dir, exist_ok=True)

for no, pca_file in enumerate(pca_filelist[START_INDEX:]):
    print(">---------------------------------<")
    print("No:",str(no),"/File:", pca_file)

    # File names look like 'GROUP7_PCA_16.csv': group id first, PCA size last.
    group_info = str(pca_file).split('_')[0].replace('GROUP','')
    print("GROUP:", group_info)
    PCA_info = str(pca_file).split('_')[-1].replace('.csv','')
    print("PCA info:", PCA_info)

    SVC_amp_kernel, SVC_ang_kernel, SVC_ZZ_kernel = build_qsvm_models(dataset, group_info, PCA_info)

    pca_df = pd.read_csv(pca_dir+"/"+str(pca_file))
    print("Shape of PCA file:", pca_df.shape)

    # First column holds the class label, the remaining columns the PCA features.
    pca_y = pca_df.iloc[:,0]
    pca_X = pca_df.iloc[:,1:]
    print("PCA X:", pca_X.shape, "/PCA Y:", pca_y.shape)

    print("Start 3-Fold CV!")

    # A separate fold counter keeps the outer file counter `no` intact.
    for fold_idx, (train_index, test_index) in enumerate(skf.split(pca_X, pca_y)):
        fold_no = fold_idx + 1
        fold_tag = "Fold-"+str(fold_no)
        print("\n##--------------------------##")
        print(f"[Fold {fold_no}] ")
        X_train, X_test = pca_X.iloc[train_index,:], pca_X.iloc[test_index,: ]
        # skf.split yields positions, so index the labels positionally as well.
        y_train, y_test = pca_y.iloc[train_index], pca_y.iloc[test_index]
        print("X_train:",X_train.shape,"/y_train:",y_train.shape,"/X_test:",X_test.shape,"/y_test:",y_test.shape)
        print("y_train:", Counter(y_train),"/y_test:",Counter(y_test))

        X_train, X_test, y_train, y_test = np.array(X_train), np.array(X_test), np.array(y_train), np.array(y_test)

        # Fit the three models in turn and report the time each fit took.
        initial_time = time.time()
        SVC_amp_kernel.fit(X_train, y_train)
        amp_time = time.time()
        print("Amp time:", amp_time-initial_time)
        SVC_ang_kernel.fit(X_train, y_train)
        ang_time = time.time()
        print("Ang time:", ang_time-amp_time)
        SVC_ZZ_kernel.fit(X_train, y_train)
        zz_time = time.time()
        print("ZZ time:", zz_time-ang_time)

        pred_amp_kernel = SVC_amp_kernel.predict(X_test)
        pred_ang_kernel = SVC_ang_kernel.predict(X_test)
        pred_ZZ_kernel = SVC_ZZ_kernel.predict(X_test)
        print(Counter(pred_amp_kernel), Counter(pred_ang_kernel), Counter(pred_ZZ_kernel))

        named_predictions = list(zip(EMBEDDING_NAMES, (pred_amp_kernel, pred_ang_kernel, pred_ZZ_kernel)))

        # Confusion matrix flattened with ravel() for labels [0, 1].
        cf_rows = [[str(pca_file), name, fold_tag] + list(confusion_matrix(y_test, pred, labels=[0, 1]).ravel())
                   for name, pred in named_predictions]
        save_fold_results(cf_rows, save_dir_cf_mat, "cf_matrices", group_info, PCA_info, fold_no)

        # NOTE(review): AUC and the ROC curve below are computed from hard class
        # predictions, not from decision_function scores -- confirm this is intended.
        auc_rows = [[str(pca_file), name, fold_tag, roc_auc_score(y_test, pred)]
                    for name, pred in named_predictions]
        save_fold_results(auc_rows, save_dir_auc_score, "auc_score", group_info, PCA_info, fold_no)

        # roc_curve returns (fpr, tpr, thresholds); each array is stored in its own cell.
        roc_rows = [[str(pca_file), name, fold_tag] + list(roc_curve(y_test, pred))
                    for name, pred in named_predictions]
        save_fold_results(roc_rows, save_dir_auroc_curve, "auroc_curve", group_info, PCA_info, fold_no)

        ba_rows = [[str(pca_file), name, fold_tag, balanced_accuracy_score(y_test, pred)]
                   for name, pred in named_predictions]
        save_fold_results(ba_rows, save_dir_ba_score, "ba_score", group_info, PCA_info, fold_no)


>---------------------------------<
No: 0 /File: GROUP7_PCA_16.csv
GROUP: 7
PCA info: 16
Shape of PCA file: (72, 17)
PCA X: (72, 16) /PCA Y: (72,)
Start 3-Fold CV!

##--------------------------##
[Fold 1] 
X_train: (48, 16) /y_train: (48,) /X_test: (24, 16) /y_test: (24,)
y_train: Counter({0: 42, 1: 6}) /y_test: Counter({0: 20, 1: 4})
Amp time: 22.75872492790222
Ang time: 8.246223211288452
num_A: 48 /num_B: 48
Kernel mat: (48, 48)
ZZ time: 39.30313181877136
num_A: 24 /num_B: 48
Kernel mat: (24, 48)
Counter({np.int64(0): 21, np.int64(1): 3}) Counter({np.int64(0): 23, np.int64(1): 1}) Counter({np.int64(1): 24})

##--------------------------##
[Fold 2] 
X_train: (48, 16) /y_train: (48,) /X_test: (24, 16) /y_test: (24,)
y_train: Counter({0: 41, 1: 7}) /y_test: Counter({0: 21, 1: 3})
Amp time: 23.165879726409912
Ang time: 8.188395977020264
num_A: 48 /num_B: 48
Kernel mat: (48, 48)
ZZ time: 39.00432825088501
num_A: 24 /num_B: 48
Kernel mat: (24, 48)
Counter({np.int64(0): 21, np.int64(1): 3})