In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import neuron
import linear
import time
# library from snntorch
from ttfe import latency 

In [4]:
# check gpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [5]:
# modify here if you use other datasets
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data import random_split

input_dimmension = 784                 # input feature의 개수 (MNIST : 784)
hidden_dimmension = 100                # hidden layer 개수
output_dimmension = 10                 # output label 개수 (MNIST : 10)
batch_sz = 128                         # batch size : dataset의 크기에 따라 조절 
# if you want to use other dataset
#data_path = './WineQT.csv'             # dataset이 저장되어있는 경로 // dataset이 csv파일이고 마지막 feature가 output label이라고 가정


# 임의의 dataset을 사용하기 위해 만든 class
class SNN_Dataset(Dataset):
  def __init__(self, csv_path):
    df = pd.read_csv(csv_path)

    self.x_data = df.iloc[:, :-1].values
    self.y_data = df.iloc[:, -1].values

  def __len__(self):
    return len(self.x_data)
  def __getitem__(self, idx):
    return self.x_data[idx], self.y_data[idx]

In [6]:
# nearly all paramters in this simulation, but it seems no use to store them in a dict...
param = {'G_min': -1.0, # the minimum (maximum) value of conductance, I found if I use real conductance value,#-1.0
        'G_max': 1.0,   # which is on 1e-4 level, the gradience will be too small even in two layer net.#1.0
        'Rd': 10e9,    # this device resistance is mannually set for smaller leaky current? / 5.0e9
        'Cm': 80e-12,   # real capacitance is absolutely larger than this value / 0.8e-10
        'Rs': 1005000,      # this series resistance value is mannually set for larger inject current?
        'Vth': 4.2,     # this is the real device threshould voltage #5.6
        'V_reset': 3.7,
        'dt': 1.75e-4,   # every time step is dt, in the one-order differential equation of neuron
        'T_sim': 20,   # could control total spike number collected
        'dim_in': input_dimmension,
        'dim_h': hidden_dimmension,
        'dim_out': output_dimmension,
        'amp' : 3,    # the gain of TIAs
        'q_bit': 7,    # quantize bit
        'epoch': 100,
        'batch_size': batch_sz,
        'learning_rate': 0.015,
        'data_dir': './MNIST',
        'train_file': 'trainning_log_7bit.txt',
        'test_file': 'test_log.txt',
        'model_dir': 'Model.pth'
}

def rate_encoding(x, T_sim):
    '''
    Encodes the input image pixels into periodic spike trains.
    
    input : x (batch_size, dim_in) - input data
           T_sim (int) - number of time steps for the simulation
    output : (batch_size, dim_in, T_sim) - periodic spikes for each pixel over T_sim time steps
    '''
    
    batch_size, dim_in = x.shape
    spikes = torch.zeros(batch_size, dim_in, T_sim).to(x.device)  # Initialize spikes tensor

    # Data directly obtained through research
    pixel_values = torch.tensor([0, 0.287273, 0.389091, 0.592727, 0.796364, 1]).to(x.device)
    periods = torch.tensor([6.5, 3.1, 2.6, 2.2, 1.9, 1.7]).to(x.device)
    
    period = torch.zeros_like(x)
    for i in range(len(pixel_values) - 1):
        mask = (x >= pixel_values[i]) & (x < pixel_values[i+1])
        # Linearly interpolate periods for those pixels
        period[mask] = periods[i] + (periods[i+1] - periods[i]) * (x[mask] - pixel_values[i]) / (pixel_values[i+1] - pixel_values[i])
    period[x == 1] = 1.7

    for t in range(T_sim):
        spikes[:, :, t] = ((t+1) // period > t // period)
    return spikes


def ttfs_encoding(x, T_sim):
    '''
    Encodes the input image pixels using time-to-first-spike encoding based on provided data.

    input: x (batch_size, dim_in) - input data
           T_sim (int) - number of time steps for the simulation
    output: (batch_size, dim_in, T_sim) - time-to-first-spike encoded spikes
    '''
    
    batch_size, dim_in = x.shape
    spikes = torch.zeros(batch_size, dim_in, T_sim).to(x.device)

    # Data directly obtained through research
    pixel_values = torch.tensor([0, 0.287273, 0.389091, 0.592727, 0.796364, 1]).to(x.device)
    spike_times = torch.tensor([14.3, 12.4, 11.8, 11.4, 11.3, 11.2]).to(x.device) 
    
    spike_times = spike_times - 11.2  # starts after 11.2ms for efficient model training
    spike_time = torch.zeros_like(x)
    
    for i in range(len(pixel_values) - 1):
        mask = (x >= pixel_values[i]) & (x < pixel_values[i+1])
        # Linearly interpolate spike times based on pixel values
        spike_time[mask] = spike_times[i] + (spike_times[i+1] - spike_times[i]) * (x[mask] - pixel_values[i]) / (pixel_values[i+1] - pixel_values[i])
    spike_time[x == 1] = spike_times[-1]
    
    for t in range(T_sim):
        spikes[:, :, t] = (spike_time < t + 1).float() * (spike_time >= t).float()
    return spikes


    
    
class Three_Layer_SNN(nn.Module):
    '''
    This net model contains 2 linear layer, 2 self-defined BatchNorm layer and 2 Neuron layer.

    linear layer: a memristor crossbar on which the MAC operation is implemented.
    BatchNorm layer: a row of TIA as the output interface of the pre-linear layer, normalize the
                    output current to  -2.0~2.0 V voltage.
    neuron layer: nonliear activation, receive input voltage and output spikes, spiking rate is taken
                    in loss computing.
    '''
    def __init__(self, param):
        super().__init__()
        self.linear1 = linear.MAC_Crossbar(param['dim_in'], param['dim_h'],
                                            param['G_min'], param['G_max'], param['q_bit'])
        self.BatchNorm1 = linear.TIA_Norm(param['dim_in'], 0.0, 200.0)    # the paramters of TIA are mannually set for moderate input voltage to neurons
        self.neuron1 = neuron.LIFNeuron(param['batch_size'], param['dim_h'], param['Rd'], param['Cm'],
                                            param['Rs'], param['Vth'], param['V_reset'], param['dt'])
        self.linear2 = linear.MAC_Crossbar(param['dim_h'], param['dim_out'],
                                            param['G_min'], param['G_max'], param['q_bit'])
        self.BatchNorm2 = linear.TIA_Norm(param['dim_h'], 0.0, 200.0)    # same as above
        self.neuron2 = neuron.LIFNeuron(param['batch_size'], param['dim_out'], param['Rd'], param['Cm'],
                                            param['Rs'], param['Vth'], param['V_reset'], param['dt'])

    def forward(self, input_vector):
        total_spikes = 0 # total spike count
        
        out_vector = self.linear1(input_vector)
        out_vector = self.BatchNorm1(out_vector)

        self.neuron1.v = self.neuron1.v.to(device)
        self.neuron2.v = self.neuron2.v.to(device)

        out_vector = self.neuron1(out_vector)
        total_spikes += torch.sum(out_vector == 1).item()
        
        out_vector = self.linear2(out_vector)
        out_vector = self.BatchNorm2(out_vector)
        out_vector = self.neuron2(out_vector)
        total_spikes += torch.sum(out_vector == 1).item()
        
        return out_vector, total_spikes

    def reset_(self):
        '''
        Reset all neurons after one forward pass,
        to ensure the independency of every input image.
        '''
        for item in self.modules():
            if hasattr(item, 'reset'):
                item.reset()

    def quant_(self):
        '''
        The quantization function in pytorch only support int8,
        so we need our own quant function for adjustable quantization precision.
        '''
        for item in self.modules():
            if hasattr(item, 'Gquant_'):
                #debug print：
                #print(item.weight.max())
                item.Gquant_()
                #debug print：
                #print(item.weight.max())



In [7]:
# time-to-first-spike encoding 방식을 통해 encdoing 했을 때 spike가 어떻게 발생하는 지 디버깅하기 위한 부분

# 픽셀 값 생성 (0 ~ 1)
pixel_values = torch.linspace(0, 1, 100).reshape(1, -1).to('cpu')  #(batch_size, dim_in)
print("Pixel Values:", pixel_values)

tau = 10
threshold = 0.01

# snntorch에서 함수를 이용해 ttfs encoding한 결과
spike_time1 = latency(pixel_values, num_steps=param['T_sim'], tau=tau, threshold=threshold, clip=False, normalize=False, linear=False, bypass=True)
spike_time1 = spike_time1.permute(1, 2, 0)  #(1, N, T_sim)
spike_time1_first_spike = torch.zeros_like(spike_time1)
for i in range(spike_time1.shape[1]): 
    first_spike_indices = torch.nonzero(spike_time1[0, i, :])
    if len(first_spike_indices) > 0:
        first_spike_time = first_spike_indices[0, 0].item()
        spike_time1_first_spike[0, i, first_spike_time] = 1 

# 직접 구한 데이터를 이용해 ttfs 방식으로 encoding
spike_time2 = ttfs_encoding(pixel_values, param['T_sim'])

print("\n ttfe spike time with snntorch :")
for i in range(spike_time1_first_spike.shape[1]):
    spike_times = torch.nonzero(spike_time1_first_spike[0, i, :]).flatten()
    print(f"Pixel {pixel_values[0, i].item()}: Spikes at time step(s) {spike_times.tolist()}")

print("\n ttfe with acquired data :")
for i in range(spike_time2.shape[1]):
    spike_times = torch.nonzero(spike_time2[0, i, :]).flatten()
    print(f"Pixel {pixel_values[0, i].item()}: Spikes at time step(s) {spike_times.tolist()}")

Pixel Values: tensor([[0.0000, 0.0101, 0.0202, 0.0303, 0.0404, 0.0505, 0.0606, 0.0707, 0.0808,
         0.0909, 0.1010, 0.1111, 0.1212, 0.1313, 0.1414, 0.1515, 0.1616, 0.1717,
         0.1818, 0.1919, 0.2020, 0.2121, 0.2222, 0.2323, 0.2424, 0.2525, 0.2626,
         0.2727, 0.2828, 0.2929, 0.3030, 0.3131, 0.3232, 0.3333, 0.3434, 0.3535,
         0.3636, 0.3737, 0.3838, 0.3939, 0.4040, 0.4141, 0.4242, 0.4343, 0.4444,
         0.4545, 0.4646, 0.4747, 0.4848, 0.4949, 0.5051, 0.5152, 0.5253, 0.5354,
         0.5455, 0.5556, 0.5657, 0.5758, 0.5859, 0.5960, 0.6061, 0.6162, 0.6263,
         0.6364, 0.6465, 0.6566, 0.6667, 0.6768, 0.6869, 0.6970, 0.7071, 0.7172,
         0.7273, 0.7374, 0.7475, 0.7576, 0.7677, 0.7778, 0.7879, 0.7980, 0.8081,
         0.8182, 0.8283, 0.8384, 0.8485, 0.8586, 0.8687, 0.8788, 0.8889, 0.8990,
         0.9091, 0.9192, 0.9293, 0.9394, 0.9495, 0.9596, 0.9697, 0.9798, 0.9899,
         1.0000]])

 ttfe spike time with snntorch :
Pixel 0.0: Spikes at time step(s) []
Pixel

In [8]:
# Rate Encoding 방식으로 model을 학습85%
# Time_step = 20, learning_rate = 0.015, total_spike = 480000, test_accuracy = 84~85%

trainset = torchvision.datasets.MNIST(root=param['data_dir'], train=True,
                                        download=True, transform=transforms.ToTensor())
testset = torchvision.datasets.MNIST(root=param['data_dir'], train=False,
                                        download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=param['batch_size'],
                                            shuffle=True, drop_last=True, num_workers=2, pin_memory=True)
testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=param['batch_size'],
                                            shuffle=False, drop_last=True, num_workers=2, pin_memory=True)

#Train the SNN with BP
net = Three_Layer_SNN(param).to(device)
loss_func = torch.nn.CrossEntropyLoss().to(device)
optim = torch.optim.Adam(net.parameters(), param['learning_rate'])

start = time.time()
for epoch in range(param['epoch']):
    net.train()
    train_accuracy = []
    
    for img, label in trainloader:
        img = img.reshape(-1, input_dimmension)
        spike_num_img = torch.zeros(param['batch_size'], param['dim_out']).to(device)

        # using gpu
        img, label = img.to(device), label.to(device)
    
        # rate coding method
        spike_time = rate_encoding(img, param['T_sim'])

        net.reset_() #set the neuron voltage as reset voltage
        for t in range(param['T_sim']):
            # rate coding method
            new_img = spike_time[:, :, t]  
            out_spike, total_spike = net(new_img)
            spike_num_img += out_spike 

        spike_rate = spike_num_img/param['T_sim']  # /26
        loss = loss_func(spike_rate, label)

        #net.zero_grad()
        optim.zero_grad()
        loss.backward()
        optim.step()

        with torch.no_grad():
            net.reset_()    # reset the neuron voltage every batch, to ensure independency between batchs
            net.quant_()    # quantize the weights after weight update

        train_accuracy.append((spike_rate.max(1)[1] == label).float().mean().item())
    accuracy_epoch = np.mean(train_accuracy)
    print('tranning epoch %d: the SNN loss is %.6f' %(epoch, loss), end=' ')
    print('trainning accuracy: %.4f' %accuracy_epoch, end=' ')
    
# validation by testset every epoch to see if the network is overfitted
    net.eval()
    validation_accuracy = []
    with torch.no_grad():
        for img_test, label_test in testloader:
            img_test = img_test.reshape(-1, input_dimmension)
            spike_num_img_test = torch.zeros(param['batch_size'], param['dim_out']).to(device)

            # using gpu
            img_test, label_test = img_test.to(device), label_test.to(device)

            # rate coding method
            spike_time = rate_encoding(img_test, param['T_sim'])
            
            total_spike = spike_time.sum()
            net.reset_() #set the neuron voltage as reset voltage
            for t in range(param['T_sim']):
                # rate coding method
                new_test_img = spike_time[:, :, t]  
                out_spike, spike_num = net(new_test_img)
                spike_num_img_test += out_spike 
                total_spike += spike_num
                
            validation_accuracy.append((spike_num_img_test.max(1)[1]==label_test).float().mean().item())
        accuracy_val = np.mean(validation_accuracy)
        print('validation accuracy: %.4f' %accuracy_val)
        print('total spike number is {}' .format(total_spike), end = " ") 
        print('elapsed time : {0:.1f}s' .format(time.time() - start))

    with open(param['train_file'], 'a') as f_t:
        s = str(epoch).ljust(6, ' ') + str(round(loss.item(), 6)).ljust(12, ' ')
        s += str(round(accuracy_epoch, 4)).ljust(10, ' ') + str(round(accuracy_val, 4)).ljust(10, ' ')
        s += str(int(total_spike.item())).ljust(12, ' ') + '\n'
        f_t.write(s)

torch.save(net.state_dict(), param['model_dir'])

tranning epoch 0: the SNN loss is 2.251962 trainning accuracy: 0.1606 validation accuracy: 0.2043
total spike number is 482288.0 elapsed time : 8.5s
tranning epoch 1: the SNN loss is 2.215293 trainning accuracy: 0.2445 validation accuracy: 0.2831
total spike number is 481148.0 elapsed time : 17.2s
tranning epoch 2: the SNN loss is 2.204765 trainning accuracy: 0.3085 validation accuracy: 0.3461
total spike number is 480558.0 elapsed time : 24.9s
tranning epoch 3: the SNN loss is 2.173622 trainning accuracy: 0.4473 validation accuracy: 0.4276
total spike number is 479437.0 elapsed time : 32.7s
tranning epoch 4: the SNN loss is 2.140737 trainning accuracy: 0.5224 validation accuracy: 0.5671
total spike number is 478509.0 elapsed time : 40.4s
tranning epoch 5: the SNN loss is 2.145893 trainning accuracy: 0.5696 validation accuracy: 0.5967
total spike number is 478496.0 elapsed time : 48.1s
tranning epoch 6: the SNN loss is 2.125846 trainning accuracy: 0.5895 validation accuracy: 0.6219
tot

In [12]:
# Time-to-first-spike encoding 
# Time_step = 20, learning_rate = 0.015, total_spike = 210000, test_accuracy = 84%

trainset = torchvision.datasets.MNIST(root=param['data_dir'], train=True,
                                        download=True, transform=transforms.ToTensor())
testset = torchvision.datasets.MNIST(root=param['data_dir'], train=False,
                                        download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=param['batch_size'],
                                            shuffle=True, drop_last=True, num_workers=2, pin_memory=True)
testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=param['batch_size'],
                                            shuffle=False, drop_last=True, num_workers=2, pin_memory=True)

#Train the SNN with BP
net = Three_Layer_SNN(param).to(device)
loss_func = torch.nn.CrossEntropyLoss().to(device)
optim = torch.optim.Adam(net.parameters(), param['learning_rate'])

start = time.time()
for epoch in range(param['epoch']):
    net.train()
    train_accuracy = []
    
    for img, label in trainloader:
        img = img.reshape(-1, input_dimmension)
        spike_num_img = torch.zeros(param['batch_size'], param['dim_out']).to(device)

        # using gpu
        img, label = img.to(device), label.to(device)
        
        # Time-to-first spike coding method
        #spike_time = latency(img, num_steps=param['T_sim'], tau=10, threshold=0.01, clip=False, normalize=False, linear=False, bypass=True)
        spike_time = ttfs_encoding(img, param['T_sim'])
        mask = torch.ones(param['batch_size'], 1).to(device)

        net.reset_() #set the neuron voltage as reset voltage
        for t in range(param['T_sim']):
            # Time-to-first spike coding method
            #new_img = spike_time[t]
            new_img = spike_time[:, :, t]  
            out_spike, _ = net(new_img)
            out_spike *= mask
            spike_num_img += out_spike
            # update mask
            spike_detected = torch.any(out_spike > 0, dim=1, keepdim=True) 
            mask = mask * (~spike_detected) 
        spike_rate = spike_num_img
        loss = loss_func(spike_rate, label)

        #net.zero_grad()
        optim.zero_grad()
        loss.backward()
        optim.step()

        with torch.no_grad():
            net.reset_()    # reset the neuron voltage every batch, to ensure independency between batchs
            net.quant_()    # quantize the weights after weight update

        train_accuracy.append((spike_rate.max(1)[1] == label).float().mean().item())
        
    accuracy_epoch = np.mean(train_accuracy)
    print('tranning epoch %d: the SNN loss is %.6f' %(epoch, loss), end=' ')
    print('trainning accuracy: %.4f' %accuracy_epoch, end=' ')
    
# validation by testset every epoch to see if the network is overfitted
    net.eval()
    validation_accuracy = []
    with torch.no_grad():
        for img_test, label_test in testloader:
            img_test = img_test.reshape(-1, input_dimmension)
            spike_num_img_test = torch.zeros(param['batch_size'], param['dim_out']).to(device)

            # using gpu
            img_test, label_test = img_test.to(device), label_test.to(device)
            
            # Time-to-first spike coding method
            #spike_time = latency(img_test, num_steps=param['T_sim'], tau=10, threshold=0.01, clip=False, normalize=False, linear=False, bypass = True)
            spike_time = ttfs_encoding(img_test, param['T_sim'])
            mask = torch.ones(param['batch_size'], 1).to(device)
            
            total_spike = spike_time.sum()
            net.reset_() #set the neuron voltage as reset voltage
            for t in range(param['T_sim']):
                # Time-to-first spike coding method
                # new_test_img = spike_time[t]
                new_test_img = spike_time[:, :, t]
                out_spike, spike_num = net(new_test_img)
                out_spike *= mask
                spike_num_img_test += out_spike
                
                # update mask
                spike_detected = torch.any(out_spike > 0, dim=1, keepdim=True) 
                mask = mask * (~spike_detected) 
                
                total_spike += spike_num
                
            validation_accuracy.append((spike_num_img_test.max(1)[1]==label_test).float().mean().item())
        accuracy_val = np.mean(validation_accuracy)
        print('validation accuracy: %.4f' %accuracy_val)
        print('total spike number is {}' .format(total_spike), end = " ") 
        print('elapsed time : {0:.1f}s' .format(time.time() - start))

    with open(param['train_file'], 'a') as f_t:
        s = str(epoch).ljust(6, ' ') + str(round(loss.item(), 6)).ljust(12, ' ')
        s += str(round(accuracy_epoch, 4)).ljust(10, ' ') + str(round(accuracy_val, 4)).ljust(10, ' ')
        s += str(int(total_spike.item())).ljust(12, ' ') + '\n'
        f_t.write(s)

torch.save(net.state_dict(), param['model_dir'])

tranning epoch 0: the SNN loss is 1.714737 trainning accuracy: 0.4948 validation accuracy: 0.5894
total spike number is 208787.0 elapsed time : 7.8s
tranning epoch 1: the SNN loss is 1.664072 trainning accuracy: 0.6348 validation accuracy: 0.6699
total spike number is 210679.0 elapsed time : 15.5s
tranning epoch 2: the SNN loss is 1.616650 trainning accuracy: 0.6820 validation accuracy: 0.7195
total spike number is 209045.0 elapsed time : 23.2s
tranning epoch 3: the SNN loss is 1.601045 trainning accuracy: 0.7155 validation accuracy: 0.7260
total spike number is 208427.0 elapsed time : 31.0s
tranning epoch 4: the SNN loss is 1.612205 trainning accuracy: 0.7367 validation accuracy: 0.7642
total spike number is 207594.0 elapsed time : 38.7s
tranning epoch 5: the SNN loss is 1.595583 trainning accuracy: 0.7588 validation accuracy: 0.7830
total spike number is 209105.0 elapsed time : 46.5s
tranning epoch 6: the SNN loss is 1.613802 trainning accuracy: 0.7703 validation accuracy: 0.7988
tot

In [14]:
# Test process after training
net_test = Three_Layer_SNN(param).to(device)
print('Loading Model, please wait......')
net_test.load_state_dict(torch.load(param['model_dir'], weights_only=True))
print('Model loaded successfully!')
list_num_spike = []
for i in range(10):
    list_num_spike.append([0])
    list_num_spike[i].append(torch.zeros(param['dim_out']))

total_correct = 0  # 전체 맞춘 예측 수
total_samples = 0  # 전체 테스트 샘플 수

with torch.no_grad():
    for img_test, label_test in testloader:
        img_test = img_test.reshape(-1, 28 * 28)
        spike_num_img_test = torch.zeros(param['batch_size'], param['dim_out']).to(device)
        net_test.reset_()  # set the neuron voltage as reset voltage

        # using gpu
        img_test, label_test = img_test.to(device), label_test.to(device)

        # Time-to-first spike coding method
        #spike_time = latency(img_test, num_steps=param['T_sim'], tau=10, threshold=0.01, clip=False, normalize=False, linear=False, bypass=True)
        spike_time = ttfs_encoding(img_test, param['T_sim'])
        mask = torch.ones(param['batch_size'], 1).to(device)
        net.reset_()  # set the neuron voltage as reset voltage

        for t in range(param['T_sim']):
            #new_test_img = spike_time[t]
            new_test_img = spike_time[:, :, t]
            out_spike, _ = net_test(new_test_img)
            out_spike *= mask
            spike_num_img_test += out_spike
            # update mask
            spike_detected = torch.any(out_spike > 0, dim=1, keepdim=True)
            mask = mask * (~spike_detected)

        pred_label = F.one_hot(spike_num_img_test.max(1)[1], num_classes=10).to('cpu')  # convert the max neuron output index to onehot vector

        # Accuracy calculation
        pred_label_class = spike_num_img_test.max(1)[1]  # predicted class for each image in the batch
        total_correct += (pred_label_class == label_test).sum().item()  # count correct predictions
        total_samples += label_test.size(0)  # update total samples

        for j in range(label_test.size(0)):
            index = label_test[j]
            list_num_spike[index][0] += 1
            list_num_spike[index][1] += pred_label[j].to('cpu')  # statistics of prediction for every input image

# Calculate and print overall accuracy
accuracy = total_correct / total_samples * 100
print(f'Test Accuracy: {accuracy:.2f}%')

# Save confusion matrix
with open('confusion_matrix.txt', 'a') as f2:
    for i in range(len(list_num_spike)):
        s = str(list_num_spike[i][0]) + ' ' + str(list_num_spike[i][1].numpy()).replace('[', '').replace(']', '') + '\n'
        f2.write(s)

Loading Model, please wait......
Model loaded successfully!
Test Accuracy: 85.44%
