In [13]:
import torch

# Check if cuda is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

Using cuda device


In [14]:
from einops.layers.torch import Rearrange
from torch import nn
from vision_models_playground.external.kan.kan import KAN

model = nn.Sequential(
    Rearrange("b ... -> b (...)"),
    KAN(
        width=[28*28, 16, 10],
        grid=3,
        k=3,
        device=device,
        symbolic_enabled=False
    ),
    nn.Softmax(dim=1)
)

In [15]:
from vision_models_playground.train import train_model_classifier
from vision_models_playground.datasets import get_mnist_dataset

train_dataset, valid_dataset = get_mnist_dataset()
train_model_classifier(model, train_dataset, valid_dataset, num_epochs=100, num_workers=2)

[32mTrain Epoch: 99, Step:     937 | MulticlassAccuracy: 0.9323 | MulticlassF1Score: 0.9323 | MulticlassMatthewsCorrCoef: 0.9248 | LossTracker: 1.5417 | : 100%|██████████| 938/938 [00:11<00:00, 83.43it/s]
[33mValid Epoch: 99, Step:     156 | MulticlassAccuracy: 0.9239 | MulticlassF1Score: 0.9239 | MulticlassMatthewsCorrCoef: 0.9155 | LossTracker: 1.5480 | : 100%|██████████| 157/157 [00:01<00:00, 98.03it/s] 


In [16]:
model = nn.Sequential(
    Rearrange("b ... -> b (...)"),
    KAN(
        width=[28*28, 256, 128, 10],
        grid=3,
        k=3,
        device=device,
        symbolic_enabled=False
    ),
    nn.Softmax(dim=1)
)

In [17]:
train_dataset, valid_dataset = get_mnist_dataset()
train_model_classifier(model, train_dataset, valid_dataset, num_epochs=100, num_workers=2)

[32mTrain Epoch: 87, Step:     937 | MulticlassAccuracy: 0.9829 | MulticlassF1Score: 0.9829 | MulticlassMatthewsCorrCoef: 0.9810 | LossTracker: 1.4793 | : 100%|██████████| 938/938 [01:53<00:00,  8.27it/s]
[33mValid Epoch: 87, Step:     156 | MulticlassAccuracy: 0.9685 | MulticlassF1Score: 0.9685 | MulticlassMatthewsCorrCoef: 0.9650 | LossTracker: 1.4939 | : 100%|██████████| 157/157 [00:14<00:00, 10.61it/s]
[32mTrain Epoch: 88, Step:     937 | MulticlassAccuracy: 0.9830 | MulticlassF1Score: 0.9830 | MulticlassMatthewsCorrCoef: 0.9812 | LossTracker: 1.4792 | : 100%|██████████| 938/938 [01:53<00:00,  8.26it/s]
[33mValid Epoch: 88, Step:     156 | MulticlassAccuracy: 0.9686 | MulticlassF1Score: 0.9686 | MulticlassMatthewsCorrCoef: 0.9651 | LossTracker: 1.4938 | : 100%|██████████| 157/157 [00:14<00:00, 10.61it/s]
[32mTrain Epoch: 89, Step:     937 | MulticlassAccuracy: 0.9832 | MulticlassF1Score: 0.9832 | MulticlassMatthewsCorrCoef: 0.9814 | LossTracker: 1.4790 | : 100%|██████████| 938

In [18]:
model = nn.Sequential(
    Rearrange("b ... -> b (...)"),
    KAN(
        width=[28*28, 64, 32, 10],
        grid=3,
        k=3,
        device=device,
        symbolic_enabled=False
    ),
    nn.Softmax(dim=1)
)

In [19]:
train_dataset, valid_dataset = get_mnist_dataset()
train_model_classifier(model, train_dataset, valid_dataset, num_epochs=100, num_workers=2)

[33mValid Epoch: 67, Step:     156 | MulticlassAccuracy: 0.9612 | MulticlassF1Score: 0.9612 | MulticlassMatthewsCorrCoef: 0.9568 | LossTracker: 1.5041 | : 100%|██████████| 157/157 [00:03<00:00, 39.56it/s]
[32mTrain Epoch: 68, Step:     937 | MulticlassAccuracy: 0.9732 | MulticlassF1Score: 0.9732 | MulticlassMatthewsCorrCoef: 0.9703 | LossTracker: 1.4923 | : 100%|██████████| 938/938 [00:28<00:00, 33.35it/s]
[33mValid Epoch: 68, Step:     156 | MulticlassAccuracy: 0.9613 | MulticlassF1Score: 0.9613 | MulticlassMatthewsCorrCoef: 0.9570 | LossTracker: 1.5039 | : 100%|██████████| 157/157 [00:03<00:00, 40.28it/s]
[32mTrain Epoch: 69, Step:     937 | MulticlassAccuracy: 0.9735 | MulticlassF1Score: 0.9735 | MulticlassMatthewsCorrCoef: 0.9706 | LossTracker: 1.4919 | : 100%|██████████| 938/938 [00:28<00:00, 33.40it/s]
[33mValid Epoch: 69, Step:     156 | MulticlassAccuracy: 0.9614 | MulticlassF1Score: 0.9614 | MulticlassMatthewsCorrCoef: 0.9571 | LossTracker: 1.5037 | : 100%|██████████| 157