In [None]:
# TaylorKAN + CNN
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import torch.nn.functional as F

class TaylorLayer(nn.Module):
  """Fully connected layer built from truncated per-edge Taylor polynomials.

  Every (output j, input i) pair owns a polynomial in x_i with ``order`` terms
  (powers 0 .. order-1):

      y_j = sum_i sum_{k=0}^{order-1} coeffs[j, i, k] * x_i**k  (+ bias_j)

  Args:
    input_dim: size of the last dimension of the input.
    out_dim: size of the last dimension of the output.
    order: number of polynomial terms per edge (highest power is order - 1).
    addbias: if True, add a learnable bias of shape (1, out_dim).
  """

  def __init__(self, input_dim, out_dim, order, addbias=True):
    super(TaylorLayer, self).__init__()
    self.input_dim = input_dim
    self.out_dim = out_dim
    self.order = order
    self.addbias = addbias

    # Taylor coefficients, shape (out_dim, input_dim, order), small random init.
    self.coeffs = nn.Parameter(torch.randn(out_dim, input_dim, order) * 0.01)
    if self.addbias:
      # Bias, shape (1, out_dim); broadcasts over the flattened rows.
      self.bias = nn.Parameter(torch.zeros(1, out_dim))

  def forward(self, x):
    """Apply the layer over the last dimension of ``x``.

    Args:
      x: tensor of shape (..., input_dim). Any number of leading dims is
        accepted, including none (a 1-D vector) or an empty batch.

    Returns:
      Tensor of shape (..., out_dim), computed in the dtype of ``coeffs``.
    """
    shape = x.shape
    outshape = shape[:-1] + (self.out_dim,)
    # Flatten the leading dims into rows. Use an explicit row count rather than
    # -1 so that an empty batch (0 rows) can still be reshaped.
    x = x.reshape(shape[:-1].numel(), self.input_dim).to(self.coeffs.dtype)

    # powers[b, i, k] = x[b, i] ** k for k = 0 .. order-1 -> (rows, input_dim, order).
    exponents = torch.arange(self.order, device=x.device, dtype=x.dtype)
    powers = x.unsqueeze(-1) ** exponents

    # Contract over inputs and powers in a single einsum (a matmul under the
    # hood) instead of expanding x to (rows, out_dim, input_dim) once per order.
    y = torch.einsum('bik,oik->bo', powers, self.coeffs)

    if self.addbias:
      y = y + self.bias

    return y.reshape(outshape)

class TaylorCNN(nn.Module):
  """Two conv -> SELU -> max-pool stages followed by two TaylorLayer heads.

  The head's input size 32*7*7 assumes the two 2x max-pools reduce the input
  to 7x7 spatially (e.g. the 28x28 single-channel MNIST images loaded in this
  file). ``forward`` returns raw, unnormalised scores of shape (batch, 10).
  """

  def __init__(self):
    super().__init__()
    # Feature extractor: 1 -> 16 -> 32 channels; padding=1 keeps the conv's
    # spatial size, and each MaxPool2d(2) halves it.
    self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
    self.pool1 = nn.MaxPool2d(2)
    self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
    self.pool2 = nn.MaxPool2d(2)
    # Classifier head: two Taylor layers with order 2 (terms x**0 and x**1).
    self.taylorkan1 = TaylorLayer(32*7*7, 128, 2)
    self.taylorkan2 = TaylorLayer(128, 10, 2)

  def forward(self, x):
    stages = ((self.conv1, self.pool1), (self.conv2, self.pool2))
    for conv, pool in stages:
      x = pool(F.selu(conv(x)))
    # Flatten each sample's feature maps into one vector.
    flat = x.view(x.shape[0], -1)
    hidden = self.taylorkan1(flat)
    return self.taylorkan2(hidden)


# Dataset and DataLoader setup: MNIST train/test splits, converted to tensors and normalised.
# NOTE(review): (0.1307,) / (0.3081,) are presumably the MNIST training-set mean / std — confirm.
transform = transforms.Compose(
  [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
  ]
)
train_dataset = datasets.MNIST('./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
# The full training split is used; wrap train_dataset in Subset(...) to train on fewer samples.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=256)

# Model, optimizer and device setup: use the GPU when CUDA is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = TaylorCNN().to(device)
# RAdam with a small learning rate.
optimizer = optim.RAdam(model.parameters(), lr=1e-4)



# Training function
def train(model, device, train_loader, optimizer, epoch, log_interval=10):
  """Run one training epoch over ``train_loader``.

  Args:
    model: module mapping a batch of inputs to class logits.
    device: device that each batch's inputs and targets are moved to.
    train_loader: DataLoader yielding (data, target) batches; the length of
      its ``dataset`` is used in the progress message.
    optimizer: optimizer stepped once per batch.
    epoch: epoch number, used only in the progress message.
    log_interval: print progress every ``log_interval`` batches (default 10);
      a non-positive value disables progress printing.
  """
  model.train()
  # Loop-invariant sizes for the progress message.
  num_samples = len(train_loader.dataset)
  num_batches = len(train_loader)
  for i, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    # Functional form of nn.CrossEntropyLoss(); avoids building a new loss
    # module on every batch.
    loss = F.cross_entropy(model(data), target)
    loss.backward()
    optimizer.step()
    if log_interval > 0 and i % log_interval == 0:
      print(f'Train Epoch: {epoch} [{i * len(data)}/{num_samples} ({100. * i / num_batches:.0f}%)]\tLoss: {loss.item():.6f}')

# Evaluation function
def evaluate(model, device, test_loader):
  """Evaluate ``model`` on ``test_loader`` and print average loss and accuracy.

  Args:
    model: module mapping a batch of inputs to class logits.
    device: device that each batch's inputs and targets are moved to.
    test_loader: DataLoader yielding (data, target) batches; the length of its
      ``dataset`` is the number of samples averaged over.

  Returns:
    Tuple ``(average_loss, accuracy)``: mean per-sample cross-entropy and the
    fraction of correctly classified samples (both 0.0 for an empty dataset).
  """
  model.eval()
  total_loss = 0.0
  correct = 0
  num_samples = len(test_loader.dataset)
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      # Sum (not mean) per batch so that dividing by the dataset size below
      # gives the true per-sample average, whatever the batch size.
      total_loss += F.cross_entropy(output, target, reduction='sum').item()
      pred = output.argmax(dim=1)
      correct += (pred == target).sum().item()
  denom = max(num_samples, 1)  # avoid ZeroDivisionError on an empty dataset
  test_loss = total_loss / denom
  print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{num_samples} ({100. * correct / denom:.0f}%)\n')
  return test_loss, correct / denom

# Run training for two epochs (numbered 0 and 1), then evaluate once on the test set.
for epoch in (0, 1):
  train(model, device, train_loader, optimizer, epoch)
evaluate(model, device, test_loader)