# KAN Autoencoder

## MNIST

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import KAN

# Load MNIST. ToTensor yields pixel values in [0, 1]; Normalize((0.5,), (0.5,))
# then applies (x - 0.5) / 0.5, mapping them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# The official training split, then the held-out test split (used here for
# validation). Both are downloaded into ./data on first use.
trainset, valset = (
    torchvision.datasets.MNIST(
        root="./data", train=use_train_split, download=True, transform=transform
    )
    for use_train_split in (True, False)
)

# Only the training loader reshuffles each epoch; validation order stays fixed.
trainloader, valloader = (
    DataLoader(split, batch_size=64, shuffle=reshuffle)
    for split, reshuffle in ((trainset, True), (valset, False))
)

# Define Sparse Autoencoder model using KAN
class SparseKANAutoencoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, grid_size=5, spline_order=3):
        super(SparseKANAutoencoder, self).__init__()
        self.encoder = KAN([input_dim, hidden_dim, latent_dim], grid_size=grid_size, spline_order=spline_order)
        self.decoder = KAN([latent_dim, hidden_dim, input_dim], grid_size=grid_size, spline_order=spline_order)

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

# Initialize model, criterion, optimizer, and scheduler.
# Each MNIST image is flattened into a 28 * 28 = 784-feature vector.
input_dim = 28 * 28
hidden_dim = 128
latent_dim = 64

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# nn.Module.to() moves the parameters in place and returns the module itself.
model = SparseKANAutoencoder(input_dim, hidden_dim, latent_dim).to(device)

# Reconstruction objective: mean squared error between output and input.
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Multiplies the learning rate by 0.8 on every scheduler.step() call.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

# Training loop.
# Images are flattened to `input_dim` features (the same value the model was
# built with), so the input width cannot drift from the model's first layer.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    with tqdm(trainloader, desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
        for images, _ in pbar:
            images = images.view(-1, input_dim).to(device)
            optimizer.zero_grad()
            outputs = model(images)
            # Autoencoder objective: the reconstruction target is the input itself.
            loss = criterion(outputs, images)
            loss.backward()
            optimizer.step()
            # Read the scalar once: each .item() forces a host/device sync.
            batch_loss = loss.item()
            # MSELoss (mean reduction) averages over the batch. Weighting by
            # batch size makes the epoch figure an exact dataset-wide mean,
            # even though the last batch may be smaller.
            train_loss += batch_loss * images.size(0)
            pbar.set_postfix(loss=batch_loss)

    train_loss /= len(trainloader.dataset)
    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}")

    # Validation: eval mode and no autograd bookkeeping.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, _ in valloader:
            images = images.view(-1, input_dim).to(device)
            outputs = model(images)
            loss = criterion(outputs, images)
            val_loss += loss.item() * images.size(0)

    val_loss /= len(valloader.dataset)
    print(f"Epoch {epoch+1}, Val Loss: {val_loss:.4f}")

    # Update learning rate (once per epoch).
    scheduler.step()

# Save only the weights (state_dict) for further analysis and explainability.
torch.save(model.state_dict(), "sparse_kan_autoencoder.pth")
# The cell's recorded output ends with this message, but it was never printed.
print("Training completed and model saved.")


Epoch 1/10: 100%|██████████| 938/938 [00:34<00:00, 26.92it/s, loss=0.0569]


Epoch 1, Train Loss: 0.1139
Epoch 1, Val Loss: 0.0548


Epoch 2/10: 100%|██████████| 938/938 [00:34<00:00, 26.93it/s, loss=0.0422]


Epoch 2, Train Loss: 0.0473
Epoch 2, Val Loss: 0.0406


Epoch 3/10: 100%|██████████| 938/938 [00:35<00:00, 26.74it/s, loss=0.035]


Epoch 3, Train Loss: 0.0377
Epoch 3, Val Loss: 0.0348


Epoch 4/10: 100%|██████████| 938/938 [00:35<00:00, 26.58it/s, loss=0.0281]


Epoch 4, Train Loss: 0.0330
Epoch 4, Val Loss: 0.0312


Epoch 5/10: 100%|██████████| 938/938 [00:35<00:00, 26.68it/s, loss=0.0326]


Epoch 5, Train Loss: 0.0302
Epoch 5, Val Loss: 0.0291


Epoch 6/10: 100%|██████████| 938/938 [00:35<00:00, 26.70it/s, loss=0.025]


Epoch 6, Train Loss: 0.0283
Epoch 6, Val Loss: 0.0275


Epoch 7/10: 100%|██████████| 938/938 [00:34<00:00, 26.95it/s, loss=0.0243]


Epoch 7, Train Loss: 0.0269
Epoch 7, Val Loss: 0.0264


Epoch 8/10: 100%|██████████| 938/938 [00:34<00:00, 26.82it/s, loss=0.0245]


Epoch 8, Train Loss: 0.0259
Epoch 8, Val Loss: 0.0256


Epoch 9/10: 100%|██████████| 938/938 [00:35<00:00, 26.20it/s, loss=0.0263]


Epoch 9, Train Loss: 0.0252
Epoch 9, Val Loss: 0.0250


Epoch 10/10: 100%|██████████| 938/938 [00:35<00:00, 26.21it/s, loss=0.0246]


Epoch 10, Train Loss: 0.0246
Epoch 10, Val Loss: 0.0245
Training completed and model saved.
