In [2]:
%pip install torch-summary

Note: you may need to restart the kernel to use updated packages.


In [10]:
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from medmnist import INFO  # Contains dataset metadata
import medmnist          # This imports all available MedMNIST dataset classes
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import math
import matplotlib.pyplot as plt
from torchsummary import summary
import pandas as pd
import numpy as np
from torch.utils.data import Dataset

In [27]:
class KANLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, grid_size=5, spline_order=5, scale_noise=0.15, scale_base=1.0,
                 scale_spline=1.0, enable_standalone_scale_spline=True, base_activation=torch.nn.SiLU, grid_eps=0.02,
                 grid_range=[-1, 1]):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.dim() == 2 and x.size(1) == self.in_features

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        return base_output + spline_output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, grid_size + spline_order) sized tensor, which is
        extremely slow and consumes a lot of memory.
        """
        weight = self.scaled_spline_weight
        activation = self.base_activation(self.base_weight)
        activation_regularization = weight.square().mean()
        entropy_regularization = (
            -weight.softmax(dim=-1) * weight.log_softmax(dim=-1)
        ).mean()
        regularization = (
            regularize_activation * activation_regularization
            + regularize_entropy * entropy_regularization
        )
        return regularization


In [28]:
import numpy as np
from torch.utils.data import Dataset
from PIL import Image  # Import PIL's Image module

class MyMedNISTDataset(Dataset):
    def __init__(self, npz_path, split='train', transform=None):
        super().__init__()
        self.split = split
        self.transform = transform
        
        # Load data from .npz using NumPy
        data_dict = np.load(npz_path)
        
        # Based on the chosen split, pick the corresponding arrays
        if split == 'train':
            self.images = data_dict['train_images']
            self.labels = data_dict['train_labels']
        elif split == 'val':
            self.images = data_dict['val_images']
            self.labels = data_dict['val_labels']
        elif split == 'test':
            self.images = data_dict['test_images']
            self.labels = data_dict['test_labels']
        else:
            raise ValueError(f"Unknown split: {split} (expected 'train', 'val', or 'test')")
        
        # Convert labels to a 1D array if necessary
        self.labels = np.squeeze(self.labels)
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        
        # Convert image to float32 if needed
        image = image.astype(np.float32)
        
        # Convert NumPy array to PIL Image.
        # Note: Depending on your data's value range, you might need to adjust this.
        # For example, if your data is already in [0, 255] you can convert directly.
        # If it's in [0, 1], consider scaling by 255 first.
        image = Image.fromarray(np.uint8(image))
        
        if self.transform:
            image = self.transform(image)
        
        return image, label


In [29]:
def get_transforms(n_channels,mean,std):
    # If not 3 channels, first convert to 3 channels; otherwise, use identity.
    convert = transforms.Grayscale(num_output_channels=3) if n_channels != 3 else lambda x: x

    train_transform = transforms.Compose([
        convert,
        transforms.RandomResizedCrop(size=(28, 28), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        transforms.RandomRotation(degrees=(-15, 15)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    test_transform = transforms.Compose([
        convert,
        transforms.ToTensor(),
        transforms.Normalize(mean=mean,
                             std=std)
    ])
    val_transform = test_transform

    return train_transform, val_transform, test_transform

# EXAMPLE USAGE
# for key in dataset_keys:
#     info = INFO[key]
#     n_channels = info.get('n_channels', 1)
#     npz_path = f"./data/{key}.npz"
#     train_dataset = MyMedNISTDataset(
#         npz_path=npz_path,
#         split='train',
#         transform=get_transforms(n_channels)[0]
#     )
#     print(f"Train {key} size:", len(train_dataset))

# Create a DataLoader to batch and shuffle
#train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [30]:
class CNNKAN(nn.Module):
    def __init__(self, num_classes):
        super(CNNKAN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = KANLinear(64 * 7 * 7, 128)
        self.fc2 = KANLinear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [31]:
#dummy input for CNNKAN model
dummy_input = torch.randn(1, 3, 28, 28)
model = CNNKAN(10)
model(dummy_input)


tensor([[-0.0236, -0.0359,  0.0215, -0.0142,  0.0044, -0.0764, -0.0408,  0.0057,
          0.0014,  0.0257]], grad_fn=<AddBackward0>)

In [32]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 100 == 99:
            print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {running_loss / 100:.6f}')
            running_loss = 0.0

def evaluate(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({accuracy:.0f}%)\n')

In [33]:
task_classes = {'breastmnist': 2,
 'dermamnist': 7,
 'octmnist': 4,
 'organamnist': 11,
 'organcmnist': 11,
 'organsmnist': 11,
 'pathmnist': 9,
 'pneumoniamnist': 2,
 'retinamnist': 5,
 'tissuemnist': 8,
 'bloodmnist': 8}

# Set device.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

dataset_keys = [
    #'breastmnist',
    #'dermamnist',
    #'octmnist',
    #'organcmnist',
    'organsmnist',
    #'pathmnist',
    #'pneumoniamnist',
    #'retinamnist',
    #'tissuemnist',
    #'bloodmnist',
    #'organamnist'
]



Using device: mps


In [34]:
batch_size = 64
epochs = 10
lr = 0.001
momentum = 0.9

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)

In [35]:
def compute_mean_std(dataset, batch_size=64):
    """Compute per-channel mean and standard deviation over a dataset."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    mean = torch.zeros(dataset[0][0].size(0))
    std = torch.zeros(dataset[0][0].size(0))
    nb_samples = 0

    for data, _ in loader:
        batch_samples = data.size(0)
        # reshape to (batch_size, channels, H*W)
        data = data.view(batch_samples, data.size(1), -1)
        mean += data.mean(2).sum(0)
        std += data.std(2).sum(0)
        nb_samples += batch_samples

    mean /= nb_samples
    std /= nb_samples
    return mean, std

In [38]:
import gc
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader

for key in dataset_keys:
    print(f"======== Training on dataset: {key} ========")
    npz_path = f"./data/{key}.npz"
    info = INFO[key]
    n_channels = info.get('n_channels', 1)
    
    # Step 1: Compute statistics from the training set
    base_transform = transforms.ToTensor()
    train_dataset_for_stats = MyMedNISTDataset(npz_path=npz_path, split='train', transform=base_transform)
    computed_mean, computed_std = compute_mean_std(train_dataset_for_stats)
    print(f"Computed mean from data: {computed_mean}, std: {computed_std} for dataset {key}")
    
    # OPTIONAL: Visualize a sample image to check scaling and channel count
    sample_img, sample_label = train_dataset_for_stats[0]
    print(f"Sample image shape: {sample_img.shape}, min: {sample_img.min()}, max: {sample_img.max()}")
    # Uncomment the following lines to display the image (if running locally)
    # plt.imshow(sample_img.squeeze(), cmap='gray' if n_channels == 1 else None)
    # plt.title(f"Label: {sample_label}")
    # plt.show()
    
    # Step 2: Decide whether to use computed values or fixed ones.
    # Note: The following lines overwrite computed values—remove/comment these if you want to use computed_mean/std.
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    # If your model expects single-channel but you want to replicate values for a 3-channel model,
    # you might consider:
    # if n_channels == 1:
    #     mean = computed_mean.tolist() * 3
    #     std = computed_std.tolist() * 3
    print(f"Using mean: {mean}, std: {std} for normalization on dataset {key}")
    
    # Step 3: Get transforms using your (computed or fixed) normalization values.
    train_transform, val_transform, test_transform = get_transforms(n_channels, mean, std)
    
    # Step 4: Create dataset splits
    train_dataset = MyMedNISTDataset(npz_path=npz_path, split='train', transform=train_transform)
    val_dataset   = MyMedNISTDataset(npz_path=npz_path, split='val', transform=val_transform)
    test_dataset  = MyMedNISTDataset(npz_path=npz_path, split='test', transform=test_transform)
    
    # Step 5: Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    
    # Optional: Overfit on a small subset (e.g., first 100 samples) to check if model can learn
    # small_subset_loader = DataLoader(torch.utils.data.Subset(train_dataset, range(100)), batch_size=10, shuffle=True)
    
    # Step 6: Create model (and reinitialize optimizer if needed)
    model = CNNKAN(task_classes[key]).to(device)
    # It might be useful to reinitialize the optimizer for each new dataset/model:
    # optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    
    for epoch in range(1, epochs + 1):
        print(f"--- Epoch {epoch} ---")
        train(model, device, train_loader, optimizer, epoch)
        evaluate(model, device, test_loader)
    
    print(f"======== Finished training on dataset: {key} ========")
    
    # Clean up GPU memory after each dataset training cycle
    del model, train_dataset, val_dataset, test_dataset, train_loader, val_loader, test_loader
    gc.collect()
    torch.cuda.empty_cache()


Computed mean from data: tensor([0.4950]), std: tensor([0.2291]) for dataset organsmnist
Sample image shape: torch.Size([1, 28, 28]), min: 0.0, max: 1.0
Using mean: [0.5, 0.5, 0.5], std: [0.5, 0.5, 0.5] for normalization on dataset organsmnist
--- Epoch 1 ---

Test set: Average loss: 2.3989, Accuracy: 784/8829 (9%)

--- Epoch 2 ---


KeyboardInterrupt: 