In [3]:
import torch
import torch.nn as nn
import numpy as np

# Define the univariate functions
class UniVariateFunction(nn.Module):
    def __init__(self, output_size):
        super(UniVariateFunction, self).__init__()
        self.linear = nn.Linear(1, output_size)

    def forward(self, x):
        x = self.linear(x)
        return torch.sin(x) # Using sin as an activation function

# Define the KAN Model
class KAN(nn.Module):
    def __init__(self):
        super(KAN, self).__init__()
        self.phi = nn.ModuleList([UniVariateFunction(1) for _ in range(2)]) # phi functions for x and y
        self.Phi = nn.Linear(2, 1) # Phi function to combine outputs

    def forward(self, x):
        x1, x2 = x[:, 0], x[:, 1]
        x1 = self.phi[0](x1.view(-1, 1))
        x2 = self.phi[1](x2.view(-1, 1))
        out = torch.cat((x1, x2), dim=1)
        out = self.Phi(out)
        return out

# Generate sample data
x = torch.linspace(-np.pi, np.pi, 200)
y = torch.linspace(-np.pi, np.pi, 200)
X, Y = torch.meshgrid(x, y)
Z = torch.sin(X) + torch.cos(Y)

# Prepare inputs and model
inputs = torch.stack([X.flatten(), Y.flatten()], dim=1)
model = KAN()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training loop
for epoch in range(100):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, Z.flatten())
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')

Epoch 0, Loss: 1.2395693063735962
Epoch 10, Loss: 1.1339961290359497
Epoch 20, Loss: 1.0631378889083862
Epoch 30, Loss: 1.0264923572540283
Epoch 40, Loss: 1.0109596252441406
Epoch 50, Loss: 1.0048270225524902
Epoch 60, Loss: 1.002313256263733
Epoch 70, Loss: 1.0011882781982422
Epoch 80, Loss: 1.000636100769043
Epoch 90, Loss: 1.000344157218933
