In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from tqdm import tqdm

from lkan.models import KANConv2d, KANLinear2

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
class MLP(nn.Module):
    """Convolutional baseline classifier (convolutional despite the name).

    The stack is two (5x5 conv -> ReLU -> 2x2 max-pool) stages followed by a
    flatten and a single linear layer with 10 outputs. The convolutions use
    stride 1 and padding 2, so only the pooling layers shrink the spatial
    size; the linear layer's ``32 * 7 * 7`` input size therefore corresponds
    to a 1-channel 28x28 input.
    """

    def __init__(self):
        super().__init__()
        # Built as a list first; positions in the Sequential (and therefore
        # the state_dict keys) are determined by the order below.
        stages = [
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(in_features=32 * 7 * 7, out_features=10),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return its output."""
        return self.layers(x)


class KAN(nn.Module):
    """Classifier assembled from ``lkan`` KAN layers.

    Mirrors the layout of the convolutional baseline but swaps each
    conv/linear layer for ``KANConv2d`` / ``KANLinear2`` and has no separate
    activation modules between layers. Channel widths are 1 -> 4 -> 8 and the
    final layer maps ``8 * 7 * 7`` features to 10 outputs.

    NOTE(review): the KAN layers are defined in ``lkan.models``; their
    internals are not visible from this notebook.
    """

    def __init__(self):
        super().__init__()
        # Order here fixes each module's index inside the Sequential.
        stages = [
            KANConv2d(
                in_channels=1, out_channels=4, kernel_size=5, stride=1, padding=2
            ),
            nn.MaxPool2d(2),
            KANConv2d(4, 8, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            KANLinear2(8 * 7 * 7, 10),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return its output."""
        return self.layers(x)

In [17]:
# --- Data configuration ---
data_dir = "../.data/"
batch_size = 64
split_ratio = 0.8  # fraction of the dataset used for training
# Seed for the train/validation partition only. Keeping the partition fixed
# means different models (or re-runs after a kernel restart) are evaluated on
# the same validation samples.
split_seed = 0

# --- Optimisation configuration ---
lr = 0.002
epochs = 5

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
    ]
)

# Downloads MNIST into data_dir on first use.
dataset = torchvision.datasets.MNIST(data_dir, transform=transform, download=True)

# Compute the split sizes once; the validation part takes the remainder so the
# two lengths always sum to len(dataset).
num_train = int(len(dataset) * split_ratio)
num_val = len(dataset) - num_train
ds_train, ds_val = torch.utils.data.random_split(
    dataset,
    [num_train, num_val],
    generator=torch.Generator().manual_seed(split_seed),
)

# Validation order is irrelevant to the metrics, so it is not shuffled.
loader_val = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, shuffle=False)
loader_train = torch.utils.data.DataLoader(
    ds_train, batch_size=batch_size, shuffle=True
)

In [18]:
# Swap which of the next two lines is commented to choose the model.
# model = MLP().to(device)
model = KAN().to(device)

# Total element count across every tensor returned by model.parameters().
counter = sum(tensor.numel() for tensor in model.parameters())

print(f"Number of parameters: {counter}")

Number of parameters: 43392


In [19]:
opt = torch.optim.Adam(model.parameters(), lr=lr)
# Build the loss module once instead of re-creating it for every batch.
criterion = nn.CrossEntropyLoss()

for epoch in range(epochs):
    print("Training...")
    model.train()  # make sure train-mode behaviour is active after validation
    train_acc = 0
    avg_train_loss = 0
    train_seen = 0
    for x, y in tqdm(loader_train):
        x, y = x.to(device), y.to(device)

        opt.zero_grad()

        y_pred = model(x)
        loss = criterion(y_pred, y)

        loss.backward()
        opt.step()

        # Accumulate per-sample totals. The loss is a batch mean, so it is
        # scaled back by the batch size; this keeps a smaller final batch from
        # being weighted the same as a full one.
        batch_len = y.size(0)
        train_acc += (y_pred.argmax(1) == y).sum().item()
        avg_train_loss += loss.item() * batch_len
        train_seen += batch_len

    # max(..., 1) avoids a ZeroDivisionError if the loader yielded nothing.
    avg_train_loss /= max(train_seen, 1)
    train_acc /= max(train_seen, 1)

    print("Validation...")

    model.eval()  # switch any mode-dependent layers to inference behaviour
    with torch.no_grad():
        acc = 0
        avg_loss = 0
        val_seen = 0
        for x, y in tqdm(loader_val):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = criterion(y_pred, y)

            batch_len = y.size(0)
            acc += (y_pred.argmax(1) == y).sum().item()
            avg_loss += loss.item() * batch_len
            val_seen += batch_len
        acc /= max(val_seen, 1)
        avg_loss /= max(val_seen, 1)
        print(
            f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss} - Val Loss: {avg_loss} - Train Acc: {train_acc} -Val Acc: {acc}"
        )

print("Done!")

Training...


  0%|          | 2/750 [00:00<00:46, 15.99it/s]

torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])


  1%|          | 6/750 [00:00<00:42, 17.53it/s]

torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])


  1%|▏         | 10/750 [00:00<00:43, 16.84it/s]

torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])


  2%|▏         | 14/750 [00:00<00:42, 17.25it/s]

torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])


  2%|▏         | 18/750 [00:01<00:41, 17.50it/s]

torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])


  3%|▎         | 22/750 [00:01<00:41, 17.38it/s]

torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])


  3%|▎         | 26/750 [00:01<00:41, 17.42it/s]

torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])


  4%|▍         | 30/750 [00:01<00:41, 17.38it/s]

torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])


  5%|▍         | 34/750 [00:01<00:40, 17.56it/s]

torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])


  5%|▌         | 38/750 [00:02<00:40, 17.46it/s]

torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])


  6%|▌         | 42/750 [00:02<00:41, 17.17it/s]

torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])


  6%|▌         | 46/750 [00:02<00:40, 17.24it/s]

torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])


  7%|▋         | 50/750 [00:02<00:40, 17.48it/s]

torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])


  7%|▋         | 54/750 [00:03<00:39, 17.56it/s]

torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])


  8%|▊         | 57/750 [00:03<00:40, 17.08it/s]

torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])
torch.Size([64, 1, 28, 28])
torch.Size([64, 4, 14, 14])





KeyboardInterrupt: 