In [2]:
from sympy import symbols, cos, exp, diff

# Defining symbols
y, a = symbols('y a')

# Defining the function
f = cos(a*y) * exp(-y**2 / 2)

# Finding the derivative with respect to y
derivative_f = diff(f, y)

# Display the derivative
derivative_f

-a*exp(-y**2/2)*sin(a*y) - y*exp(-y**2/2)*cos(a*y)

In [2]:
'''This is a sample code for the simulations of the paper:
Bozorgasl, Zavareh and Chen, Hao, Wav-KAN: Wavelet Kolmogorov-Arnold Networks (May, 2024)

https://arxiv.org/abs/2405.12832
and also available at:
https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4835325
We used efficient KAN notation and some part of the code:https://github.com/Blealtan/efficient-kan

'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import math

class KANLinear(nn.Module):
    def __init__(self, in_features, out_features, wavelet_type='mexican_hat'):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.wavelet_type = wavelet_type

        # Parameters for wavelet transformation
        self.scale = nn.Parameter(torch.ones(out_features, in_features))
        self.translation = nn.Parameter(torch.zeros(out_features, in_features))

        # Linear weights for combining outputs
        #self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.weight1 = nn.Parameter(torch.Tensor(out_features, in_features)) #not used; you may like to use it for wieghting base activation and adding it like Spl-KAN paper
        self.wavelet_weights = nn.Parameter(torch.Tensor(out_features, in_features))

        nn.init.kaiming_uniform_(self.wavelet_weights, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))

        # Base activation function #not used for this experiment
        self.base_activation = nn.SiLU()

        # Batch normalization
        self.bn = nn.BatchNorm1d(out_features)

    def wavelet_transform(self, x):
        if x.dim() == 2:
            x_expanded = x.unsqueeze(1)
        else:
            x_expanded = x

        print(x_expanded.size())

        translation_expanded = self.translation.unsqueeze(0).expand(x.size(0), -1, -1)
        print(f'{translation_expanded.size()=}')
        scale_expanded = self.scale.unsqueeze(0).expand(x.size(0), -1, -1)
        print(f'{scale_expanded.size()=}')
        x_scaled = (x_expanded - translation_expanded) / scale_expanded
        print(f'{x_scaled.size()=}')

        # Implementation of different wavelet types
        if self.wavelet_type == 'mexican_hat':
            term1 = ((x_scaled ** 2)-1)
            term2 = torch.exp(-0.5 * x_scaled ** 2)
            print(f'{term1.size()=}, {term2.size()=}')
            wavelet = (2 / (math.sqrt(3) * math.pi**0.25)) * term1 * term2
            wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)
            print(f'{wavelet.size()=}, {wavelet_weighted.size()=}')
            wavelet_output = wavelet_weighted.sum(dim=2)
            print(f'{wavelet_output.size()=}')
        elif self.wavelet_type == 'morlet':
            omega0 = 5.0  # Central frequency
            real = torch.cos(omega0 * x_scaled)
            envelope = torch.exp(-0.5 * x_scaled ** 2)
            wavelet = envelope * real
            wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)
            wavelet_output = wavelet_weighted.sum(dim=2)
            
        elif self.wavelet_type == 'dog':
            # Implementing Derivative of Gaussian Wavelet 
            dog = -x_scaled * torch.exp(-0.5 * x_scaled ** 2)
            wavelet = dog
            wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)
            wavelet_output = wavelet_weighted.sum(dim=2)
        elif self.wavelet_type == 'meyer':
            # Implement Meyer Wavelet here
            # Constants for the Meyer wavelet transition boundaries
            v = torch.abs(x_scaled)
            pi = math.pi

            def meyer_aux(v):
                return torch.where(v <= 1/2,torch.ones_like(v),torch.where(v >= 1,torch.zeros_like(v),torch.cos(pi / 2 * nu(2 * v - 1))))

            def nu(t):
                return t**4 * (35 - 84*t + 70*t**2 - 20*t**3)
            # Meyer wavelet calculation using the auxiliary function
            wavelet = torch.sin(pi * v) * meyer_aux(v)
            wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)
            wavelet_output = wavelet_weighted.sum(dim=2)
        elif self.wavelet_type == 'shannon':
            # Windowing the sinc function to limit its support
            pi = math.pi
            sinc = torch.sinc(x_scaled / pi)  # sinc(x) = sin(pi*x) / (pi*x)

            # Applying a Hamming window to limit the infinite support of the sinc function
            window = torch.hamming_window(x_scaled.size(-1), periodic=False, dtype=x_scaled.dtype, device=x_scaled.device)
            # Shannon wavelet is the product of the sinc function and the window
            wavelet = sinc * window
            wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)
            wavelet_output = wavelet_weighted.sum(dim=2)
            #You can try many more wavelet types ...
        else:
            raise ValueError("Unsupported wavelet type")

        return wavelet_output

    def forward(self, x):
        wavelet_output = self.wavelet_transform(x)
        #You may like test the cases like Spl-KAN
        #wav_output = F.linear(wavelet_output, self.weight)
        #base_output = F.linear(self.base_activation(x), self.weight1)

        base_output = F.linear(x, self.weight1)
        combined_output =  wavelet_output #+ base_output 

        # Apply batch normalization
        return self.bn(combined_output)

class KAN(nn.Module):
    def __init__(self, layers_hidden, wavelet_type='mexican_hat'):
        super(KAN, self).__init__()
        self.layers = nn.ModuleList()
        for in_features, out_features in zip(layers_hidden[:-1], layers_hidden[1:]):
            self.layers.append(KANLinear(in_features, out_features, wavelet_type))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

In [3]:
x = torch.randn(1000, 128)
layer = KANLinear(128, 144)

layer(x)

torch.Size([1000, 1, 128])
translation_expanded.size()=torch.Size([1000, 144, 128])
scale_expanded.size()=torch.Size([1000, 144, 128])
x_scaled.size()=torch.Size([1000, 144, 128])
term1.size()=torch.Size([1000, 144, 128]), term2.size()=torch.Size([1000, 144, 128])
wavelet.size()=torch.Size([1000, 144, 128]), wavelet_weighted.size()=torch.Size([1000, 144, 128])
wavelet_output.size()=torch.Size([1000, 144])


tensor([[ 0.9303,  0.4346, -1.1224,  ..., -0.4664,  1.1976,  0.9476],
        [ 0.8187,  0.7688, -0.2611,  ..., -1.2191,  1.0781,  0.2735],
        [ 0.2438,  0.5340, -2.0622,  ..., -0.5690, -1.8118,  0.7056],
        ...,
        [-0.7016, -0.3629, -1.8911,  ..., -0.7746, -1.1137,  0.0870],
        [-2.0423,  0.6564,  1.9434,  ...,  0.3529,  1.7313, -0.1726],
        [ 0.0924,  1.4126,  0.6934,  ...,  0.3420, -0.5840, -0.7927]],
       grad_fn=<NativeBatchNormBackward0>)