How to apply KAN on Computer Vision #9

Open
WuZhuoran opened this issue May 1, 2024 · 39 comments

Comments

@WuZhuoran

Hi Author,

Thank you for your great work. I am wondering whether we can apply this network to vision tasks such as classification, detection, segmentation, etc.

Thank you for your help.

@WuZhuoran
Author

WuZhuoran commented May 1, 2024

Update on this topic:

I wrote a short notebook to test training and evaluation on the MNIST dataset. If we want to apply KAN to 2D or 3D tasks, would one possible way be to make KANLayer inherit from nn.Conv2d?

Here is the screenshot:

[image]

And here is the Traceback:

description:   0%|                                                           | 0/20 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 1
----> 1 results = model.train(dataset, opt="LBFGS", steps=20, loss_fn=torch.nn.CrossEntropyLoss());

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KAN.py:913, in KAN.train(self, dataset, opt, steps, log, lamb, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, update_grid, grid_update_num, loss_fn, lr, stop_grid_update_step, batch, small_mag_threshold, small_reg_factor, metrics, sglr_avoid, save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder, device)
    910 test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
    912 if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid:
--> 913     self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
    916 if opt == "LBFGS":
    917     optimizer.step(closure)

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KAN.py:242, in KAN.update_grid_from_samples(self, x)
    219 '''
    220 update grid from samples
    221 
   (...)
    239 tensor([0.0128, 1.0064, 2.0000, 2.9937, 3.9873, 4.9809])
    240 '''
    241 for l in range(self.depth):
--> 242     self.forward(x)
    243     self.act_fun[l].update_grid_from_samples(self.acts[l])

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KAN.py:313, in KAN.forward(self, x)
    308 self.acts.append(x) # acts shape: (batch, width[l])
    311 for l in range(self.depth):
--> 313     x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
    315     if self.symbolic_enabled == True:
    316         x_symbolic, postacts_symbolic = self.symbolic_fun[l](x)

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KANLayer.py:172, in KANLayer.forward(self, x)
    170 batch = x.shape[0]
    171 # x: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim)
--> 172 x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim,).to(self.device)).reshape(batch, self.size).permute(1,0)
    173 preacts = x.permute(1,0).clone().reshape(batch, self.out_dim, self.in_dim)
    174 base = self.base_fun(x).permute(1,0) # shape (batch, size)

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/torch/functional.py:385, in einsum(*args)
    380     return einsum(equation, *_operands)
    382 if len(operands) <= 2 or not opt_einsum.enabled:
    383     # the path for contracting 0 or 1 time(s) is already optimized
    384     # or the user has disabled using opt_einsum
--> 385     return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
    387 path = None
    388 if opt_einsum.is_available():

RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (3) for operand 0 and no ellipsis was given
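
For readers hitting the same traceback: the einsum in KANLayer.forward uses the subscripts 'ij', so it expects a 2-D input of shape (batch, in_dim), while a batch of MNIST images is 3-D. The following is a minimal sketch in plain PyTorch (tensor sizes and variable names are illustrative assumptions, not code from the library) that reproduces the mismatch and shows that flattening resolves it.

```python
import torch

out_dim = 5
images = torch.rand(8, 28, 28)  # a batch of MNIST-sized images (3-D tensor)

# The subscripts 'ij' describe a 2-D operand, so a 3-D batch is rejected.
try:
    torch.einsum("ij,k->ikj", images, torch.ones(out_dim))
    failed = False
except RuntimeError as err:
    failed = True
    print("3-D input fails:", err)

# Flattening each image to a vector gives the expected (batch, in_dim) shape.
flat = images.flatten(start_dim=1)  # (8, 784)
expanded = torch.einsum("ij,k->ikj", flat, torch.ones(out_dim))
print(expanded.shape)  # torch.Size([8, 5, 784])
```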

@KindXiaoming
Owner

Yeah, I think KANs, as they are right now, cannot handle convolution. It seems reasonable to define ConvKAN layers. Given the current implementation, the only thing you can do with vision tasks is flatten a whole image into a vector, totally abandoning spatial information (which is not good; that's why I think extra development is needed).

@KindXiaoming
Owner

KindXiaoming commented May 1, 2024

As a quick, cute example, you can play with KAN as if it were an MLP for MNIST.

Please make sure input data have shape [data size, indim], indim=784. Also, the input dimension of KAN should be 784, and output should be 10. So e.g., these KANs are valid for MNIST: KAN(width=[784,5,10]) or KAN(width=[784,5,5,10]). Also you may want to include say batch=128 in model.train() to train on batches rather than the whole dataset (which is fine, but I worry it might run too slowly on cpu haha).

@WuZhuoran
Author

WuZhuoran commented May 1, 2024

> As a quick, cute example, you can play with KAN as if it were an MLP for MNIST.
>
> Please make sure input data have shape [data size, indim], indim=784. Also, the input dimension of KAN should be 784, and output should be 10. So e.g., these KANs are valid for MNIST: KAN(width=[784,5,10]) or KAN(width=[784,5,5,10]). Also you may want to include say batch=128 in model.train() to train on batches rather than the whole dataset (which is fine, but I worry it might run too slowly on cpu haha).

Thanks for the quick reply.

It did work with

model = KAN(width=[784,5,5,10], grid=3, k=3).to(device)

and

dataset['train_input'] = torch.flatten(train_dataset.data, start_dim=1).to(device)
dataset['test_input'] = torch.flatten(test_dataset.data, start_dim=1).to(device)

Now training works on the cpu device (slow, as expected), but it raises an error on an Apple chip with the mps device.

results = model.train(dataset, opt="LBFGS", steps=20, loss_fn=torch.nn.CrossEntropyLoss(), batch=128, device='cpu');

But anyway, flattening the image into 1D is not a good idea in general. A VisionKAN or KAN_Conv2d needs to be implemented. LOL.
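
One detail worth spelling out for the flattening step: torchvision's MNIST `.data` tensor is uint8, so it usually needs a cast to float before it is fed to the model. Below is a minimal sketch of a helper that builds the dataset dict with the keys used in this thread. The name `make_kan_dataset`, the scaling to [0, 1], and the random stand-in tensors are assumptions for illustration only.

```python
import torch

def make_kan_dataset(train_x, train_y, test_x, test_y, device="cpu"):
    """Flatten images to (N, H*W) float32 in [0, 1]; keep labels as int64 class ids."""
    def prep(x):
        return x.flatten(start_dim=1).to(torch.float32).div(255.0).to(device)

    return {
        "train_input": prep(train_x),
        "test_input": prep(test_x),
        "train_label": train_y.to(torch.long).to(device),
        "test_label": test_y.to(torch.long).to(device),
    }

# Random uint8 stand-ins for train_dataset.data / train_dataset.targets.
train_x = torch.randint(0, 256, (32, 28, 28), dtype=torch.uint8)
test_x = torch.randint(0, 256, (8, 28, 28), dtype=torch.uint8)
train_y = torch.randint(0, 10, (32,))
test_y = torch.randint(0, 10, (8,))

dataset = make_kan_dataset(train_x, train_y, test_x, test_y)
print(dataset["train_input"].shape, dataset["train_input"].dtype)
```

Integer class labels are what torch.nn.CrossEntropyLoss expects, which is why the labels are kept as int64 here.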

@KindXiaoming
Owner

Nice! yes, there's still some issue with GPU training. Looking forward to your new development :-)

@WuZhuoran
Author

> Nice! yes, there's still some issue with GPU training. Looking forward to your new development :-)

Yeah, about GPU training: I might need to use CUDA first. For MPS, it raises this error for the classification example:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!

I already put the datasets and the model on the mps device:

dataset['train_input'] = torch.from_numpy(train_input).to(torch.float32).to(device)
dataset['test_input'] = torch.from_numpy(test_input).to(torch.float32).to(device)
dataset['train_label'] = torch.from_numpy(train_label[:,None]).to(torch.float32).to(device)
dataset['test_label'] = torch.from_numpy(test_label[:,None]).to(torch.float32).to(device)

model = KAN(width=[2,1], grid=3, k=3).to(torch.float32).to(device)

results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc), device=device);

The full traceback is:

description:   0%|                                                           | 0/20 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[9], line 9
      6 def test_acc():
      7     return torch.mean((torch.round(model(dataset['test_input'])[:,0]) == dataset['test_label'][:,0]).float())
----> 9 results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc), device=device);
     10 results['train_acc'][-1], results['test_acc'][-1]

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KAN.py:913, in KAN.train(self, dataset, opt, steps, log, lamb, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, update_grid, grid_update_num, loss_fn, lr, stop_grid_update_step, batch, small_mag_threshold, small_reg_factor, metrics, sglr_avoid, save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder, device)
    910 test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
    912 if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid:
--> 913     self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
    916 if opt == "LBFGS":
    917     optimizer.step(closure)

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KAN.py:242, in KAN.update_grid_from_samples(self, x)
    219 '''
    220 update grid from samples
    221 
   (...)
    239 tensor([0.0128, 1.0064, 2.0000, 2.9937, 3.9873, 4.9809])
    240 '''
    241 for l in range(self.depth):
--> 242     self.forward(x)
    243     self.act_fun[l].update_grid_from_samples(self.acts[l])

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KAN.py:313, in KAN.forward(self, x)
    308 self.acts.append(x) # acts shape: (batch, width[l])
    311 for l in range(self.depth):
--> 313     x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
    315     if self.symbolic_enabled == True:
    316         x_symbolic, postacts_symbolic = self.symbolic_fun[l](x)

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KANLayer.py:172, in KANLayer.forward(self, x)
    170 batch = x.shape[0]
    171 # x: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim)
--> 172 x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim,).to(self.device)).reshape(batch, self.size).permute(1,0)
    173 preacts = x.permute(1,0).clone().reshape(batch, self.out_dim, self.in_dim)
    174 base = self.base_fun(x).permute(1,0) # shape (batch, size)

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/torch/functional.py:385, in einsum(*args)
    380     return einsum(equation, *_operands)
    382 if len(operands) <= 2 or not opt_einsum.enabled:
    383     # the path for contracting 0 or 1 time(s) is already optimized
    384     # or the user has disabled using opt_einsum
--> 385     return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
    387 path = None
    388 if opt_einsum.is_available():

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!
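
The failing frame allocates a helper tensor with `.to(self.device)`, where `self.device` is a plain attribute of the layer. Calling `.to(device)` on a module moves parameters and buffers but does not update such an attribute, which would explain the mix of mps:0 and cpu. A common way to avoid this class of error is to allocate helper tensors on the input's device. The sketch below shows that pattern on a toy layer. `ExpandLayer` and `pick_device` are hypothetical names and this is not the library's code.

```python
import torch
import torch.nn as nn

class ExpandLayer(nn.Module):
    """Toy layer (hypothetical) that repeats the input once per output unit."""

    def __init__(self, out_dim):
        super().__init__()
        self.out_dim = out_dim

    def forward(self, x):
        # Allocate on x's device and dtype instead of a device string cached at construction.
        ones = torch.ones(self.out_dim, device=x.device, dtype=x.dtype)
        return torch.einsum("ij,k->ikj", x, ones)

def pick_device():
    """Prefer CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
x = torch.rand(4, 2, device=device)
y = ExpandLayer(3).to(device)(x)
print(y.shape, y.device)
```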

@cpellet

cpellet commented May 2, 2024

It turns out that replacing model = KAN(width=[784,5,5,10], grid=3, k=3).to(device) by model = KAN(width=[784,5,5,10], grid=3, k=3, device=device) does the trick for me! Here is a full example training on mps for reference:

from kan import *
from tensorflow import keras  # used only to download MNIST
import cv2
import numpy as np
import torch

device = "mps"
model = KAN(width=[7*7, 5, 5, 128], grid=3, k=3, device=device)

(X_train,y_train),(X_test,y_test) = keras.datasets.mnist.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0

# downsample to 7x7
X_train = np.array([cv2.resize(x, (7,7)) for x in X_train])
X_test = np.array([cv2.resize(x, (7,7)) for x in X_test])

dataset = {}
dataset['train_input'] = torch.flatten(torch.from_numpy(X_train), start_dim=1).to(torch.float32).to(device)
dataset['train_label'] = torch.from_numpy(y_train).to(torch.float32).to(device)
dataset['test_input'] = torch.flatten(torch.from_numpy(X_test), start_dim=1).to(torch.float32).to(device)
dataset['test_label'] = torch.from_numpy(y_test).to(torch.float32).to(device)

model.train(dataset, opt="LBFGS", steps=20, batch=128)

@noahvandal

Were you able to actually train on MNIST using a flat dataset?

@cpellet

cpellet commented May 2, 2024

No, unfortunately. @WuZhuoran's comments make sense; I was solely making a point about getting training on mps to work.

@MeDenTec

MeDenTec commented May 2, 2024

Hi everybody, please let me know if any of you has successfully applied KANs to a computer vision task, or integrated them with CNNs.

@MeDenTec

MeDenTec commented May 2, 2024

Also, how can I integrate and train KAN layers with CNNs after flattening the tensors? Could anybody please share the code?

@genglinxiao

Very interesting. I'd really like to see a direct comparison between KAN and MLP in CNN architecture.

@MiXaiLL76

I tried something like this, but it didn't work: the loss decreases slowly.

import cv2
import numpy as np
import torch
import torchvision
from kan import KAN
import matplotlib.pyplot as plt

train_data = torchvision.datasets.MNIST(
    root="./mnist_data", train=True, download=True, transform=None
)
test_data = torchvision.datasets.MNIST(
    root="./mnist_data", train=False, download=True, transform=None
)

valid_labels = [0, 1, 2]

X_train = []
y_train = []

for pil_img, label in train_data:
    if label in valid_labels:
        x = np.array(pil_img)
        x = cv2.resize(x, (7, 7))
        X_train.append(x.astype(float))
        y_train.append(label)

X_train = np.array(X_train)
y_train = np.array(y_train)

mean, std = np.mean(X_train), np.std(X_train)
print(f"{mean=}")
print(f"{std=}")

X_test = []
y_test = []
for pil_img, label in test_data:
    if label in valid_labels:
        x = np.array(pil_img)
        x = cv2.resize(x, (7, 7))
        X_test.append(x.astype(float))
        y_test.append(label)

X_test = np.array(X_test)
y_test = np.array(y_test)

X_test = (X_test - mean) / std
X_train = (X_train - mean) / std

device = "cpu"
model = KAN(width=[x.shape[0]**2, 20, 20, len(valid_labels)], grid=3, k=3, device=device)

dataset = {}
dataset["train_input"] = (
    torch.flatten(torch.from_numpy(X_train), start_dim=1).to(torch.float32).to(device)
)
dataset["train_label"] = torch.from_numpy(y_train).to(torch.float32).to(device)
dataset["test_input"] = (
    torch.flatten(torch.from_numpy(X_test), start_dim=1).to(torch.float32).to(device)
)
dataset["test_label"] = torch.from_numpy(y_test).to(torch.float32).to(device)

result = model.train(dataset, opt="Adam", steps=100, lr=0.1, batch=len(valid_labels), device=device)

plt.plot(result['train_loss'], label="train_loss")
plt.plot(result['test_loss'], label="test_loss")
plt.ylim(0, 5)
plt.legend()
plt.show()
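
A possible explanation for the slow loss, offered as a guess rather than a diagnosis: no loss_fn is passed above, so if model.train falls back to a mean squared error of the form mean((x - y)**2) (an assumption about the default), float labels of shape (batch,) get broadcast against outputs of shape (batch, 3). With batch equal to the number of classes, this runs without error but compares every output row against the whole label vector. The plain-PyTorch sketch below illustrates the pitfall with made-up tensors and contrasts it with a cross-entropy loss on integer labels.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(3, 3)                 # (batch, n_classes)
labels_f = torch.tensor([0.0, 1.0, 2.0])   # float labels, shape (batch,)

# (3, 3) - (3,) broadcasts silently: each row is compared with the full label vector.
mse = torch.mean((logits - labels_f) ** 2)
same = torch.mean((logits - labels_f.expand(3, 3)) ** 2)
print(torch.isclose(mse, same))

# With a different batch size the same expression fails instead of broadcasting.
try:
    torch.mean((torch.randn(4, 3) - torch.zeros(4)) ** 2)
    mismatch_raises = False
except RuntimeError:
    mismatch_raises = True

# Classification setup: integer class ids with a cross-entropy loss.
ce = torch.nn.CrossEntropyLoss()(logits, labels_f.long())
print(ce.item())
```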

@KindXiaoming
Owner

KindXiaoming commented May 2, 2024

My experience with MNIST is that a 2-layer KAN with very few hidden neurons (say 5 or 10) is enough to train on MNIST (though my impression may have come from accuracy rather than loss), i.e., KAN(width=[49, 10, 3]) in your case. It's likely that the accuracy is high even though the loss is still high.

So please try computing acc as well. You can refer to this tutorial to see how to do this. Basically, it's something like

def train_acc():
    return torch.mean((torch.round(model(dataset['train_input'])[:,0]) == dataset['train_label'][:,0]).float())

def test_acc():
    return torch.mean((torch.round(model(dataset['test_input'])[:,0]) == dataset['test_label'][:,0]).float())

results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc));
results['train_acc'][-1], results['test_acc'][-1]
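
The metric above rounds a single output column, which fits the two-class tutorial setup. For a multi-class model trained with a cross-entropy loss, accuracy is usually computed from the argmax over the output columns. The following is a minimal sketch of that variant, evaluated in chunks to limit memory use. `class_accuracy` and the identity stand-in model are illustrative assumptions, not part of the library.

```python
import torch

@torch.no_grad()
def class_accuracy(model, inputs, labels, batch_size=1024):
    """Fraction of samples whose argmax prediction equals the integer label."""
    correct = 0
    for start in range(0, inputs.shape[0], batch_size):
        logits = model(inputs[start:start + batch_size])
        preds = logits.argmax(dim=1)
        correct += (preds == labels[start:start + batch_size]).sum().item()
    return correct / inputs.shape[0]

# Stand-in model: identity, so row i of the identity matrix predicts class i.
model = torch.nn.Identity()
inputs = torch.eye(4)
labels = torch.tensor([0, 1, 2, 0])
acc = class_accuracy(model, inputs, labels, batch_size=2)
print(acc)  # 0.75
```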

@MeDenTec

MeDenTec commented May 3, 2024

> Very interesting. I'd really like to see a direct comparison between KAN and MLP in CNN architecture.

I am also willing to do so, but I don't know how to integrate and train them simultaneously.

@GorkaAbad

GorkaAbad commented May 3, 2024

Hi,
here is a working MNIST example using CUDA, reusing some code from above. It may be verbose and far from optimal.

I get around 73% test accuracy in about 1 minute. Playing with the network size may improve the performance.

import cv2
import numpy as np
import torch
import torchvision
from kan import KAN
import matplotlib.pyplot as plt


def train_acc():
    # The model may end up on the CPU here (something about KAN's implementation),
    # so fall back to CPU tensors if the first attempt hits a device mismatch.
    try:
        arg = (
            torch.argmax(model(dataset["train_input"]), dim=1) == dataset["train_label"]
        )
    except RuntimeError:
        arg = torch.argmax(model(dataset["train_input"].to("cpu")), dim=1) == dataset[
            "train_label"
        ].to("cpu")
    return torch.mean(arg.float())


def test_acc():
    try:
        arg = torch.argmax(model(dataset["test_input"]), dim=1) == dataset["test_label"]
    except RuntimeError:
        arg = torch.argmax(model(dataset["test_input"].to("cpu")), dim=1) == dataset[
            "test_label"
        ].to("cpu")
    return torch.mean(arg.float())


train_data = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=None
)
test_data = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=None
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using: {device}")
valid_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

X_train = []
y_train = []

for pil_img, label in train_data:
    if label in valid_labels:
        x = np.array(pil_img)
        x = cv2.resize(x, (7, 7))
        X_train.append(x.astype(float))
        y_train.append(label)

X_train = np.array(X_train)
y_train = np.array(y_train)

mean, std = np.mean(X_train), np.std(X_train)
print(f"{mean=}")
print(f"{std=}")

X_test = []
y_test = []
for pil_img, label in test_data:
    if label in valid_labels:
        x = np.array(pil_img)
        x = cv2.resize(x, (7, 7))
        X_test.append(x.astype(float))
        y_test.append(label)

X_test = np.array(X_test)
y_test = np.array(y_test)

X_test = (X_test - mean) / std
X_train = (X_train - mean) / std


# 7 * 7 = 49 inputs after downsampling
model = KAN(width=[7 * 7, 20, len(valid_labels)], grid=5, k=3, device=device)

dataset = {}
# Inputs must stay floating point: .long() would truncate the standardized pixel values.
dataset["train_input"] = (
    torch.flatten(torch.from_numpy(X_train), start_dim=1).float().to(device)
)
dataset["train_label"] = torch.from_numpy(y_train).long().to(device)

dataset["test_input"] = (
    torch.flatten(torch.from_numpy(X_test), start_dim=1).float().to(device)
)
dataset["test_label"] = torch.from_numpy(y_test).long().to(device)

loss_fn = torch.nn.CrossEntropyLoss()

result = model.train(
    dataset,
    opt="Adam",
    steps=50,
    lr=0.1,
    batch=512,
    # metrics=(
    #     train_acc,
    #     test_acc,
    # ),  # this is the slower step, so its better to evaluate it after training
    loss_fn=loss_fn,
    # device=device,
)

acc = test_acc()
print(f"Test accuracy: {acc.item()}")


plt.plot(result["train_loss"], label="train_loss")
plt.plot(result["test_loss"], label="test_loss")
plt.ylim(0, 5)
plt.legend()
plt.savefig("loss.png")

@Menghuan1918

Hi, here's my attempt. It takes about 30 s to run on CUDA and reaches about 83% accuracy.

import cv2
import numpy as np
import torch
import torchvision
from kan import KAN
import matplotlib.pyplot as plt

def preprocess_data(data):
    images = []
    labels = []
    for img, label in data:
        img = cv2.resize(np.array(img), (7, 7))
        img = img.flatten() / 255.0
        images.append(img)
        labels.append(label)
    return np.array(images), np.array(labels)

train_data = torchvision.datasets.MNIST(
    root="./mnist_data", train=True, download=True, transform=None
)
test_data = torchvision.datasets.MNIST(
    root="./mnist_data", train=False, download=True, transform=None
)

train_images, train_labels = preprocess_data(train_data)
test_images, test_labels = preprocess_data(test_data)

device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using {device} device")

dataset = {
    "train_input": torch.from_numpy(train_images).float().to(device),
    "train_label": torch.from_numpy(train_labels).to(device),
    "test_input": torch.from_numpy(test_images).float().to("cpu"),
    "test_label": torch.from_numpy(test_labels).to("cpu"),
}

model = KAN(width=[49, 10, 10], device=device)

results = model.train(
    dataset,
    opt="Adam",
    lr=0.05,
    steps=100,
    batch=512,
    loss_fn=torch.nn.CrossEntropyLoss(),
)
torch.save(model.state_dict(), "kan.pth")


del model
model = KAN(width=[49, 10, 10], device="cpu")
model.load_state_dict(torch.load("kan.pth"))

def test_acc():
    with torch.no_grad():
        predictions = torch.argmax(model(dataset["test_input"]), dim=1)
        correct = (predictions == dataset["test_label"]).float()
        accuracy = correct.mean()
    return accuracy

acc = test_acc()
print(f"Test accuracy: {acc.item() * 100:.2f}%")

plt.plot(results["train_loss"], label="train")
plt.plot(results["test_loss"], label="test")
plt.legend()
plt.savefig("kan.png")

Output

@Fredrik00

I also think a pure KAN implementation for computer vision does not look very promising because it makes no use of spatial locality. An interesting idea could be to define a KAN-based 2D convolution layer that replaces the 2D kernel with a spline (or a full KAN layer?) working on flattened 2D patches of similar sizes to the regular kernels. At small enough kernel sizes (say 3x3), the loss in fine-grained spatial locality might not be as detrimental to model performance.

@Shomvel

Shomvel commented May 3, 2024

I compared a KAN with an MLP that is 5x smaller. In 10 epochs, the KAN reached 91% accuracy whereas the MLP reached 97%. The KAN loss goes down more slowly than that of the MLP.

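import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

# Placeholder: paste the KAN implementation from https://github.com/Blealtan/efficient-kan here
class EKAN(nn.Module):
    pass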

# a simple MLP model
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.layers(x)

# Data preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to PyTorch tensors and scale to [0,1]
    transforms.Normalize((0.5,), (0.5,))  # Subtract 0.5 and divide by 0.5: maps [0, 1] to [-1, 1]
])

# Load the datasets
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize model, loss function, and optimizer
model = EKAN([28*28, 64, 10]).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training the model
def train_model(num_epochs):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            images = images.view(images.shape[0], -1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}')

train_model(10)

# Testing the model
def test_model():
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            images = images.view(images.shape[0], -1)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        print(f'Test Accuracy: {accuracy:.2f}%')

test_model()
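
To make the "5x smaller" comparison concrete, here is a rough parameter count in plain Python. The MLP count follows directly from the nn.Linear shapes above. The KAN count is only an estimate under an assumption about the efficient-kan layer: one base weight, grid_size + spline_order spline coefficients, and one spline scaler per edge, with grid_size=5 and spline_order=3. The function names are made up for this sketch.

```python
def mlp_params(widths):
    """Weights plus biases of a fully connected network."""
    return sum(i * o + o for i, o in zip(widths, widths[1:]))

def kan_params(widths, grid_size=5, spline_order=3):
    """Estimated count: per edge, 1 base weight + spline coefficients + 1 spline scaler (assumed layout)."""
    per_edge = 1 + (grid_size + spline_order) + 1
    return sum(i * o * per_edge for i, o in zip(widths, widths[1:]))

mlp = mlp_params([784, 128, 64, 10])
kan = kan_params([784, 64, 10])
print(mlp, kan, round(kan / mlp, 2))  # 109386 508160 4.65
```

Under these assumptions the ratio comes out close to the factor of 5 mentioned above.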

@WuZhuoran
Author

> Were you able to actually train on MNIST using a flat dataset?

Hi,

I did train on the MNIST dataset, but only by flattening each image into a 1D vector. I think we still need more development for computer vision tasks.

@WuZhuoran
Author

> I also think a pure KAN implementation for computer vision does not look very promising because it makes no use of spatial locality. An interesting idea could be to define a KAN-based 2D convolution layer that replaces the 2D kernel with a spline (or a full KAN layer?) working on flattened 2D patches of similar sizes to the regular kernels. At small enough kernel sizes (say 3x3), the loss in fine-grained spatial locality might not be as detrimental to model performance.

Good point on the 2D conv layer. One possibility is to define a kan_conv2d layer; then we can build KAN3D or KAN directly with different conv2d layers.

Currently, all the tests on images (such as MNIST) process the data into a 1D vector, which is not very useful.

@zdx3578

zdx3578 commented May 4, 2024

Before using conv2d, what about using a VAE latent space: train the KAN on MNIST with the VAE encoder output as the KAN input?

@xiaol

xiaol commented May 4, 2024

According to my experiments, a modified version of KAN outperformed an MLP of the same shape (both 768, 64, 10) on the MNIST dataset, using the efficient-kan code above with some tweaks.

This is KAN+:
[image]

This is the MLP:
[image]

@zdx3578

zdx3578 commented May 4, 2024

@xiaol use Handwritten Sequence Trajectories?

@juntaoJianggavin

> Also, how can I integrate and train KAN layers with CNNs after flattening the tensors? Could anybody please share the code?

I tried replacing the MLP with a KAN in CNN models, and the two perform about the same.

https://github.com/juntaoJianggavin/kan-cifar10/tree/main

@Uljibuh

Uljibuh commented May 6, 2024

How can I build a Conv-KAN? How do I integrate convolution into KAN?

@paulestano

paulestano commented May 7, 2024

I used a 'linearized version' of nn.Conv2d, built with nn.Unfold and a reshape, to make a KANConv2d. I'm not completely sure whether it makes sense, and I don't think it's efficient at all, but you may check it out.
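
For anyone wondering what the nn.Unfold approach looks like in general, here is a minimal sketch of the idea. It is not the code from the repository mentioned above. `UnfoldConv2d` is a hypothetical wrapper that extracts every patch, flattens it, and applies an arbitrary module `patch_fn` to it. An nn.Linear is used as a stand-in for `patch_fn` so the example runs without a KAN library; a KAN layer mapping C*k*k inputs to out_channels outputs could be passed instead.

```python
import torch
import torch.nn as nn

class UnfoldConv2d(nn.Module):
    """Apply `patch_fn`, mapping (N, C*k*k) -> (N, out_channels), to every image patch."""

    def __init__(self, patch_fn, out_channels, kernel_size=3, stride=1, padding=0):
        super().__init__()
        self.patch_fn = patch_fn
        self.out_channels = out_channels
        self.k, self.s, self.p = kernel_size, stride, padding
        self.unfold = nn.Unfold(kernel_size, padding=padding, stride=stride)

    def forward(self, x):
        b, c, h, w = x.shape
        patches = self.unfold(x)                                  # (B, C*k*k, L)
        n_patches = patches.shape[-1]
        patches = patches.transpose(1, 2).reshape(b * n_patches, -1)  # (B*L, C*k*k)
        out = self.patch_fn(patches)                              # (B*L, out_channels)
        out = out.reshape(b, n_patches, self.out_channels).transpose(1, 2)
        h_out = (h + 2 * self.p - self.k) // self.s + 1
        w_out = (w + 2 * self.p - self.k) // self.s + 1
        return out.reshape(b, self.out_channels, h_out, w_out)

# Stand-in patch function: a linear map over each flattened 3x3 patch.
layer = UnfoldConv2d(nn.Linear(1 * 3 * 3, 4), out_channels=4, kernel_size=3)
x = torch.rand(2, 1, 8, 8)
y = layer(x)
print(y.shape)  # torch.Size([2, 4, 6, 6])
```

With a linear `patch_fn` this reduces to an ordinary convolution, which is a handy sanity check before swapping in a spline-based layer.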

@SimoSbara

I also tried a simple implementation of LeNet, but with a KAN as the classifier: https://github.com/SimoSbara/kan-lenet

The KAN receives the flattened output of the convolutional layers.
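
The general pattern of a convolutional feature extractor followed by a classifier on the flattened features can be sketched as follows. This is not the code from the linked repository. The LeNet-like layer sizes and the class name `ConvFeaturesWithHead` are assumptions, and an nn.Linear stands in for the head so the example runs without a KAN library; a KAN with 400 inputs and 10 outputs could be passed as `head` instead.

```python
import torch
import torch.nn as nn

class ConvFeaturesWithHead(nn.Module):
    """LeNet-like conv stack whose flattened output feeds an arbitrary classifier head."""

    def __init__(self, head):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(), nn.MaxPool2d(2),  # 28 -> 28 -> 14
            nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2),            # 14 -> 10 -> 5
        )
        self.head = head  # receives (N, 16 * 5 * 5) = (N, 400)

    def forward(self, x):
        return self.head(self.features(x).flatten(start_dim=1))

head = nn.Linear(16 * 5 * 5, 10)  # stand-in for a KAN classifier
net = ConvFeaturesWithHead(head)
logits = net(torch.rand(2, 1, 28, 28))
print(logits.shape)  # torch.Size([2, 10])
```

Because the head is an ordinary submodule, a single optimizer over net.parameters() trains the convolutional layers and the head together.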

@HaiFengZeng

I think a combination of the two (nn.Linear and KAN) works fine for the MNIST task:

import torchvision
import torch
from torchvision import transforms 
import torch.nn as nn
import torch.nn.functional as F
from kan import KAN
import tqdm
transform = transforms.Compose(
    [transforms.ToTensor(),
    #  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    ]
    )

trainset = torchvision.datasets.MNIST(root='./MNIST', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=500,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.MNIST(root='./MNIST', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=500,
                                         shuffle=False, num_workers=2)
print(len(trainset),len(testset))
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(28*28,64).cuda()
        self.kan = KAN(width=[64,16,10], grid=5, k=3, seed=0,device='cuda:0')
    
    def forward(self,x):
        x = self.linear(x)
        out = self.kan(x)
        return out


net = Net().cuda()

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.002,)

for epoch in range(4):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in tqdm.tqdm(enumerate(trainloader, 0)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        
        # print('predict.size=',pred.size())
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        x = inputs.view(inputs.size(0),-1).cuda()
        outputs = net(x)
        loss = criterion(outputs, labels.cuda())
        loss.backward()
        optimizer.step()

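        # print statistics
        running_loss += loss.item()
        if i % 100 == 99:    # print every 100 mini-batches
            # NOTE: the sum is divided by 2000 (as posted), so the printed value is
            # not a per-batch mean; divide by 100 here to get one.
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
    print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')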
    correct = 0
    total = 0
    # net.eval()
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            # calculate outputs by running images through the network
            x = inputs.view(inputs.size(0),-1).cuda()
            outputs = net(x)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.cuda()).sum().item()

        print(f'epoch {epoch} Accuracy of the network on the 10000 test images: {100 * correct // total} %')
    # net.train()
print('Finished Training')

After 4 epochs of training, the accuracy reaches 96%. The logs look like this:

60000 10000
99it [00:52,  1.90it/s][1,   100] loss: 0.028
120it [01:03,  1.89it/s]
[1,   120] loss: 0.002
epoch 0 Accuracy of the network on the 10000 test images: 93 %
99it [00:51,  1.95it/s][2,   100] loss: 0.010
120it [01:02,  1.91it/s]
[2,   120] loss: 0.002
epoch 1 Accuracy of the network on the 10000 test images: 95 %
99it [00:51,  1.91it/s][3,   100] loss: 0.006
120it [01:02,  1.93it/s]
[3,   120] loss: 0.001
epoch 2 Accuracy of the network on the 10000 test images: 95 %
99it [00:51,  1.91it/s][4,   100] loss: 0.005
120it [01:02,  1.93it/s]
[4,   120] loss: 0.001
epoch 3 Accuracy of the network on the 10000 test images: 96 %
Finished Training

@SimoSbara

> I think a combination of the two (nn.Linear and KAN) works fine for the MNIST task:

    correct = 0
    total = 0
    # net.eval()
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            # calculate outputs by running images through the network
            x = inputs.view(inputs.size(0),-1).cuda()
            outputs = net(x)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.cuda()).sum().item()

        print(f'epoch {epoch} Accuracy of the network on the 10000 test images: {100 * correct // total} %')
    # net.train()
print('Finished Training')

After 4 epochs training, acc comes to 96%, the logs looks like:

60000 10000
99it [00:52,  1.90it/s][1,   100] loss: 0.028
120it [01:03,  1.89it/s]
[1,   120] loss: 0.002
epoch 0 Accuracy of the network on the 10000 test images: 93 %
99it [00:51,  1.95it/s][2,   100] loss: 0.010
120it [01:02,  1.91it/s]
[2,   120] loss: 0.002
epoch 1 Accuracy of the network on the 10000 test images: 95 %
99it [00:51,  1.91it/s][3,   100] loss: 0.006
120it [01:02,  1.93it/s]
[3,   120] loss: 0.001
epoch 2 Accuracy of the network on the 10000 test images: 95 %
99it [00:51,  1.91it/s][4,   100] loss: 0.005
120it [01:02,  1.93it/s]
[4,   120] loss: 0.001
epoch 3 Accuracy of the network on the 10000 test images: 96 %
Finished Training

Compared with an MLP, it's a good improvement.
In real cases, though, convolution is what gives real robustness in OCR applications.

It would be nice to have a performance benchmark for bigger nets where KAN replaces the MLP.

@hesamsheikh

As a bit of an experiment, I tried training a KAN on MNIST:

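import torch
import matplotlib.pyplot as plt
from kan import KAN

# `dataset` is assumed to be a dict of tensors with the keys train_input,
# train_label, test_input and test_label, holding MNIST images resized to
# 7x7 and flattened (it is not defined in this snippet).

def create_kan():
    return KAN(width=[7**2, 3, 10], grid=3, k=3)
model = create_kan()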

def test_acc():
    with torch.no_grad():
        predictions = torch.argmax(model(dataset["test_input"]), dim=1)
        correct = (predictions == dataset["test_label"]).float()
        accuracy = correct.mean()
    return accuracy

def train_acc():
    with torch.no_grad():
        predictions = torch.argmax(model(dataset["train_input"].to('cpu')), dim=1)
        correct = (predictions == dataset["train_label"].to('cpu')).float()
        accuracy = correct.mean()
    return accuracy

# Train the model
results = model.train(
    dataset,
    opt="LBFGS",
    steps=20,
    batch=512,
    loss_fn=torch.nn.CrossEntropyLoss(),
    metrics=(train_acc, test_acc)
)
torch.save(model.state_dict(), "kan.pth")

del model
model = create_kan()
model.load_state_dict(torch.load("kan.pth"))

acc = test_acc()
print(f"Test accuracy: {acc.item() * 100:.2f}%")

plt.plot(results["train_loss"], label="train")
plt.plot(results["test_loss"], label="test")
plt.legend()

I get 81% accuracy with a KAN of 10640 parameters.
image

Doing the same experiment, I'm getting 91% accuracy with a fully connected network of 15306 parameters:

import torch
import torch.nn as nn
import torch.optim as optim

class FullyConnectedNN(nn.Module):
    def __init__(self):
        super(FullyConnectedNN, self).__init__()
        self.fc1 = nn.Linear(7*7, 128)  # 7*7 is the size of the resized and flattened image
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)  # Output 10 classes for MNIST digits

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def train_and_evaluate(model, train_data, train_labels, test_data, test_labels, epochs=20, batch_size=512):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    for epoch in range(epochs):
        for i in range(0, len(train_data), batch_size):
            inputs = train_data[i:i+batch_size]
            labels = train_labels[i:i+batch_size]
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')
        test_acc = evaluate_accuracy(model, test_data, test_labels)
        print(f'Test Accuracy: {test_acc}')
    
def evaluate_accuracy(model, data, labels):
    with torch.no_grad():
        outputs = model(data)
        _, predicted = torch.max(outputs, 1)
        correct = (predicted == labels).float()
        accuracy = correct.mean()
    return accuracy

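# Create and train the model. train_tensor, train_labels_tensor, test_tensor and
# test_labels_tensor are assumed to be defined earlier from the same 7x7 data.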
fcnn_model = FullyConnectedNN()
print(sum(p.numel() for p in fcnn_model.parameters()))
train_and_evaluate(fcnn_model, train_tensor, train_labels_tensor, test_tensor, test_labels_tensor)

This is somewhat far off from the scaling benefits of KAN over MLP shown in the experiments. Granted, flattening an image for a vision task is not best practice, but it is an equal setting for KAN and MLP. So what is your take?
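
For reference, the 15306 figure for the fully connected network follows directly from the layer sizes in `FullyConnectedNN` (49 → 128 → 64 → 10, each layer with a bias). A short check in plain Python; the helper name `linear_params` is only illustrative:

```python
def linear_params(n_in, n_out):
    # one weight per input/output pair, plus one bias per output unit
    return n_in * n_out + n_out

layer_sizes = [7 * 7, 128, 64, 10]
per_layer = [linear_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)
print(per_layer, total)  # [6400, 8256, 650] 15306
```

The KAN count is not reproduced here because it depends on which tensors the library registers as parameters.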

@GeorgeDeac

GeorgeDeac commented May 11, 2024

Maybe the distribution of the data in the flattened vector is harder to represent with splines than with perceptrons (which have the universal approximation theorem behind them). I would imagine that flattening an image into a single vector could produce very sudden, local differences across instances, i.e. a finer granularity, which might be inherently harder to represent with splines?

Edit:
Here is the distribution of the input data we are actually trying to learn from in that flattened vector:

image

Which corresponds to this heatmap in the non-flattened image:

image

And these are the ranges of pixel intensities:

image

So yeah, there are a lot of sudden jumps.
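
To make the intuition about sudden jumps concrete, here is a toy sketch, not a statement about what pykan does internally. It assumes an order-1 (piecewise-linear) interpolant on a coarse uniform grid of 5 intervals as a stand-in for a low-resolution spline, and compares how well it tracks a smooth target and a step. All names are hypothetical.

```python
def piecewise_linear(f, num_intervals):
    """Interpolate f on [0, 1] through equally spaced knots."""
    knots = [i / num_intervals for i in range(num_intervals + 1)]
    values = [f(k) for k in knots]

    def interp(x):
        # index of the interval containing x (the last interval also takes x = 1)
        j = min(int(x * num_intervals), num_intervals - 1)
        t = (x - knots[j]) / (knots[j + 1] - knots[j])
        return (1 - t) * values[j] + t * values[j + 1]

    return interp

def max_error(f, g, samples=1001):
    """Largest pointwise gap between f and g on a dense grid over [0, 1]."""
    points = [i / (samples - 1) for i in range(samples)]
    return max(abs(f(x) - g(x)) for x in points)

def smooth(x):
    return x * x

def step(x):
    return 1.0 if x >= 0.5 else 0.0

err_smooth = max_error(smooth, piecewise_linear(smooth, 5))
err_step = max_error(step, piecewise_linear(step, 5))
print(f"smooth target: {err_smooth:.3f}, step target: {err_step:.3f}")
```

On this grid the smooth target is tracked to about 0.01 while the step is off by about 0.5 at the jump. Refining the grid narrows the region around the jump where the interpolant is wrong, but the worst-case error there stays near 0.5 unless a knot lands on the jump, which is one way to read the "limited parameters vs. sudden inflexions" guess.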

@hesamsheikh

I was expecting much more resistance to sudden local jumps from splines. This is what I would infer from the continual learning section of the paper, as splines preserve local information much more than MLPs.

image

I guess @KindXiaoming would have some idea about this.

@GeorgeDeac

I would just guess that there might be a limit on representation power set by how small and sudden the inflexions are, given the number of parameters we have for the splines? I would also like to investigate the actual reason, to be honest.

@AlexBodner

We implemented KAN convolutional layers. Check out our repo, based on the efficient-kan implementation:
https://github.com/AntonioTepsich/Convolutional-KANs

@hesamsheikh

Your results also suggest that, in the case of MNIST, KAN isn't able to scale as much as promised in the paper, ending up essentially on the same level as an MLP in terms of parameters.

@GeorgeDeac

GeorgeDeac commented May 15, 2024

Probably, after all, the representation power of KANs depends a lot on the distribution and shape of the data. The spline representation imposes constraints that make some shapes harder to represent, in contrast to MLPs, which don't care that much about extreme shapes. There are still cases where I think KANs consistently outperform MLPs, but I guess it depends a lot on the data domain we are dealing with.

I also saw some implementations that use RBFs instead of splines for KANs. I imagine that RBFs behave kind of similarly, and that they would compare better against MLPs if our data contains Gaussian shapes and has some normality.
https://github.com/ZiyaoLi/fast-kan
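
As a rough illustration of why an RBF can stand in for a spline basis function, the sketch below compares a cardinal cubic B-spline with a Gaussian of the same variance. This is an assumption-laden toy: the variance-matched Gaussian is chosen here for illustration and is not claimed to be the parametrization used in fast-kan.

```python
import math

def cubic_bspline(x):
    """Cardinal cubic B-spline centred at 0 with unit knot spacing."""
    a = abs(x)
    if a < 1:
        return 2 / 3 - a**2 + a**3 / 2
    if a < 2:
        return (2 - a) ** 3 / 6
    return 0.0

def gaussian(x, variance=1 / 3):
    """Normalised Gaussian; variance 1/3 equals that of the cubic B-spline."""
    return math.exp(-x * x / (2 * variance)) / math.sqrt(2 * math.pi * variance)

# Compare the two bumps on a dense grid over [-3, 3].
xs = [i / 100 for i in range(-300, 301)]
max_gap = max(abs(cubic_bspline(x) - gaussian(x)) for x in xs)
print(f"largest pointwise gap on [-3, 3]: {max_gap:.3f}")
```

The two bumps differ by only a few hundredths against a peak height of 2/3, so both bases are smooth and localized in much the same way. The main practical difference is that the Gaussian has no compact support.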

I also saw many more that replace the kernel of the standard KAN with different polynomial representations (like Chebyshev), and even a wavelet kernel:
https://github.com/SynodicMonth/ChebyKAN
https://github.com/mlsquare/xKAN

But at the end of the day, I think all of these are biased towards better representing certain domains / data shapes and might not universally scale in all cases (depending on the data).

It would be beneficial to make a synthetic data benchmark with some examples of consistent extreme shapes / gradients or other edge cases, and test all these architectures against MLPs.
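
A minimal sketch of what such a synthetic benchmark could start from, in plain Python. The target functions, their names and the `make_dataset` helper are all hypothetical choices for illustration. The returned dict reuses the `train_input` / `train_label` / `test_input` / `test_label` keys seen earlier in this thread, but holds nested lists rather than tensors.

```python
import math
import random

# Hypothetical 1D targets on [0, 1], ranging from smooth to abrupt.
TARGETS = {
    "smooth": lambda x: x * x,
    "oscillating": lambda x: math.sin(20 * math.pi * x),
    "step": lambda x: 1.0 if x >= 0.5 else 0.0,
    "spike": lambda x: math.exp(-((x - 0.5) ** 2) / (2 * 0.01 ** 2)),
}

def make_dataset(name, n_train=1000, n_test=200, seed=0):
    """Sample x uniformly from [0, 1) and label it with the chosen target."""
    rng = random.Random(seed)
    f = TARGETS[name]

    def sample(n):
        xs = [rng.random() for _ in range(n)]
        return [[x] for x in xs], [[f(x)] for x in xs]

    train_x, train_y = sample(n_train)
    test_x, test_y = sample(n_test)
    return {"train_input": train_x, "train_label": train_y,
            "test_input": test_x, "test_label": test_y}

datasets = {name: make_dataset(name) for name in TARGETS}
print({name: len(d["train_input"]) for name, d in datasets.items()})
```

Each architecture (spline KAN, RBF or Chebyshev variants, MLP) could then be trained on every entry with a matched parameter budget, after converting the lists to tensors.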

@XiangboGaoBarry

Hi, here I implemented ConvKAN with different activation formulations, along with their corresponding inference times: https://github.com/XiangboGaoBarry/ConvKAN-Zoo
We evaluate the results on the CIFAR10 dataset.

@tommarvoloriddle

We are trying to use KAN on ViT to replace MLP for training on ImageNet, and we welcome co-builds!
Vision-KAN
