## Optimization playground

Something to try and optimize a KAN for a mulitvariate function

Importing a few libraries...

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch

In [2]:
class ChebyKANLayer(nn.Module):

    def __init__(self, dim_in, dim_out, order):
        super(ChebyKANLayer, self).__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.order = order

        self.coeffs = nn.Parameter(torch.empty(self.dim_in, self.dim_out, self.order+1))
        nn.init.normal_(self.coeffs, mean=0.0, std=1/(dim_in * (order + 1)))

    def forward(self, x_in):

        x_in = torch.reshape(x_in, (-1, self.dim_in))
        x_in = torch.tanh(x_in)

        bases = torch.ones(x_in.shape[0], self.dim_in, self.order+1, device=x_in.device)

        if self.order > 0:
            bases[:, :, 1] = x_in
        
        for i in range(2, self.order+1):
            bases[:, :, i] = 2 * x_in * bases[:, :, i-1].clone() - bases[:, :, i-2].clone()

        y_out = torch.einsum('bij, ioj -> bo', bases, self.coeffs)
        y_out = y_out.view(-1, self.dim_out)
        return y_out

In [43]:
class KAN(nn.Module):
    def __init__(self):
        super(KAN, self).__init__()
        self.layers = []
        self.best_coeffs = []

    def add_layer(self, layer):
        self.layers.append(layer)
        self.layerMod = nn.ModuleList(self.layers)

    def forward(self, x_in):
        y_pred = self.layers[0].forward(x_in)
        for layer in self.layers[1:]:
            y_pred = layer.forward(y_pred)
        return y_pred

    def train(self, x_in, y_target, learn_rate, epochs, prnt=False):
        loss_fn = nn.MSELoss()
        optimizer = torch.optim.Adam(self.parameters(), learn_rate)
        best_loss = np.inf

        for i in range(epochs):
            optimizer.zero_grad()
            y_pred = self.forward(x_in)
            loss = loss_fn(y_pred, y_target)

            if (loss.detach() < best_loss):
                best_loss = loss.detach()
                self.best_coeffs = [layer.coeffs for layer in self.layers]
            
            loss.backward()
            optimizer.step()

            if prnt and not (i % 100):
                print(f"Epoch: {i}, Loss: {loss}")
            
        for i, layer in enumerate(self.layers):
            layer.coeffs = self.best_coeffs[i]

        if (prnt):
            print(f"Training complete, best loss = {best_loss}")

In [44]:
target_fn = lambda T: torch.sin(torch.pi*(T[..., 0])) * torch.sin(torch.pi*(T[..., 1]))
coord_side = 1000

x1s = torch.linspace(-1, 1, steps=coord_side).unsqueeze(1)
x2s = torch.linspace(-1, 1, steps=coord_side).unsqueeze(1)

x_eval = torch.stack((x1s, x2s), dim=-1)
y_target = target_fn(x_eval.detach().clone())

# test and train split
n_training = int(x_eval.shape[0] * 0.8)
training_idxs = np.random.randint(x_eval.shape[0], size=n_training)
test_idxs = [i for i in range(x_eval.shape[0]) if i not in training_idxs]
x_train = x_eval[training_idxs]
x_test = x_eval[test_idxs]
y_train = y_target[training_idxs]
y_test = y_target[test_idxs]

learn_rate = 0.1

model = KAN()
model.add_layer(ChebyKANLayer(2, 8, 8))
model.add_layer(ChebyKANLayer(8, 1, 8))

model.train(x_train, y_train, learn_rate, 1000, prnt=True)

loss_fn = nn.MSELoss()

x1_check = 0.4
x2_check = 0.3
check_vec = torch.tensor([x1_check, x2_check]).reshape(1, 1, -1)
pred_acc = 0
avg_int = 100
for i in range(avg_int):
    pred_acc+= model.forward(check_vec).detach().item()
pred_acc/=avg_int
print(f"Prediction: {check_pred.detach()}")

Epoch: 0, Loss: 0.401570588350296
Epoch: 100, Loss: 0.00407611345872283
Epoch: 200, Loss: 0.0013517929473891854
Epoch: 300, Loss: 9.917444549500942e-05
Epoch: 400, Loss: 0.0002658458543010056
Epoch: 500, Loss: 3.987685340689495e-05
Epoch: 600, Loss: 2.262530142616015e-05
Epoch: 700, Loss: 0.0003760546096600592
Epoch: 800, Loss: 0.00023832630540709943
Epoch: 900, Loss: 1.3815755664836615e-05
Training complete, best loss = 7.4402573773113545e-06
Prediction: tensor([[0.9774]])
