In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
import sys
sys.path.append('/root/pygcn/KAN/')
from efficient_kan import KAN


In [2]:
def test_mul():
    kan = KAN([2, 2, 1], base_activation=nn.Identity)
    optimizer = torch.optim.LBFGS(kan.parameters(), lr=0.01)
    with tqdm(range(100)) as pbar:
        for i in pbar:
            loss, reg_loss = None, None

            def closure():
                optimizer.zero_grad()
                x = torch.rand(1024, 2)
                y = kan(x, update_grid=(i % 20 == 0))

                assert y.shape == (1024, 1)
                nonlocal loss, reg_loss
                u = x[:, 0]
                v = x[:, 1]
                loss = nn.functional.mse_loss(y.squeeze(-1), (u + v) / (1 + u * v))
                reg_loss = kan.regularization_loss(1, 0)
                (loss + 1e-5 * reg_loss).backward()
                return loss + reg_loss

            optimizer.step(closure)
            pbar.set_postfix(mse_loss=loss.item(), reg_loss=reg_loss.item())
    for layer in kan.layers:
        print(layer.spline_weight)

In [3]:
test_mul()

100%|██████████| 100/100 [00:13<00:00,  7.47it/s, mse_loss=0.272, reg_loss=0.037]

Parameter containing:
tensor([[[-2.3287e-07, -1.1657e-06,  1.2044e-06,  1.2311e-06,  5.7441e-07,
          -4.7940e-07, -8.2431e-07,  1.0370e-07],
         [ 3.2427e-07, -6.4477e-08, -1.8308e-06,  1.3025e-06,  7.0478e-07,
           6.8023e-07, -5.0191e-07, -1.3001e-07]],

        [[-7.3482e-09, -3.6562e-08, -7.9247e-08,  1.7031e-07,  2.0228e-07,
          -1.4799e-07, -6.6688e-08,  1.2810e-08],
         [ 1.0623e-08,  2.3835e-08, -6.1538e-08,  4.7462e-08,  3.5137e-08,
          -4.4629e-08,  1.3413e-08,  7.5484e-09]]], requires_grad=True)
Parameter containing:
tensor([[[-9.1601e-07, -1.7218e-06,  1.3411e-06,  1.0660e-06, -7.3104e-07,
           4.8550e-07, -4.0431e-07,  3.5600e-08],
         [ 1.5891e-02, -1.5247e-02, -2.6624e-02, -3.4371e-02, -4.7528e-02,
          -5.7160e-02, -4.9101e-02, -4.9666e-02]]], requires_grad=True)



