How to apply KAN on Computer Vision #9
Update on this topic: I wrote a short notebook to test training and evaluation on the MNIST dataset. If we want to apply KAN to a 2D or 3D task, one possible way is to change what KANlayer inherits from. Here is the screenshot, and here is the traceback:
Yeah, I think KANs, as they are right now, cannot handle convolution. It seems reasonable to define
As a quick cute example, you may try playing with KAN as if it were an MLP for MNIST. Please make sure the input data have shape
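The flattened-input setup suggested above can be sketched with toy data. This is a minimal illustration of the expected input shape only; the random array is a stand-in for real MNIST images:

```python
import numpy as np

# Toy stand-in for a batch of 64 MNIST-like 28x28 images.
# pykan's KAN takes 2D inputs of shape (batch, n_features),
# so each image must be flattened into one row of 784 values
# (this shape convention is an assumption based on this thread).
X = np.random.rand(64, 28, 28)
flat = X.reshape(X.shape[0], -1)
print(flat.shape)  # (64, 784)
```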
Thanks for the quick reply. It did work with
and
Now training works on device cpu (slow as expected), but it raises an error when using an Apple chip with device
But anyway, flattening the image into 1D is not a good idea in general. A
Nice! Yes, there's still some issue with GPU training. Looking forward to your new development :-)
Yeah, about GPU training: I might need to use CUDA first. For MPS, it raises this error for the classification example:
I already put the datasets and model on
The full traceback is:
It turns out that replacing

```python
# imports added for completeness (cv2, numpy, and torch are used below)
import cv2
import numpy as np
import torch
from kan import *
from tensorflow import keras

device = "mps"
model = KAN(width=[7*7, 5, 5, 128], grid=3, k=3, device=device)

(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0

# downsample to 7x7
X_train = np.array([cv2.resize(x, (7, 7)) for x in X_train])
X_test = np.array([cv2.resize(x, (7, 7)) for x in X_test])

dataset = {}
dataset['train_input'] = torch.flatten(torch.from_numpy(X_train), start_dim=1).to(torch.float32).to(device)
dataset['train_label'] = torch.from_numpy(y_train).to(torch.float32).to(device)
dataset['test_input'] = torch.flatten(torch.from_numpy(X_test), start_dim=1).to(torch.float32).to(device)
dataset['test_label'] = torch.from_numpy(y_test).to(torch.float32).to(device)

model.train(dataset, opt="LBFGS", steps=20, batch=128)
```
Were you able to actually train on MNIST using a flat dataset?
No, unfortunately. @WuZhuoran's comments make sense; I was solely making a point about getting training on mps to work.
Hi everybody, please let me know if any of you have successfully applied KANs to any computer vision tasks, or integrated them with CNNs?
Also, how can I integrate and train KAN layers with CNNs after flattening the tensors? If anybody has code, please share it.
Very interesting. I'd really like to see a direct comparison between KAN and MLP in a CNN architecture.
I tried something like this, but it didn't work:

```python
import cv2
import numpy as np
import torch
import torchvision
from kan import KAN
import matplotlib.pyplot as plt

train_data = torchvision.datasets.MNIST(
    root="./mnist_data", train=True, download=True, transform=None
)
test_data = torchvision.datasets.MNIST(
    root="./mnist_data", train=False, download=True, transform=None
)

valid_labels = [0, 1, 2]

X_train = []
y_train = []
for pil_img, label in train_data:
    if label in valid_labels:
        x = np.array(pil_img)
        x = cv2.resize(x, (7, 7))
        X_train.append(x.astype(float))
        y_train.append(label)
X_train = np.array(X_train)
y_train = np.array(y_train)

mean, std = np.mean(X_train), np.std(X_train)
print(f"{mean=}")
print(f"{std=}")

X_test = []
y_test = []
for pil_img, label in test_data:
    if label in valid_labels:
        x = np.array(pil_img)
        x = cv2.resize(x, (7, 7))
        X_test.append(x.astype(float))
        y_test.append(label)
X_test = np.array(X_test)
y_test = np.array(y_test)

X_test = (X_test - mean) / std
X_train = (X_train - mean) / std

device = "cpu"
model = KAN(width=[x.shape[0]**2, 20, 20, len(valid_labels)], grid=3, k=3, device=device)

dataset = {}
dataset["train_input"] = (
    torch.flatten(torch.from_numpy(X_train), start_dim=1).to(torch.float32).to(device)
)
dataset["train_label"] = torch.from_numpy(y_train).to(torch.float32).to(device)
dataset["test_input"] = (
    torch.flatten(torch.from_numpy(X_test), start_dim=1).to(torch.float32).to(device)
)
dataset["test_label"] = torch.from_numpy(y_test).to(torch.float32).to(device)

result = model.train(dataset, opt="Adam", steps=100, lr=0.1, batch=len(valid_labels), device=device)

plt.plot(result['train_loss'], label="train_loss")
plt.plot(result['test_loss'], label="test_loss")
plt.ylim(0, 5)
plt.legend()
plt.show()
```
My experience with MNIST is that a 2-layer KAN with extremely few hidden neurons (say 5 or 10) is enough to train on MNIST (but maybe my impression was from accuracy), i.e.,

So please try computing acc as well. You can refer to this tutorial to see how to do this. Basically, it's something like:

```python
def train_acc():
    return torch.mean((torch.round(model(dataset['train_input'])[:,0]) == dataset['train_label'][:,0]).float())

def test_acc():
    return torch.mean((torch.round(model(dataset['test_input'])[:,0]) == dataset['test_label'][:,0]).float())

results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc))
results['train_acc'][-1], results['test_acc'][-1]
```
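The rounding-based metric above fits a single regression-style output; for multi-class logits (as in the 10-class MNIST setups later in this thread), an argmax-based accuracy is the usual choice. A small self-contained sketch, independent of pykan:

```python
import numpy as np

def accuracy(logits, labels):
    # Fraction of samples whose argmax class matches the integer label.
    return float(np.mean(np.argmax(logits, axis=1) == labels))

# Tiny worked example: rows are per-class scores for 3 samples.
logits = np.array([[2.0, 0.1, 0.3],
                   [0.2, 1.5, 0.1],
                   [0.0, 0.2, 3.0]])
labels = np.array([0, 1, 1])
print(accuracy(logits, labels))  # 2 of 3 predictions correct
```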
I am also willing to do so, but I don't know how to integrate and train them simultaneously.
Hi, I get around 73% test accuracy in about 1 minute. Playing with the network size may improve the performance:

```python
import cv2
import numpy as np
import torch
import torchvision
from kan import KAN
import matplotlib.pyplot as plt

def train_acc():
    # model for some reason is on cpu only here, something about KAN's implementation
    try:
        arg = (
            torch.argmax(model(dataset["train_input"]), dim=1) == dataset["train_label"]
        )
    except:
        arg = torch.argmax(model(dataset["train_input"].to("cpu")), dim=1) == dataset[
            "train_label"
        ].to("cpu")
    return torch.mean(arg.float())

def test_acc():
    try:
        arg = torch.argmax(model(dataset["test_input"]), dim=1) == dataset["test_label"]
    except:
        arg = torch.argmax(model(dataset["test_input"].to("cpu")), dim=1) == dataset[
            "test_label"
        ].to("cpu")
    return torch.mean(arg.float())

train_data = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=None
)
test_data = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=None
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using: {device}")

valid_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

X_train = []
y_train = []
for pil_img, label in train_data:
    if label in valid_labels:
        x = np.array(pil_img)
        x = cv2.resize(x, (7, 7))
        X_train.append(x.astype(float))
        y_train.append(label)
X_train = np.array(X_train)
y_train = np.array(y_train)

mean, std = np.mean(X_train), np.std(X_train)
print(f"{mean=}")
print(f"{std=}")

X_test = []
y_test = []
for pil_img, label in test_data:
    if label in valid_labels:
        x = np.array(pil_img)
        x = cv2.resize(x, (7, 7))
        X_test.append(x.astype(float))
        y_test.append(label)
X_test = np.array(X_test)
y_test = np.array(y_test)

X_test = (X_test - mean) / std
X_train = (X_train - mean) / std

model = KAN(width=[x.shape[0] ** 2, 20, len(valid_labels)], grid=5, k=3, device=device)

dataset = {}
dataset["train_input"] = (
    torch.flatten(torch.from_numpy(X_train), start_dim=1).long().to(device)
)
dataset["train_label"] = torch.from_numpy(y_train).long().to(device)
dataset["test_input"] = (
    torch.flatten(torch.from_numpy(X_test), start_dim=1).long().to(device)
)
dataset["test_label"] = torch.from_numpy(y_test).long().to(device)

loss_fn = torch.nn.CrossEntropyLoss()
result = model.train(
    dataset,
    opt="Adam",
    steps=50,
    lr=0.1,
    batch=512,
    # metrics=(
    #     train_acc,
    #     test_acc,
    # ),  # this is the slower step, so it's better to evaluate it after training
    loss_fn=loss_fn,
    # device=device,
)

acc = test_acc()
print(f"Test accuracy: {acc.item()}")

plt.plot(result["train_loss"], label="train_loss")
plt.plot(result["test_loss"], label="test_loss")
plt.ylim(0, 5)
plt.legend()
plt.savefig("loss.png")
```
Hi, here's my attempted code; it takes about 30 s to run on CUDA and gets about 83% accuracy:

```python
import cv2
import numpy as np
import torch
import torchvision
from kan import KAN
import matplotlib.pyplot as plt

def preprocess_data(data):
    images = []
    labels = []
    for img, label in data:
        img = cv2.resize(np.array(img), (7, 7))
        img = img.flatten() / 255.0
        images.append(img)
        labels.append(label)
    return np.array(images), np.array(labels)

train_data = torchvision.datasets.MNIST(
    root="./mnist_data", train=True, download=True, transform=None
)
test_data = torchvision.datasets.MNIST(
    root="./mnist_data", train=False, download=True, transform=None
)

train_images, train_labels = preprocess_data(train_data)
test_images, test_labels = preprocess_data(test_data)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

dataset = {
    "train_input": torch.from_numpy(train_images).float().to(device),
    "train_label": torch.from_numpy(train_labels).to(device),
    "test_input": torch.from_numpy(test_images).float().to("cpu"),
    "test_label": torch.from_numpy(test_labels).to("cpu"),
}

model = KAN(width=[49, 10, 10], device=device)
results = model.train(
    dataset,
    opt="Adam",
    lr=0.05,
    steps=100,
    batch=512,
    loss_fn=torch.nn.CrossEntropyLoss(),
)

torch.save(model.state_dict(), "kan.pth")
del model

model = KAN(width=[49, 10, 10], device="cpu")
model.load_state_dict(torch.load("kan.pth"))

def test_acc():
    with torch.no_grad():
        predictions = torch.argmax(model(dataset["test_input"]), dim=1)
        correct = (predictions == dataset["test_label"]).float()
        accuracy = correct.mean()
    return accuracy

acc = test_acc()
print(f"Test accuracy: {acc.item() * 100:.2f}%")

plt.plot(results["train_loss"], label="train")
plt.plot(results["test_loss"], label="test")
plt.legend()
plt.savefig("kan.png")
```
I also think a pure KAN implementation for computer vision does not look very promising, since it makes no use of spatial locality. An interesting idea could be to define a KAN-based 2D convolution layer that replaces the 2D kernel with a spline (or a full KAN layer?) working on flattened 2D patches of similar sizes to the regular kernels. At small enough kernel sizes (say 3x3), the loss of fine-grained spatial locality might not be as detrimental to model performance.
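The patch idea above can be made concrete: extract every k x k window of the image and flatten it into a row, producing exactly the (n_patches, k*k) matrix a per-patch spline or KAN "kernel" would consume. A minimal numpy sketch of just the patch extraction (the learnable per-patch function itself is left out):

```python
import numpy as np

def extract_patches(img, k=3):
    """Flatten every k x k patch of a 2D image into one row.

    The result has shape ((H-k+1)*(W-k+1), k*k) -- the input a
    per-patch KAN/spline layer would operate on, row by row.
    """
    H, W = img.shape
    rows = [img[i:i+k, j:j+k].ravel()
            for i in range(H - k + 1)
            for j in range(W - k + 1)]
    return np.stack(rows)

img = np.arange(25, dtype=float).reshape(5, 5)
patches = extract_patches(img, k=3)
print(patches.shape)  # (9, 9): nine 3x3 patches, each flattened to 9 values
```

In PyTorch the same unrolling is what `torch.nn.Unfold` does in batched form, which would be the natural building block for a real layer.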
Compared KAN with a 5x smaller MLP. In 10 epochs, KAN reached 91% acc whereas MLP reached 97%. KAN loss goes down more slowly than that of MLP.
Hi, I did train on the MNIST dataset, but I just flattened the images into 1D vectors. I think we still need more development for computer vision tasks.
Good point on the 2D conv layer. One possible way is to define a
Currently, all the tests on images (such as MNIST) process the data into 1D vectors, which is not very useful.
Before using conv2d, what about using a VAE latent space: train KAN on MNIST using the VAE encoder output as the KAN input?
@xiaol What about using Handwritten Sequence Trajectories?
I tried to replace MLP with KAN in CNN models, and the performances are close to each other. |
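One way to set up that kind of comparison is to keep a fixed CNN feature extractor and swap only the classifier head. A sketch of the wiring (the KAN head itself is omitted; the `nn.Linear` below simply marks where a KAN of matching width would go):

```python
import torch
import torch.nn as nn

# Fixed convolutional feature extractor for 28x28 grayscale inputs.
backbone = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(2),   # 28x28 -> 14x14
    nn.Flatten(),      # -> 8 * 14 * 14 = 1568 features per image
)

# Swappable head: to compare MLP vs. KAN under identical features,
# replace this nn.Linear with a KAN of width [1568, 10] (hypothetical here).
head = nn.Linear(8 * 14 * 14, 10)

x = torch.randn(4, 1, 28, 28)  # dummy batch of 4 images
logits = head(backbone(x))
print(tuple(logits.shape))  # (4, 10)
```

Keeping the backbone frozen across both runs isolates the head's contribution to the comparison.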
How can I build a Conv-KAN? How do I integrate convolution into KAN?
I used a 'linearized version' of
I also tried a simple implementation of LeNet, but with KAN as the classifier: https://github.com/SimoSbara/kan-lenet. The KAN receives flattened data from the convolutions.
I think a combination of them (nn.Linear, KAN) works fine for the MNIST task:
After 4 epochs of training, accuracy comes to 96%; the logs look like:
In comparison with MLP it's a good improvement. It would be nice to have a performance benchmark for bigger nets where KAN replaces the MLP.
I was expecting much more resistance to sudden local jumps from splines. This is what I would infer from the continual learning section of the paper, as splines preserve local information much more than MLPs do. I guess @KindXiaoming would have some idea about this.
I would just guess that there might be a representation-power limit given by how small and sudden the inflections are, given the number of parameters we have for the splines? I would also like to investigate the actual reason, tbh.
We implemented the KAN Convolutional Layers; check out our repo, based on the efficient-kan implementation:
Your results also point out that, in the case of MNIST, KAN isn't able to scale as much as promised in the paper, essentially being at the same level as MLP in terms of parameters.
Probably, after all, the representation power of KANs depends a lot on the distribution and shape of the data. Spline representation imposes constraints that make some shapes harder to represent, in contrast to MLPs, which don't care that much about extreme shapes. There are still cases where I think KANs consistently outperform MLPs, but I guess it depends a lot on the data domain we are dealing with.

I also saw some implementations that use RBFs instead of splines for KANs; I imagine RBFs are kinda similar, and they would compare better against MLPs if our data contains Gaussian shapes and has some normality. I also saw many more that replace the kernel of the standard KAN with different polynomial representations (like Chebyshev), and even saw a wavelet kernel.

But at the end of the day, I think all of these are biased towards better representing certain domains / data shapes and might not universally scale in all cases (depending on the data). It would be beneficial to make a synthetic data benchmark with some examples of consistent extreme shapes / gradients or other edge cases, and test all these architectures against MLPs.
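A minimal version of such a synthetic benchmark could be a 1D target mixing a smooth component with a step discontinuity, on which spline, RBF, Chebyshev, and MLP variants could all be fitted and compared. The target function below is purely illustrative:

```python
import numpy as np

def target(x):
    # Smooth component plus a sudden jump of height 1 at x = 0.5,
    # the kind of "extreme shape" discussed above.
    return np.sin(4 * x) + np.where(x > 0.5, 1.0, 0.0)

X = np.linspace(0.0, 1.0, 1000)
y = target(X)
# A fitted model's error concentrated near x = 0.5 would reveal
# how each architecture copes with the discontinuity.
```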
Hi, here I implement ConvKAN with different activation formulations and their corresponding inference times: https://github.com/XiangboGaoBarry/ConvKAN-Zoo
We are trying to use KAN in a ViT to replace the MLP for training on ImageNet, and we welcome co-builders!
Hi Author,
Thank you for your great work. I am wondering if we can apply this network to vision-based tasks such as classification/detection/segmentation, etc.
Thank you for your help.