# KAN: Kolmogorov–Arnold Networks With ......

Implementation of KAN architecture with proper feature extraction and attention mechanisms for MRI image analysis.

## 1. Import Dependencies

Import required libraries and modules for implementing KAN.

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, classification_report

## 2. Data Loading and Preprocessing

Setup data pipelines with proper transforms for MRI images.

In [2]:
# Data directories
TRAIN_DIR = "/home/mhs/research/thesis/Brain MRI ND-5 Dataset/Training"
TEST_DIR  = "/home/mhs/research/thesis/Brain MRI ND-5 Dataset/Testing"


# Define image transformations
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale
    transforms.Resize((64, 64)),                # Resize for backbone
    transforms.ToTensor(),                        # Convert to tensor
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # Repeat grayscale to 3 channels
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])   # ImageNet normalization
])

from torch.utils.data import Subset
import random

# Create datasets
train_dataset = datasets.ImageFolder(TRAIN_DIR, transform=transform)
test_dataset = datasets.ImageFolder(TEST_DIR, transform=transform)

# Create subsets properly
# Method 1: Random selection of indices
train_indices = random.sample(range(len(train_dataset)), min(1000, len(train_dataset)))
test_indices = random.sample(range(len(test_dataset)), min(500, len(test_dataset)))

# Create subset datasets
train_dataset = Subset(train_dataset, train_indices)
test_dataset = Subset(test_dataset, test_indices)

# Alternative Method 2: Take the first N samples
# train_dataset = Subset(train_dataset, list(range(min(1000, len(train_dataset)))))
# test_dataset = Subset(test_dataset, list(range(min(500, len(test_dataset)))))

# Create dataloaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                         num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                        num_workers=4, pin_memory=True)

# When printing class counts, we need to access the original dataset's classes
original_train_dataset = datasets.ImageFolder(TRAIN_DIR)
num_classes = len(original_train_dataset.classes)
print(f"Number of classes: {num_classes}")
print(f"Training samples: {len(train_dataset)}")
print(f"Testing samples: {len(test_dataset)}")


Number of classes: 4
Training samples: 1000
Testing samples: 500


## 3. Define KAN Architecture

Implement the Kernel Attention Network model with proper feature extraction.

In [3]:
# class KernelAttention(nn.Module):
#     def __init__(self, in_dim, kernel_size=7):
#         super().__init__()
#         self.conv = nn.Conv2d(in_dim, in_dim, kernel_size=kernel_size, 
#                              padding=kernel_size//2, groups=in_dim)
#         self.spatial_gate = nn.Sequential(
#             nn.Conv2d(in_dim, 1, kernel_size=1),
#             nn.Sigmoid()
#         )
        
#     def forward(self, x):
#         # Local feature aggregation
#         local_feat = self.conv(x)
#         # Generate attention weights
#         attn = self.spatial_gate(local_feat)
#         return x * attn

# class KANModel(nn.Module):
#     def __init__(self, num_classes, backbone='resnet18'):
#         super().__init__()
        
#         # 1. Feature Extraction Backbone
#         if backbone == 'resnet18':
#             base = models.resnet18(pretrained=True)
#             self.feature_dim = 512
#         else:
#             raise ValueError(f"Unsupported backbone: {backbone}")
            
#         # Remove the final FC layer
#         self.features = nn.Sequential(*list(base.children())[:-2])
        
#         # 2. Kernel Attention Module
#         self.attention = KernelAttention(self.feature_dim)
        
#         # 3. Global Average Pooling
#         self.gap = nn.AdaptiveAvgPool2d(1)
        
#         # 4. Classifier
#         self.classifier = nn.Sequential(
#             nn.Linear(self.feature_dim, 256),
#             nn.ReLU(inplace=True),
#             nn.Dropout(0.5),
#             nn.Linear(256, num_classes)
#         )
        
#     def forward(self, x):
#         # Extract features
#         x = self.features(x)  # [B, 512, H', W']
        
#         # Apply kernel attention
#         x = self.attention(x)
        
#         # Global average pooling
#         x = self.gap(x)      # [B, 512, 1, 1]
#         x = x.view(x.size(0), -1)  # [B, 512]
        
#         # Classification
#         return self.classifier(x)

In [4]:
# class BSpline(nn.Module):
#     """B-spline implementation for KAN."""

#     def __init__(self, in_dim, grid_size=5, degree=3):
#         super().__init__()
#         self.in_dim = in_dim
#         self.grid_size = grid_size
#         self.degree = degree

#         # Learnable control points - one set for each channel
#         self.control_points = nn.Parameter(torch.randn(in_dim, grid_size))

#         # Fixed grid points from 0 to 1
#         self.register_buffer("grid", torch.linspace(0, 1, grid_size))

#     def forward(self, x):
#         # x shape: [B, C, N]
#         B, C, N = x.shape

#         # Ensure C matches in_dim
#         assert (
#             C == self.in_dim
#         ), f"Input has {C} channels but BSpline expects {self.in_dim}"

#         # Normalize input to [0, 1]
#         x_min = x.min(dim=2, keepdim=True)[0]
#         x_max = x.max(dim=2, keepdim=True)[0]
#         x_norm = (x - x_min) / (x_max - x_min + 1e-8)

#         # Initialize output tensor
#         out = torch.zeros_like(x)

#         # Compute B-spline weights for each channel
#         for c in range(C):
#             # Expand grid for broadcasting
#             grid_expanded = self.grid.view(1, -1)  # [1, grid_size]
#             x_expanded = x_norm[:, c, :].unsqueeze(2)  # [B, N, 1]

#             # Compute weights using RBF
#             weights = torch.exp(
#                 -((x_expanded - grid_expanded) ** 2) / 0.1
#             )  # [B, N, grid_size]
#             weights = weights / (weights.sum(dim=2, keepdim=True) + 1e-8)

#             # Apply weights to control points
#             out[:, c, :] = torch.matmul(
#                 weights, self.control_points[c].unsqueeze(1)
#             ).squeeze(2)

#         return out


# class KANAttention(nn.Module):
#     def __init__(self, in_dim, grid_size=5, degree=3):
#         super().__init__()
#         self.in_dim = in_dim

#         # Modified KAN layer
#         self.spline = BSpline(in_dim, grid_size, degree)
#         self.attention_conv = nn.Conv1d(in_dim, 1, 1)
#         self.activation = nn.Sigmoid()

#     def forward(self, x):
#         # x: [B, C, H, W]
#         B, C, H, W = x.shape

#         # Reshape to [B, C, H*W]
#         x_flat = x.view(B, C, -1)

#         # Apply B-spline transformation
#         x_spline = self.spline(x_flat)

#         # Generate attention weights
#         attn = self.attention_conv(x_spline)  # [B, 1, H*W]
#         attn = self.activation(attn)

#         # Reshape attention back to spatial dimensions
#         attn = attn.view(B, 1, H, W)

#         # Apply attention
#         return x * attn.expand_as(x)


# class KANModel(nn.Module):
#     def __init__(self, num_classes, backbone="mobilenetv3"):
#         super().__init__()

#         # 1. Feature Extraction Backbone
#         if backbone == "resnet18":
#             base = models.resnet18(pretrained=True)
#             self.feature_dim = 512
#             # Remove final FC layer and keep feature extractor
#             self.features = nn.Sequential(*list(base.children())[:-2])
#         elif backbone == "mobilenetv3":
#             base = models.mobilenet_v3_small(
#                 weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
#             )
#             # For MobileNetV3, we need to determine the feature dimension correctly
#             self.feature_dim = (
#                 576  # MobileNetV3-Small has 576 channels in its last conv layer
#             )
#             # Replace classifier but keep features
#             self.features = base.features
#         else:
#             raise ValueError(f"Unsupported backbone: {backbone}")

#         # 2. KAN Attention Module
#         self.attention = KANAttention(self.feature_dim)

#         # 3. Global Average Pooling
#         self.gap = nn.AdaptiveAvgPool2d(1)

#         # 4. Final Classifier
#         self.classifier = nn.Sequential(
#             nn.Linear(self.feature_dim, 256),
#             nn.BatchNorm1d(256),
#             nn.ReLU(inplace=True),
#             nn.Dropout(0.5),
#             nn.Linear(256, num_classes),
#         )

#     def forward(self, x):
#         # Feature extraction
#         x = self.features(x)  # [B, feature_dim, H', W']

#         # Apply KAN attention
#         x = self.attention(x)

#         # Global average pooling
#         x = self.gap(x)  # [B, feature_dim, 1, 1]
#         x = x.view(x.size(0), -1)  # [B, feature_dim]

#         # Classification
#         return self.classifier(x)

In [5]:
# class KANModel(nn.Module):
#     def __init__(self, num_classes, backbone="mobilenetv3"):
#         super().__init__()

#         # 1. Feature Extraction Backbone
#         if backbone == "resnet18":
#             base = models.resnet18(pretrained=True)
#             self.feature_dim = 512
#         elif backbone == "mobilenetv3":
#             base = models.mobilenet_v3_small(
#                 weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
#             )
#             self.feature_dim = 512
#         else:
#             raise ValueError(f"Unsupported backbone: {backbone}")

#         # Remove final FC layer and keep feature extractor
#         self.features = nn.Sequential(*list(base.children())[:-2])

#         # 2. KAN Attention Module
#         self.attention = KANAttention(self.feature_dim)

#         # 3. Global Average Pooling
#         self.gap = nn.AdaptiveAvgPool2d(1)

#         # 4. Final Classifier
#         self.classifier = nn.Sequential(
#             nn.Linear(self.feature_dim, 256),
#             nn.BatchNorm1d(256),
#             nn.ReLU(inplace=True),
#             nn.Dropout(0.5),
#             nn.Linear(256, num_classes),
#         )

#     def forward(self, x):
#         # Feature extraction
#         x = self.features(x)  # [B, 512, H', W']

#         # Apply KAN attention
#         x = self.attention(x)

#         # Global average pooling
#         x = self.gap(x)  # [B, 512, 1, 1]
#         x = x.view(x.size(0), -1)  # [B, 512]

#         # Classification
#         return self.classifier(x)

In [6]:
class SpatialBSpline(nn.Module):
    """2D B-spline that processes spatial dimensions directly."""
    def __init__(self, in_channels, out_channels, grid_size=5):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.grid_size = grid_size
        
        # Control points for each input-output channel pair
        self.control_points = nn.Parameter(torch.randn(in_channels, out_channels, grid_size * grid_size))
        
        # Create 2D grid
        x = torch.linspace(0, 1, grid_size)
        y = torch.linspace(0, 1, grid_size)
        xx, yy = torch.meshgrid(x, y, indexing="ij")
        grid = torch.stack([xx.flatten(), yy.flatten()], dim=1)
        self.register_buffer("grid", grid)  # [grid_size^2, 2]
        
    def forward(self, x):
        # x: [B, C_in, H, W]
        B, C_in, H, W = x.shape
        
        # Create output tensor
        out = torch.zeros(B, self.out_channels, H, W, device=x.device)
        
        # Normalize spatial values to [0,1] for grid lookup
        h_norm = torch.linspace(0, 1, H, device=x.device)
        w_norm = torch.linspace(0, 1, W, device=x.device)
        
        # Create normalized 2D position grid for the image
        norm_h, norm_w = torch.meshgrid(h_norm, w_norm, indexing="ij")
        positions = torch.stack([norm_h.flatten(), norm_w.flatten()], dim=1)  # [H*W, 2]
        
        # Process each position with B-spline
        for c_in in range(C_in):
            # Get channel data and normalize to [0,1]
            x_c = x[:, c_in]  # [B, H, W]
            x_flat = x_c.reshape(B, -1)  # [B, H*W]
            x_min = x_flat.min(dim=1, keepdim=True)[0]
            x_max = x_flat.max(dim=1, keepdim=True)[0] 
            x_norm = (x_flat - x_min) / (x_max - x_min + 1e-8)  # [B, H*W]
            
            # Compute weights based on positions and pixel values
            pos_weights = torch.zeros(B, H*W, self.grid_size*self.grid_size, device=x.device)
            
            for b in range(B):
                # Compute spatial distances to grid points
                dist_to_grid = torch.cdist(positions, self.grid, p=2)  # [H*W, grid_size^2]
                
                # Combine with pixel value influence
                x_influence = x_norm[b].unsqueeze(1).expand(-1, self.grid_size*self.grid_size)
                combined_dist = dist_to_grid * (1.0 + x_influence)
                
                # RBF kernel
                weights = torch.exp(-combined_dist / 0.2)  # [H*W, grid_size^2]
                weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-8)
                pos_weights[b] = weights
            
            # Apply weights to control points for each output channel
            for c_out in range(self.out_channels):
                control_pts = self.control_points[c_in, c_out]  # [grid_size^2]
                # Weighted sum for each position
                weighted_vals = torch.matmul(pos_weights, control_pts)  # [B, H*W]
                out[:, c_out] += weighted_vals.view(B, H, W)
                
        return out

class PureBSplineBlock(nn.Module):
    """Block using only B-splines with no convolutional layers."""
    def __init__(self, in_channels, out_channels, grid_size=5):
        super().__init__()
        
        # Feature transformation with B-splines
        self.spatial_spline = SpatialBSpline(
            in_channels, out_channels, grid_size=grid_size
        )
        
        # Normalization and activation
        # Change this line - normalize only over the channel dimension
        self.norm = nn.LayerNorm(out_channels)  
        self.act = nn.GELU()
        
        # B-spline for attention weights
        self.attention_spline = SpatialBSpline(
            out_channels, 1, grid_size=grid_size
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: [B, C, H, W]
        
        # Apply spatial B-spline
        features = self.spatial_spline(x)
        
        # Layer normalization across channels
        B, C, H, W = features.shape
        features = features.permute(0, 2, 3, 1)  # [B, H, W, C]
        features = self.norm(features)  # This now applies across the last dimension only
        features = features.permute(0, 3, 1, 2)  # [B, C, H, W]
        features = self.act(features)
        
        # Apply attention
        attention = self.attention_spline(features)
        attention = self.sigmoid(attention)
        
        # Apply attention weights
        return features * attention

class PureKANModel(nn.Module):
    """Image classifier using only B-splines with NO convolutional layers."""
    def __init__(self, num_classes, in_channels=3):
        super().__init__()
        
        # Initial feature extraction
        self.input_proj = nn.Linear(in_channels, 32)
        
        # B-spline blocks
        self.block1 = PureBSplineBlock(32, 64, grid_size=7)
        self.pool1 = nn.AvgPool2d(2)
        
        self.block2 = PureBSplineBlock(64, 128, grid_size=7)
        self.pool2 = nn.AvgPool2d(2)
        
        self.block3 = PureBSplineBlock(128, 256, grid_size=9)
        self.pool3 = nn.AvgPool2d(2)
        
        # Global pooling
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        
        # Final classifier
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
        
    def forward(self, x):
        # x: [B, C, H, W]
        B, C, H, W = x.shape
        
        # Initial projection
        x = x.permute(0, 2, 3, 1)  # [B, H, W, C]
        x = self.input_proj(x)  # [B, H, W, 32]
        x = x.permute(0, 3, 1, 2)  # [B, 32, H, W]
        
        # Process through B-spline blocks
        x = self.block1(x)
        x = self.pool1(x)
        
        x = self.block2(x)
        x = self.pool2(x)
        
        x = self.block3(x)
        x = self.pool3(x)
        
        # Global pooling
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        
        # Classification
        return self.classifier(x)

## 4. Training Configuration

Setup training parameters and optimization configuration.

In [7]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = PureKANModel(num_classes=num_classes, in_channels=3).to(device)
# Possibly lower learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)

# Initialize model
# model = KANModel(num_classes=num_classes).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Learning rate scheduler (removed verbose parameter)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=5
)

Using device: cuda


## 5. Training and Evaluation Loop

Implement the main training loop with evaluation metrics.

In [8]:
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    for images, labels in tqdm(loader, desc='Training'):
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Track metrics
        running_loss += loss.item()
        preds = torch.argmax(outputs, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    
    # Compute epoch metrics
    epoch_loss = running_loss / len(loader)
    epoch_acc = accuracy_score(all_labels, all_preds)
    
    return epoch_loss, epoch_acc

def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Evaluating'):
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # Compute metrics
    avg_loss = running_loss / len(loader)
    accuracy = accuracy_score(all_labels, all_preds)
    report = classification_report(all_labels, all_preds, 
                                target_names=train_dataset.classes)
    
    return avg_loss, accuracy, report

# Training loop
num_epochs = 10
best_acc = 0.0

train_losses = []
train_accs = []
val_losses = []
val_accs = []

for epoch in range(1, num_epochs + 1):
    # Training phase
    train_loss, train_acc = train_epoch(model, train_loader, criterion, 
                                      optimizer, device)
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # Evaluation phase
    val_loss, val_acc, val_report = evaluate(model, test_loader, criterion, device)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # Learning rate scheduling
    scheduler.step(val_acc)
    
    # Save best model
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'kan_best_model.pth')

    # Print epoch results
    print(f"\nEpoch {epoch}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    
    # Print detailed validation report every 5 epochs
    if epoch % 5 == 0:
        print("\nValidation Report:")
        print(val_report)

# Final evaluation
print("\nLoading best model for final evaluation...")
model.load_state_dict(torch.load('kan_best_model.pth'))
test_loss, test_acc, test_report = evaluate(model, test_loader, criterion, device)

print("\nFinal Test Results:")
print(f"Test Accuracy: {test_acc:.4f}")
print("\nDetailed Classification Report:")
print(test_report)

Training:  16%|█▌        | 5/32 [22:24<2:01:02, 268.97s/it]


KeyboardInterrupt: 

In [None]:
def count_parameters(model):
    """Count the total number of trainable parameters in the model"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Count parameters
total_params = count_parameters(model)
print(f"Total trainable parameters: {total_params:,}")

# Display parameter distribution by module
param_sizes = {}
for name, param in model.named_parameters():
    if param.requires_grad:
        module_name = name.split('.')[0]
        if module_name not in param_sizes:
            param_sizes[module_name] = 0
        param_sizes[module_name] += param.numel()

print("\nParameter distribution by module:")
for module_name, param_count in param_sizes.items():
    percentage = 100 * param_count / total_params
    print(f"{module_name}: {param_count:,} parameters ({percentage:.2f}%)")

In [None]:
import matplotlib.pyplot as plt

# Modify the training loop to store metrics
# train_losses = []
# train_accs = []
# val_losses = []
# val_accs = []

# for epoch in range(1, num_epochs + 1):
#     # Training phase
#     train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
#     train_losses.append(train_loss)
#     train_accs.append(train_acc)
    
#     # Evaluation phase
#     val_loss, val_acc, val_report = evaluate(model, test_loader, criterion, device)
#     val_losses.append(val_loss)
#     val_accs.append(val_acc)

# Plotting
plt.figure(figsize=(12, 5))

# Plot Loss
plt.subplot(1, 2, 1)
plt.plot(range(1, num_epochs + 1), train_losses, 'b-', label='Training Loss')
plt.plot(range(1, num_epochs + 1), val_losses, 'r-', label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True)

# Plot Accuracy
plt.subplot(1, 2, 2)
plt.plot(range(1, num_epochs + 1), train_accs, 'b-', label='Training Accuracy')
plt.plot(range(1, num_epochs + 1), val_accs, 'r-', label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()