In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
import random
import time
from tqdm import tqdm  # 
import os
from pathlib import Path
# 
# Seed every random number generator the notebook touches so runs are repeatable.
seed = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Enable cuDNN's benchmark mode (kernel auto-tuning).
torch.backends.cudnn.benchmark = True

# Report whether CUDA is usable, then the name and memory (GiB) of each visible GPU.
print(f"cuda:{torch.cuda.is_available()}")
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        device = torch.device(f"cuda:{i}")
        properties = torch.cuda.get_device_properties(device)
        print(f"GPU {i}:{properties.name}:{properties.total_memory/1024/1024/1024:.2f}GB")

# Device used by the rest of the notebook: the default CUDA device if present, else the CPU.
# This deliberately overwrites the per-GPU `device` values assigned in the loop above.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


## ###############################################################################

In [None]:

def rand_sparse_matrix(rows, cols, connection_rate):
    """Create a random sparse ``rows x cols`` matrix with an exact fill ratio.

    ``int(rows * cols * connection_rate)`` *distinct* positions are chosen
    uniformly at random and filled with values drawn uniformly from [-1, 1).

    Positions are sampled without replacement.  Drawing row and column indices
    independently lets several entries land on the same cell; a COO tensor
    sums such duplicates when densified, which lowers the real density below
    ``connection_rate`` and lets values leave [-1, 1).

    Args:
        rows: Number of rows of the matrix.
        cols: Number of columns of the matrix.
        connection_rate: Fraction of entries to populate, in (0, 1].

    Returns:
        A ``torch.sparse_coo_tensor`` of shape ``(rows, cols)``.

    Raises:
        AssertionError: If ``connection_rate`` is outside (0, 1].
    """
    assert 0 < connection_rate <= 1, "The connection rate must be between (0,1]"

    # Number of entries to populate.
    num_elements = rows * cols
    num_nonzero = int(num_elements * connection_rate)

    # Pick distinct flat positions, then split them into (row, col) pairs.
    flat_positions = torch.randperm(num_elements)[:num_nonzero]
    row_indices = torch.div(flat_positions, cols, rounding_mode="floor")
    col_indices = flat_positions % cols

    # COO layout expects a (2, nnz) index tensor.
    indices = torch.stack((row_indices, col_indices))

    # Uniform values in [-1, 1).
    values = torch.rand(num_nonzero) * 2 - 1

    return torch.sparse_coo_tensor(indices, values, (rows, cols))


def Aindx2Nrange(numnodes, div, Aindx):
    """Translate 1-based block numbers into the node indices they cover.

    ``numnodes`` nodes are split into ``abs(div)`` equally sized, consecutive
    blocks (integer division; any remainder nodes belong to no block).  Block
    ``k`` covers indices ``(k - 1) * size`` up to ``k * size - 1``.

    Args:
        numnodes: Total number of nodes.
        div: Number of blocks; only its absolute value is used.
        Aindx: A single block number, or a list of block numbers.

    Returns:
        A flat list of node indices, block by block in the order given.
    """
    block_len = numnodes // abs(div)
    # Anything that is not a list is treated as one block number.
    block_ids = Aindx if isinstance(Aindx, list) else [Aindx]

    return [
        node
        for block in block_ids
        for node in range((block - 1) * block_len, block * block_len)
    ]


def adam_update(m, v, dw, beta1, beta2, t, epsilon=1e-8):
    """Run one Adam step and return the normalised update direction.

    The inputs are not modified; new moment tensors are returned.

    Args:
        m: Running first-moment estimate.
        v: Running second-moment estimate.
        dw: Current gradient.
        beta1: Decay rate of the first moment.
        beta2: Decay rate of the second moment.
        t: Step counter; must be >= 1, otherwise the correction divides by zero.
        epsilon: Small constant added to the denominator.

    Returns:
        ``(update, m, v)``: the bias-corrected direction (not yet scaled by a
        learning rate) and the refreshed moments.
    """
    # Exponential moving averages of the gradient and of its element-wise square.
    first_moment = beta1 * m + (1 - beta1) * dw
    second_moment = beta2 * v + (1 - beta2) * (dw ** 2)

    # Bias correction for the zero-initialised moments.
    first_hat = first_moment / (1 - beta1 ** t)
    second_hat = second_moment / (1 - beta2 ** t)

    step = first_hat / (torch.sqrt(second_hat) + epsilon)
    return step, first_moment, second_moment

def linear(x):
    """Identity activation: hand the input back unchanged (same object)."""
    return x

def linear_d(x):
    """Derivative of the identity activation: a tensor of ones shaped like ``x``."""
    return torch.full_like(x, 1)

def tanh(x):
    """Element-wise hyperbolic tangent of the tensor ``x``."""
    return x.tanh()

def tanh_d(x):
    """Derivative of tanh evaluated at ``x``: ``1 - tanh(x)^2``."""
    activated = torch.tanh(x)
    return 1 - activated.pow(2)

def hard_sigmoid(x):
    """Clip ``x`` element-wise to the interval [0, 1]."""
    return x.clamp(min=0, max=1)

def hard_sigmoid_d(x):
    """Derivative of ``hard_sigmoid``: 1 where ``0 <= x <= 1`` (bounds included), else 0."""
    inside_unit_interval = (x >= 0) & (x <= 1)
    return inside_unit_interval.to(x.dtype)

def relu6(x):
    """ReLU capped at 6: element-wise ``max(min(x, 6), 0)``."""
    ceiling = torch.full_like(x, 6)
    floor = torch.zeros_like(x)
    return torch.maximum(torch.minimum(x, ceiling), floor)

def relu6_d(x):
    """Derivative of ``relu6``: 1 where ``0 < x <= 6``, else 0."""
    in_linear_region = (x > 0) & (x <= 6)
    return in_linear_region.to(x.dtype)

def softmax(x):
    """Row-wise softmax over dimension 1 of ``x``.

    Each row's maximum is subtracted before exponentiating so that large
    inputs do not overflow.
    """
    shifted = x - x.max(dim=1, keepdim=True).values
    exponentials = torch.exp(shifted)
    return exponentials / exponentials.sum(dim=1, keepdim=True)



# One-hot 
def one_hot(labels, n_out):
    """Encode integer class labels as one-hot rows.

    The result is allocated on the same device as ``labels``.  Building it on
    the CPU while the labels live on a GPU makes the indexed assignment (and
    any later arithmetic with GPU tensors) fail with a device mismatch.

    Args:
        labels: 1-D integer tensor of class indices in ``[0, n_out)``.
        n_out: Number of classes, i.e. the width of each one-hot row.

    Returns:
        A float tensor of shape ``(labels.size(0), n_out)`` holding 1 at each
        row's label position and 0 elsewhere.
    """
    num_samples = labels.size(0)
    one_hot_labels = torch.zeros(num_samples, n_out, device=labels.device)
    # Row i gets a 1 in column labels[i].
    one_hot_labels[torch.arange(num_samples, device=labels.device), labels] = 1
    return one_hot_labels

def cross_entropy_loss(output, target):
    """Batch-averaged cross-entropy between probabilities and one-hot targets.

    Args:
        output: Predicted probabilities, one row per sample.
        target: Target distribution with the same shape as ``output``.

    Returns:
        Scalar tensor: minus the mean over rows of ``sum(target * log(output))``.
    """
    tiny = 1e-8  # keeps log() away from 0
    probabilities = output.clamp(tiny, 1. - tiny)
    per_sample = (target * probabilities.log()).sum(dim=1)
    return -per_sample.mean()


In [None]:
class FRNN_ML4(torch.nn.Module):
    """Four recurrent layers with fixed random feedback and hand-written Adam updates.

    The model owns four ``H_size x H_size`` recurrent matrices (``Wr_1`` ..
    ``Wr_4``).  Each is drawn with ``rand_sparse_matrix`` and rescaled so that
    its largest eigenvalue magnitude equals ``config.RNNSR``.  A layer's state
    is obtained by iterating ``h <- f(h @ Wr + bias)`` for ``t2sta`` steps
    (see ``It2sta``).

    Hidden units are addressed in blocks through
    ``Aindx2Nrange(H_size, NU, k)``.  The projection matrices are sized with
    ``NperU``, so the block width ``H_size // NU`` has to equal ``NperU`` for
    the matrix shapes to line up.

    Two input-to-output paths exist:

    * ``forward1`` / ``backward1``: layers 1 and 2, read out through ``Wout_1``.
    * ``forward2`` / ``backward2``: layers 3 and 4, read out through ``Wout_2``.

    No autograd is involved.  All weights are plain tensors (not
    ``nn.Parameter``); the ``backward*`` methods derive per-layer error signals
    as the difference between a layer's state with and without the fed-back
    output error, and apply ``adam_update`` in place.

    The forward/backward methods cache their intermediates on ``self``
    (``z_k``, ``cbias_k``, ...), so each ``backward*`` call must directly
    follow the matching ``forward*`` call on the same batch.
    """
    def __init__(self, config):
        """Create all weights, feedback matrices and optimiser state.

        Args:
            config: Object providing ``device``, ``f``, ``feedbackLearning``,
                ``RNNLearning``, ``RNNBiasLearning``, ``H_size``, ``NperU``,
                ``NU``, ``t2sta``, ``RNNSR``, ``RNNCR``, ``Feedforwardsc``,
                ``Feedbacksc``, ``batch_size``, ``num_classes``,
                ``input_size``, ``gamma1``, ``alpha1``, ``lambda1`` and
                ``learning_rate``.
        """
        super(FRNN_ML4, self).__init__()
        self.device = config.device

        self.f = config.f  # activation applied at every recurrent iteration
        self.feedbackLearning = config.feedbackLearning  # stored; not read in this class
        self.RNNLearning = config.RNNLearning  # update the recurrent matrices Wr_k?
        self.RNNBiasLearning = config.RNNBiasLearning  # update the biases?

        self.H_size = config.H_size
        self.NperU = config.NperU
        self.NU = config.NU 

        self.t2sta = config.t2sta  # number of recurrent iterations per relaxation
        self.RNNSR = config.RNNSR  # target spectral radius of each Wr_k
        self.Feedforwardsc = config.Feedforwardsc  # std of the forward projections
        self.Feedbacksc = config.Feedbacksc  # std of the feedback projections
        self.batch_size = config.batch_size  # configured batch size; divisor of every gradient below

        # Flag stored on the instance; not read anywhere in this class.
        self.W3L = True

        # Recurrent layer 1: sparse random matrix, densified on the CPU.
        self.Wr_1 = rand_sparse_matrix(config.H_size, config.H_size, config.RNNCR).to_dense()
        # Boolean connectivity mask (taken before rescaling); restricts later Wr updates.
        self.Wr_1_s = self.Wr_1 !=0 
        # Rescale so the largest eigenvalue magnitude equals RNNSR.
        self.Wr_1 = self.Wr_1 / torch.abs( torch.linalg.eigvals(self.Wr_1)).max() * self.RNNSR 
        self.Wr_1 = self.Wr_1.to(self.device)
        self.bias_1 = torch.zeros(config.H_size,device=self.device) 
        self.Wr_1_s = self.Wr_1_s.to(self.device) 


        # Recurrent layer 2 (same construction as layer 1).
        self.Wr_2 = rand_sparse_matrix(config.H_size, config.H_size, config.RNNCR).to_dense()
        self.Wr_2_s = self.Wr_2 !=0
        self.Wr_2 = self.Wr_2 / torch.abs( torch.linalg.eigvals(self.Wr_2)).max() * self.RNNSR
        self.Wr_2 = self.Wr_2.to(self.device)
        self.bias_2 = torch.zeros(config.H_size,device=self.device)
        self.Wr_2_s = self.Wr_2_s.to(self.device) 

        # Recurrent layer 3 (same construction as layer 1).
        self.Wr_3 = rand_sparse_matrix(config.H_size, config.H_size, config.RNNCR).to_dense()
        self.Wr_3_s = self.Wr_3 !=0
        self.Wr_3 = self.Wr_3 / torch.abs( torch.linalg.eigvals(self.Wr_3)).max() * self.RNNSR
        self.Wr_3 = self.Wr_3.to(self.device)
        self.bias_3 = torch.zeros(config.H_size,device=self.device) 
        self.Wr_3_s = self.Wr_3_s.to(self.device) 

        # Recurrent layer 4 (same construction as layer 1).
        self.Wr_4 = rand_sparse_matrix(config.H_size, config.H_size, config.RNNCR).to_dense()
        self.Wr_4_s = self.Wr_4 !=0
        self.Wr_4 = self.Wr_4 / torch.abs( torch.linalg.eigvals(self.Wr_4)).max() * self.RNNSR
        self.Wr_4 = self.Wr_4.to(self.device)
        self.bias_4 = torch.zeros(config.H_size,device=self.device)  
        self.Wr_4_s = self.Wr_4_s.to(self.device) 

        # Random feedback projections.  None of the Wb* tensors is updated in this class.
        # Output error -> hidden block (num_classes x NperU).
        self.Wb_f_1 = torch.randn(config.num_classes, config.NperU,device=self.device) * self.Feedbacksc 
        self.Wb_f_2 = torch.randn(config.num_classes, config.NperU,device=self.device) * self.Feedbacksc  
        self.Wb_f_3 = torch.randn(config.num_classes, config.NperU,device=self.device) * self.Feedbacksc  

        # Hidden error -> hidden block (NperU x NperU).  Only Wb_2 and Wb_3 are read below.
        self.Wb_1 = torch.randn(config.NperU, config.NperU,device=self.device) * self.Feedbacksc
        self.Wb_2 = torch.randn(config.NperU, config.NperU,device=self.device) * self.Feedbacksc
        self.Wb_3 = torch.randn(config.NperU, config.NperU,device=self.device) * self.Feedbacksc
        self.Wb_4 = torch.randn(config.NperU, config.NperU,device=self.device) * self.Feedbacksc

        # Input -> hidden block projections.  WinX_3 is not read in this class.
        self.WinX_1 = torch.randn(config.input_size, config.NperU,device=self.device) * self.Feedforwardsc
        self.WinX_2 = torch.randn(config.input_size, config.NperU,device=self.device) * self.Feedforwardsc
        self.WinX_3 = torch.randn(config.input_size, config.NperU,device=self.device) * self.Feedforwardsc

        # Hidden block -> hidden block projections.  Only Win_1 and Win_3 are read below.
        self.Win_1 = torch.randn(config.NperU, config.NperU,device=self.device) * self.Feedforwardsc
        self.Win_2 = torch.randn(config.NperU, config.NperU,device=self.device) * self.Feedforwardsc  
        self.Win_3 = torch.randn(config.NperU, config.NperU,device=self.device) * self.Feedforwardsc
        self.Win_4 = torch.randn(config.NperU, config.NperU,device=self.device) * self.Feedforwardsc  

        # Readout weights and biases.  Wout_3 / bias_f_3 are not read in this class.
        self.Wout_1 = torch.randn(config.NperU, config.num_classes,device=self.device) * self.Feedforwardsc
        self.Wout_2 = torch.randn(config.NperU, config.num_classes,device=self.device) * self.Feedforwardsc
        self.Wout_3 = torch.randn(config.NperU, config.num_classes,device=self.device) * self.Feedforwardsc
        self.bias_f_1 = torch.zeros(config.num_classes,device=self.device)
        self.bias_f_2 = torch.zeros(config.num_classes,device=self.device)
        self.bias_f_3 = torch.zeros(config.num_classes,device=self.device)

        # Scalar hyper-parameters.
        self.beta1 = 1  # scale of the error signal injected during the backward passes
        self.gamma1 = config.gamma1  # stored; not read in this class
        self.alpha1 = config.alpha1  # extra factor on the recurrent-weight (Wr_k) step
        self.lambda1 = config.lambda1  # decay coefficient; applied as `W * lambda1`, not scaled by the learning rate

        # Adam state: one (first moment, second moment) pair per trainable tensor.
        # Recurrent matrices and recurrent biases.
        self.opt_m_Wr_1, self.opt_v_Wr_1 = torch.zeros_like(self.Wr_1), torch.zeros_like(self.Wr_1)
        self.opt_m_b_1, self.opt_v_b_1 = torch.zeros_like(self.bias_1), torch.zeros_like(self.bias_1)
        self.opt_m_Wr_2, self.opt_v_Wr_2 = torch.zeros_like(self.Wr_2), torch.zeros_like(self.Wr_2)
        self.opt_m_b_2, self.opt_v_b_2 = torch.zeros_like(self.bias_2), torch.zeros_like(self.bias_2)
        self.opt_m_Wr_3, self.opt_v_Wr_3 = torch.zeros_like(self.Wr_3), torch.zeros_like(self.Wr_3)  
        self.opt_m_b_3, self.opt_v_b_3 = torch.zeros_like(self.bias_3), torch.zeros_like(self.bias_3)
        self.opt_m_Wr_4, self.opt_v_Wr_4 = torch.zeros_like(self.Wr_4), torch.zeros_like(self.Wr_4)  
        self.opt_m_b_4, self.opt_v_b_4 = torch.zeros_like(self.bias_4), torch.zeros_like(self.bias_4)  

        # Input projections.
        self.opt_m_WinX_1, self.opt_v_WinX_1 = torch.zeros_like(self.WinX_1), torch.zeros_like(self.WinX_1)
        self.opt_m_WinX_2, self.opt_v_WinX_2 = torch.zeros_like(self.WinX_2), torch.zeros_like(self.WinX_2)

        # Inter-layer projections (the Win_2 / Win_4 state is allocated but never used below).
        self.opt_m_Win_1, self.opt_v_Win_1 = torch.zeros_like(self.Win_1), torch.zeros_like(self.Win_1)
        self.opt_m_Win_2, self.opt_v_Win_2 = torch.zeros_like(self.Win_2), torch.zeros_like(self.Win_2)  
        self.opt_m_Win_3, self.opt_v_Win_3 = torch.zeros_like(self.Win_3), torch.zeros_like(self.Win_3)
        self.opt_m_Win_4, self.opt_v_Win_4 = torch.zeros_like(self.Win_4), torch.zeros_like(self.Win_4) 


        # Readout weights and biases.
        self.opt_m_Wout_1, self.opt_v_Wout_1 = torch.zeros_like(self.Wout_1), torch.zeros_like(self.Wout_1)
        self.opt_m_Wout_2, self.opt_v_Wout_2 = torch.zeros_like(self.Wout_2), torch.zeros_like(self.Wout_2)
        self.opt_m_bias_f_1, self.opt_v_bias_f_1 = torch.zeros_like(self.bias_f_1), torch.zeros_like(self.bias_f_1)
        self.opt_m_bias_f_2, self.opt_v_bias_f_2 = torch.zeros_like(self.bias_f_2), torch.zeros_like(self.bias_f_2)
        

        self.opt_beta1, self.opt_beta2 = 0.9, 0.999
        self.opt_epsilon = 1e-8  # stored; adam_update is called with its own default epsilon
        self.opt_eta = config.learning_rate
        self.opt_t1 = 0  # Adam step counter for backward1
        self.opt_t2 = 0  # Adam step counter for backward2

    def forward1(self, x):
        """Run path 1 (layers 1 and 2) and return class probabilities.

        Args:
            x: 2-D input batch with ``input_size`` columns.

        Returns:
            Softmax output of shape ``(x.size(0), num_classes)``.  The states
            ``z_1``/``z_2`` and drives ``cbias_1``/``cbias_2`` stay cached on
            ``self`` for ``backward1``.
        """
        bs = x.size(0)
        # Layer 1: the input drives block 1; relax from a zero state.
        self.z_1 = torch.zeros(bs , self.bias_1.size(0),device=self.device)
        self.cbias_1 = torch.zeros(bs , self.bias_1.size(0),device=self.device)
        self.cbias_1[:, Aindx2Nrange(self.H_size, self.NU, 1)] = x.mm(self.WinX_1) 
        self.z_1 = self.It2sta(self.Wr_1, self.z_1, self.bias_1 + self.cbias_1, self.t2sta)

        # Layer 2: block 3 of layer 1 drives block 1 of layer 2.
        self.z_2 = torch.zeros(bs , self.bias_2.size(0),device=self.device)
        self.cbias_2 = torch.zeros(bs , self.bias_2.size(0),device=self.device)
        self.cbias_2[:, Aindx2Nrange(self.H_size, self.NU, 1)] = self.z_1[:, Aindx2Nrange(self.H_size, self.NU, 3)].mm(self.Win_1) 
        self.z_2 = self.It2sta(self.Wr_2, self.z_2, self.bias_2 + self.cbias_2, self.t2sta)

        # Readout from block 3 of layer 2.
        self.zfh1 = self.z_2[:, Aindx2Nrange(self.H_size, self.NU, 3)].mm(self.Wout_1) + self.bias_f_1
        self.zf1 = softmax(self.zfh1)
        # self.zf1 = self.f(self.zfh1)
        return self.zf1
    
    def forward2(self, x):
        """Run path 2 (layers 3 and 4) and return class probabilities.

        Args:
            x: 2-D input batch with ``input_size`` columns.

        Returns:
            Softmax output of shape ``(x.size(0), num_classes)``.  The states
            ``z_3``/``z_4`` and drives ``cbias_3``/``cbias_4`` stay cached on
            ``self`` for ``backward2``.
        """
        bs = x.size(0)

        # Layer 3: the input drives block 5; relax from a zero state.
        self.z_3 = torch.zeros(bs , self.bias_3.size(0),device=self.device) 
        self.cbias_3 = torch.zeros(bs , self.bias_3.size(0),device=self.device)
        self.cbias_3[:, Aindx2Nrange(self.H_size, self.NU, 5)] = x.mm(self.WinX_2) 
        self.z_3 = self.It2sta(self.Wr_3, self.z_3, self.bias_3 + self.cbias_3, self.t2sta)

        # Layer 4: block 2 of layer 3 drives block 1 of layer 4.
        self.z_4 = torch.zeros(bs , self.bias_4.size(0),device=self.device) 
        self.cbias_4 = torch.zeros(bs , self.bias_4.size(0),device=self.device)
        self.cbias_4[:, Aindx2Nrange(self.H_size, self.NU, 1)] = self.z_3[:, Aindx2Nrange(self.H_size, self.NU, 2)].mm(self.Win_3)  # 
        self.z_4 = self.It2sta(self.Wr_4, self.z_4, self.bias_4 + self.cbias_4, self.t2sta)

        # Readout from block 3 of layer 4.
        self.zfh2 = self.z_4[:, Aindx2Nrange(self.H_size, self.NU, 3)].mm(self.Wout_2) + self.bias_f_2
        self.zf2 = softmax(self.zfh2)
        # self.zf2 = self.f(self.zfh2)
        return self.zf2

    def backward1(self, x, y, output):
        """Update the path-1 weights from one batch; call right after ``forward1``.

        The output error is injected into layer 2 directly.  To reach layer 1
        it is routed through layers 4 and 3, which are re-initialised here
        without any input drive, so the ``z_3``/``z_4``/``cbias_3``/``cbias_4``
        values cached by an earlier ``forward2`` are overwritten.

        Args:
            x: The batch that was passed to ``forward1``.
            y: One-hot targets, same shape as ``output``.
            output: The probabilities returned by ``forward1``.
        """
        bs = x.size(0)   

        # Output error.
        self.e_f_1 = output - y
        # L2: inject the projected output error into block 4, relax again from z_2,
        # and take the resulting state change as layer 2's error.
        self.cbias_2[:, Aindx2Nrange(self.H_size, self.NU, 4)] = -self.beta1 * self.e_f_1.mm(self.Wb_f_1)
        self.y_2 = self.z_2
        self.y_2 = self.It2sta(self.Wr_2, self.y_2, self.bias_2 + self.cbias_2, self.t2sta)
        self.e_2 = self.z_2 - self.y_2

        # L4: reset layer 4 and relax it with no external drive (bias_4 only) as reference state.
        self.z_4 = torch.zeros(bs , self.bias_4.size(0),device=self.device)
        self.cbias_4 = torch.zeros(bs , self.bias_4.size(0),device=self.device)
        self.z_4 = self.It2sta(self.Wr_4, self.z_4, self.bias_4 + self.cbias_4, self.t2sta) #It2staf

        # Inject the output error into block 5 of layer 4 and measure the state change.
        self.cbias_4[:, Aindx2Nrange(self.H_size, self.NU, 5)] = -self.beta1 * self.e_f_1.mm(self.Wb_f_3) 
        self.y_4 = self.z_4
        self.y_4 = self.It2sta(self.Wr_4, self.y_4, self.bias_4 + self.cbias_4, self.t2sta) #
        self.e_4 = self.z_4 - self.y_4

        # L3: same procedure, driven by block 6 of layer 4's error into block 3.
        self.z_3 = torch.zeros(bs , self.bias_3.size(0),device=self.device)
        self.cbias_3 = torch.zeros(bs , self.bias_3.size(0),device=self.device)
        self.z_3 = self.It2sta(self.Wr_3, self.z_3, self.bias_3 + self.cbias_3, self.t2sta) #

        self.cbias_3[:, Aindx2Nrange(self.H_size, self.NU, 3)] = -self.beta1 * self.e_4[:, Aindx2Nrange(self.H_size, self.NU, 6)].mm(self.Wb_3)
        self.y_3 = self.z_3
        self.y_3 = self.It2sta(self.Wr_3, self.y_3, self.bias_3 + self.cbias_3, self.t2sta) #
        self.e_3 = self.z_3 - self.y_3

        # L1: block 4 of layer 3's error is injected into block 6 of layer 1
        # (cbias_1 still holds the input drive written by forward1).
        self.cbias_1[:, Aindx2Nrange(self.H_size, self.NU, 6)] = -self.beta1 * self.e_3[:, Aindx2Nrange(self.H_size, self.NU, 4)].mm(self.Wb_2)
        self.y_1 = self.z_1
        self.y_1 = self.It2sta(self.Wr_1, self.y_1, self.bias_1 + self.cbias_1, self.t2sta)
        self.e_1 = self.z_1 - self.y_1

        ## Gradients: presynaptic activity (transposed) times the error of the receiving block.
        # NOTE(review): the divisor is the configured batch size, not `bs`; the two
        # differ whenever a batch is smaller than config.batch_size - confirm intended.
        self.dWinX_1 = x.t().mm(self.e_1[:, Aindx2Nrange(self.H_size, self.NU, 1)]) / self.batch_size
        self.dWin_1 = self.z_1[:, Aindx2Nrange(self.H_size, self.NU, 3)].t().mm(self.e_2[:, Aindx2Nrange(self.H_size, self.NU, 1)]) / self.batch_size
        self.dWout_1 = self.z_2[:, Aindx2Nrange(self.H_size, self.NU, 3)].t().mm(self.e_f_1) / self.batch_size
        self.dbias_f_1 = self.e_f_1.sum(0) / self.batch_size

        if self.RNNLearning: 
            self.dWr_1 = self.z_1.t().mm(self.e_1) / self.batch_size
            self.dWr_2 = self.z_2.t().mm(self.e_2) / self.batch_size
            
        if self.RNNBiasLearning:
            self.dbias_1 = self.e_1.sum(0) / self.batch_size
            self.dbias_2 = self.e_2.sum(0) / self.batch_size
                
        # Apply the updates in place: learning rate times the Adam direction, plus `lambda1 * W` decay.
        self.opt_t1 += 1
        dWinX_1, self.opt_m_WinX_1, self.opt_v_WinX_1    = adam_update(self.opt_m_WinX_1, self.opt_v_WinX_1, self.dWinX_1, self.opt_beta1, self.opt_beta2, self.opt_t1)
        self.WinX_1 -= self.opt_eta * dWinX_1 + self.WinX_1 * self.lambda1  

        dWin_1, self.opt_m_Win_1, self.opt_v_Win_1      = adam_update(self.opt_m_Win_1, self.opt_v_Win_1,   self.dWin_1, self.opt_beta1, self.opt_beta2, self.opt_t1)
        self.Win_1 -= self.opt_eta * dWin_1 + self.Win_1 * self.lambda1 

        dWout_1, self.opt_m_Wout_1, self.opt_v_Wout_1    = adam_update(self.opt_m_Wout_1, self.opt_v_Wout_1, self.dWout_1, self.opt_beta1, self.opt_beta2, self.opt_t1)
        self.Wout_1 -= self.opt_eta * dWout_1 + self.Wout_1 * self.lambda1 



        if self.RNNLearning: 
            # Recurrent updates are masked by Wr_k_s (existing connections only) and scaled by alpha1.
            dWr_1, self.opt_m_Wr_1, self.opt_v_Wr_1          = adam_update(self.opt_m_Wr_1, self.opt_v_Wr_1,     self.dWr_1, self.opt_beta1, self.opt_beta2, self.opt_t1)
            self.Wr_1 -= self.opt_eta * dWr_1 * self.Wr_1_s * self.alpha1 + self.Wr_1 * self.lambda1 
            dWr_2, self.opt_m_Wr_2, self.opt_v_Wr_2         = adam_update(self.opt_m_Wr_2, self.opt_v_Wr_2,     self.dWr_2, self.opt_beta1, self.opt_beta2, self.opt_t1)
            self.Wr_2 -= self.opt_eta * dWr_2 * self.Wr_2_s * self.alpha1 + self.Wr_2 * self.lambda1 

        if self.RNNBiasLearning:
            dbias_1, self.opt_m_b_1, self.opt_v_b_1          = adam_update(self.opt_m_b_1, self.opt_v_b_1,       self.dbias_1, self.opt_beta1, self.opt_beta2, self.opt_t1)
            self.bias_1 -= self.opt_eta * dbias_1 + self.bias_1 * self.lambda1
            dbias_2, self.opt_m_b_2, self.opt_v_b_2         = adam_update(self.opt_m_b_2, self.opt_v_b_2,       self.dbias_2, self.opt_beta1, self.opt_beta2, self.opt_t1)
            self.bias_2 -= self.opt_eta * dbias_2 + self.bias_2 * self.lambda1 
            
            # NOTE(review): the readout bias is only updated inside this RNNBiasLearning
            # branch although dbias_f_1 is always computed above - confirm intended.
            dbias_f_1, self.opt_m_bias_f_1, self.opt_v_bias_f_1    = adam_update(self.opt_m_bias_f_1, self.opt_v_bias_f_1, self.dbias_f_1, self.opt_beta1, self.opt_beta2, self.opt_t1)
            self.bias_f_1 -= self.opt_eta * dbias_f_1 + self.bias_f_1 * self.lambda1 

        
    def backward2(self, x, y, output):
        """Update the path-2 weights from one batch; call right after ``forward2``.

        Reuses ``z_3``, ``z_4``, ``cbias_3`` and ``cbias_4`` cached by
        ``forward2``.

        Args:
            x: The batch that was passed to ``forward2``.
            y: One-hot targets, same shape as ``output``.
            output: The probabilities returned by ``forward2``.
        """
       
        # Output error.
        self.e_f_2 = output - y

        # L4: inject the projected output error into block 4 and measure the state change.
        self.cbias_4[:, Aindx2Nrange(self.H_size, self.NU, 4)] = -self.beta1 * self.e_f_2.mm(self.Wb_f_2) 
        self.y_4 = self.z_4
        self.y_4 = self.It2sta(self.Wr_4, self.y_4, self.bias_4 + self.cbias_4, self.t2sta)
        self.e_4 = self.z_4 - self.y_4

        # L3: block 6 of layer 4's error is injected into block 3 of layer 3.
        self.cbias_3[:, Aindx2Nrange(self.H_size, self.NU, 3)] = -self.beta1 * self.e_4[:, Aindx2Nrange(self.H_size, self.NU, 6)].mm(self.Wb_3)
        self.y_3 = self.z_3
        self.y_3 = self.It2sta(self.Wr_3, self.y_3, self.bias_3 + self.cbias_3, self.t2sta)
        self.e_3 = self.z_3 - self.y_3

        ## Gradients: presynaptic activity (transposed) times the error of the receiving block.
        # NOTE(review): the divisor is the configured batch size, not x.size(0) - confirm intended.
        self.dWinX_2 = x.t().mm(self.e_3[:, Aindx2Nrange(self.H_size, self.NU, 5)]) / self.batch_size
        self.dWin_3 = self.z_3[:, Aindx2Nrange(self.H_size, self.NU, 2)].t().mm(self.e_4[:, Aindx2Nrange(self.H_size, self.NU, 1)]) / self.batch_size
        self.dWout_2 = self.z_4[:, Aindx2Nrange(self.H_size, self.NU, 3)].t().mm(self.e_f_2) / self.batch_size
        self.dbias_f_2 = self.e_f_2.sum(0) / self.batch_size
        
        
        if self.RNNLearning: 
            self.dWr_3 = self.z_3.t().mm(self.e_3) / self.batch_size
            self.dWr_4 = self.z_4.t().mm(self.e_4) / self.batch_size   

        if self.RNNBiasLearning:
            self.dbias_3 = self.e_3.sum(0) / self.batch_size
            self.dbias_4 = self.e_4.sum(0) / self.batch_size


        # Apply the updates in place: learning rate times the Adam direction, plus `lambda1 * W` decay.
        self.opt_t2 += 1
        dWinX_2, self.opt_m_WinX_2, self.opt_v_WinX_2    = adam_update(self.opt_m_WinX_2, self.opt_v_WinX_2, self.dWinX_2, self.opt_beta1, self.opt_beta2, self.opt_t2)
        self.WinX_2 -= self.opt_eta * dWinX_2 + self.WinX_2 * self.lambda1
        
        dWin_3, self.opt_m_Win_3, self.opt_v_Win_3      = adam_update(self.opt_m_Win_3, self.opt_v_Win_3,   self.dWin_3, self.opt_beta1, self.opt_beta2, self.opt_t2)
        self.Win_3 -= self.opt_eta * dWin_3 + self.Win_3 * self.lambda1 

        dWout_2, self.opt_m_Wout_2, self.opt_v_Wout_2    = adam_update(self.opt_m_Wout_2, self.opt_v_Wout_2, self.dWout_2, self.opt_beta1, self.opt_beta2, self.opt_t2)
        
        self.Wout_2 -= self.opt_eta * dWout_2 + self.Wout_2 * self.lambda1 

        if self.RNNLearning: 
            # Recurrent updates are masked by Wr_k_s (existing connections only) and scaled by alpha1.
            dWr_3, self.opt_m_Wr_3, self.opt_v_Wr_3         = adam_update(self.opt_m_Wr_3, self.opt_v_Wr_3,     self.dWr_3, self.opt_beta1, self.opt_beta2, self.opt_t2)
            self.Wr_3 -= self.opt_eta * dWr_3 * self.Wr_3_s * self.alpha1 + self.Wr_3 * self.lambda1
            dWr_4, self.opt_m_Wr_4, self.opt_v_Wr_4         = adam_update(self.opt_m_Wr_4, self.opt_v_Wr_4,     self.dWr_4, self.opt_beta1, self.opt_beta2, self.opt_t2)
            self.Wr_4 -= self.opt_eta * dWr_4 * self.Wr_4_s * self.alpha1 + self.Wr_4 * self.lambda1

        if self.RNNBiasLearning:
            dbias_3, self.opt_m_b_3, self.opt_v_b_3         = adam_update(self.opt_m_b_3, self.opt_v_b_3,       self.dbias_3, self.opt_beta1, self.opt_beta2, self.opt_t2)
            self.bias_3 -= self.opt_eta * dbias_3 + self.bias_3 * self.lambda1 
            dbias_4, self.opt_m_b_4, self.opt_v_b_4         = adam_update(self.opt_m_b_4, self.opt_v_b_4,       self.dbias_4, self.opt_beta1, self.opt_beta2, self.opt_t2)
            self.bias_4 -= self.opt_eta * dbias_4 + self.bias_4 * self.lambda1 

            # NOTE(review): the readout bias is only updated inside this RNNBiasLearning
            # branch although dbias_f_2 is always computed above - confirm intended.
            dbias_f_2, self.opt_m_bias_f_2, self.opt_v_bias_f_2    = adam_update(self.opt_m_bias_f_2, self.opt_v_bias_f_2, self.dbias_f_2, self.opt_beta1, self.opt_beta2, self.opt_t2)
            self.bias_f_2 -= self.opt_eta * dbias_f_2 + self.bias_f_2 * self.lambda1 



    def It2sta(self, Wr, h, bias, itsta):
        """Iterate ``h <- f(h @ Wr + bias)`` ``itsta`` times and return the final ``h``.

        The argument ``h`` is not modified; each step creates a new tensor.
        """
        
        for indx in range(itsta):
            h = self.f(h.mm(Wr)+bias)
        
        return h
    
    def It2staf(self, Wr, h, bias, itsta):
        """Like ``It2sta`` but without the activation: ``h <- h @ Wr + bias``.

        Not called anywhere in this class.
        """

        for indx in range(itsta):
            h = (h.mm(Wr)+bias)
        
        return h
    


In [None]:
## ###############################################################################
# 
class Config:
    """Hyper-parameters and switches for one FRNN_ML4 experiment.

    Every value is a plain instance attribute.  ``device``, ``tanh`` and
    ``tanh_d`` are taken from the notebook's global namespace at construction.
    """

    def __init__(self):
        # --- Problem size -------------------------------------------------
        self.device = device
        self.input_size = 28 * 28        # flattened 28x28 image
        self.num_classes = 10            # class labels 0-9
        hidden_units = 1536              # neurons in each hidden (recurrent) layer
        self.H_size = hidden_units

        # --- Block structure and activation ---------------------------------
        units_per_block = 256            # neurons in each block
        self.NperU = units_per_block
        self.NU = hidden_units // units_per_block   # number of blocks per layer
        self.f = tanh
        self.df = tanh_d

        # --- Recurrent-layer initialisation ---------------------------------
        self.RNNCR = 0.25                # connection rate of the recurrent matrices
        self.RNNSR = 0.25                # spectral radius the recurrent matrices are scaled to
        self.Feedforwardsc = 0.1         # scale of the feed-forward weights
        self.Feedbacksc = 0.5            # scale of the feedback weights
        self.t2sta = 8                   # recurrent iterations per relaxation

        # --- Optimisation ---------------------------------------------------
        self.num_epochs = 100            # training epochs
        self.batch_size = 500            # batch size
        self.learning_rate = 0.001       # learning rate
        self.alpha1 = 0.0001             # recurrent weight update coefficient
        self.lambda1 = 0.0               # L2 regularisation coefficient (disabled)
        self.feedbackLearning = False
        self.gamma1 = 0

        # --- Learning switches ----------------------------------------------
        self.RNNLearning = False         # whether the recurrent weights learn
        self.RNNBiasLearning = False     # whether the biases learn


config = Config()
# Data pipeline: both datasets are loaded with the same transform and batch size.

transform = transforms.Compose([
    transforms.ToTensor(),  # convert each image to a tensor
])

# Dataset 1: FashionMNIST (the "...1" loaders).
train_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.FashionMNIST(root='./data', train=False, transform=transform)

train_loader1 = DataLoader(dataset=train_dataset, batch_size=config.batch_size, shuffle=True)
test_loader1 = DataLoader(dataset=test_dataset, batch_size=config.batch_size, shuffle=False)
# Materialise every batch once and keep it on `device`.  Because the loader is
# iterated a single time here, the shuffled batch composition is fixed from now on.
train_loader1_gpu = [(data.to(device), target.to(device)) for data, target in train_loader1]
test_loader1_gpu = [(data.to(device), target.to(device)) for data, target in test_loader1]

# Dataset 2: MNIST (the "...2" loaders).  `train_dataset` / `test_dataset` are
# rebound here, so after this cell they refer to MNIST.
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform)

train_loader2 = DataLoader(dataset=train_dataset, batch_size=config.batch_size, shuffle=True)
test_loader2 = DataLoader(dataset=test_dataset, batch_size=config.batch_size, shuffle=False)
train_loader2_gpu = [(data.to(device), target.to(device)) for data, target in train_loader2]
test_loader2_gpu = [(data.to(device), target.to(device)) for data, target in test_loader2]



# Output directory and the tag used in log lines and result file names.
res_path = Path("./Reusing/B")
res_path.mkdir(parents=True, exist_ok=True)
taskinfo = f'FRNN_ML_PathB_RNNLearning{config.RNNLearning}'

# Number of independent repetitions, and the metric store indexed as
# res[run, metric, epoch] with 6 metric rows per run.
multitest = 5
res = np.zeros((multitest,6,config.num_epochs))

In [None]:
# Repeat the whole experiment `multitest` times with a freshly initialised model.
# Metric rows written into res[imultitest]:
#   0-2: loader-1 data (printed as FMNIST) - summed train loss, train accuracy, test accuracy
#   3-5: loader-2 data (printed as MNIST)  - summed train loss, train accuracy, test accuracy
for imultitest in range(multitest):
    model = FRNN_ML4(config).to(device)


    # Assigned but not used anywhere in this cell.
    best_y_true = None
    best_y_pred = None
    ###### Training #######

    for epoch in range(config.num_epochs):
        losssum = 0
        # Reorder the pre-built batches in place (their contents stay fixed).
        random.shuffle(train_loader1_gpu)
        random.shuffle(train_loader2_gpu)
        # Phase 1: one pass over loader-2 batches through path 2 (forward2 / backward2).
        with tqdm(total=len(train_loader2_gpu), desc=f'Epoch {epoch + 1}/{config.num_epochs}: ', unit='batch', ncols=90, mininterval=1) as pbar:
            for i, (images, labels) in enumerate(train_loader2_gpu):
                # Flatten each image to a vector and one-hot encode the labels.
                images = images.view(-1, config.input_size)
                labels_one_hot = one_hot(labels, config.num_classes)

                # Forward pass.
                outputs = model.forward2(images)

                loss = cross_entropy_loss(outputs, labels_one_hot)
                losssum += loss.item()
                
                ## Weight update; relies on the state cached by the forward2 call above.
                model.backward2(images, labels_one_hot, outputs)

                # Show the running mean of the batch losses.
                pbar.set_postfix({'loss': f'{losssum/(i+1):.6f}'})  
                pbar.update(1)
        losssum = 0 
        # Phase 2: one pass over loader-1 batches through path 1 (forward1 / backward1).
        with tqdm(total=len(train_loader1_gpu), desc=f'Epoch {epoch + 1}/{config.num_epochs}: ', unit='batch', ncols=90, mininterval=1) as pbar:
            for i, (images, labels) in enumerate(train_loader1_gpu):
                # Flatten each image to a vector and one-hot encode the labels.
                images = images.view(-1, config.input_size)
                labels_one_hot = one_hot(labels, config.num_classes)

                # Forward pass.
                outputs = model.forward1(images)

                loss = cross_entropy_loss(outputs, labels_one_hot)
                losssum += loss.item()
                
                ## Weight update; relies on the state cached by the forward1 call above.
                model.backward1(images, labels_one_hot, outputs)

                pbar.set_postfix({'loss': f'{losssum/(i+1):.6f}'})  
                pbar.update(1)


        # Evaluate path 1 on its training batches (loss is summed over batches, not averaged).
        train_correct = 0
        train_total = 0
        losssum = 0
        with torch.no_grad():
            for images, labels in train_loader1_gpu:
                images = images.view(-1, config.input_size)
                labels_one_hot = one_hot(labels, config.num_classes)
                outputs = model.forward1(images)
                _, predicted = torch.max(outputs, 1)
                loss = cross_entropy_loss(outputs, labels_one_hot)
                losssum += loss.item()
                train_total += labels.size(0)
                train_correct += (predicted == labels).sum().item()

        train_accuracy = train_correct / train_total
        
        # Evaluate path 1 on its test batches.
        test_correct = 0
        test_total = 0
        y_true = []
        y_pred = []
        with torch.no_grad():
            for images, labels in test_loader1_gpu:
                images = images.view(-1, config.input_size)
                outputs = model.forward1(images)
                _, predicted = torch.max(outputs, 1)
                y_true.extend(labels.cpu().numpy())
                y_pred.extend(predicted.cpu().numpy())
                test_total += labels.size(0)
                test_correct += (predicted == labels).sum().item()

        test_accuracy = test_correct / test_total
        res[imultitest][0][epoch] = losssum
        res[imultitest][1][epoch] = train_accuracy
        res[imultitest][2][epoch] = test_accuracy

        # Evaluate path 2 on its training batches.
        train_correct = 0
        train_total = 0
        losssum = 0
        with torch.no_grad():
            for images, labels in train_loader2_gpu:
                images = images.view(-1, config.input_size)
                labels_one_hot = one_hot(labels, config.num_classes)
                outputs = model.forward2(images)
                _, predicted = torch.max(outputs, 1)
                loss = cross_entropy_loss(outputs, labels_one_hot)
                losssum += loss.item()
                train_total += labels.size(0)
                train_correct += (predicted == labels).sum().item()

        train_accuracy = train_correct / train_total
        

        # Evaluate path 2 on its test batches.  y_true / y_pred are reset here, so
        # after the epoch they hold the loader-2 test labels and predictions.
        test_correct = 0
        test_total = 0
        y_true = []
        y_pred = []
        with torch.no_grad():
            for images, labels in test_loader2_gpu:
                images = images.view(-1, config.input_size)
                outputs = model.forward2(images)
                _, predicted = torch.max(outputs, 1)
                y_true.extend(labels.cpu().numpy())
                y_pred.extend(predicted.cpu().numpy())
                test_total += labels.size(0)
                test_correct += (predicted == labels).sum().item()

        test_accuracy = test_correct / test_total
        res[imultitest][3][epoch] = losssum
        res[imultitest][4][epoch] = train_accuracy
        res[imultitest][5][epoch] = test_accuracy

        # Per-epoch summary line: loader-1 metrics first, then loader-2 metrics.
        current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        print(f'\t Loss: {res[imultitest][0][epoch] :.3f},\t'
            f'Train : {res[imultitest][1][epoch] * 100:.2f}%,\t'
            f'Test : {res[imultitest][2][epoch] * 100:.2f}%,\t'
            f'Loss: {res[imultitest][3][epoch] :.3f},\t'
            f'Train : {res[imultitest][4][epoch] * 100:.2f}%,\t'
            f'Test : {res[imultitest][5][epoch] * 100:.2f}%,\t'
            f'Current Time: {current_time}')

    # Best (maximum over epochs) train / test accuracy of this run.
    # `current_time` here is still the timestamp of the last epoch's summary line.
    print(f"{taskinfo} Time: {current_time}, Epochs: {config.num_epochs}, Learning Rate: {config.learning_rate}, "
        f"FMNIST Best: {max(res[imultitest][1]) * 100:.2f}%  {max(res[imultitest][2]) * 100:.2f}%\n",
        f"MNIST Best: {max(res[imultitest][4]) * 100:.2f}%  {max(res[imultitest][5]) * 100:.2f}%\n")

    # File-name friendly timestamp; also used by the saving cell after the loop.
    current_time = time.strftime("%Y-%m-%d_%H%M%S", time.localtime())

    # Append this run's summary and the configuration to the shared log file.
    with open(res_path/"res.txt", "a") as f:
        f.write(f"{taskinfo} Time: {current_time}, Epochs: {config.num_epochs}, Learning Rate: {config.learning_rate}, "
        f"FMNIST Best: {max(res[imultitest][1]) * 100:.2f}%  {max(res[imultitest][2]) * 100:.2f}%\n"
        f"MNIST Best: {max(res[imultitest][4]) * 100:.2f}%  {max(res[imultitest][5]) * 100:.2f}%\n")

        # Dump every non-dunder, non-callable config attribute (callables such as
        # the activation functions are skipped).
        for attr_name in dir(config):
            if not attr_name.startswith('__') and not callable(getattr(config, attr_name)):  
                attr_value = getattr(config, attr_name)
                f.write(f'{attr_name}: {attr_value}  ')  
        f.write(f'\n\n')  

In [None]:
# Store the full metric array together with the task tag; `current_time` is the
# timestamp set at the end of the last training run.
np.savez(res_path / f"{taskinfo}_Time{current_time}", res=res, taskinfo=taskinfo)
# Maximum over the epoch axis for every (run, metric) pair.
tt = np.max(res[:,:,:], axis= -1)

# Mean over runs of each run's best train / test accuracy (metric rows 1, 2 and 4, 5).
print(f"\n -------\n Final mean {taskinfo} Time: {current_time}, Epochs: {config.num_epochs}, Learning Rate: {config.learning_rate}, "
        f"FMNIST Best: {np.mean(tt[:,1],axis=0) * 100:.2f}%  {np.mean(tt[:,2],axis=0)* 100:.2f}%\n",
        f"MNIST Best: {np.mean(tt[:,4],axis=0) * 100:.2f}%  {np.mean(tt[:,5],axis=0) * 100:.2f}%\n")

