In [3]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import math
import numpy as np

class ChebyshevKANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, degree):
        super(ChebyshevKANLayer, self).__init__()
        self.inputdim = input_dim
        self.outdim = output_dim
        self.degree = degree

        self.cheby_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))
        nn.init.xavier_normal_(self.cheby_coeffs)
        self.register_buffer("arange", torch.arange(0, degree + 1, 1))

    def chebyshev_polynomials(self, x):
        T = [torch.ones_like(x), x]
        for n in range(2, self.degree + 1):
            T.append(2 * x * T[n - 1] - T[n - 2])
        return torch.stack(T, dim=-1)

    def forward(self, x):
        x = x.view(-1, self.inputdim)
        x = 2 * (x - x.min(dim=1, keepdim=True)[0]) / (x.max(dim=1, keepdim=True)[0] - x.min(dim=1, keepdim=True)[0]) - 1
        T = self.chebyshev_polynomials(x)
        y = torch.einsum("bij,ioj->bo", T, self.cheby_coeffs)
        y = y.view(-1, self.outdim)
        return y

class KAN(nn.Module):
    def __init__(self, layers_hidden, degree=3):
        super(KAN, self).__init__()
        self.layers = nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(ChebyshevKANLayer(in_features, out_features, degree))

    def forward(self, x: torch.Tensor):
        for layer in self.layers:
            x = layer(x)
        return x

class MultiheadKANAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, rotation_matrix, degree=3):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.position_emb = rotation_matrix

        self.qkv_linear = ChebyshevKANLayer(hidden_size, hidden_size * 3, degree)
        self.out = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        batch_size, seq_length, hidden_size = x.size()
        qkv = self.qkv_linear(x)
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.transpose(1, 2)
        queries, keys, values = qkv.chunk(3, dim=-1)
        queries = apply_rotary_pos_emb(queries, self.position_emb)
        keys = apply_rotary_pos_emb(keys, self.position_emb)
        scores = torch.matmul(queries, keys.transpose(2, 3))
        scores = scores / (self.head_dim ** 0.5)
        attention = F.softmax(scores, dim=-1)
        context = torch.matmul(attention, values)
        context = context.transpose(1, 2)
        context = context.reshape(batch_size, seq_length, hidden_size)
        output = self.out(context)
        return output

class KANFormer(nn.Module):
    def __init__(self, num_features, hidden_size, num_heads, n_blocks, ff_dims, max_seq_len, device, degree=3):
        super().__init__()
        self.embedding = nn.Linear(num_features, hidden_size)
        head_dim = hidden_size // num_heads
        rope = RotaryPositionalEmbedding(head_dim, max_seq_len)
        rotation_matrix = rope(max_seq_len).to(device)
        self.blocks = nn.ModuleList([KANBlock(hidden_size, num_heads, rotation_matrix, degree) for _ in range(n_blocks)])
        self.ff = nn.ModuleList()
        in_size = max_seq_len * hidden_size
        for f in ff_dims:
            self.ff.append(ChebyshevKANLayer(in_size, f, degree))
            in_size = f

    def forward(self, x):
        x = self.embedding(x)
        for block in self.blocks:
            x = block(x)
        x = x.flatten(start_dim=1)
        for f in self.ff:
            x = f(x)
        return x

class KANBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, rotation_matrix, degree=3):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size)
        self.attention = MultiheadKANAttention(hidden_size, num_heads, rotation_matrix, degree)

    def forward(self, x):
        x1 = self.attention(self.norm1(x))
        out = x + x1
        return out

class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len):
        super(RotaryPositionalEmbedding, self).__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.register_buffer('pos_enc', self._generate_positional_encoding(max_seq_len))

    def _generate_positional_encoding(self, seq_len):
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        pos_enc = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        return pos_enc

    def forward(self, seq_len):
        return self.pos_enc[:seq_len, :]

def apply_rotary_pos_emb(x, pos_emb):
    x_cos, x_sin = torch.split(pos_emb, x.shape[-1] // 2, dim=-1)
    x1_rot = (x[..., ::2] * x_cos) + (rotate_half(x[..., 1::2]) * x_sin)
    x2_rot = (x[..., 1::2] * x_cos) + (rotate_half(x[..., ::2]) * x_sin)
    x_rot = torch.cat([x1_rot, x2_rot], dim=-1)
    return x_rot

def rotate_half(x):
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


In [4]:
start=time.time()
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.optim as optim

# Define the transformations for the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load the MNIST dataset
train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('.', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Adjust the model to handle the MNIST input size
class KANFormerMNIST(KANFormer):
    def __init__(self, hidden_size, num_heads, n_blocks, ff_dims, max_seq_len, device, degree=3):
        super().__init__(28, hidden_size, num_heads, n_blocks, ff_dims, max_seq_len, device, degree)

    def forward(self, x):
        # Flatten the input images and treat each row as a sequence
        x = x.view(x.size(0), 28, 28)
        return super().forward(x)

# Initialize the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = KANFormerMNIST(64, 8, 4, [128, 64], 28, device, 3).to(device)

# Define the optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = torch.nn.CrossEntropyLoss()

# Train the model
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

# Evaluate the model
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Run the training and testing loop
for epoch in range(1, 11):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
stop=time.time()
print(stop-start)













Test set: Average loss: 0.0002, Accuracy: 9564/10000 (96%)















Test set: Average loss: 0.0001, Accuracy: 9636/10000 (96%)















Test set: Average loss: 0.0001, Accuracy: 9707/10000 (97%)















Test set: Average loss: 0.0001, Accuracy: 9750/10000 (98%)















Test set: Average loss: 0.0001, Accuracy: 9783/10000 (98%)















Test set: Average loss: 0.0001, Accuracy: 9804/10000 (98%)















Test set: Average loss: 0.0001, Accuracy: 9778/10000 (98%)















Test set: Average loss: 0.0001, Accuracy: 9805/10000 (98%)















Test set: Average loss: 0.0001, Accuracy: 9802/10000 (98%)















Test set: Average loss: 0.0001, Accuracy: 9823/10000 (98%)



1101.1623513698578


#### The model is not optimized