In [1]:
import torch
from pykan.kan.spline import curve2coef, coef2curve, B_batch
from pykan.kan.KAN import KAN
from torch import nn

import numpy as np
from tqdm import tqdm

In [14]:
class LAN_layer_2D(nn.Module):
    """Learnable per-channel activation: a spline term plus a scaled base function.

    Each of the ``dim`` channels has its own spline (coefficients ``coef``) on a
    shared uniform grid. The output is
    ``scale_base * base_fun(x) + scale_sp * spline(x) + bias.weight``.
    It is meant as a drop-in replacement for an element-wise activation inside
    a CNN. It therefore accepts ``(batch, dim, *spatial)`` tensors (e.g. NCHW)
    as well as ``(batch, dim)`` tensors, and returns a tensor of the same shape.

    Args:
        dim: number of channels (size of axis 1 of the input).
        num: number of grid intervals; the grid has ``num + 1`` points per channel.
        k: spline order passed to pykan's ``curve2coef`` / ``coef2curve``.
        noise_scale: amplitude of the random initial spline values (divided by ``num``).
        scale_base: multiplier for ``base_fun(x)`` (a plain float, not trained).
        scale_sp: multiplier for the spline term (a plain float, not trained).
        base_fun: base activation applied to the input (default ``SiLU``).
        grid_eps: stored on the instance; not used by ``forward``.
        grid_range: ``(low, high)`` bounds of the initial uniform grid.
        sp_trainable, sb_trainable: accepted but currently unused.
        device: device for the grid, coefficients and bias.
    """

    def __init__(
        self,
        dim=2,
        num=5,
        k=3,
        noise_scale=0.1,
        scale_base=1.0,
        scale_sp=1.0,
        base_fun=torch.nn.SiLU(),
        grid_eps=0.02,
        grid_range=(-1, 1),
        sp_trainable=True,
        sb_trainable=True,
        device='cpu'
    ):
        super().__init__()

        self.dim = dim
        self.num = num
        self.k = k
        self.base_fun = base_fun
        self.device = device
        self.grid_eps = grid_eps

        self.scale_base = scale_base
        self.scale_sp = scale_sp

        # Define the grid: the same uniform grid of num + 1 points, repeated for every channel.
        self.grid = torch.einsum('i,j->ij', torch.ones(self.dim, device=device), torch.linspace(grid_range[0], grid_range[1], steps=num + 1, device=device))  # (dim, num + 1)
        self.grid = torch.nn.Parameter(self.grid).requires_grad_(False)

        # Small random initial spline values, fitted to coefficients below.
        noises = (torch.rand(self.dim, self.grid.shape[1]) - 1 / 2) * noise_scale / num
        noises = noises.to(device)
        # Coefficient layout is defined by pykan's curve2coef (presumably (dim, num + k) — TODO confirm).
        self.coef = torch.nn.Parameter(curve2coef(self.grid, noises, self.grid, k, device))

        # Per-channel learnable offset; only `.weight` (shape (1, dim)) is used, initialised to zero.
        self.bias = nn.Linear(dim, 1, bias=False, device=device)
        nn.init.zeros_(self.bias.weight)
        # NOTE(review): unused in this class, and a plain tensor (not a buffer), so `.to()` won't move it.
        self.acts_scale = torch.zeros(dim).to(self.device)

    def forward(self, x):
        """Apply the activation independently to every channel.

        Args:
            x: tensor of shape ``(batch, dim, *spatial)`` or ``(batch, dim)``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` has fewer than 2 dims or axis 1 is not ``dim``.
        """
        if x.dim() < 2 or x.shape[1] != self.dim:
            raise ValueError(
                f"expected input of shape (batch, {self.dim}, ...), got {tuple(x.shape)}"
            )
        # Move channels to the last axis BEFORE flattening. With channels-first
        # memory, a plain reshape(-1, C) would mix spatial positions instead of
        # gathering one channel vector per position.
        x_last = x.movedim(1, -1)              # (batch, *spatial, dim)
        x_flat = x_last.reshape(-1, self.dim)  # (N, dim)
        x_t = x_flat.permute(1, 0)             # (dim, N): layout passed to coef2curve
        spline = coef2curve(x_eval=x_t, grid=self.grid, coef=self.coef, k=self.k, device=self.device).permute(1, 0)  # (N, dim)
        base = self.base_fun(x_t).permute(1, 0)  # (N, dim)
        y = self.scale_base * base + self.scale_sp * spline
        y = y + self.bias.weight               # (1, dim) broadcasts over rows
        # Undo the flattening and put channels back on axis 1.
        return y.reshape(x_last.shape).movedim(-1, 1)

    # NOTE(review): the grid-update method below is not ported to this layer
    # (it references `self.size`, which this class does not define).
    # def update_grid_from_samples(self, x):
    #     batch = x.shape[0]
    #     x = torch.einsum('ij,k->ikj', x, torch.ones(self.dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0)
    #     x_pos = torch.sort(x, dim=1)[0]
    #     y_eval = coef2curve(x_pos, self.grid, self.coef, self.k, device=self.device)
    #     num_interval = self.grid.shape[1] - 1
    #     ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
    #     grid_adaptive = x_pos[:, ids]
    #     margin = 0.01
    #     grid_uniform = torch.cat([grid_adaptive[:, [0]] - margin + (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) * a for a in np.linspace(0, 1, num=self.grid.shape[1])], dim=1)
    #     self.grid.data = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
    #     self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k, device=self.device)

In [15]:
# Use the GPU when one is available so the notebook also runs on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [26]:
import torchvision.models as models

# `pretrained=True` is deprecated since torchvision 0.13; IMAGENET1K_V1 selects
# the same ImageNet weights that `pretrained=True` used to load.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

In [27]:
# Freeze every pretrained ResNet weight. Modules attached to `model` after this
# cell are not affected, so their parameters stay trainable.
model.requires_grad_(False)

In [28]:
# Replace every ReLU in the ResNet with a learnable per-channel LAN activation.
# Channel counts are read from the neighbouring BatchNorm layers instead of
# being hard-coded, so this works for any BasicBlock-based ResNet.
model.relu = LAN_layer_2D(model.bn1.num_features, device=device)

for stage in (model.layer1, model.layer2, model.layer3, model.layer4):
    for block in stage:
        # A BasicBlock applies its single `relu` after bn1 and after the residual
        # add; both tensors have bn2.num_features channels. Bottleneck blocks
        # (which have bn3) apply it at two different widths, so a single
        # per-channel layer cannot replace it.
        if hasattr(block, "bn3"):
            raise TypeError("LAN_layer_2D replacement only supports BasicBlock ResNets")
        block.relu = LAN_layer_2D(block.bn2.num_features, device=device)

In [29]:
# Move the pretrained ResNet weights, together with the inserted LAN layers'
# parameters, onto the selected device.
model = model.to(device=device)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/4.72k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/85.4k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/29.1G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/29.3G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/29.0G [00:00<?, ?B/s]