In [1]:
# TaylorKAN + CNN
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class TaylorLayer(nn.Module):
  def __init__(self, input_dim, out_dim, order, addbias=True):
    super(TaylorLayer, self).__init__()
    self.input_dim = input_dim
    self.out_dim = out_dim
    self.order = order
    self.addbias = addbias

    # Initialize Taylor coefficients
    self.coeffs = nn.Parameter(torch.randn(out_dim, input_dim, order) * 0.01)  
    if self.addbias:
      self.bias = nn.Parameter(torch.zeros(1, out_dim)) 

  def forward(self, x):
    shape = x.shape
    outshape = shape[0:-1] + (self.out_dim,)
    x = torch.reshape(x, (-1, self.input_dim)) 

    x_expanded = x.unsqueeze(1).expand(-1, self.out_dim, -1)

    # Compute and accumulate each term of the Taylor expansion
    y = torch.zeros((x.shape[0], self.out_dim), device=x.device)  

    for i in range(self.order):
      term = (x_expanded ** i) * self.coeffs[:, :, i]  
      y += term.sum(dim=-1) 

    if self.addbias:
      y += self.bias  

    y = torch.reshape(y, outshape)
    return y

class TaylorCNN(nn.Module):
  def __init__(self):
    super(TaylorCNN, self).__init__()
    self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
    self.pool1 = nn.MaxPool2d(2)
    self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
    self.pool2 = nn.MaxPool2d(2)
    self.taylorkan1 = TaylorLayer(32*7*7, 128, 2)
    self.taylorkan2 = TaylorLayer(128, 10, 2)

  def forward(self, x):
    x = F.selu(self.conv1(x))
    x = self.pool1(x)
    x = F.selu(self.conv2(x))
    x = self.pool2(x)
    x = x.view(x.size(0), -1)
    x = self.taylorkan1(x)
    x = self.taylorkan2(x)
    return x


# Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TaylorCNN().to(device)

# optimizer = optim.LBFGS(model.parameters(), lr=0.001)
# optimizer = optim.SGD(model.parameters(), lr=0.001, weight_decay=1e-4, momentum=0.9)
optimizer = optim.RAdam(model.parameters(), lr=0.0001)



# Training
def train(model, device, train_loader, optimizer, epoch):
  model.train()
  for i, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = model(data.to(device))
    loss = nn.CrossEntropyLoss()(output, target.to(device))
    loss.backward()
    optimizer.step()
    if i % 10 == 0:
      print(f'Train Epoch: {epoch} [{i * len(data)}/{len(train_loader.dataset)} ({100. * i / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

# Evaluation
def evaluate(model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += nn.CrossEntropyLoss()(output, target).item()
      pred = output.argmax(dim=1, keepdim=True)
      correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Running
for epoch in range(0, 2):
  train(model, device, train_loader, optimizer, epoch)
evaluate(model, device, test_loader)

  from .autonotebook import tqdm as notebook_tqdm


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [01:22<00:00, 120275.15it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 90122.48it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:13<00:00, 118485.01it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 130117.67it/s]


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw


Test set: Average loss: 0.0006, Accuracy: 9561/10000 (96%)

