# CIFAR10 pretrained -> KAN conversion

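Download a pretrained CIFAR-10 model from torch.hub and swap its supported layers for their KAN equivalents, copying the weights. Then compare the converted model against the original in three ways: output parity on a random batch, single-batch inference latency, and accuracy on the CIFAR-10 test set. No training is performed.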

In [13]:
import torch
from converted_KAN.convert import convert_to_kan

In [14]:
# --- Configuration ---
# Pretrained checkpoint to fetch via torch.hub (repository + entry-point name).
HUB_REPO = "chenyaofo/pytorch-cifar-models"
HUB_MODEL = "cifar10_resnet20"

# Numeric settings shared by every cell below; a visible GPU is preferred.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float32

# Shape of the random probe batch used by the parity check and the timing cell,
# laid out as (batch, channels, height, width).
EXAMPLE_INPUT_SHAPE = (2, 3, 32, 32)
# Batch size for the CIFAR10 test-set DataLoader.
BATCH_SIZE = 256


In [15]:
# --- Load pretrained CIFAR10 model from torch.hub ---
# NOTE(review): torch.hub.load imports and executes code from HUB_REPO;
# only point it at repositories you trust.
model = (
    torch.hub.load(HUB_REPO, HUB_MODEL, pretrained=True, verbose=True)
    .to(device=device, dtype=dtype)
    .eval()
)
print(model)


CifarResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias

Using cache found in /home/janis/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


In [16]:
# --- Download CIFAR10 test set ---
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Per-channel input normalization applied to every test image.
# NOTE(review): confirm these mean/std values match the preprocessing the hub
# checkpoint was trained with; a mismatch would lower measured accuracy.
normalize = transforms.Normalize(
    mean=(0.4914, 0.4822, 0.4465),
    std=(0.2470, 0.2435, 0.2616),
)
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])
test_dataset = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=test_transform,
)

# Pin host memory only when batches will be copied to a CUDA device.
pin_memory = (device.type == "cuda")
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # fixed order so both models see identical batches
    num_workers=2,
    pin_memory=pin_memory,
)
print(f"Test samples: {len(test_dataset)}")


Test samples: 10000


In [17]:
# --- Convert to KAN ---
# inplace=False presumably returns a converted copy, leaving `model` usable for
# the side-by-side comparisons below.
# Move the copy to the configured dtype as well as the device. The original model
# is cast to `dtype` when it is loaded, and the replacement layers may not be
# created in that dtype; casting here keeps both models (and the inputs built
# with `dtype`) consistent if the config is changed, e.g. to float64.
kan_model = convert_to_kan(model, inplace=False).to(device=device, dtype=dtype)
kan_model.eval()
print(kan_model)


CifarResNet(
  (conv1): Conv2dKAN()
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2dKAN()
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2dKAN()
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2dKAN()
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2dKAN()
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2dKAN()
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2dKAN()
      (bn2): BatchNor

In [18]:
# --- Output parity check ---
# The converted model should reproduce the original logits up to float rounding
# on a fixed random batch. The check fails loudly instead of only printing, so a
# broken conversion stops "Restart & Run All" here.
torch.manual_seed(0)
example_input = torch.randn(*EXAMPLE_INPUT_SHAPE, device=device, dtype=dtype)
with torch.no_grad():
    out_orig = model(example_input)
    out_kan = kan_model(example_input)
assert out_orig.shape == out_kan.shape, (
    f"output shape mismatch: {tuple(out_orig.shape)} vs {tuple(out_kan.shape)}"
)
close = torch.allclose(out_orig, out_kan, atol=1e-5, rtol=1e-5)
max_diff = (out_orig - out_kan).abs().max().item()
print("Outputs allclose:", close)
print("Max abs diff:", max_diff)
assert close, f"KAN outputs diverge from the original model (max abs diff {max_diff:.3e})"


Outputs allclose: True
Max abs diff: 1.430511474609375e-06


In [19]:
# --- Inference timing (single batch) ---
import time

def benchmark(model, sample, iters=50, warmup=10):
    """Measure the mean wall-clock latency of ``model(sample)``.

    The model is put in eval mode. It then runs ``warmup`` untimed forward
    passes followed by ``iters`` timed passes, all under ``torch.no_grad()``.

    Args:
        model: Module to time; ``model.eval()`` is called on it.
        sample: Input tensor passed unchanged to every forward call. CUDA
            synchronization is keyed off this tensor's own device, so the
            helper does not depend on any notebook-global ``device``.
        iters: Number of timed forward passes; must be >= 1.
        warmup: Number of untimed warm-up passes run before timing.

    Returns:
        Mean seconds per forward pass, as a float.

    Raises:
        ValueError: If ``iters`` is less than 1. Previously this gave a
            ZeroDivisionError or a negative latency.
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    # CUDA work is queued asynchronously, so synchronize before reading the
    # clock; otherwise the timer would mostly measure kernel launches.
    on_cuda = sample.device.type == "cuda"

    def _sync():
        if on_cuda:
            torch.cuda.synchronize(sample.device)

    model.eval()
    with torch.no_grad():
        for _ in range(warmup):
            model(sample)
        _sync()
        start = time.perf_counter()
        for _ in range(iters):
            model(sample)
        _sync()
        end = time.perf_counter()
    return (end - start) / iters

torch.manual_seed(1)
bench_input = torch.randn(*EXAMPLE_INPUT_SHAPE, device=device, dtype=dtype)
# benchmark() reports seconds per forward pass; convert both results to milliseconds.
orig_ms, kan_ms = (benchmark(net, bench_input) * 1000 for net in (model, kan_model))
print(f"Orig latency (ms/batch): {orig_ms:.3f}")
print(f"KAN latency (ms/batch): {kan_ms:.3f}")


Orig latency (ms/batch): 7.259
KAN latency (ms/batch): 19.339


In [20]:
# --- Evaluate accuracy on CIFAR10 test set ---
def evaluate(model, loader, device=None, dtype=None):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    Args:
        model: Classifier. ``model.eval()`` is called, and ``model(images)``
            is expected to return logits with classes along dim 1.
        loader: Iterable of ``(images, labels)`` pairs, e.g. a DataLoader.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter, so the helper no longer relies on a
            notebook-global ``device``. With no parameters and no override,
            batches are left where they are.
        dtype: Dtype ``images`` are cast to. Defaults to the dtype of the
            model's first parameter when that parameter is floating point.

    Returns:
        Fraction of correctly classified samples, or ``nan`` if the loader
        yields no samples.
    """
    # Infer placement from the model itself so inputs always match its weights.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        if device is None:
            device = first_param.device
        if dtype is None and first_param.is_floating_point():
            dtype = first_param.dtype

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            # Tensor.to with None for device/dtype leaves that attribute unchanged.
            images = images.to(device=device, dtype=dtype)
            labels = labels.to(device=device)
            logits = model(images)
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total if total > 0 else float("nan")

# Evaluate both networks on the same (unshuffled) test loader, original first.
acc_orig, acc_kan = [evaluate(net, test_loader) for net in (model, kan_model)]
print(f"Original acc: {acc_orig:.4f} | KAN acc: {acc_kan:.4f}")


Original acc: 0.9212 | KAN acc: 0.9212
