In [18]:
#@title Installs
#!pip install torch
#!pip install timm
#!pip install tqdm
#!pip install numpy

In [1]:
!nvidia-smi

Thu Apr  3 02:12:13 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce GTX 1060 6GB    Off | 00000000:09:00.0  On |                  N/A |
| 25%   50C    P5              12W / 180W |    597MiB /  6144MiB |     27%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
#@title Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import dataset


import argparse
import timm
from tqdm import tqdm

import json

  from .autonotebook import tqdm as notebook_tqdm


# Components


In [3]:
#@title ReluKANLayer
class ReluKANLayer(nn.Module):
    """ReLU-KAN layer: a bank of squared ReLU "bump" basis functions per input
    feature, mixed into the outputs by a convolution spanning the whole basis map.

    For every input feature, ``g + k`` bumps are evaluated. Bump ``j`` is
    ``(r * relu(x - low_j) * relu(high_j - x)) ** 2`` with
    ``low_j = (j - k) / g`` and ``high_j = low_j + (k + 1) / g``.

    Args:
        input_size: number of input features.
        output_size: number of output features.
        g: grid size (grid step is ``1 / g``).
        k: each bump spans ``k + 1`` grid steps.
        is_train: if True, the bump boundaries are trainable parameters.

    Input shape is ``(batch, input_size)``; output shape is
    ``(batch, output_size, 1)``.
    """

    def __init__(self, input_size: int, output_size: int, g: int, k: int,  is_train: bool = False):
        super().__init__()

        self.g = g
        self.k = k
        # The product of the two ReLUs peaks at ((k + 1) / (2g))**2 at a bump's
        # midpoint; multiplying by r rescales that peak to 1.
        self.r = 4 * g * g / ((k + 1) * (k + 1))
        self.input_size = input_size
        self.output_size = output_size

        lows = np.arange(-k, g) / g
        highs = lows + (k + 1) / g
        # The same boundary row is repeated for every input feature.
        low_grid = np.array([lows] * input_size)
        high_grid = np.array([highs] * input_size)
        self.phase_low = nn.Parameter(torch.Tensor(low_grid), requires_grad=is_train)
        self.phase_height = nn.Parameter(torch.Tensor(high_grid), requires_grad=is_train)
        # The kernel covers the entire (g + k) x input_size basis map, so each
        # output is a weighted sum of every basis response plus a bias.
        self.equal_size_conv = nn.Conv2d(1, output_size, (g + k, input_size))

    def forward(self, x):
        """Map ``x`` of shape (batch, input_size) to (batch, output_size, 1)."""
        # (batch, input_size, 1) broadcast against (1, input_size, g + k).
        column = x.unsqueeze(2)
        rising = torch.relu(column - self.phase_low.unsqueeze(0))
        falling = torch.relu(self.phase_height.unsqueeze(0) - column)
        basis = rising * falling * self.r
        basis = basis * basis
        # NOTE: a reshape, not a transpose. Because the conv kernel spans the
        # whole map, this only permutes which weight pairs with which value.
        basis = basis.reshape((len(basis), 1, self.g + self.k, self.input_size))
        mixed = self.equal_size_conv(basis)
        return mixed.reshape((len(mixed), self.output_size, 1))

In [4]:
#@title ReluKANOperator2d
class ReluKANOperator2d(nn.Module):
    """2-D sliding-window operator whose per-patch response comes from KAN modules.

    Patches are extracted with ``F.unfold`` using the same geometry as
    ``nn.Conv2d`` (kernel_size, stride, padding, dilation). Each output channel
    owns its own KAN module, which maps a flattened patch of
    ``in_channels * kernel_h * kernel_w`` values to a single number.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 g=None, k=None, kan_module_constructor=None):
        """\
        Parameters:
          - in_channels: Number of channels in the input.
          - out_channels: Number of output channels (each will have its own KAN module).
          - kernel_size: Kernel size (int, or a 2-element tuple/list).
          - stride, padding, dilation: Convolution parameters (int, or a 2-element tuple/list).
          - groups: Stored but not used in this basic implementation.
          - bias: Whether to add a learnable per-channel bias (initialised to zero).
          - g, k: Parameters for ReluKANLayer (if using default constructor).
          - kan_module_constructor: Optional callable that accepts the flattened patch size
            and returns a KAN module. Given N patches of shape (N, patch_size), the module
            must return N values in any shape (e.g. (N, 1, 1) or (N, 1)).

        Raises:
          - ValueError: if g or k is missing while the default constructor is used,
            or if a size argument is a sequence whose length is not 2.
        """
        super(ReluKANOperator2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Normalise size arguments to (height, width) tuples. Lists are accepted
        # like in nn.Conv2d; previously a list was wrapped into a pair of lists.
        self.kernel_size = self._pair(kernel_size, "kernel_size")
        self.stride = self._pair(stride, "stride")
        self.padding = self._pair(padding, "padding")
        self.dilation = self._pair(dilation, "dilation")
        self.groups = groups

        # The flattened patch size: in_channels * kernel_height * kernel_width.
        self.patch_size = in_channels * self.kernel_size[0] * self.kernel_size[1]

        # Use the provided kan_module_constructor or default to one that uses ReluKANLayer.
        if kan_module_constructor is None:
            if g is None or k is None:
                raise ValueError("Provide g and k parameters for the default ReluKANLayer constructor")
            def default_kan_module_constructor(in_features):
                # Each KAN module converts a flattened patch to a scalar (output_size=1).
                return ReluKANLayer(input_size=in_features, output_size=1, g=g, k=k)
            kan_module_constructor = default_kan_module_constructor

        # Create one KAN module per output channel.
        self.kan_modules = nn.ModuleList(
            [kan_module_constructor(self.patch_size) for _ in range(out_channels)]
        )

        # Optional bias per output channel.
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

    @staticmethod
    def _pair(value, name):
        """Return ``value`` as an (h, w) tuple; a scalar is used for both sides."""
        if isinstance(value, (tuple, list)):
            if len(value) != 2:
                raise ValueError(f"{name} must be an int or a pair of ints, got {value!r}")
            return tuple(value)
        return (value, value)

    def forward(self, x):
        """Apply the operator to ``x`` of shape (B, in_channels, H, W).

        Returns a tensor of shape (B, out_channels, H_out, W_out). H_out and
        W_out follow the nn.Conv2d output-size formula.

        Raises:
          - ValueError: if ``x`` is not 4-D or its channel count differs from
            ``in_channels`` (otherwise this surfaces as an obscure reshape
            error inside the KAN modules).
        """
        if x.dim() != 4 or x.shape[1] != self.in_channels:
            raise ValueError(
                f"Expected input of shape (B, {self.in_channels}, H, W), got {tuple(x.shape)}"
            )
        B, _, H, W = x.shape

        # (B, patch_size, L) -> (B, L, patch_size); L is the number of window positions.
        patches = F.unfold(x, kernel_size=self.kernel_size, dilation=self.dilation,
                           padding=self.padding, stride=self.stride).transpose(1, 2)
        L = patches.shape[1]
        # Fold batch and position together so each KAN module sees (B*L, patch_size).
        flat_patches = patches.reshape(B * L, self.patch_size)

        # Each module yields B*L values. reshape (rather than view) also accepts
        # non-contiguous outputs from custom modules.
        out_tensor = torch.stack(
            [kan(flat_patches).reshape(B, L) for kan in self.kan_modules], dim=1
        )  # (B, out_channels, L)

        # Output spatial size, same formula as nn.Conv2d.
        H_out = (H + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1
        W_out = (W + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1
        out_tensor = out_tensor.reshape(B, self.out_channels, H_out, W_out)

        if self.bias is not None:
            out_tensor = out_tensor + self.bias.view(1, -1, 1, 1)

        return out_tensor

In [5]:
#@title ReluKANBlock
class ReluKANBlock(nn.Module):
    """A 3x3 ReLU-KAN "convolution" (padding 1) followed by BatchNorm2d.

    With stride 1 the spatial size is preserved. With stride 2 each spatial
    side becomes ceil(side / 2).
    """

    def __init__(self, in_channels, out_channels, g, k, stride=1):
        super().__init__()
        self.layer = ReluKANOperator2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            g=g,
            k=k,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """KAN operator, then batch normalisation."""
        return self.bn(self.layer(x))

In [6]:
#@title ReluKANNetB0
class ReluKANNetB0(nn.Module):
    """Small image classifier built from ReluKANBlocks.

    Layout: a stem block (stride 1, ``base_channels`` outputs), then one stage
    per entry of ``depth_list``, then global average pooling, dropout and a
    linear head. Stage ``s`` stacks ``depth_list[s]`` blocks with
    ``channel_configs[s]`` output channels. The first block of every stage
    except stage 0 uses stride 2.

    Args:
        in_channels: channels of the input images.
        num_classes: size of the classifier output.
        g, k: ReLU-KAN basis parameters passed to every block.
        depth_list: blocks per stage (default ``[1, 1, 1, 1]``).
        channel_configs: output channels per stage (default ``[8, 12, 16, 20]``).
        base_channels: output channels of the stem.
        dropout: dropout probability before the linear head.

    Raises:
        ValueError: if a stage with at least one block has no entry in
            ``channel_configs``.
    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 g=3,
                 k=3,
                 depth_list=None,
                 channel_configs=None,
                 base_channels=32,
                 dropout=0.2,
                 ):
        super(ReluKANNetB0, self).__init__()

        # None sentinels replace mutable list defaults, which were shared across
        # instances (and exposed for mutation through self.depth_list).
        if depth_list is None:
            depth_list = [1, 1, 1, 1]
        if channel_configs is None:
            channel_configs = [8, 12, 16, 20]

        self.depth_list = depth_list
        self.stem = ReluKANBlock(in_channels, base_channels, g, k, stride=1)

        self.stages = nn.ModuleList()
        input_channels = base_channels
        for stage, repeats in enumerate(depth_list):
            if repeats > 0 and stage >= len(channel_configs):
                raise ValueError(
                    f"channel_configs has {len(channel_configs)} entries but depth_list "
                    f"requests blocks in stage {stage}"
                )
            stage_layers = []
            for i in range(repeats):
                # Downsample at the start of every stage except the first.
                stride = 2 if i == 0 and stage != 0 else 1
                stage_layers.append(ReluKANBlock(input_channels, channel_configs[stage], g, k, stride=stride))
                input_channels = channel_configs[stage]
            self.stages.append(nn.Sequential(*stage_layers))

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(dropout)
        # input_channels now holds the width of the last non-empty stage (or the stem).
        self.fc = nn.Linear(input_channels, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for images ``x``."""
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        return x

# Datasets

In [7]:
#@title cifar10
def cifar10():
    """Load the CIFAR-10 train and test splits, downloading them if needed.

    Every image is resized to 32x32, converted to a tensor and normalised with
    mean 0.5 / std 0.5 per RGB channel.

    Returns:
        (train_dataset, val_dataset, in_channels=3, num_classes=10)
    """
    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def load_split(root, train):
        return torchvision.datasets.CIFAR10(root=root, train=train, download=True, transform=preprocess)

    # NOTE(review): each split has its own root, so torchvision fetches and
    # stores the full archive under both; a shared root would suffice.
    train_dataset = load_split("./data/cifar10/train", True)
    val_dataset = load_split("./data/cifar10/test", False)
    return train_dataset, val_dataset, 3, 10

In [8]:
#@title cifar100
def cifar100():
    """Load the CIFAR-100 train and test splits, downloading them if needed.

    Every image is resized to 32x32, converted to a tensor and normalised with
    mean 0.5 / std 0.5 per RGB channel.

    Returns:
        (train_dataset, val_dataset, in_channels=3, num_classes=100)
    """
    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def load_split(root, train):
        return torchvision.datasets.CIFAR100(root=root, train=train, download=True, transform=preprocess)

    # NOTE(review): each split has its own root, so torchvision fetches and
    # stores the full archive under both; a shared root would suffice.
    train_dataset = load_split("./data/cifar100/train", True)
    val_dataset = load_split("./data/cifar100/test", False)
    return train_dataset, val_dataset, 3, 100

In [9]:
#@title stanford_cars
def stanford_cars():
    """Load Stanford Cars from pre-split ImageFolder directories.

    Expects ./data/stanford_cars/train and ./data/stanford_cars/test, each laid
    out as <class_name>/<image>. Images are resized to 224x224, converted to
    tensors and normalised with mean 0.5 / std 0.5 per RGB channel.

    Returns:
        (train_dataset, val_dataset, in_channels, num_classes)

    Raises:
        ValueError: if the train and test splits contain different class folders.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_dataset = torchvision.datasets.ImageFolder(root="./data/stanford_cars/train", transform=transform)
    val_dataset = torchvision.datasets.ImageFolder(root="./data/stanford_cars/test", transform=transform)
    # ImageFolder derives label indices per split from the sorted folder names,
    # so a class missing from one split would silently shift later labels.
    if train_dataset.classes != val_dataset.classes:
        raise ValueError(
            "stanford_cars: train and test class folders differ; labels would be misaligned."
        )
    in_channels = 3
    num_classes = 196
    return train_dataset, val_dataset, in_channels, num_classes

In [10]:
#@title food101
def food101():
    """Load Food-101 from pre-split ImageFolder directories.

    Expects ./data/food101/train and ./data/food101/test, each laid out as
    <class_name>/<image>. Images are resized to 224x224, converted to tensors
    and normalised with mean 0.5 / std 0.5 per RGB channel.

    Returns:
        (train_dataset, val_dataset, in_channels, num_classes)

    Raises:
        ValueError: if the train and test splits contain different class folders.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_dataset = torchvision.datasets.ImageFolder(root="./data/food101/train", transform=transform)
    val_dataset = torchvision.datasets.ImageFolder(root="./data/food101/test", transform=transform)
    # ImageFolder derives label indices per split from the sorted folder names,
    # so a class missing from one split would silently shift later labels.
    if train_dataset.classes != val_dataset.classes:
        raise ValueError(
            "food101: train and test class folders differ; labels would be misaligned."
        )
    in_channels = 3
    num_classes = 101
    return train_dataset, val_dataset, in_channels, num_classes

In [11]:
#@title oxford_iiit_pet
def oxford_iiit_pet():
    """Load Oxford-IIIT Pet from pre-split ImageFolder directories.

    Expects ./data/oxford_iiit_pet/train and ./data/oxford_iiit_pet/test, each
    laid out as <class_name>/<image>. Images are resized to 224x224, converted
    to tensors and normalised with mean 0.5 / std 0.5 per RGB channel.

    Returns:
        (train_dataset, val_dataset, in_channels, num_classes)

    Raises:
        ValueError: if the train and test splits contain different class folders.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_dataset = torchvision.datasets.ImageFolder(root="./data/oxford_iiit_pet/train", transform=transform)
    val_dataset = torchvision.datasets.ImageFolder(root="./data/oxford_iiit_pet/test", transform=transform)
    # ImageFolder derives label indices per split from the sorted folder names,
    # so a class missing from one split would silently shift later labels.
    if train_dataset.classes != val_dataset.classes:
        raise ValueError(
            "oxford_iiit_pet: train and test class folders differ; labels would be misaligned."
        )
    in_channels = 3
    num_classes = 37
    return train_dataset, val_dataset, in_channels, num_classes

In [12]:
def get_datasets(dataset_name, batch_size, num_workers=2):
    """Build training and validation DataLoaders for a named dataset.

    Args:
        dataset_name: One of "cifar10", "cifar100", "stanford_cars", "food101"
            or "oxford_iiit_pet".
        batch_size: Batch size used by both loaders.
        num_workers: Worker processes per DataLoader (default 2; use 0 to load
            in the main process).

    Returns:
        (train_loader, val_loader, in_channels, num_classes). Only the training
        loader shuffles.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    builders = {
        "cifar10": cifar10,
        "cifar100": cifar100,
        "stanford_cars": stanford_cars,
        "food101": food101,
        "oxford_iiit_pet": oxford_iiit_pet
    }
    if dataset_name not in builders:
        raise ValueError(f"Dataset '{dataset_name}' is not supported.")
    train_dataset, val_dataset, in_channels, num_classes = builders[dataset_name]()

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader, in_channels, num_classes

# Training

In [13]:
#@title Train epoch

def train_epoch(model, train_loader, optimizer, criterion, device, epoch, total_epochs):
    """Run one optimisation pass over ``train_loader`` and return the mean loss.

    Each batch loss is multiplied by the batch size before accumulation and the
    sum is divided by the dataset length. This gives a per-sample mean,
    assuming ``criterion`` averages over the batch (TODO confirm for custom
    criteria). ``epoch`` and ``total_epochs`` only label the progress bar.
    """
    model.train()
    loss_sum = 0.0
    progress = tqdm(train_loader, desc=f"Training Epoch {epoch}/{total_epochs}", leave=False)
    for batch_inputs, batch_targets in progress:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()
        loss_sum += batch_loss.item() * batch_inputs.size(0)
    return loss_sum / len(train_loader.dataset)

In [14]:
#@title Validate

def validate(model, val_loader, criterion, device, epoch, total_epochs):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Returns:
        (avg_loss, accuracy): the batch-size-weighted loss divided by the
        dataset length, and the fraction of samples whose argmax logit matches
        the target. ``epoch``/``total_epochs`` only label the progress bar.
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    with torch.no_grad():
        progress = tqdm(val_loader, desc=f"Validation Epoch {epoch}/{total_epochs}", leave=False)
        for batch_inputs, batch_targets in progress:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            logits = model(batch_inputs)
            loss_sum += criterion(logits, batch_targets).item() * batch_inputs.size(0)
            hits += (logits.argmax(dim=1) == batch_targets).sum().item()
    dataset_size = len(val_loader.dataset)
    return loss_sum / dataset_size, hits / dataset_size

# Running

In [15]:
#@title Main function
def main():
    """Train and evaluate the model(s) selected in the Args cell.

    Reads notebook globals set in the Args cell: dataset_arg, batch_size_arg,
    model_arg, epochs_arg, lr_arg, output_arg and, for "relukannet", g_arg,
    k_arg, depth_list_arg and channel_configs_arg. Each model is trained with
    Adam and cross-entropy loss. After each model finishes, the per-epoch
    metrics and parameter counts of every model trained so far are written to
    ./results/<output_arg>.json.
    """
    import os

    train_loader, val_loader, in_channels, num_classes = get_datasets(dataset_arg, batch_size_arg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on {device}")
    results = {}

    # Create the output directory before training. Previously a missing
    # ./results folder only raised at json-dump time, after every epoch had run.
    results_dir = "./results"
    os.makedirs(results_dir, exist_ok=True)
    results_path = os.path.join(results_dir, f"{output_arg}.json")

    # Map model names to their constructors with appropriate in_channels and num_classes.
    model_dict = {
        "relukannet": lambda: ReluKANNetB0(in_channels=in_channels, num_classes=num_classes, g=g_arg, k=k_arg, depth_list=depth_list_arg, channel_configs=channel_configs_arg),
        "mobilenetv2_100": lambda: timm.create_model("mobilenetv2_100", pretrained=False, num_classes=num_classes),
        "efficientnet_lite0": lambda: timm.create_model("efficientnet_lite0", pretrained=False, num_classes=num_classes),
        "resnet18": lambda: timm.create_model("resnet18", pretrained=False, num_classes=num_classes),
        "resnet34": lambda: timm.create_model("resnet34", pretrained=False, num_classes=num_classes),
        #"resnet50": lambda: timm.create_model("resnet50", pretrained=False, num_classes=num_classes),
        "vit_tiny_patch16_224": lambda: timm.create_model("vit_tiny_patch16_224", pretrained=False, num_classes=num_classes),
        #"vit_base_patch16_224": lambda: timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=num_classes),
        #"convnext_base": lambda: timm.create_model("convnext_base", pretrained=False, num_classes=num_classes),
    }

    # Only the model chosen in the form is trained. A list is kept so several
    # models can be compared in one run (stand-in for a command-line argument).
    models_arg = [model_arg]

    for model_name in models_arg:
        if model_name not in model_dict:
            print(f"Model '{model_name}' not recognized. Skipping.")
            continue

        print(f"\nTraining model: {model_name}")
        model = model_dict[model_name]()
        model.to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr_arg)
        criterion = nn.CrossEntropyLoss()

        # Per-epoch history, recorded in the results file.
        train_losses = []
        val_losses = []
        val_accuracies = []

        for epoch in range(1, epochs_arg + 1):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device, epoch, epochs_arg)
            val_loss, val_accuracy = validate(model, val_loader, criterion, device, epoch, epochs_arg)
            print(
                f"Epoch {epoch}/{epochs_arg} - Train Loss: {train_loss:.4f} | "
                f"Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.4f}"
            )
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            val_accuracies.append(val_accuracy)

        results[model_name] = {
            "metrics": {
                "train_loss": train_losses,
                "val_loss": val_losses,
                "val_accuracy": val_accuracies,
            },
            "number_of_parameters": sum(p.numel() for p in model.parameters()),
            # The total above includes parameters with requires_grad=False, such
            # as ReluKANLayer's fixed phase grids, so record the trainable count too.
            "number_of_trainable_parameters": sum(
                p.numel() for p in model.parameters() if p.requires_grad
            ),
        }

        # Rewritten after every model so earlier results survive a later failure.
        with open(results_path, 'w') as f:
            json.dump(results, f, indent=4)
        print(f"\nTraining results saved to {results_path}")

In [16]:
def print_model_param_count():
    """Print trainable and total parameter counts for the model named by ``model_arg``.

    Uses the Args-cell globals model_arg, dataset_arg and batch_size_arg, plus
    g_arg, k_arg, depth_list_arg and channel_configs_arg for "relukannet". The
    dataset is loaded only to obtain in_channels and num_classes.
    """
    # These lambdas close over in_channels/num_classes, which are assigned
    # further down (after the name check). They are only read when a lambda is
    # called, so building the dict first avoids loading data for unknown names.
    model_dict = {
        "relukannet": lambda: ReluKANNetB0(in_channels=in_channels, num_classes=num_classes, g=g_arg, k=k_arg, depth_list=depth_list_arg, channel_configs=channel_configs_arg),
        "mobilenetv2_100": lambda: timm.create_model("mobilenetv2_100", pretrained=False, num_classes=num_classes),
        "efficientnet_lite0": lambda: timm.create_model("efficientnet_lite0", pretrained=False, num_classes=num_classes),
        "resnet18": lambda: timm.create_model("resnet18", pretrained=False, num_classes=num_classes),
        "resnet34": lambda: timm.create_model("resnet34", pretrained=False, num_classes=num_classes),
        "vit_tiny_patch16_224": lambda: timm.create_model("vit_tiny_patch16_224", pretrained=False, num_classes=num_classes),
    }

    if model_arg not in model_dict:
        print(f"Model '{model_arg}' not recognized.")
        return

    # Get dataset details
    _, _, in_channels, num_classes = get_datasets(dataset_arg, batch_size_arg)

    model = model_dict[model_arg]()

    # ReluKANLayer's phase grids are parameters created with requires_grad=False,
    # so trainable and total counts differ. main() records the total as
    # "number_of_parameters", so report both to keep the numbers comparable.
    total_count = sum(p.numel() for p in model.parameters())
    trainable_count = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Model '{model_arg}' has {trainable_count} trainable parameters ({total_count} in total).")

In [17]:
#@title Args
# Experiment configuration. These module-level globals are read directly by
# main() and print_model_param_count(). The "#@param" annotations render this
# cell as a Colab form, so keep each annotated assignment on a single line.
dataset_arg="cifar10" #@param ["cifar10", "cifar100", "stanford_cars", "food101", "oxford_iiit_pet"]
model_arg = "relukannet" #@param ["relukannet","mobilenetv2_100","efficientnet_lite0","resnet18","resnet34","vit_tiny_patch16_224"]

# Batch size for both DataLoaders, epochs per model, and the Adam learning rate.
batch_size_arg = 36 # @param {"type":"slider","min":1,"max":64,"step":1}
epochs_arg = 100 #@param
lr_arg = 0.001 #@param

# main() writes results to ./results/<output_arg>.json.
output_arg = "kannet_t1" # @param {"type":"string"}

# ReLU-KAN basis parameters (grid step 1/g; each basis spans k+1 grid steps).
# Only used when model_arg is "relukannet".
g_arg = 3 #@param
k_arg = 3 #@param

# ReluKANNetB0 layout: blocks per stage and output channels per stage.
# Only used when model_arg is "relukannet".
depth_list_arg = [2, 2, 1, 1] #@param
channel_configs_arg = [4, 8, 10, 12] #@param

# Training is disabled here; uncomment main() to train. As written, this cell
# only reports the parameter count of the selected model.
#main()
print_model_param_count()

Model 'relukannet' has 29386 parameters.
