# Jupyter testing

##### Libraries and seed

In [114]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # NOTE(review): F is not used in the cells visible here — confirm it is needed further down
import torchvision.datasets as datasets
# Seed PyTorch's global random number generator so that anything drawn from it
# afterwards is repeatable from run to run. As the last expression of the cell,
# the returned Generator object is echoed as the cell output.
torch.manual_seed(42)

<torch._C.Generator at 0x7fe7defb3290>

### Architecture
A simple example architecture to get down the basics of PyTorch syntax.

In [115]:
class test_architecture(nn.Module):
    """Toy network: three stacked 3x3 convolutions followed by one linear layer.

    No activation functions are applied between the layers. The linear layer
    is applied to the trailing axis of the convolution output, so that axis
    must have length 22 (this is the width left after the three unpadded 3x3
    convolutions when the input images are 28 pixels wide).
    """

    def __init__(self):
        """Create the four layers in the order in which ``forward`` applies them."""
        super().__init__()
        # 1 input channel -> 3 feature maps, then two 3 -> 3 convolutions.
        self.l1_conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
        self.l2_conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
        self.l3_conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
        # Maps the last axis from 22 values to 4 values.
        self.l4_lin = nn.Linear(in_features=22, out_features=4)

    def forward(self, X):
        """Pass ``X`` through the three convolutions and then the linear layer.

        Args:
            X: Batch of single-channel images, laid out as
                (batch, 1, height, width).

        Returns:
            The output of ``l4_lin``; its trailing axis has length 4.
        """
        stages = (self.l1_conv, self.l2_conv, self.l3_conv, self.l4_lin)
        out = X
        for stage in stages:
            out = stage(out)
        return out


# Module-level instance used by the following cells.
model = test_architecture()

In [116]:
# Load the MNIST training split from ./data, downloading it first if it is
# not already there. No transform is attached to the dataset.
mnist_trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=None,
)
print(mnist_trainset)

# First three raw images, cast to float32 and given an explicit channel axis
# so the batch has shape (3, 1, 28, 28).
input_data = (
    mnist_trainset.data[:3]
    .to(torch.float32)
    .reshape(3, 1, 28, 28)
)
print(input_data)
print(model)

# Forward pass; being the last expression, its result is shown as the cell output.
model(input_data)

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]])
test_architecture(
  (l1_conv): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
  (l2_conv): Conv2d

tensor([[[[-3.0528e-01,  4.6293e-01,  6.7150e-01,  2.6270e+00],
          [-8.1140e-01,  4.4309e+00,  5.2215e+00,  1.2427e+01],
          [ 8.4917e+00,  1.3060e+01,  1.3225e+01,  2.6844e+01],
          [ 1.6758e+01,  1.5343e+01,  3.1265e+01,  3.1428e+01],
          [ 3.5510e+01,  3.4429e+00,  4.5722e+01,  2.1526e+01],
          [ 4.0918e+01, -9.1646e+00,  4.9929e+01,  1.1196e+01],
          [ 4.2445e+01,  2.5389e-02,  4.4011e+01, -6.4203e+00],
          [ 3.8544e+01,  1.2340e+01,  3.3843e+01, -1.8951e+01],
          [ 2.5377e+01,  1.4749e+01,  3.2777e+01, -1.9754e+01],
          [ 1.3138e+01,  9.6234e+00,  3.3372e+01, -2.0819e+01],
          [ 2.6933e+00, -1.3757e+01,  3.9112e+01, -1.0594e+01],
          [-3.6848e+00, -3.6477e+01,  3.6255e+01,  5.3784e+00],
          [-2.7813e+00, -4.3146e+01,  2.5471e+01,  2.0044e+01],
          [ 3.6943e+00, -2.7906e+01,  1.1458e+01,  2.5558e+01],
          [ 9.7729e+00, -7.1121e+00,  3.6073e+00,  2.8212e+01],
          [ 1.4368e+01, -3.1465e+00,  5.

Now we do the training loop.

In [117]:
# NOTE(review): disabled sketch of a training loop — nothing in this cell runs.
# It refers to names that are not defined in the cells visible above (KAN,
# grids, k, device, dataset, steps, train_losses, test_losses); these have to
# be defined or imported before the loop can be enabled. If enabled as
# written, it would also rebind `model`, replacing the test_architecture
# instance created earlier.
# for i in range(grids.shape[0]):
#     if i == 0:
#         model = KAN(width=[2,1,1], grid=grids[i], k=k, seed=1, device=device)
#     if i != 0:
#         model = model.refine(grids[i])
#     results = model.fit(dataset, opt="adam", steps=steps)
#     train_losses += results['train_loss']
#     test_losses += results['test_loss']

The Kolmogorov-Arnold representation theorem (with splines), as presented in https://arxiv.org/html/2404.19756v1: $$F(x) \approx \sum_{q=0}^{2n} \Phi_q\left(\sum_{p=1}^{n} \phi_{q,p}(x_p)\right),$$ where the inner functions $\phi_{q,p}\colon I \rightarrow \mathbb{R}$ and the outer functions $\Phi_q\colon \mathbb{R} \rightarrow \mathbb{R}$ are (univariate) splines, $I = [0,1]$, $F\colon I^n \rightarrow \mathbb{R}$ and $x = (x_1, \ldots, x_p, \ldots, x_n) \in I^n$. The general Kolmogorov-Arnold representation theorem holds with equality, but there the $\phi_{q,p}$ and $\Phi_q$ are in general not splines, and in the worst case they can be highly irregular (non-smooth, even fractal).

Note that the interval $I$ needs to be divided into $G \in \mathbb{N}$ segments, which determines the number of $B$-splines needed in the basis. Idea: can this be translated to manifolds?