# Kernel Attention Network (KAN) for Brain Tumor Classification

Implementation of KAN architecture with proper feature extraction and attention mechanisms for MRI image analysis.

## 1. Import Dependencies

Import required libraries and modules for implementing KAN.

In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, classification_report

## 2. Data Loading and Preprocessing

Setup data pipelines with proper transforms for MRI images.

In [3]:
# Data directories
TRAIN_DIR = "/home/mhs/thesis/Brain MRI ND-5 Dataset/tumordata/Training"
TEST_DIR  = "/home/mhs/thesis/Brain MRI ND-5 Dataset/tumordata/Testing"

# Define image transformations
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale
    transforms.Resize((224, 224)),                # Resize for backbone
    transforms.ToTensor(),                        # Convert to tensor
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # Repeat grayscale to 3 channels
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])   # ImageNet normalization
])

# Create datasets
train_dataset = datasets.ImageFolder(TRAIN_DIR, transform=transform)
test_dataset  = datasets.ImageFolder(TEST_DIR,  transform=transform)

# Create dataloaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                         num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False,
                         num_workers=4, pin_memory=True)

num_classes = len(train_dataset.classes)
print(f"Number of classes: {num_classes}")
print(f"Training samples: {len(train_dataset)}")
print(f"Testing samples:  {len(test_dataset)}")

Number of classes: 4
Training samples: 13927
Testing samples:  3961


## 3. Define KAN Architecture

Implement the Kernel Attention Network model with proper feature extraction.

In [4]:
class KernelAttention(nn.Module):
    def __init__(self, in_dim, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(in_dim, in_dim, kernel_size=kernel_size, 
                             padding=kernel_size//2, groups=in_dim)
        self.spatial_gate = nn.Sequential(
            nn.Conv2d(in_dim, 1, kernel_size=1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        # Local feature aggregation
        local_feat = self.conv(x)
        # Generate attention weights
        attn = self.spatial_gate(local_feat)
        return x * attn

class KANModel(nn.Module):
    def __init__(self, num_classes, backbone='resnet18'):
        super().__init__()
        
        # 1. Feature Extraction Backbone
        if backbone == 'resnet18':
            base = models.resnet18(pretrained=True)
            self.feature_dim = 512
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")
            
        # Remove the final FC layer
        self.features = nn.Sequential(*list(base.children())[:-2])
        
        # 2. Kernel Attention Module
        self.attention = KernelAttention(self.feature_dim)
        
        # 3. Global Average Pooling
        self.gap = nn.AdaptiveAvgPool2d(1)
        
        # 4. Classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.feature_dim, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )
        
    def forward(self, x):
        # Extract features
        x = self.features(x)  # [B, 512, H', W']
        
        # Apply kernel attention
        x = self.attention(x)
        
        # Global average pooling
        x = self.gap(x)      # [B, 512, 1, 1]
        x = x.view(x.size(0), -1)  # [B, 512]
        
        # Classification
        return self.classifier(x)

## 4. Training Configuration

Setup training parameters and optimization configuration.

In [7]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize model
model = KANModel(num_classes=num_classes).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

# Learning rate scheduler (removed verbose parameter)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='max',
    factor=0.5,
    patience=5
)

Using device: cuda


## 5. Training and Evaluation Loop

Implement the main training loop with evaluation metrics.

In [8]:
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    for images, labels in tqdm(loader, desc='Training'):
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Track metrics
        running_loss += loss.item()
        preds = torch.argmax(outputs, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    
    # Compute epoch metrics
    epoch_loss = running_loss / len(loader)
    epoch_acc = accuracy_score(all_labels, all_preds)
    
    return epoch_loss, epoch_acc

def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Evaluating'):
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # Compute metrics
    avg_loss = running_loss / len(loader)
    accuracy = accuracy_score(all_labels, all_preds)
    report = classification_report(all_labels, all_preds, 
                                target_names=train_dataset.classes)
    
    return avg_loss, accuracy, report

# Training loop
num_epochs = 10
best_acc = 0.0

for epoch in range(1, num_epochs + 1):
    # Training phase
    train_loss, train_acc = train_epoch(model, train_loader, criterion, 
                                      optimizer, device)
    
    # Evaluation phase
    val_loss, val_acc, val_report = evaluate(model, test_loader, criterion, device)
    
    # Learning rate scheduling
    scheduler.step(val_acc)
    
    # Save best model
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'kan_best_model.pth')
    
    # Print epoch results
    print(f"\nEpoch {epoch}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    
    # Print detailed validation report every 5 epochs
    if epoch % 5 == 0:
        print("\nValidation Report:")
        print(val_report)

# Final evaluation
print("\nLoading best model for final evaluation...")
model.load_state_dict(torch.load('kan_best_model.pth'))
test_loss, test_acc, test_report = evaluate(model, test_loader, criterion, device)

print("\nFinal Test Results:")
print(f"Test Accuracy: {test_acc:.4f}")
print("\nDetailed Classification Report:")
print(test_report)

Training: 100%|██████████| 436/436 [-1:59:30<00:00, -14.31it/s]
Evaluating: 100%|██████████| 124/124 [00:03<00:00, 37.38it/s]



Epoch 1/10
Train Loss: 0.2162, Train Acc: 0.9273
Val Loss: 0.1920, Val Acc: 0.9518


Training: 100%|██████████| 436/436 [00:15<00:00, 27.86it/s]
Evaluating: 100%|██████████| 124/124 [00:02<00:00, 41.46it/s]



Epoch 2/10
Train Loss: 0.0448, Train Acc: 0.9850
Val Loss: 0.2281, Val Acc: 0.9609


Training: 100%|██████████| 436/436 [01:03<00:00,  6.88it/s]
Evaluating: 100%|██████████| 124/124 [-1:59:16<00:00, -2.78it/s]



Epoch 3/10
Train Loss: 0.0278, Train Acc: 0.9911
Val Loss: 0.1626, Val Acc: 0.9664


Training: 100%|██████████| 436/436 [00:15<00:00, 27.93it/s]
Evaluating: 100%|██████████| 124/124 [00:02<00:00, 41.98it/s]



Epoch 4/10
Train Loss: 0.0270, Train Acc: 0.9913
Val Loss: 0.1982, Val Acc: 0.9626


Training: 100%|██████████| 436/436 [00:15<00:00, 28.27it/s]
Evaluating: 100%|██████████| 124/124 [00:02<00:00, 51.97it/s]



Epoch 5/10
Train Loss: 0.0144, Train Acc: 0.9955
Val Loss: 0.1915, Val Acc: 0.9692

Validation Report:
                  precision    recall  f1-score   support

    glioma_tumor       1.00      0.93      0.96      1208
meningioma_tumor       0.93      0.99      0.96       930
        no_tumor       0.95      1.00      0.97       831
 pituitary_tumor       0.99      0.98      0.98       992

        accuracy                           0.97      3961
       macro avg       0.97      0.97      0.97      3961
    weighted avg       0.97      0.97      0.97      3961



Training: 100%|██████████| 436/436 [01:03<00:00,  6.90it/s]
Evaluating: 100%|██████████| 124/124 [-1:59:15<00:00, -2.74it/s]



Epoch 6/10
Train Loss: 0.0200, Train Acc: 0.9943
Val Loss: 0.2150, Val Acc: 0.9702


Training: 100%|██████████| 436/436 [00:15<00:00, 28.32it/s]
Evaluating: 100%|██████████| 124/124 [00:02<00:00, 51.73it/s]



Epoch 7/10
Train Loss: 0.0160, Train Acc: 0.9950
Val Loss: 0.1916, Val Acc: 0.9674


Training: 100%|██████████| 436/436 [00:15<00:00, 28.04it/s]
Evaluating: 100%|██████████| 124/124 [00:02<00:00, 51.78it/s]



Epoch 8/10
Train Loss: 0.0128, Train Acc: 0.9961
Val Loss: 0.2205, Val Acc: 0.9712


Training: 100%|██████████| 436/436 [00:15<00:00, 28.19it/s]
Evaluating: 100%|██████████| 124/124 [00:02<00:00, 52.38it/s]



Epoch 9/10
Train Loss: 0.0082, Train Acc: 0.9972
Val Loss: 0.2349, Val Acc: 0.9571


Training: 100%|██████████| 436/436 [00:15<00:00, 28.15it/s]
Evaluating: 100%|██████████| 124/124 [00:02<00:00, 50.78it/s]



Epoch 10/10
Train Loss: 0.0166, Train Acc: 0.9950
Val Loss: 0.1996, Val Acc: 0.9702

Validation Report:
                  precision    recall  f1-score   support

    glioma_tumor       0.99      0.94      0.96      1208
meningioma_tumor       0.94      0.97      0.96       930
        no_tumor       0.96      1.00      0.98       831
 pituitary_tumor       0.98      0.99      0.98       992

        accuracy                           0.97      3961
       macro avg       0.97      0.97      0.97      3961
    weighted avg       0.97      0.97      0.97      3961


Loading best model for final evaluation...


Evaluating: 100%|██████████| 124/124 [00:02<00:00, 52.13it/s]


Final Test Results:
Test Accuracy: 0.9712

Detailed Classification Report:
                  precision    recall  f1-score   support

    glioma_tumor       0.99      0.93      0.96      1208
meningioma_tumor       0.94      0.99      0.96       930
        no_tumor       0.95      1.00      0.97       831
 pituitary_tumor       0.99      0.98      0.99       992

        accuracy                           0.97      3961
       macro avg       0.97      0.97      0.97      3961
    weighted avg       0.97      0.97      0.97      3961




