In [2]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import style
style.use('seaborn-whitegrid')

device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
class MultiRBFnn(nn.Module):
    def __init__(self, in_feature, add_rbf_num, device):
        super(MultiRBFnn, self).__init__()

        self.add_rbf_num = add_rbf_num  # additional RBFs number
        self.in_feature = in_feature    # count features

        self.centers_list = []
        self.sigmas_list = []
        self.weights_list = []

        self.change_th = 3

    def first_rbf_parameter(self, input_data, target): 

        # input_data shape : (data_num)
        # target data shape : (in_feature, 1, data_num)
        
        # first layer centers, weights, sigmas
        # centers, sigmas : (add_rbf_num, 1)
        # weights : (in_feature, add_rbf_num)

        find_index_input = input_data.clone().detach()
        fine_index_target = target.clone().detach()

        find_sigma = target.clone().detach()
        find_weight = target.clone().detach()
        center_index_list = []

        # first MultiRBFs initial centers and weights parameters
        for i in range(self.add_rbf_num):
            index_ = torch.argmax(torch.sum(torch.abs(fine_index_target), dim = 0)).cpu().detach().tolist()
            fine_index_target[:,:,index_] = 0
            center_index_list.append(index_)

        center_index_list = torch.tensor(center_index_list, device=device)
        initcenter = torch.index_select(find_index_input, 0, center_index_list)[-self.add_rbf_num:].reshape(self.add_rbf_num,1)
        initweight = torch.index_select(find_weight, 2, center_index_list)[-self.add_rbf_num:].reshape(self.in_feature, self.add_rbf_num)

        # first MultiRBFs initial sigmas parameters                
        sigma_list = []
        dft = torch.log(torch.abs(torch.fft.fft(find_sigma).real))
        dft =  torch.abs(dft / torch.max(dft)) **-1
        for k in center_index_list:
            sigma_list.append(torch.mean(dft[:,:,k]).reshape(1))

        initsigma = torch.cat(sigma_list)[-self.add_rbf_num:].reshape(self.add_rbf_num, 1)


        return initcenter, initweight, initsigma

    def rbf_gaussian(self, input_data): # no problem
        out = torch.exp(-1 *(torch.pow((input_data - self.centers), 2))) / (torch.pow(self.sigma, 2))

        return out
    
    def forward(self, input_data): # no problem
        R = self.rbf_gaussian(input_data)
        pred = torch.mm(self.weights, R).reshape(self.in_feature, 1, input_data.size(-1))

        return R, pred
    
    def rbf_gradient(self, input_data, C, S, W): # no problem
        rbf_output = (-2 * (input_data - C) / torch.pow(S, 2)) * \
                        (torch.exp(-1 * (torch.pow((input_data - C), 2) / (torch.pow(S, 2)))))
        rbf_grad = torch.mm(W, rbf_output)
        
        return rbf_grad.reshape(self.in_feature, 1, input_data.size(-1)) # (in_feature, 1, data_num)

    def first_grad(self, input_data, target):  # no problem
        space = (input_data,)
        ori_grad = torch.gradient(target, spacing = space, dim = 2, edge_order  = 1)
        return ori_grad[0] # (in_feature, 1, data_num)
        
    def target_grad(self, input_data, centers, sigmas, weights, ori_grad): # no problem
        true_grad = ori_grad - self.rbf_gradient(input_data, centers, sigmas, weights)
         
        return true_grad # (in_feature, 1, data_num)
    
    def predict(self, input_data): # ? 
        rbf_output = torch.exp(-1 * (torch.pow((input_data - self.done_centers), 2) / \
                                     (torch.pow(self.done_sigma, 2))))
        pred = torch.mm(self.done_weights.reshape(self.in_feature, self.add_rbf_num),
                         rbf_output).reshape(self.in_feature, 1, input_data.size(-1))

        return rbf_output, pred
    
    def Loss(self, pred, target, pred_grad, true_grad): # center, sigma 랑 weight loss를 따로 구해야 되나?
        # value loss + gradient loss 

        return torch.mean(torch.pow(target - pred,2) + torch.pow(true_grad - pred_grad, 2))
    

    def L2_F(self, input_data): # 이상함
        return -2 * (input_data - self.centers) / torch.pow(self.sigma, 2) # (add_rbf_num, data_num)

    # partial derivative define

    def L2_2_derivative_weight(self, input_data, radial_output):
        return self.L2_F(input_data) * radial_output               # (add_rbf_num, data_num)

    def rbf_gaussian_derivative_centers(self, input_data): # no problem
        output = (2 * (input_data - self.centers) / (torch.pow(self.sigma, 2))) * self.rbf_gaussian(input_data)

        return output  # size = (add_rbf_num, data_num)
    
    def rbf_gaussian_derivative_sigma(self, input_data): # no problem
        output = (2 * torch.pow((input_data - self.centers), 2) / (torch.pow(self.sigma, 3))) * self.rbf_gaussian(input_data)

        return output  # size = (add_rbf_num, data_num)
    
    # additional RBFs 
    def add_rbf_parameter(self, input_data, error):
        find_index_input = input_data.clone().detach()
        find_index_error = error.clone().detach()
        
        find_weight = error.clone().detach()
        find_sigma = error.clone().detach()
        
        center_index_list = []

        for i in range(self.add_rbf_num * (self.change_time + 1)):
            index_ = torch.argmax(torch.sum(torch.abs(find_index_error), dim = 0)).cpu().detach().tolist()

            find_index_error[:,:,index_] = 0
            center_index_list.append(index_)

        center_index_list = torch.tensor(center_index_list, device=device)
        initcenter = torch.index_select(find_index_input, 0, center_index_list)[-self.add_rbf_num:].reshape(self.add_rbf_num,1)
        initweight = torch.index_select(find_weight, 2, center_index_list)[:,:,-self.add_rbf_num:].reshape(self.in_feature, self.add_rbf_num)


        sigma_list = []
        dft = torch.log(torch.abs(torch.fft.fft(find_sigma).real))
        
        dft = (torch.abs(dft / torch.max(dft))**-1)
        for k in center_index_list:
            sigma_list.append(torch.mean(dft[:,:,k]).reshape(1))
        initsigma = torch.cat(sigma_list)[-self.add_rbf_num:].reshape(self.add_rbf_num,1)

        return initcenter, initweight, initsigma
    
    def change_init(self, na):
        if na == 1:
            loss_list = self.train_loss_list[-self.change_th:]
            if self.number > self.change_th and max(loss_list) == min(loss_list):
                self.change_time += 1
            elif self.number > self.change_th and loss_list[0] < loss_list[1] and loss_list[1] < loss_list[2]:
                self.change_time += 1
            else:
                self.change_time = 0
        else:
            self.change_time += 1

    def best_forward(self, input_data, best_center, best_sigma, best_weight): # ?
        rbf_output = torch.exp(-1 * (torch.pow((input_data - best_center), 2) / \
                                     (torch.pow(best_sigma, 2))))
        pred = torch.mm(best_weight.reshape(self.in_feature, self.add_rbf_num), 
                        rbf_output).reshape(self.in_feature, 1, input_data.size(-1))

        return rbf_output, pred

    def backward_propagation(self, input_data, R, pred, target, target_grad, pred_grad):
        
        L2_1_error = -2 * (target - pred)
        L2_2_error = -2 * (target_grad - pred_grad)

        # updata partial derivative

        deltaSigma1 = torch.mm(self.weights, self.rbf_gaussian_derivative_sigma(input_data))
        #deltaSigma1 = self.rbf_gaussian_derivative_sigma(input_data) * L2_1_error                       # (in_feature, add_rbf_num, data_num)
        #deltaSigma1 *= self.weights.reshape(self.in_feature, self.add_rbf_num, 1)                   # (in_feature, add_rbf_num, data_num)

        deltaSigma2 = self.rbf_gaussian_derivative_sigma(input_data) * L2_2_error                       # (in_feature, add_rbf_num, data_num)
        deltaSigma2 *= self.L2_F(input_data) * self.weights.reshape(self.in_feature, self.add_rbf_num, 1)    # (in_feature, add_rbf_num, data_num)

        deltaSigma =  torch.sum(torch.sum(deltaSigma1, dim=2), dim = 0) + torch.sum(torch.sum(deltaSigma2, dim=2), dim = 0) # (add_rbf_num) 

        # center partial derivative
        deltaCenter1 = self.rbf_gaussian_derivative_centers(input_data) * L2_1_error
        deltaCenter1 *= self.weights.reshape(self.in_feature, self.add_rbf_num, 1)
        
        deltaCenter2 = self.rbf_gaussian_derivative_centers(input_data) * L2_2_error
        deltaCenter2 *= self.L2_F(input_data) * self.weights.reshape(self.in_feature, self.add_rbf_num, 1)

        deltaCenter =  torch.sum(torch.sum(deltaCenter1, dim=2),dim =0) + torch.sum(torch.sum(deltaCenter2, dim=2), dim = 0) # (add_rbf_num)


        # weight partial derivative
        delta_weight1 = torch.sum((R * L2_1_error), dim=2)        # (in_feature, add_rbf_num)
        delta_weight2 = torch.sum((self.L2_2_derivative_weight(input_data, R) * L2_2_error), dim = 2) # (in_feature, add_rbf_num)
        delta_weight = delta_weight1 + delta_weight2 # (in_feature, add_rbf_num)

        # BP update
        self.weights -= self.lr * delta_weight
        self.centers -= self.lr * deltaCenter.reshape(self.add_rbf_num, 1)
        self.sigma -= self.lr * deltaSigma.reshape(self.add_rbf_num, 1)

    def plot_train(self, input_data, best_pred): #done
        fig, ax = plt.subplots(1, 3, figsize = (30, 5))
        for i in range(self.in_feature):
            ax[i].plot(input_data.cpu().detach().numpy(), self.target[i][0].cpu().detach().numpy())
            ax[i].plot(input_data.cpu().detach().numpy(), best_pred[i][0].cpu().detach().numpy())
        plt.show()

    def train(self, input_data, target, epochs, lr, loss_th):
        self.lr = lr
        self.target = target.clone().detach()
        self.number = 0
        self.train_loss_list = []
        self.loss_th = loss_th
        self.change_time = 0

        break_time = len(input_data) / self.add_rbf_num
        # count_loss_chage = 0
        # count_round_change = 0

        loss = 100000
        
        while self.loss_th < loss:

            print("{}th additional rbflayer".format(self.number))
            # first rbflayer
            if self.number == 0:
                self.centers, self.weights, self.sigma = self.first_rbf_parameter(input_data, self.target)
                first_grad = self.first_grad(input_data, target)

                for epoch in range(epochs):
                    R, pred = self.forward(input_data)
                    rbf_grad = self.rbf_gradient(input_data, self.centers, self.sigma, self.weights)

                    self.backward_propagation(input_data, R, pred, self.target, first_grad, rbf_grad)
                    epoch_loss = self.Loss(pred, self.target, rbf_grad, first_grad)

                    if epoch == 0:
                        print("{}th additional RBFlayer {}th epoch loss: {}".format(self.number, epoch, epoch_loss))
                        self.best_loss = epoch_loss.clone().detach()
                        self.best_center = self.centers.clone().detach()
                        self.best_sigma = self.sigma.clone().detach()
                        self.best_weight = self.weights.clone().detach()
                    
                    else:
                        if self.best_loss > epoch_loss:
                            self.best_loss = epoch_loss.clone().detach()
                            self.best_center = self.centers.clone().detach()
                            self.best_sigma = self.sigma.clone().detach()
                            self.best_weight = self.weights.clone().detach()

                    if (epoch + 1) % 250 == 0:
                        print("{}th additional RBFlayer {}th epoch MSE Loss: {}".format(self.number, epoch, epoch_loss))

            else:
                self.change_init(na)
                if self.change_time > break_time:
                    break
                
                self.centers, self.weights, self.sigma = self.add_rbf_parameter(input_data, self.target)

                for epoch in range(epochs):
                    R, pred = self.forward(input_data)
                    rbf_grad = self.rbf_gradient(input_data, self.centers, self.sigma, self.weights)

                    if epoch == 0:
                        print("{}th additional RBFlayer {}th epoch loss: {}".format(self.number, epoch,
                                                                                     self.Loss(pred, self.target, rbf_grad, target_grad)))
                        self.best_loss = self.Loss(pred, self.target, rbf_grad, target_grad).clone().detach()
                        self.best_center = self.centers.clone().detach()
                        self.best_sigma = self.sigma.clone().detach()
                        self.best_weight = self.weights.clone().detach()

                    self.backward_propagation(input_data, R, pred, self.target, first_grad, rbf_grad)
                    epoch_loss = self.Loss(pred, self.target, rbf_grad, target_grad)

                    if (epoch + 1) % 250 == 0:
                        print("{}th additional RBFlayer {}th epoch MSE Loss: {}".format(self.number, epoch, epoch_loss))
                    
                    if self.best_loss > epoch_loss:
                        self.best_loss = epoch_loss.clone().detach()
                        self.best_center = self.centers.clone().detach()
                        self.best_sigma = self.sigma.clone().detach()
                        self.best_weight = self.weights.clone().detach()
                
            best_R, best_pred = self.best_forward(input_data, self.best_center, self.best_sigma, self.best_weight)
            best_grad = self.rbf_gradient(input_data, self.best_center, self.best_sigma, self.best_weight)

            if self.number ==0:
                train_loss = self.Loss(best_pred, self.target, best_grad, first_grad)
            else:
                train_loss = self.Loss(best_pred, self.target, best_grad, target_grad)

            print("{}th additional RBFlayer best loss : ".format(self.number, train_loss))
            self.train_loss_list.append(train_loss)

            # additional rbf plot print
            self.plot_train(input_data, best_pred)

            if torch.isnan(train_loss) == False:
                na = 1
                self.target = self.target - best_pred  # target update
                loss = train_loss  # loss update
                self.number += 1  # additional rbf number update
                self.centers_list.append(self.best_center)
                self.sigmas_list.append(self.best_sigma)
                self.weights_list.append(self.best_weight)

                self.done_centers = torch.cat(self.centers_list, dim  =0)
                self.done_sigma = torch.cat(self.sigmas_list, dim = 0)
                self.done_weights = torch.cat(self.weights_list, dim = 1)
                target_grad = self.target_grad(input_data, self.done_centers, self.done_sigma, self.done_weights, first_grad)
            else:
                na = 0

In [7]:
import numpy as np
a = np.arange(0,1,0.1)

input_ = torch.tensor(a, device = device)

In [8]:
target = torch.rand((3,1,10), device = device)

In [9]:
target

tensor([[[0.6720, 0.7563, 0.3495, 0.7097, 0.9419, 0.3091, 0.9351, 0.2300,
          0.1129, 0.6612]],

        [[0.8131, 0.0613, 0.8664, 0.2892, 0.3724, 0.7141, 0.5831, 0.7052,
          0.1583, 0.8392]],

        [[0.5025, 0.8348, 0.3649, 0.7155, 0.8703, 0.1943, 0.7555, 0.3257,
          0.0147, 0.5034]]], device='cuda:0')

In [19]:
# input_data shape : (data_num)
# target data shape : (in_feature, 1, data_num)

# first layer centers, weights, sigmas
# centers, sigmas : (add_rbf_num, 1)
# weights : (in_feature, add_rbf_num)

centers = torch.tensor([0.3, 0.4], device = device).reshape(2,1)
sigma = torch.tensor([0.2,0.3], device = device).reshape(2,1)
weights = torch.tensor([[0.5, .3], [0.2, .1], [0.7, .3]], device = device, dtype= float).reshape(3, 2)

In [21]:
def rbf_gaussian(input_data):
    out = torch.exp(-1 *(torch.pow((input_data - centers), 2))) / (torch.pow(sigma, 2))

    return out # (add_rbf_num, data_len)

def forward( input_data):
    R = rbf_gaussian(input_data)
    pred = torch.mm(weights, R).reshape(3, 1, input_data.size(-1))

    return R, pred 

In [15]:
rbf_gaussian(input_)

tensor([[22.8483, 24.0197, 24.7512, 25.0000, 24.7512, 24.0197, 22.8483, 21.3036,
         19.4700, 17.4419],
        [ 9.4683, 10.1548, 10.6754, 11.0006, 11.1111, 11.0006, 10.6754, 10.1548,
          9.4683,  8.6533]], device='cuda:0', dtype=torch.float64)

In [22]:
forward(input_)

(tensor([[22.8483, 24.0197, 24.7512, 25.0000, 24.7512, 24.0197, 22.8483, 21.3036,
          19.4700, 17.4419],
         [ 9.4683, 10.1548, 10.6754, 11.0006, 11.1111, 11.0006, 10.6754, 10.1548,
           9.4683,  8.6533]], device='cuda:0', dtype=torch.float64),
 tensor([[[14.2646, 15.0563, 15.5783, 15.8002, 15.7090, 15.3100, 14.6268,
           13.6982, 12.5755, 11.3170]],
 
         [[ 5.5165,  5.8194,  6.0178,  6.1001,  6.0614,  5.9040,  5.6372,
            5.2762,  4.8408,  4.3537]],
 
         [[18.8343, 19.8603, 20.5285, 20.8002, 20.6592, 20.1140, 19.1964,
           17.9590, 16.4695, 14.8053]]], device='cuda:0', dtype=torch.float64))

In [30]:
def first_grad(input_data, target): 
    space = (input_data,)
    ori_grad = torch.gradient(target, spacing = space, dim = 1, edge_order  = 1)
    return ori_grad[0] # (in_feature, 1, data_num)

In [28]:
first_grad(input_, target)

tensor([[[ 0.8421, -1.6129, -0.2327,  2.9619, -2.0032, -0.0336, -0.3955,
          -4.1111,  2.1564,  5.4832]],

        [[-7.5183,  0.2667,  1.1395, -2.4703,  2.1248,  1.0539, -0.0447,
          -2.1241,  0.6702,  6.8092]],

        [[ 3.3230, -0.6880, -0.5962,  2.5270, -2.6059, -0.5739,  0.6570,
          -3.7041,  0.8886,  4.8876]]], device='cuda:0', dtype=torch.float64)

In [33]:
first_grad(input_, target[2])

tensor([[ 3.3230, -0.6880, -0.5962,  2.5270, -2.6059, -0.5739,  0.6570, -3.7041,
          0.8886,  4.8876]], device='cuda:0', dtype=torch.float64)

In [34]:
def rbf_gaussian_derivative_sigma(input_data): # no problem
    output = (2 * torch.pow((input_data - centers), 2) / (torch.pow(sigma, 3))) * rbf_gaussian(input_data)

    return output  # size = (add_rbf_num, data_num)

In [35]:
S_der = rbf_gaussian_derivative_sigma(input_)

In [36]:
S_der.size()

torch.Size([2, 10])

In [38]:
weights.size()

torch.Size([3, 2])

In [42]:
torch.mm(weights, S_der)

tensor([[2.9071e+02, 1.4041e+02, 4.0428e+01, 2.4446e+00, 3.0939e+01, 1.2254e+02,
         2.6653e+02, 4.4638e+02, 6.4210e+02, 8.3296e+02],
        [1.1404e+02, 5.4809e+01, 1.5539e+01, 8.1486e-01, 1.2376e+01, 4.8854e+01,
         1.0598e+02, 1.7720e+02, 2.5460e+02, 3.2998e+02],
        [3.9353e+02, 1.8845e+02, 5.2804e+01, 2.4446e+00, 4.3315e+01, 1.7058e+02,
         3.6935e+02, 6.1681e+02, 8.8548e+02, 1.1469e+03]], device='cuda:0',
       dtype=torch.float64)