## Application of the classical NN with the classical kernel function

#### 0. Load libraries

In [None]:
import os
from os import listdir
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter

from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score as f1_score_calculation
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import data
#import Embedding
import Hybrid_nn

#### 1. Load dataset

In [None]:
# Dataset selection. All four values are forwarded to data.data_load_and_process;
# target and sampling are also embedded in the output file names.
target = 'ALDH1'
sampling = '1_6'
feature_reduction = False
classes = [0,1]

# In the code visible in this notebook, quantum_embed and n_qubits are only used
# to tag the output file names.
quantum_embed = 'ZZ'
n_qubits = 8
# Kernel selected inside loss_calculation: 'RBF' or 'Linear'.
kernel = 'RBF'

In [None]:
# Fetch the train/test arrays for the selected target from the project loader.
X_train, X_test, Y_train, Y_test = data.data_load_and_process(
    dataset='protein',
    target=target,
    sampling=sampling,
    feature_reduction=feature_reduction,
    classes=classes,
)

In [None]:
# Report the shapes of the arrays returned by the loader.
print(
    "X_train:", X_train.shape,
    "/ X_test:", X_test.shape,
    "/Y_train:", Y_train.shape,
    "/Y_test:", Y_test.shape,
)

In [None]:
# Per-label sample counts of the train and test splits.
print(
    Counter(Y_train),
    Counter(Y_test),
)

In [None]:
# Split the held-out set in half: one half for validation, the other for testing.
# The split is stratified on the labels and uses a fixed seed so it is repeatable.
X_valid, X_test, Y_valid, Y_test = train_test_split(
    X_test,
    Y_test,
    test_size=0.5,
    shuffle=True,
    stratify=Y_test,
    random_state=10,
)

In [None]:
# Shapes after splitting the held-out data into validation and test halves.
# X_valid is included so the report covers all four arrays produced by the split.
print("X_valid:", X_valid.shape, "/ X_test:", X_test.shape,
      "/Y_valid:", Y_valid.shape, "/Y_test:", Y_test.shape)

In [None]:
# Summary of every split. Each label names the array whose shape follows it
# (Y_valid was previously printed under a "/Y_test:" label).
print("X_train:", X_train.shape, "/Y_train:", Y_train.shape,
      "/ X_valid:", X_valid.shape, "/Y_valid:", Y_valid.shape,
      "/ X_test:", X_test.shape, "/Y_test:", Y_test.shape)

In [None]:
# Convert the numpy arrays to tensors: features as float32, labels as int64.
X_train, X_valid, X_test = (
    torch.from_numpy(features).float()
    for features in (X_train, X_valid, X_test)
)
Y_train, Y_valid, Y_test = (
    torch.from_numpy(labels).long()
    for labels in (Y_train, Y_valid, Y_test)
)

In [None]:
# Output directory for the loss log (.txt) and model weights (.pt) written by train_models.
# NOTE: this is a machine-specific absolute path; it must already exist, because
# open(..., 'w') does not create missing directories.
save_dir = '/Users/jungguchoi/Library/Mobile Documents/com~apple~CloudDocs/1_Post_doc(Cleveland_clinic:2024.10~2025.09)/1_Research_project/3_quantum_embedding_comparison_sequence(2024.09 ~ XXXX.XX)/2_exp/60_Dr_Park_Meeting_and_comments_SEP1725/2_new_classical_counterparts/15_ALDH1_NN_RBF_1_6_ratio/'

#### 2. Early Stopping

In [None]:
class EarlyStopper:
    """Signal when the validation loss has stopped improving.

    A call to ``early_stop`` that sets a new minimum resets the counter. A call
    whose loss exceeds the minimum by more than ``min_delta`` increments it.
    Losses in between leave the counter untouched.
    """

    def __init__(self, patience=40, min_delta=0):
        self.counter = 0
        self.min_validation_loss = np.inf
        self.patience = patience
        self.min_delta = min_delta

    def early_stop(self, validation_loss):
        """Record ``validation_loss``; return True once patience is exhausted."""
        best_so_far = self.min_validation_loss

        # New best: remember it and start counting from scratch.
        if validation_loss < best_so_far:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        # Not clearly worse than the best: nothing to count.
        if not validation_loss > (best_so_far + self.min_delta):
            return False

        self.counter += 1
        return True if self.counter >= self.patience else False

#### 3. Main function for MLP training

In [None]:
# Mini-batch size used to build Train_dataloader.
batch_size = 512
# Upper bound on the number of full passes over Train_dataloader in train_models.
iterations = 1000
# Passed to train_models, which uses it as the Adam learning rate.
learning_rate = 0.00001

In [None]:
# Pair each feature tensor with its labels so the splits can be served by DataLoaders.
train_dataset = TensorDataset(X_train, Y_train)
valid_dataset = TensorDataset(X_valid, Y_valid)
test_dataset = TensorDataset(X_test, Y_test)

# Training mini-batches are reshuffled on every pass; the final short batch is kept.
Train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
# Validation and test are each served as one batch holding the whole split.
# Shuffling is disabled: the loss averages over all sample pairs, so reordering
# would only perturb floating-point summation order and consume global RNG state.
# drop_last is off because it has no effect on a single full-size batch.
Valid_dataloader = DataLoader(valid_dataset, batch_size=int(X_valid.shape[0]), shuffle=False, drop_last=False)
Test_dataloader = DataLoader(test_dataset, batch_size=int(X_test.shape[0]), shuffle=False, drop_last=False)

In [None]:
def loss_calculation(proj_data, y, kernel_type=None, gamma=1):
    """Mean squared error between a kernel matrix and pairwise label targets.

    A kernel matrix ``K`` is built from the rows of ``proj_data`` and compared
    with the target ``0.5 * (1 + y_i * y_j)`` over every unordered pair
    ``i < j`` (strict upper triangle).

    NOTE(review): with labels in {-1, +1} the target is 1 for same-class pairs
    and 0 for mixed pairs. With labels in {0, 1} it is 1 only when both labels
    are 1 and 0.5 otherwise. Confirm which encoding the data loader produces.

    Args:
        proj_data: 2-D tensor with one projected sample per row.
        y: tensor of labels, one per row of ``proj_data``; it is flattened.
        kernel_type: ``'RBF'`` or ``'Linear'``. When ``None`` (the default) the
            module-level ``kernel`` setting is used.
        gamma: width parameter of the RBF kernel ``exp(-gamma * ||a - b||^2)``.

    Returns:
        Scalar tensor holding the mean pairwise loss. For fewer than two rows
        there are no pairs, and a zero that is still attached to the autograd
        graph is returned instead of NaN.

    Raises:
        ValueError: if the kernel name is neither ``'RBF'`` nor ``'Linear'``.
    """
    if kernel_type is None:
        kernel_type = kernel  # fall back to the notebook-wide setting

    if kernel_type not in ('RBF', 'Linear'):
        raise ValueError(
            f"Unsupported kernel {kernel_type!r}; expected 'RBF' or 'Linear'."
        )

    n = proj_data.size(0)
    if n < 2:
        # No pairs to average over; keep the graph so .backward() still works.
        return proj_data.sum() * 0.0

    if kernel_type == 'RBF':
        squared_norms = (proj_data ** 2).sum(dim=1).unsqueeze(1)  # shape: (n, 1)
        # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2
        dists = squared_norms - 2 * proj_data @ proj_data.t() + squared_norms.t()
        # Rounding can push tiny distances slightly below zero.
        dists = torch.clamp(dists, min=0.0)
        K = torch.exp(-gamma * dists)  # shape: (n, n)
    else:  # 'Linear'
        K = proj_data @ proj_data.t()  # shape: (n, n)

    y_flat = y.view(-1)  # shape: (n,)
    labels = y_flat.unsqueeze(1) * y_flat.unsqueeze(0)  # shape: (n, n)

    loss_matrices = (K - 0.5 * (1 + labels)) ** 2

    # Each unordered pair is counted once; the diagonal is excluded.
    tri_indices = torch.triu_indices(n, n, offset=1)
    upper_elements = loss_matrices[tri_indices[0], tri_indices[1]]

    return upper_elements.mean()

In [None]:
def train_models(model_name, batch_size, learning_rate):
    """Train a model on the kernel loss, then evaluate it and save the results.

    The model is optimised with Adam over the module-level ``Train_dataloader``
    for at most ``iterations`` passes. Every 10th pass the validation loss is
    computed and fed to an ``EarlyStopper``. Afterwards the test loss is
    computed, the loss histories are written to a ``.txt`` file, and the model
    weights are saved to a ``.pt`` file, both under ``save_dir``.

    Args:
        model_name: key passed to ``Hybrid_nn.get_model``; also the prefix of
            the output file names.
        batch_size: not used inside this function (batches come from the
            module-level ``Train_dataloader``); kept so existing calls still work.
        learning_rate: learning rate handed to ``torch.optim.Adam``.

    Returns:
        ``(train_loss, valid_loss)``. ``train_loss`` holds the loss of the last
        mini-batch of each pass. ``valid_loss`` holds one value per validation
        batch, recorded every 10th pass.
    """
    train_loss, valid_loss = [], []
    model = Hybrid_nn.get_model(model_name)
    # Validation runs only every 10th pass, so patience=40 means 40 validation checks.
    early_stopper = EarlyStopper(patience=40, min_delta=0)
    early_stopped, final_it = False, 0

    opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for it in range(iterations):
        # Back to training mode after any evaluation done in the previous pass.
        model.train()
        for train_inputs, train_targets in Train_dataloader:
            train_proj_data = model(train_inputs)

            loss_training = loss_calculation(train_proj_data, train_targets)

            opt.zero_grad()
            loss_training.backward()
            opt.step()

        train_loss.append(loss_training.item())

        if it % 10 != 0:
            continue

        print("-------------------------------------")
        print(f"Iterations: {it} Training Loss: {loss_training.item()}")

        # Evaluate in eval mode so layers such as dropout or batch norm behave
        # deterministically.
        model.eval()
        with torch.no_grad():
            for valid_inputs, valid_targets in Valid_dataloader:
                valid_proj_data = model(valid_inputs)
                loss_validation = loss_calculation(valid_proj_data, valid_targets)
                print(f"Validation Loss: {loss_validation}")
                valid_loss.append(loss_validation.item())

                # Pass a plain float so the stopper does not hold on to tensors.
                if early_stopper.early_stop(loss_validation.item()):
                    print("Loss converged!")
                    early_stopped = True
                    final_it = it
                    break

        if early_stopped:
            break

    # Stays None if Test_dataloader yields no batch.
    loss_test = None
    model.eval()
    with torch.no_grad():
        for test_inputs, test_targets in Test_dataloader:
            test_proj_data = model(test_inputs)
            loss_test = loss_calculation(test_proj_data, test_targets)
            print(f"Test Loss: {loss_test}")

    # Shared stem of the two output files (.txt log and .pt weights).
    output_stem = f"{save_dir}/{model_name}_LIT-PCBA_{str(target)}_{str(sampling)}_sampling_MLP_{quantum_embed}_{str(n_qubits)}_qubits({kernel})"

    # The context manager closes the file even if a write fails.
    with open(f"{output_stem}.txt", 'w') as f:
        f.write("Loss History:\n")
        f.write(str(train_loss))
        f.write("\n\n")
        f.write("Validation Loss History:\n")
        f.write(str(valid_loss))
        f.write("\n")
        f.write("\n\n")
        f.write(f"Test Loss: {loss_test}\n")
        if early_stopped:
            f.write(f"Validation Loss converged. Early Stopped at iterations {final_it}")

    torch.save(model.state_dict(), f"{output_stem}.pt")

    return train_loss, valid_loss

In [None]:
# Train the 'MLP1' architecture and keep both loss histories for plotting.
model_name = 'MLP1'
train_loss, valid_loss = train_models(
    model_name=model_name,
    batch_size=batch_size,
    learning_rate=learning_rate,
)

In [None]:
# Training loss curve: one point per iteration (the loss of that iteration's last mini-batch).
plt.plot(train_loss)
plt.show()

In [None]:
# Validation loss curve: values are appended during the validation pass that runs
# every 10th iteration, so the x-axis counts validation checks, not iterations.
plt.plot(valid_loss)
plt.show()