In [52]:
import os
from functools import partial
from typing import Callable, List, Optional, Tuple, Type, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch import flatten
from torch.cuda.amp import autocast, GradScaler
from tqdm.autonotebook import tqdm

In [53]:
# Fetch the torch-conv-kan sources and import its layers directly from the checkout.
# NOTE(review): the clone target and both %cd calls are relative to the current
# working directory, so re-running this cell in a live kernel clones and descends
# again -- the recorded output of this cell shows a path nested several
# "torch-conv-kan/.../models" levels deep. Restart the kernel before re-running.
!git clone https://github.com/IvanDrokin/torch-conv-kan.git
%cd torch-conv-kan
# Presumably resolved from the repository root entered by the %cd above -- confirm.
from kan_convs import KALNConv2DLayer, KANConv2DLayer, KACNConv2DLayer, FastKANConv2DLayer, KAGNConv2DLayer
from kans import KAN, KALN, KACN, KAGN, FastKAN
%cd models
# Presumably resolved from the "models" sub-directory entered by the %cd above -- confirm.
from model_utils import kan_conv3x3, kagn_conv3x3, kacn_conv3x3, kaln_conv3x3, fast_kan_conv3x3, moe_kaln_conv3x3
from model_utils import kan_conv1x1, kagn_conv1x1, kacn_conv1x1, kaln_conv1x1, fast_kan_conv1x1

c:\ITMO_CAI\2023-2024\Wine_Aroma\torch-conv-kan\torch-conv-kan\models\torch-conv-kan\models\torch-conv-kan
c:\ITMO_CAI\2023-2024\Wine_Aroma\torch-conv-kan\torch-conv-kan\models\torch-conv-kan\models\torch-conv-kan\models


Cloning into 'torch-conv-kan'...

using dhist requires you to install the `pickleshare` library.



In [54]:
class BasicBlockTemplate(nn.Module):
    """Residual block skeleton: a 3x3-style layer, then a 1x1-style layer, plus a skip.

    Concrete blocks inject the two layer factories; this class only wires the
    resulting modules together and adds the shortcut.  Nothing else (no extra
    activation or normalisation) is applied here -- whatever the factories
    return is used as-is.

    Args:
        conv1x1x1_fun: Factory called as ``f(planes, planes)``; builds ``conv2``.
        conv3x3x3_fun: Factory called as
            ``f(inplanes, planes, stride=stride, groups=groups)``; builds ``conv1``.
        inplanes: Input size handed to the ``conv1`` factory.
        planes: Output size of ``conv1`` and input/output size of ``conv2``.
        stride: Forwarded to the ``conv1`` factory and stored on the instance.
        downsample: Optional module applied to the block input before the
            residual addition; ``None`` means the input is added unchanged.
        groups: Forwarded to the ``conv1`` factory.
        base_width: Accepted but not used by this template.
        dilation: Values above 1 are rejected.

    Raises:
        NotImplementedError: If ``dilation > 1``.
    """

    expansion: int = 1

    def __init__(
            self,
            conv1x1x1_fun,
            conv3x3x3_fun,
            inplanes: int,
            planes: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            groups: int = 1,
            base_width: int = 64,
            dilation: int = 1,
    ) -> None:
        super().__init__()
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        self.stride = stride
        # When stride != 1 both conv1 and the downsample module reduce the input.
        # Sub-modules are registered in the order conv1, conv2, downsample.
        self.conv1 = conv3x3x3_fun(inplanes, planes, stride=stride, groups=groups)
        self.conv2 = conv1x1x1_fun(planes, planes)
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        """Return ``conv2(conv1(x))`` added to the (optionally projected) input."""
        # The main branch is evaluated before the shortcut projection.
        main_branch = self.conv2(self.conv1(x))
        shortcut = x if self.downsample is None else self.downsample(x)
        return main_branch + shortcut


class KANBasicBlock(BasicBlockTemplate):
    """Basic residual block whose layers come from ``kan_conv3x3`` / ``kan_conv1x1``.

    ``spline_order``, ``grid_size``, ``base_activation``, ``grid_range``,
    ``dropout``, ``l1_decay`` and any extra ``norm_kwargs`` are bound to both
    factories; the remaining arguments are passed through to
    ``BasicBlockTemplate`` unchanged.
    """

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 spline_order: int = 3,
                 grid_size: int = 5, base_activation: Optional[Callable[..., nn.Module]] = nn.GELU,
                 grid_range: List = [-1, 1],
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 dropout: float = 0.0,
                 l1_decay: float = 0.0,
                 **norm_kwargs
                 ):
        # One set of layer options, bound identically to both factories.
        layer_options = dict(spline_order=spline_order, grid_size=grid_size,
                             base_activation=base_activation, grid_range=grid_range,
                             dropout=dropout, l1_decay=l1_decay, **norm_kwargs)
        make_1x1 = partial(kan_conv1x1, **layer_options)
        make_3x3 = partial(kan_conv3x3, **layer_options)

        super().__init__(make_1x1, make_3x3,
                         inplanes=inplanes, planes=planes,
                         stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width,
                         dilation=dilation)


class FastKANBasicBlock(BasicBlockTemplate):
    """Basic residual block whose layers come from ``fast_kan_conv3x3`` / ``fast_kan_conv1x1``.

    ``grid_size``, ``base_activation``, ``grid_range``, ``dropout``,
    ``l1_decay`` and any extra ``norm_kwargs`` are bound to both factories;
    the remaining arguments are passed through to ``BasicBlockTemplate``.
    """

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 grid_size: int = 5,
                 base_activation: Optional[Callable[..., nn.Module]] = nn.SiLU,
                 grid_range: List = [-1, 1],
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 dropout: float = 0.0,
                 l1_decay: float = 0.0,
                 **norm_kwargs):
        # One set of layer options, bound identically to both factories.
        layer_options = dict(grid_size=grid_size, base_activation=base_activation,
                             grid_range=grid_range, l1_decay=l1_decay,
                             dropout=dropout, **norm_kwargs)
        make_1x1 = partial(fast_kan_conv1x1, **layer_options)
        make_3x3 = partial(fast_kan_conv3x3, **layer_options)

        super().__init__(make_1x1, make_3x3,
                         inplanes=inplanes, planes=planes,
                         stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width,
                         dilation=dilation)


class KALNBasicBlock(BasicBlockTemplate):
    """Basic residual block whose layers come from ``kaln_conv3x3`` / ``kaln_conv1x1``.

    ``degree``, ``dropout``, ``l1_decay`` and any extra ``norm_kwargs`` are
    bound to both factories; the remaining arguments are passed through to
    ``BasicBlockTemplate``.
    """

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 degree: int = 3,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 dropout: float = 0.0,
                 l1_decay: float = 0.0,
                 **norm_kwargs):
        # One set of layer options, bound identically to both factories.
        layer_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, **norm_kwargs)
        make_1x1 = partial(kaln_conv1x1, **layer_options)
        make_3x3 = partial(kaln_conv3x3, **layer_options)

        super().__init__(make_1x1, make_3x3,
                         inplanes=inplanes, planes=planes,
                         stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width,
                         dilation=dilation)


class KAGNBasicBlock(BasicBlockTemplate):
    """Basic residual block whose layers come from ``kagn_conv3x3`` / ``kagn_conv1x1``.

    ``degree``, ``dropout``, ``l1_decay`` and any extra ``norm_kwargs`` are
    bound to both factories; the remaining arguments are passed through to
    ``BasicBlockTemplate``.
    """

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 degree: int = 3,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 dropout: float = 0.0,
                 l1_decay: float = 0.0,
                 **norm_kwargs):
        # One set of layer options, bound identically to both factories.
        layer_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, **norm_kwargs)
        make_1x1 = partial(kagn_conv1x1, **layer_options)
        make_3x3 = partial(kagn_conv3x3, **layer_options)

        super().__init__(make_1x1, make_3x3,
                         inplanes=inplanes, planes=planes,
                         stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width,
                         dilation=dilation)


class KACNBasicBlock(BasicBlockTemplate):
    """Basic residual block whose layers come from ``kacn_conv3x3`` / ``kacn_conv1x1``.

    ``degree``, ``dropout``, ``l1_decay`` and any extra ``norm_kwargs`` are
    bound to both factories; the remaining arguments are passed through to
    ``BasicBlockTemplate``.
    """

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 degree: int = 3,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 dropout: float = 0.0,
                 l1_decay: float = 0.0,
                 **norm_kwargs):
        # One set of layer options, bound identically to both factories.
        layer_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, **norm_kwargs)
        make_1x1 = partial(kacn_conv1x1, **layer_options)
        make_3x3 = partial(kacn_conv3x3, **layer_options)

        super().__init__(make_1x1, make_3x3,
                         inplanes=inplanes, planes=planes,
                         stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width,
                         dilation=dilation)


class BottleneckTemplate(nn.Module):
    """Residual bottleneck skeleton (1x1 -> 3x3 -> 1x1 style layers) built from injected factories.

    As in torchvision's Bottleneck, the downsampling stride sits on the middle
    3x3 layer (``conv2``), whereas the original implementation from
    "Deep residual learning for image recognition"
    (https://arxiv.org/abs/1512.03385) puts it on the first 1x1 layer
    (``conv1``).  This variant is also known as ResNet V1.5 and improves
    accuracy according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    No separate normalisation layers are created here; the modules returned by
    the factories are chained directly.

    Args:
        conv1x1x1_fun: Factory called as ``f(in_size, out_size)``; builds
            ``conv1`` and ``conv3``.
        conv3x3x3_fun: Factory called as
            ``f(width, width, stride=..., groups=..., dilation=...)``; builds ``conv2``.
        inplanes: Input size handed to the ``conv1`` factory.
        planes: Base size; the block outputs ``planes * expansion``.
        stride: Forwarded to the ``conv2`` factory and stored on the instance.
        downsample: Optional module applied to the block input before the
            residual addition.
        groups: Forwarded to the ``conv2`` factory and used in the width formula.
        base_width: Scales the inner width relative to 64.
        dilation: Forwarded to the ``conv2`` factory.
    """

    expansion: int = 4

    def __init__(
            self,
            conv1x1x1_fun,
            conv3x3x3_fun,
            inplanes: int,
            planes: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            groups: int = 1,
            base_width: int = 64,
            dilation: int = 1,
    ) -> None:
        super().__init__()
        # Inner width shared by the three layers' interfaces.
        per_group_width = int(planes * (base_width / 64.0))
        width = per_group_width * groups

        self.stride = stride
        # When stride != 1 both conv2 and the downsample module reduce the input.
        # Sub-modules are registered in the order conv1, conv2, conv3, downsample.
        self.conv1 = conv1x1x1_fun(inplanes, width)
        self.conv2 = conv3x3x3_fun(width, width, stride=stride, groups=groups, dilation=dilation)
        self.conv3 = conv1x1x1_fun(width, planes * self.expansion)
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        """Return ``conv3(conv2(conv1(x)))`` added to the (optionally projected) input."""
        out = x
        for layer in (self.conv1, self.conv2, self.conv3):
            out = layer(out)

        # The shortcut projection runs after the main branch.
        shortcut = x if self.downsample is None else self.downsample(x)
        return out + shortcut


class KANBottleneck(BottleneckTemplate):
    """Bottleneck residual block whose layers come from ``kan_conv1x1`` / ``kan_conv3x3``.

    ``spline_order``, ``grid_size``, ``base_activation``, ``grid_range``,
    ``dropout``, ``l1_decay`` and any extra ``norm_kwargs`` are bound to both
    factories; the remaining arguments are passed through to
    ``BottleneckTemplate`` unchanged.
    """

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 spline_order: int = 3,
                 grid_size: int = 5, base_activation: Optional[Callable[..., nn.Module]] = nn.GELU,
                 grid_range: List = [-1, 1],
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 dropout: float = 0.0,
                 l1_decay: float = 0.0,
                 **norm_kwargs
                 ):
        # One set of layer options, bound identically to both factories.
        layer_options = dict(spline_order=spline_order, grid_size=grid_size,
                             base_activation=base_activation, grid_range=grid_range,
                             dropout=dropout, l1_decay=l1_decay, **norm_kwargs)
        make_1x1 = partial(kan_conv1x1, **layer_options)
        make_3x3 = partial(kan_conv3x3, **layer_options)

        super().__init__(make_1x1, make_3x3,
                         inplanes=inplanes, planes=planes,
                         stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width,
                         dilation=dilation)


class FastKANBottleneck(BottleneckTemplate):
    """Bottleneck residual block whose layers come from ``fast_kan_conv1x1`` / ``fast_kan_conv3x3``.

    ``grid_size``, ``base_activation``, ``grid_range``, ``dropout``,
    ``l1_decay`` and any extra ``norm_kwargs`` are bound to both factories;
    the remaining arguments are passed through to ``BottleneckTemplate``.
    """

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 grid_size: int = 5,
                 base_activation: Optional[Callable[..., nn.Module]] = nn.SiLU,
                 grid_range: List = [-1, 1],
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 dropout: float = 0.0,
                 l1_decay: float = 0.0,
                 **norm_kwargs):
        # One set of layer options, bound identically to both factories.
        layer_options = dict(grid_size=grid_size, base_activation=base_activation,
                             grid_range=grid_range, dropout=dropout,
                             l1_decay=l1_decay, **norm_kwargs)
        make_1x1 = partial(fast_kan_conv1x1, **layer_options)
        make_3x3 = partial(fast_kan_conv3x3, **layer_options)

        super().__init__(make_1x1, make_3x3,
                         inplanes=inplanes, planes=planes,
                         stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width,
                         dilation=dilation)


class KALNBottleneck(BottleneckTemplate):
    """Bottleneck residual block whose layers come from ``kaln_conv1x1`` / ``kaln_conv3x3``.

    ``degree``, ``dropout``, ``l1_decay`` and any extra ``norm_kwargs`` are
    bound to both factories; the remaining arguments are passed through to
    ``BottleneckTemplate``.
    """

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 degree: int = 3,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 dropout: float = 0.0,
                 l1_decay: float = 0.0,
                 **norm_kwargs):
        # One set of layer options, bound identically to both factories.
        layer_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, **norm_kwargs)
        make_1x1 = partial(kaln_conv1x1, **layer_options)
        make_3x3 = partial(kaln_conv3x3, **layer_options)

        super().__init__(make_1x1, make_3x3,
                         inplanes=inplanes, planes=planes,
                         stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width,
                         dilation=dilation)


class KAGNBottleneck(BottleneckTemplate):
    """Bottleneck residual block whose layers come from ``kagn_conv1x1`` / ``kagn_conv3x3``.

    ``degree``, ``dropout``, ``l1_decay`` and any extra ``norm_kwargs`` are
    bound to both factories; the remaining arguments are passed through to
    ``BottleneckTemplate``.
    """

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 degree: int = 3,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 dropout: float = 0.0,
                 l1_decay: float = 0.0,
                 **norm_kwargs):
        # One set of layer options, bound identically to both factories.
        layer_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, **norm_kwargs)
        make_1x1 = partial(kagn_conv1x1, **layer_options)
        make_3x3 = partial(kagn_conv3x3, **layer_options)

        super().__init__(make_1x1, make_3x3,
                         inplanes=inplanes, planes=planes,
                         stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width,
                         dilation=dilation)


class MoEKALNBottleneck(BottleneckTemplate):
    """Bottleneck residual block with a mixture-of-experts middle layer.

    ``conv1`` and ``conv3`` are built with ``kaln_conv1x1``; ``conv2`` is built
    with ``moe_kaln_conv3x3``, which additionally receives ``num_experts``,
    ``k`` and ``noisy_gating``.  ``degree``, ``dropout``, ``l1_decay`` and any
    extra ``norm_kwargs`` are bound to both factories; the remaining arguments
    are passed through to ``BottleneckTemplate``.

    Unlike the template, ``forward`` takes a ``train`` flag and returns a pair,
    so this block cannot be driven through a plain single-output pipeline.
    """

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 degree: int = 3,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 num_experts: int = 8,
                 noisy_gating: bool = True,
                 k: int = 2,
                 dropout: float = 0.0,
                 l1_decay: float = 0.0,
                 **norm_kwargs
                 ):
        # Options common to the plain and the mixture-of-experts factory.
        shared_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, **norm_kwargs)
        conv1x1x1_fun = partial(kaln_conv1x1, **shared_options)
        conv3x3x3_fun = partial(moe_kaln_conv3x3, num_experts=num_experts,
                                k=k, noisy_gating=noisy_gating, **shared_options)

        super().__init__(conv1x1x1_fun,
                         conv3x3x3_fun,
                         inplanes=inplanes,
                         planes=planes,
                         stride=stride,
                         downsample=downsample,
                         groups=groups,
                         base_width=base_width,
                         dilation=dilation)

    def forward(self, x: Tensor, train: bool = True) -> Tuple[Tensor, Tensor]:
        """Run the block and return ``(output, moe_loss)``.

        Args:
            x: Block input.
            train: Forwarded to ``conv2`` as its ``train`` keyword.  It defaults
                to ``True`` independently of ``self.training``, so callers must
                pass ``train=False`` explicitly when that is what they want.

        Returns:
            The residual output together with the auxiliary value produced by
            ``conv2`` (presumably a scalar loss tensor -- confirm against
            ``moe_kaln_conv3x3``).
        """
        out = self.conv1(x)
        # conv2 is the mixture-of-experts layer: it yields both features and a loss term.
        out, moe_loss = self.conv2(out, train=train)
        out = self.conv3(out)

        # The shortcut projection runs after the main branch.
        identity = x if self.downsample is None else self.downsample(x)
        return out + identity, moe_loss


class MoEKALNBasicBlock(BasicBlockTemplate):
    """Basic residual block with a mixture-of-experts first layer.

    ``conv1`` is built with ``moe_kaln_conv3x3``, which additionally receives
    ``num_experts``, ``k`` and ``noisy_gating``; ``conv2`` is built with
    ``kaln_conv1x1``.  ``degree``, ``dropout``, ``l1_decay`` and any extra
    ``norm_kwargs`` are bound to both factories; the remaining arguments are
    passed through to ``BasicBlockTemplate``.

    Unlike the template, ``forward`` takes a ``train`` flag and returns a pair,
    so this block cannot be driven through a plain single-output pipeline.
    """

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 degree: int = 3,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 num_experts: int = 8,
                 noisy_gating: bool = True,
                 k: int = 2,
                 dropout: float = 0.0,
                 l1_decay: float = 0.0,
                 **norm_kwargs):
        # Options common to the plain and the mixture-of-experts factory.
        shared_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, **norm_kwargs)
        conv1x1x1_fun = partial(kaln_conv1x1, **shared_options)
        conv3x3x3_fun = partial(moe_kaln_conv3x3, num_experts=num_experts,
                                k=k, noisy_gating=noisy_gating, **shared_options)

        super().__init__(conv1x1x1_fun,
                         conv3x3x3_fun,
                         inplanes=inplanes,
                         planes=planes,
                         stride=stride,
                         downsample=downsample,
                         groups=groups,
                         base_width=base_width,
                         dilation=dilation)

    def forward(self, x: Tensor, train: bool = True) -> Tuple[Tensor, Tensor]:
        """Run the block and return ``(output, moe_loss)``.

        Args:
            x: Block input.
            train: Forwarded to ``conv1`` as its ``train`` keyword.  It defaults
                to ``True`` independently of ``self.training``, so callers must
                pass ``train=False`` explicitly when that is what they want.

        Returns:
            The residual output together with the auxiliary value produced by
            ``conv1`` (presumably a scalar loss tensor -- confirm against
            ``moe_kaln_conv3x3``).
        """
        # conv1 is the mixture-of-experts layer: it yields both features and a loss term.
        out, moe_loss = self.conv1(x, train=train)
        out = self.conv2(out)

        # The shortcut projection runs after the main branch.
        identity = x if self.downsample is None else self.downsample(x)
        return out + identity, moe_loss


class KACNBottleneck(BottleneckTemplate):
    """Bottleneck residual block whose layers come from ``kacn_conv1x1`` / ``kacn_conv3x3``.

    ``degree``, ``dropout``, ``l1_decay`` and any extra ``norm_kwargs`` are
    bound to both factories; the remaining arguments are passed through to
    ``BottleneckTemplate``.
    """

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 degree: int = 3,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 dropout: float = 0.0,
                 l1_decay: float = 0.0,
                 **norm_kwargs):
        # One set of layer options, bound identically to both factories.
        layer_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, **norm_kwargs)
        make_1x1 = partial(kacn_conv1x1, **layer_options)
        make_3x3 = partial(kacn_conv3x3, **layer_options)

        super().__init__(make_1x1, make_3x3,
                         inplanes=inplanes, planes=planes,
                         stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width,
                         dilation=dilation)


class ResKANet(nn.Module):
    """ResNet-style network assembled from KAN residual blocks.

    Data flow (see ``_forward_impl``): stem layer ``conv1`` -> optional max-pool
    -> ``layer1`` .. ``layer4`` -> adaptive average pool to 1x1 -> flatten ->
    dropout -> ``nn.Linear`` classifier ``fc``.

    The four stages use ``8``, ``16``, ``32`` and ``64`` times ``width_scale``
    planes; stages 2-4 are built with stride 2 unless the matching entry of
    ``replace_stride_with_dilation`` is true.

    Args:
        block: Residual block class.  Must be one of the KAN, FastKAN, KALN,
            KAGN or KACN basic/bottleneck classes.
        layers: Number of blocks per stage; entries 0-3 are used.
        input_channels: First argument given to the stem layer.
        use_first_maxpool: Whether a ``nn.MaxPool2d`` follows the stem.
        mp_kernel_size, mp_stride, mp_padding: Max-pool settings.
        fcnv_kernel_size, fcnv_stride, fcnv_padding: Stem layer settings.
        num_classes: Output size of the classifier.
        groups: Passed to every block as ``groups``.
        width_per_group: Passed to every block as ``base_width``.
        width_scale: Multiplier applied to every stage width.
        replace_stride_with_dilation: ``None`` or three booleans, one per
            stage 2-4; a true entry trades that stage's stride for dilation.
        dropout_linear: Dropout probability applied before the classifier.
        hidden_layer_dim: Stored on the instance.  No hidden layer is built by
            this class, so the value does not currently change the network.
        **kan_kwargs: Extra keyword arguments forwarded to every block and
            every downsample layer; the stem receives them without
            ``l1_decay``.

    Raises:
        ValueError: If ``replace_stride_with_dilation`` does not have 3 entries.
        TypeError: If ``block`` is not a supported block class.
    """

    def __init__(
            self,
            block: Type[Union[KANBasicBlock, FastKANBasicBlock, KALNBasicBlock, KACNBasicBlock, KAGNBasicBlock,
                              KANBottleneck, FastKANBottleneck, KALNBottleneck, KACNBottleneck, KAGNBottleneck]],
            layers: List[int],
            input_channels: int = 3,
            use_first_maxpool: bool = True,
            mp_kernel_size: int = 3, mp_stride: int = 2, mp_padding: int = 1,
            fcnv_kernel_size: int = 7, fcnv_stride: int = 2, fcnv_padding: int = 3,
            num_classes: int = 1000,
            groups: int = 1,
            width_per_group: int = 64,
            width_scale: int = 1,
            replace_stride_with_dilation: Optional[List[bool]] = None,
            dropout_linear: float = 0.25,
            hidden_layer_dim: Optional[int] = None,
            **kan_kwargs
    ) -> None:
        super().__init__()

        self.input_channels = input_channels
        self.inplanes = 8 * width_scale
        self.hidden_layer_dim = hidden_layer_dim
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.use_first_maxpool = use_first_maxpool

        # No hidden projection is built; the attribute stays because
        # _forward_impl and the classifier sizing below consult it.
        self.hidden_layer = None

        # Validate the block type up front, before any of its attributes are read.
        stem_cls, _ = self._kan_family(block)

        # The stem layer takes the KAN options minus 'l1_decay' (and 'groups').
        stem_kwargs = {key: value for key, value in kan_kwargs.items()
                       if key not in ('l1_decay', 'groups')}
        self.conv1 = stem_cls(input_channels, self.inplanes, kernel_size=fcnv_kernel_size,
                              stride=fcnv_stride, padding=fcnv_padding, **stem_kwargs)

        self.maxpool = None
        if use_first_maxpool:
            self.maxpool = nn.MaxPool2d(kernel_size=mp_kernel_size, stride=mp_stride, padding=mp_padding)

        self.layer1 = self._make_layer(block, 8 * width_scale, layers[0], **kan_kwargs)
        self.layer2 = self._make_layer(block, 16 * width_scale, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0],
                                       **kan_kwargs)
        self.layer3 = self._make_layer(block, 32 * width_scale, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1],
                                       **kan_kwargs)
        self.layer4 = self._make_layer(block, 64 * width_scale, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2],
                                       **kan_kwargs)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.drop = nn.Dropout(p=dropout_linear)

        # The classifier is a plain linear layer and is now built exactly once.
        # NOTE(review): a KAN-family head used to be constructed here and then
        # immediately overwritten by this nn.Linear, so it never took part in the
        # model. That dead construction was removed. If a KAN classifier
        # (optionally with ``hidden_layer_dim``) is actually wanted, build it
        # here in place of the linear layer.
        if self.hidden_layer is None:
            classifier_in = 64 * width_scale * block.expansion
        else:
            classifier_in = hidden_layer_dim
        self.fc = nn.Linear(classifier_in, num_classes)

    @staticmethod
    def _kan_family(block):
        """Return ``(stem_layer_class, conv1x1_factory)`` for the family ``block`` belongs to.

        Raises:
            TypeError: If ``block`` is not one of the supported block classes.
        """
        families = (
            ((KANBasicBlock, KANBottleneck), KANConv2DLayer, kan_conv1x1),
            ((FastKANBasicBlock, FastKANBottleneck), FastKANConv2DLayer, fast_kan_conv1x1),
            ((KALNBasicBlock, KALNBottleneck), KALNConv2DLayer, kaln_conv1x1),
            ((KAGNBasicBlock, KAGNBottleneck), KAGNConv2DLayer, kagn_conv1x1),
            ((KACNBasicBlock, KACNBottleneck), KACNConv2DLayer, kacn_conv1x1),
        )
        for members, stem_cls, conv1x1 in families:
            if block in members:
                return stem_cls, conv1x1
        # Report the offending class itself; type(block) would always be ``type``.
        raise TypeError(f"Block {block} is not supported")

    def _make_layer(
            self,
            block: Type[Union[KANBasicBlock, FastKANBasicBlock, KALNBasicBlock, KACNBasicBlock, KAGNBasicBlock,
                              KANBottleneck, FastKANBottleneck, KALNBottleneck, KACNBottleneck, KAGNBottleneck]],
            planes: int,
            blocks: int,
            stride: int = 1,
            dilate: bool = False,
            **kan_kwargs
    ) -> nn.Sequential:
        """Build one stage of ``blocks`` residual blocks and advance ``self.inplanes``.

        Args:
            block: Residual block class.
            planes: Base width of the stage; it outputs ``planes * block.expansion``.
            blocks: Number of blocks; the first one is always created.
            stride: Stride of the first block (ignored when ``dilate`` is true).
            dilate: If true, ``self.dilation`` is multiplied by ``stride`` and
                the stage itself runs with stride 1.
            **kan_kwargs: Forwarded to every block and to the downsample layer.

        Returns:
            The blocks wrapped in an ``nn.Sequential``.

        Raises:
            TypeError: If a downsample layer is needed and ``block`` is unsupported.
        """
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        out_planes = planes * block.expansion
        if stride != 1 or self.inplanes != out_planes:
            # The skip path needs a projection whenever resolution or width changes.
            _, conv1x1 = self._kan_family(block)
            downsample = partial(conv1x1, **kan_kwargs)(self.inplanes, out_planes, stride=stride)

        # The first block performs the stride/width change and uses the dilation
        # that was in effect before this stage.
        stage = [
            block(
                self.inplanes, planes, stride=stride, downsample=downsample, groups=self.groups,
                base_width=self.base_width, dilation=previous_dilation, **kan_kwargs
            )
        ]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    **kan_kwargs
                )
            )

        return nn.Sequential(*stage)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run the full network on ``x`` and return the classifier output."""
        x = self.conv1(x)
        if self.use_first_maxpool:
            x = self.maxpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)
        if self.hidden_layer is not None:
            # Applied before flattening, i.e. on the pooled 1x1 feature map.
            x = self.hidden_layer(x)
        x = flatten(x, 1)
        x = self.drop(x)
        return self.fc(x)

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        """Forward pass; extra keyword arguments are accepted and ignored."""
        return self._forward_impl(x)


class MoEResKANet(nn.Module):
    """ResNet-style classifier built from mixture-of-experts KALN residual blocks.

    The network is a KALN first convolution, an optional max-pool, four
    residual stages, global average pooling, an optional 1x1 KALN hidden layer,
    dropout and a linear classifier. Each block returns ``(features, moe_loss)``;
    the forward pass returns the logits together with the auxiliary MoE loss
    averaged over the four stages.

    Args:
        block: Block class to stack; ``MoEKALNBottleneck`` or ``MoEKALNBasicBlock``.
        layers: Number of blocks in each of the four stages.
        input_channels: Channels of the input tensor.
        use_first_maxpool: Whether to apply a max-pool after the first convolution.
        mp_kernel_size, mp_stride, mp_padding: Parameters of that max-pool.
        fcnv_kernel_size, fcnv_stride, fcnv_padding: Parameters of the first convolution.
        num_classes: Output size of the final linear layer.
        groups: Forwarded to every block as ``groups``.
        width_per_group: Forwarded to every block as ``base_width``.
        width_scale: Multiplier applied to the base channel widths (16/8/16/32/64).
        replace_stride_with_dilation: Three booleans; a true entry replaces the
            stride of stage 2/3/4 by dilation. ``None`` means all false.
        num_experts, noisy_gating, k: Mixture-of-experts settings forwarded to
            every block of every stage.
        hidden_layer_dim: Output channels of the optional 1x1 KALN layer placed
            before the classifier; ``None`` disables it.
        dropout_linear: Dropout probability applied before the classifier.
        **kan_kwargs: Extra keyword arguments forwarded to the KALN layers and
            blocks (``l1_decay`` is removed for the first convolution).

    Raises:
        ValueError: If ``replace_stride_with_dilation`` does not have 3 elements.
        TypeError: If ``block`` is not one of the supported MoE block classes.
    """

    def __init__(
            self,
            block: Type[Union[MoEKALNBottleneck, MoEKALNBasicBlock]],
            layers: List[int],
            input_channels: int = 3,
            use_first_maxpool: bool = True,
            mp_kernel_size: int = 3, mp_stride: int = 2, mp_padding: int = 1,
            fcnv_kernel_size: int = 7, fcnv_stride: int = 2, fcnv_padding: int = 3,
            num_classes: int = 1000,
            groups: int = 1,
            width_per_group: int = 64,
            width_scale: int = 1,
            replace_stride_with_dilation: Optional[List[bool]] = None,
            num_experts: int = 8,
            noisy_gating: bool = True,
            k: int = 2,
            hidden_layer_dim: Optional[int] = None,
            dropout_linear: float = 0.0,
            **kan_kwargs
    ) -> None:
        super().__init__()

        self.input_channels = input_channels
        self.inplanes = 16 * width_scale
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.use_first_maxpool = use_first_maxpool

        self.hidden_layer = None

        # The first convolution is built without the l1_decay option.
        kan_kwargs_clean = kan_kwargs.copy()
        kan_kwargs_clean.pop('l1_decay', None)
        if block in (MoEKALNBottleneck, MoEKALNBasicBlock):
            self.conv1 = KALNConv2DLayer(input_channels, self.inplanes, kernel_size=fcnv_kernel_size,
                                         stride=fcnv_stride, padding=fcnv_padding, **kan_kwargs_clean)
            if hidden_layer_dim is not None:
                self.hidden_layer = kaln_conv1x1(64 * width_scale * block.expansion, hidden_layer_dim, **kan_kwargs)
        else:
            # Report the offending class itself; type(block) would always be <class 'type'>.
            raise TypeError(f"Block {block} is not supported")
        self.maxpool = None
        if use_first_maxpool:
            self.maxpool = nn.MaxPool2d(kernel_size=mp_kernel_size, stride=mp_stride, padding=mp_padding)

        # The same MoE settings must reach every stage, including the first one.
        moe_kwargs = dict(num_experts=num_experts, noisy_gating=noisy_gating, k=k)
        self.layer1 = self._make_layer(block, 8 * width_scale, layers[0], **moe_kwargs, **kan_kwargs)
        self.layer2 = self._make_layer(block, 16 * width_scale, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0],
                                       **moe_kwargs, **kan_kwargs)
        self.layer3 = self._make_layer(block, 32 * width_scale, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1],
                                       **moe_kwargs, **kan_kwargs)
        self.layer4 = self._make_layer(block, 64 * width_scale, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2],
                                       **moe_kwargs, **kan_kwargs)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * width_scale * block.expansion if self.hidden_layer is None else hidden_layer_dim,
                            num_classes)
        self.drop = nn.Dropout(p=dropout_linear)

    def _make_layer(
            self,
            block: Type[Union[MoEKALNBottleneck, MoEKALNBasicBlock]],
            planes: int,
            blocks: int,
            stride: int = 1,
            dilate: bool = False,
            num_experts: int = 8,
            noisy_gating: bool = True,
            k: int = 2,
            **kan_kwargs
    ) -> nn.Module:
        """Build one stage of ``blocks`` MoE blocks and advance ``self.inplanes``.

        Only the first block receives ``stride`` and the optional 1x1 KALN
        ``downsample``; the downsample is created when the stride is not 1 or
        the channel count changes. When ``dilate`` is true the stride is folded
        into ``self.dilation`` instead.

        Returns:
            An ``nn.ModuleList`` with the blocks in execution order (a list is
            used because each block returns a ``(features, moe_loss)`` pair).

        Raises:
            TypeError: If ``block`` is not a supported MoE block class.
        """
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            if block not in (MoEKALNBottleneck, MoEKALNBasicBlock):
                raise TypeError(f"Block {block} is not supported")
            # num_experts / noisy_gating / k are explicit parameters of this method,
            # so they can never appear in kan_kwargs and need no filtering here.
            downsample = kaln_conv1x1(self.inplanes, planes * block.expansion, stride=stride, **kan_kwargs)

        layers = [
            block(
                self.inplanes, planes, stride=stride, downsample=downsample, groups=self.groups,
                base_width=self.base_width, dilation=previous_dilation, num_experts=num_experts,
                noisy_gating=noisy_gating, k=k, **kan_kwargs
            )
        ]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    num_experts=num_experts,
                    noisy_gating=noisy_gating,
                    k=k,
                    **kan_kwargs
                )
            )

        return nn.ModuleList(layers)

    def _forward_layer(self, layer, x, train) -> tuple:
        """Apply every block of ``layer`` in order and sum their MoE losses.

        Returns:
            ``(x, moe_loss)`` where ``moe_loss`` is the sum of the per-block
            losses (the integer 0 for an empty stage).
        """
        moe_loss = 0
        for block in layer:
            x, _moe_loss = block(x, train)
            moe_loss += _moe_loss
        return x, moe_loss

    def _forward_impl(self, x: Tensor, train: bool = True) -> tuple:
        """Run the full network.

        Args:
            x: Input tensor.
            train: Flag passed to every block as its second positional argument.

        Returns:
            ``(logits, moe_loss)`` where ``moe_loss`` is the mean of the four
            per-stage MoE losses.
        """
        x = self.conv1(x)
        if self.use_first_maxpool:
            x = self.maxpool(x)

        x, moe_loss1 = self._forward_layer(self.layer1, x, train)
        x, moe_loss2 = self._forward_layer(self.layer2, x, train)
        x, moe_loss3 = self._forward_layer(self.layer3, x, train)
        x, moe_loss4 = self._forward_layer(self.layer4, x, train)

        x = self.avgpool(x)
        if self.hidden_layer is not None:
            x = self.hidden_layer(x)
        x = flatten(x, 1)
        x = self.drop(x)
        x = self.fc(x)

        return x, (moe_loss1 + moe_loss2 + moe_loss3 + moe_loss4) / 4

    def forward(self, x: Tensor, train: bool = True) -> tuple:
        """Return ``(logits, moe_loss)`` for ``x``; see ``_forward_impl``."""
        return self._forward_impl(x, train)


def reskanet_18x32p(input_channels, num_classes, groups: int = 1, spline_order: int = 3, grid_size: int = 5,
                    base_activation: Optional[Callable[..., nn.Module]] = nn.GELU,
                    grid_range: List = [-1, 1], hidden_layer_dim=None, dropout: float = 0.0, l1_decay: float = 0.0,
                    dropout_linear: float = 0.25, affine: bool = False):
    """Build a ``ResKANet`` of ``KANBasicBlock`` with stage depths [2, 2, 2, 2].

    The first convolution is 3x3 with stride 1 and padding 1, and the initial
    max-pool is disabled. All remaining arguments are forwarded unchanged.
    """
    stem_options = dict(use_first_maxpool=False, fcnv_kernel_size=3, fcnv_stride=1, fcnv_padding=1)
    head_options = dict(num_classes=num_classes, hidden_layer_dim=hidden_layer_dim,
                        dropout_linear=dropout_linear)
    kan_options = dict(spline_order=spline_order, grid_size=grid_size, base_activation=base_activation,
                       grid_range=grid_range, dropout=dropout, l1_decay=l1_decay, affine=affine)
    return ResKANet(KANBasicBlock, [2, 2, 2, 2], input_channels=input_channels, groups=groups,
                    width_per_group=64, **stem_options, **head_options, **kan_options)


def fast_reskanet_18x32p(input_channels, num_classes, groups: int = 1, grid_size: int = 5,
                         base_activation: Optional[Callable[..., nn.Module]] = nn.GELU,
                         grid_range: List = [-1, 1], hidden_layer_dim=None, dropout: float = 0.0, l1_decay: float = 0.0,
                         dropout_linear: float = 0.25, affine: bool = False):
    """Build a ``ResKANet`` of ``FastKANBasicBlock`` with stage depths [2, 2, 2, 2].

    The first convolution is 3x3 with stride 1 and padding 1, and the initial
    max-pool is disabled. All remaining arguments are forwarded unchanged.
    """
    stem_options = dict(use_first_maxpool=False, fcnv_kernel_size=3, fcnv_stride=1, fcnv_padding=1)
    head_options = dict(num_classes=num_classes, hidden_layer_dim=hidden_layer_dim,
                        dropout_linear=dropout_linear)
    kan_options = dict(grid_size=grid_size, base_activation=base_activation, grid_range=grid_range,
                       dropout=dropout, l1_decay=l1_decay, affine=affine)
    return ResKANet(FastKANBasicBlock, [2, 2, 2, 2], input_channels=input_channels, groups=groups,
                    width_per_group=64, **stem_options, **head_options, **kan_options)


def reskalnet_18x32p(input_channels, num_classes, groups: int = 1, degree: int = 3, width_scale: int = 1,
                     hidden_layer_dim=None, dropout: float = 0.0, l1_decay: float = 0.0,
                     dropout_linear: float = 0.25, affine: bool = False):
    """Build a ``ResKANet`` of ``KALNBasicBlock`` with stage depths [2, 2, 2, 2].

    The first convolution is 3x3 with stride 1 and padding 1, and the initial
    max-pool is disabled. All remaining arguments are forwarded unchanged.
    """
    stem_options = dict(use_first_maxpool=False, fcnv_kernel_size=3, fcnv_stride=1, fcnv_padding=1)
    head_options = dict(num_classes=num_classes, hidden_layer_dim=hidden_layer_dim,
                        dropout_linear=dropout_linear)
    kan_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, affine=affine)
    return ResKANet(KALNBasicBlock, [2, 2, 2, 2], input_channels=input_channels, groups=groups,
                    width_per_group=64, width_scale=width_scale,
                    **stem_options, **head_options, **kan_options)


def reskagnet_18x32p(input_channels, num_classes, groups: int = 1, degree: int = 3, width_scale: int = 1,
                     hidden_layer_dim=None, dropout: float = 0.0, l1_decay: float = 0.0,
                     dropout_linear: float = 0.25, affine: bool = False):
    """Build a ``ResKANet`` of ``KAGNBasicBlock`` with stage depths [2, 2, 2, 2].

    The first convolution is 3x3 with stride 1 and padding 1, and the initial
    max-pool is disabled. All remaining arguments are forwarded unchanged.
    """
    stem_options = dict(use_first_maxpool=False, fcnv_kernel_size=3, fcnv_stride=1, fcnv_padding=1)
    head_options = dict(num_classes=num_classes, hidden_layer_dim=hidden_layer_dim,
                        dropout_linear=dropout_linear)
    kan_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, affine=affine)
    return ResKANet(KAGNBasicBlock, [2, 2, 2, 2], input_channels=input_channels, groups=groups,
                    width_per_group=64, width_scale=width_scale,
                    **stem_options, **head_options, **kan_options)


def reskalnet_18x64p(input_channels, num_classes, groups: int = 1, degree: int = 3, width_scale: int = 1,
                     hidden_layer_dim=None, dropout: float = 0.0, l1_decay: float = 0.0,
                     dropout_linear: float = 0.25, affine: bool = False):
    """Build a ``ResKANet`` of ``KALNBasicBlock`` with stage depths [2, 2, 2, 2].

    The first convolution is 3x3 with stride 1 and padding 1, and the initial
    max-pool is enabled. All remaining arguments are forwarded unchanged.
    """
    stem_options = dict(use_first_maxpool=True, fcnv_kernel_size=3, fcnv_stride=1, fcnv_padding=1)
    head_options = dict(num_classes=num_classes, hidden_layer_dim=hidden_layer_dim,
                        dropout_linear=dropout_linear)
    kan_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, affine=affine)
    return ResKANet(KALNBasicBlock, [2, 2, 2, 2], input_channels=input_channels, groups=groups,
                    width_per_group=64, width_scale=width_scale,
                    **stem_options, **head_options, **kan_options)


def moe_reskalnet_18x64p(input_channels, num_classes, groups: int = 1, degree: int = 3, width_scale: int = 1,
                         num_experts: int = 8, noisy_gating: bool = True, k: int = 2,
                         hidden_layer_dim=None, dropout: float = 0.0, l1_decay: float = 0.0,
                         dropout_linear: float = 0.25, affine: bool = False):
    """Build a ``MoEResKANet`` of ``MoEKALNBasicBlock`` with stage depths [2, 2, 2, 2].

    The first convolution is 3x3 with stride 1 and padding 1, and the initial
    max-pool is enabled. The mixture-of-experts settings and all remaining
    arguments are forwarded unchanged.
    """
    stem_options = dict(use_first_maxpool=True, fcnv_kernel_size=3, fcnv_stride=1, fcnv_padding=1)
    head_options = dict(num_classes=num_classes, hidden_layer_dim=hidden_layer_dim,
                        dropout_linear=dropout_linear)
    moe_options = dict(num_experts=num_experts, noisy_gating=noisy_gating, k=k)
    kan_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, affine=affine)
    return MoEResKANet(MoEKALNBasicBlock, [2, 2, 2, 2], input_channels=input_channels, groups=groups,
                       width_per_group=64, width_scale=width_scale,
                       **stem_options, **head_options, **moe_options, **kan_options)


def reskacnet_18x32p(input_channels, num_classes, groups: int = 1, degree: int = 3,
                     hidden_layer_dim=None, dropout: float = 0.0, l1_decay: float = 0.0,
                     dropout_linear: float = 0.25, affine: bool = False):
    """Build a ``ResKANet`` of ``KACNBasicBlock`` with stage depths [2, 2, 2, 2].

    The first convolution is 3x3 with stride 1 and padding 1, and the initial
    max-pool is disabled. All remaining arguments are forwarded unchanged.
    """
    stem_options = dict(use_first_maxpool=False, fcnv_kernel_size=3, fcnv_stride=1, fcnv_padding=1)
    head_options = dict(num_classes=num_classes, hidden_layer_dim=hidden_layer_dim,
                        dropout_linear=dropout_linear)
    kan_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, affine=affine)
    return ResKANet(KACNBasicBlock, [2, 2, 2, 2], input_channels=input_channels, groups=groups,
                    width_per_group=64, **stem_options, **head_options, **kan_options)


def reskalnet_50x64p(input_channels, num_classes, groups: int = 1, degree: int = 3, width_scale: int = 1,
                     dropout: float = 0.15, dropout_linear: float = 0.25, l1_decay: float = 0.0,
                     hidden_layer_dim=None, affine: bool = False):
    """Build a ``ResKANet`` of ``KALNBottleneck`` with stage depths [3, 4, 6, 3].

    The first convolution is 3x3 with stride 1 and padding 1, and the initial
    max-pool is enabled. All remaining arguments are forwarded unchanged.
    """
    stem_options = dict(use_first_maxpool=True, fcnv_kernel_size=3, fcnv_stride=1, fcnv_padding=1)
    head_options = dict(num_classes=num_classes, hidden_layer_dim=hidden_layer_dim,
                        dropout_linear=dropout_linear)
    kan_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, affine=affine)
    return ResKANet(KALNBottleneck, [3, 4, 6, 3], input_channels=input_channels, groups=groups,
                    width_per_group=64, width_scale=width_scale,
                    **stem_options, **head_options, **kan_options)


def moe_reskalnet_50x64p(input_channels, num_classes, groups: int = 1, degree: int = 3, width_scale: int = 1,
                         num_experts: int = 8, noisy_gating: bool = True, k: int = 2,
                         hidden_layer_dim=None, dropout: float = 0.15, dropout_linear: float = 0.25,
                         l1_decay: float = 0.0, affine: bool = False):
    """Build a ``MoEResKANet`` of ``MoEKALNBottleneck`` with stage depths [3, 4, 6, 3].

    The first convolution is 3x3 with stride 1 and padding 1, and the initial
    max-pool is enabled. The mixture-of-experts settings and all remaining
    arguments are forwarded unchanged.
    """
    stem_options = dict(use_first_maxpool=True, fcnv_kernel_size=3, fcnv_stride=1, fcnv_padding=1)
    head_options = dict(num_classes=num_classes, hidden_layer_dim=hidden_layer_dim,
                        dropout_linear=dropout_linear)
    moe_options = dict(num_experts=num_experts, noisy_gating=noisy_gating, k=k)
    kan_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, affine=affine)
    return MoEResKANet(MoEKALNBottleneck, [3, 4, 6, 3], input_channels=input_channels, groups=groups,
                       width_per_group=64, width_scale=width_scale,
                       **stem_options, **head_options, **moe_options, **kan_options)


def reskalnet_101x64p(input_channels, num_classes, groups: int = 1, degree: int = 3, width_scale: int = 1,
                      hidden_layer_dim=None, dropout: float = 0.15, dropout_linear: float = 0.25,
                      l1_decay: float = 0.0, affine: bool = False):
    """Build a ``ResKANet`` of ``KALNBottleneck`` with stage depths [3, 4, 23, 3].

    The first convolution is 3x3 with stride 1 and padding 1, and the initial
    max-pool is enabled. All remaining arguments are forwarded unchanged.
    """
    stem_options = dict(use_first_maxpool=True, fcnv_kernel_size=3, fcnv_stride=1, fcnv_padding=1)
    head_options = dict(num_classes=num_classes, hidden_layer_dim=hidden_layer_dim,
                        dropout_linear=dropout_linear)
    kan_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, affine=affine)
    return ResKANet(KALNBottleneck, [3, 4, 23, 3], input_channels=input_channels, groups=groups,
                    width_per_group=64, width_scale=width_scale,
                    **stem_options, **head_options, **kan_options)


def reskagnet_101x64p(input_channels, num_classes, groups: int = 1, degree: int = 3, width_scale: int = 1,
                      hidden_layer_dim=None, dropout: float = 0.15, dropout_linear: float = 0.25,
                      l1_decay: float = 0.0, affine: bool = False):
    """Build a ``ResKANet`` of ``KAGNBottleneck`` with stage depths [3, 4, 23, 3].

    The first convolution is 3x3 with stride 1 and padding 1, and the initial
    max-pool is enabled. All remaining arguments are forwarded unchanged.
    """
    stem_options = dict(use_first_maxpool=True, fcnv_kernel_size=3, fcnv_stride=1, fcnv_padding=1)
    head_options = dict(num_classes=num_classes, hidden_layer_dim=hidden_layer_dim,
                        dropout_linear=dropout_linear)
    kan_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, affine=affine)
    return ResKANet(KAGNBottleneck, [3, 4, 23, 3], input_channels=input_channels, groups=groups,
                    width_per_group=64, width_scale=width_scale,
                    **stem_options, **head_options, **kan_options)


def reskalnet_101x32p(input_channels, num_classes, groups: int = 1, degree: int = 3, width_scale: int = 1,
                      hidden_layer_dim=None, dropout: float = 0.15, dropout_linear: float = 0.25,
                      l1_decay: float = 0.0, affine: bool = False):
    """Build a ``ResKANet`` of ``KALNBottleneck`` with stage depths [3, 4, 23, 3].

    The first convolution is 3x3 with stride 1 and padding 1, and the initial
    max-pool is disabled. All remaining arguments are forwarded unchanged.
    """
    stem_options = dict(use_first_maxpool=False, fcnv_kernel_size=3, fcnv_stride=1, fcnv_padding=1)
    head_options = dict(num_classes=num_classes, hidden_layer_dim=hidden_layer_dim,
                        dropout_linear=dropout_linear)
    kan_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, affine=affine)
    return ResKANet(KALNBottleneck, [3, 4, 23, 3], input_channels=input_channels, groups=groups,
                    width_per_group=64, width_scale=width_scale,
                    **stem_options, **head_options, **kan_options)


def moe_reskalnet_101x64p(input_channels, num_classes, groups: int = 1, degree: int = 3, width_scale: int = 1,
                          num_experts: int = 8, noisy_gating: bool = True, k: int = 2,
                          hidden_layer_dim=None, dropout: float = 0.15, dropout_linear: float = 0.25,
                          l1_decay: float = 0.0, affine: bool = False):
    """Build a ``MoEResKANet`` of ``MoEKALNBottleneck`` with stage depths [3, 4, 23, 3].

    The first convolution is 3x3 with stride 1 and padding 1, and the initial
    max-pool is enabled. The mixture-of-experts settings and all remaining
    arguments are forwarded unchanged.
    """
    stem_options = dict(use_first_maxpool=True, fcnv_kernel_size=3, fcnv_stride=1, fcnv_padding=1)
    head_options = dict(num_classes=num_classes, hidden_layer_dim=hidden_layer_dim,
                        dropout_linear=dropout_linear)
    moe_options = dict(num_experts=num_experts, noisy_gating=noisy_gating, k=k)
    kan_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, affine=affine)
    return MoEResKANet(MoEKALNBottleneck, [3, 4, 23, 3], input_channels=input_channels, groups=groups,
                       width_per_group=64, width_scale=width_scale,
                       **stem_options, **head_options, **moe_options, **kan_options)


def reskalnet_152x64p(input_channels, num_classes, groups: int = 1, degree: int = 3, width_scale: int = 1,
                      hidden_layer_dim=None, dropout: float = 0.15, dropout_linear: float = 0.25,
                      l1_decay: float = 0.0, affine: bool = False):
    """Build a ``ResKANet`` of ``KALNBottleneck`` with stage depths [3, 8, 36, 3].

    The first convolution is 3x3 with stride 1 and padding 1, and the initial
    max-pool is enabled. All remaining arguments are forwarded unchanged.
    """
    stem_options = dict(use_first_maxpool=True, fcnv_kernel_size=3, fcnv_stride=1, fcnv_padding=1)
    head_options = dict(num_classes=num_classes, hidden_layer_dim=hidden_layer_dim,
                        dropout_linear=dropout_linear)
    kan_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, affine=affine)
    return ResKANet(KALNBottleneck, [3, 8, 36, 3], input_channels=input_channels, groups=groups,
                    width_per_group=64, width_scale=width_scale,
                    **stem_options, **head_options, **kan_options)


def reskalnet_152x32p(input_channels, num_classes, groups: int = 1, degree: int = 3, width_scale: int = 1,
                      hidden_layer_dim=None, dropout: float = 0.15, dropout_linear: float = 0.25,
                      l1_decay: float = 0.0, affine: bool = False):
    """Build a ``ResKANet`` of ``KALNBottleneck`` with stage depths [3, 8, 36, 3].

    The first convolution is 3x3 with stride 1 and padding 1, and the initial
    max-pool is disabled. All remaining arguments are forwarded unchanged.
    """
    stem_options = dict(use_first_maxpool=False, fcnv_kernel_size=3, fcnv_stride=1, fcnv_padding=1)
    head_options = dict(num_classes=num_classes, hidden_layer_dim=hidden_layer_dim,
                        dropout_linear=dropout_linear)
    kan_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, affine=affine)
    return ResKANet(KALNBottleneck, [3, 8, 36, 3], input_channels=input_channels, groups=groups,
                    width_per_group=64, width_scale=width_scale,
                    **stem_options, **head_options, **kan_options)


def moe_reskalnet_152x64p(input_channels, num_classes, groups: int = 1, degree: int = 3, width_scale: int = 1,
                          num_experts: int = 8, noisy_gating: bool = True, k: int = 2,
                          hidden_layer_dim=None, dropout: float = 0.15, dropout_linear: float = 0.25,
                          l1_decay: float = 0.0, affine: bool = False):
    """Build a ``MoEResKANet`` of ``MoEKALNBottleneck`` with stage depths [3, 8, 36, 3].

    The first convolution is 3x3 with stride 1 and padding 1, and the initial
    max-pool is enabled. The mixture-of-experts settings and all remaining
    arguments are forwarded unchanged.
    """
    stem_options = dict(use_first_maxpool=True, fcnv_kernel_size=3, fcnv_stride=1, fcnv_padding=1)
    head_options = dict(num_classes=num_classes, hidden_layer_dim=hidden_layer_dim,
                        dropout_linear=dropout_linear)
    moe_options = dict(num_experts=num_experts, noisy_gating=noisy_gating, k=k)
    kan_options = dict(degree=degree, dropout=dropout, l1_decay=l1_decay, affine=affine)
    return MoEResKANet(MoEKALNBottleneck, [3, 8, 36, 3], input_channels=input_channels, groups=groups,
                       width_per_group=64, width_scale=width_scale,
                       **stem_options, **head_options, **moe_options, **kan_options)

In [55]:
# Directory holding the preprocessed feature/label arrays. Set the
# WINE_AROMA_DATA_DIR environment variable to run the notebook on another
# machine; the default is the original location. An absolute path is kept
# because an earlier cell changes the working directory with %cd.
DATA_DIR = os.environ.get('WINE_AROMA_DATA_DIR', 'C:/ITMO_CAI/2023-2024/Wine_Aroma')
X_array = np.load(f'{DATA_DIR}/X_array.npy')
Y_array = np.load(f'{DATA_DIR}/Y_array.npy')

In [56]:
# Hold out 20% of the samples for testing; the fixed random_state keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(X_array, Y_array, test_size=0.2, random_state=998)

In [57]:
# Convert the splits to tensors. torch.as_tensor keeps the dtype and, unlike
# torch.tensor, returns an existing tensor unchanged, so re-running this cell
# neither re-copies the data nor triggers the copy-construct warning.
X_train = torch.as_tensor(X_train)
y_train = torch.as_tensor(y_train)
X_test = torch.as_tensor(X_test)
y_test = torch.as_tensor(y_test)

In [58]:
# Report the number of samples in each split: train features/labels, then test features/labels.
for split in (X_train, y_train, X_test, y_test):
    print(len(split))

359
359
90
90


In [59]:
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each input with its label by index.

    Args:
        x: Sized, indexable collection of inputs.
        y: Sized, indexable collection of labels aligned with ``x``.

    Raises:
        ValueError: If ``x`` and ``y`` differ in length.
    """

    def __init__(self, x, y):
        # Fail early instead of raising an index error halfway through an epoch.
        if len(x) != len(y):
            raise ValueError(f"x and y must have the same length, got {len(x)} and {len(y)}")
        self.x = x
        self.y = y

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair stored at position ``idx``."""
        return self.x[idx], self.y[idx]

In [60]:
# Create the Dataset and DataLoader objects.
# Both loaders use batches of 5 samples; only the training loader shuffles.
dataset_train = CustomDataset(X_train, y_train)
dataset_test = CustomDataset(X_test, y_test)
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=5, shuffle=True)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=5, shuffle=False)

In [61]:
# Sanity check: walk the training loader once and print the feature shape of every batch.
for x, y in dataloader_train:
    print(x.size())

torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])
torch.Size([5, 44, 100])


In [62]:
# Create the neural network instance
# NOTE(review): if ResKANet hands layers[i] to _make_layer as the number of blocks
# per stage (as MoEResKANet does), this builds 16 + 32 + 64 + 128 = 240 residual
# blocks - confirm this depth is intended.
net = ResKANet(layers=[16, 32, 64, 128], block = KANBasicBlock, input_channels = 1, num_classes=10)
# Print the network architecture
print(net)

ResKANet(
  (conv1): KANConv2DLayer(
    (base_activation): GELU(approximate='none')
    (base_conv): ModuleList(
      (0): Conv2d(1, 8, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    )
    (spline_conv): ModuleList(
      (0): Conv2d(8, 8, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    )
    (layer_norm): ModuleList(
      (0): InstanceNorm2d(8, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    )
    (prelus): ModuleList(
      (0): PReLU(num_parameters=1)
    )
  )
  (fc): Linear(in_features=64, out_features=10, bias=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): KANBasicBlock(
      (conv1): KANConv2DLayer(
        (base_activation): GELU(approximate='none')
        (base_conv): ModuleList(
          (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (spline_conv): ModuleList(
          (0): Conv2

In [63]:
# Cast the model to float64 with the Module API. Unlike rewriting param.data by
# hand, this also converts floating-point buffers, so parameters and buffers
# cannot end up with different dtypes.
net = net.double()

In [64]:
# Binary cross-entropy on raw logits (the accuracy helper applies the sigmoid itself).
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.0001, betas=(0.9, 0.999))
# NOTE(review): scheduler.step() is not called in any cell visible in this part of
# the notebook, so this exponential decay currently has no effect on the learning
# rate - confirm whether it should be stepped once per epoch or removed.
scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer,
    gamma = 0.6
)

In [65]:
def accuracy(pred, label, threshold=0.5):
    """Return the fraction of label entries predicted correctly.

    Args:
        pred: Tensor of raw logits; a sigmoid is applied before thresholding.
        label: Tensor of 0/1 targets with the same shape as ``pred``.
        threshold: Probability above which an entry is predicted as 1.

    Returns:
        Number of matching entries divided by the total number of label
        entries, as a Python float.
    """
    probabilities = torch.sigmoid(pred)
    pred_labels = (probabilities > threshold).float()  # Threshold the predictions
    correct = (pred_labels == label).sum().item()  # Compare predictions with labels
    # numel() counts every label entry, so 1-D and >2-D label tensors work as well
    # as the (batch, classes) case that size(0) * size(1) was limited to.
    total = label.numel()
    return correct / total

In [66]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook does not crash on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = net.to(device)
loss_fn = loss_fn.to(device)

In [67]:
# Mixed-precision settings: use_amp is the flag passed to autocast in the
# training and evaluation cells; the scaler wraps backward() and optimizer.step().
use_amp = True
scaler = torch.cuda.amp.GradScaler()
# Enable cuDNN benchmark mode and do not force deterministic kernels.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

In [68]:
epochs = 30

# NOTE(review): no learning-rate scheduler is stepped in this loop, so the
# ExponentialLR object created earlier in the notebook never changes the
# learning rate - confirm whether a per-epoch scheduler.step() is intended.
for epoch in range(epochs):
    # Put the model in training mode explicitly, so re-running this cell after
    # any cell that switched it to eval mode still trains with dropout active.
    model.train()
    loss_val = 0.0
    acc_val = 0.0
    for sample in tqdm(dataloader_train):
        matrix, label = sample[0].to(device), sample[1].to(device)
        optimizer.zero_grad()

        # Add the channel dimension to the data
        matrix = matrix.unsqueeze(1)

        with autocast(use_amp):
            pred = model(matrix)
            loss = loss_fn(pred, label.float())

        # Scale the loss before backward, then step through the scaler.
        scaler.scale(loss).backward()
        loss_item = loss.item()
        loss_val += loss_item

        scaler.step(optimizer)
        scaler.update()

        acc_current = accuracy(pred.cpu().float(), label.cpu().float())
        acc_val += acc_current

    # Both metrics are averaged over the number of batches.
    print(f'Epoch: [{epoch+1}/{epochs}], Loss: {loss_val/len(dataloader_train):.5f}, Accuracy: {acc_val/len(dataloader_train):.3f}')

100%|██████████| 72/72 [06:13<00:00,  5.19s/it]


Epoch: [1/30], Loss: 6.55452, Accuracy: 0.578


100%|██████████| 72/72 [05:56<00:00,  4.96s/it]


Epoch: [2/30], Loss: 3.72537, Accuracy: 0.663


100%|██████████| 72/72 [05:55<00:00,  4.94s/it]


Epoch: [3/30], Loss: 3.31697, Accuracy: 0.669


100%|██████████| 72/72 [06:03<00:00,  5.04s/it]


Epoch: [4/30], Loss: 3.02538, Accuracy: 0.682


100%|██████████| 72/72 [06:00<00:00,  5.01s/it]


Epoch: [5/30], Loss: 2.94530, Accuracy: 0.684


100%|██████████| 72/72 [05:54<00:00,  4.92s/it]


Epoch: [6/30], Loss: 2.84038, Accuracy: 0.687


100%|██████████| 72/72 [05:55<00:00,  4.94s/it]


Epoch: [7/30], Loss: 2.74176, Accuracy: 0.680


100%|██████████| 72/72 [05:54<00:00,  4.92s/it]


Epoch: [8/30], Loss: 2.50347, Accuracy: 0.700


100%|██████████| 72/72 [05:53<00:00,  4.91s/it]


Epoch: [9/30], Loss: 2.39559, Accuracy: 0.692


100%|██████████| 72/72 [05:56<00:00,  4.96s/it]


Epoch: [10/30], Loss: 2.42708, Accuracy: 0.694


100%|██████████| 72/72 [05:43<00:00,  4.77s/it]


Epoch: [11/30], Loss: 2.46536, Accuracy: 0.685


100%|██████████| 72/72 [05:48<00:00,  4.85s/it]


Epoch: [12/30], Loss: 2.36726, Accuracy: 0.690


100%|██████████| 72/72 [05:50<00:00,  4.86s/it]


Epoch: [13/30], Loss: 2.26860, Accuracy: 0.687


100%|██████████| 72/72 [05:41<00:00,  4.74s/it]


Epoch: [14/30], Loss: 2.26077, Accuracy: 0.691


100%|██████████| 72/72 [05:38<00:00,  4.70s/it]


Epoch: [15/30], Loss: 2.07177, Accuracy: 0.705


100%|██████████| 72/72 [05:38<00:00,  4.70s/it]


Epoch: [16/30], Loss: 2.09575, Accuracy: 0.704


100%|██████████| 72/72 [05:40<00:00,  4.73s/it]


Epoch: [17/30], Loss: 2.10438, Accuracy: 0.692


100%|██████████| 72/72 [05:37<00:00,  4.69s/it]


Epoch: [18/30], Loss: 1.94379, Accuracy: 0.703


100%|██████████| 72/72 [05:47<00:00,  4.82s/it]


Epoch: [19/30], Loss: 1.89352, Accuracy: 0.701


100%|██████████| 72/72 [06:02<00:00,  5.03s/it]


Epoch: [20/30], Loss: 1.97431, Accuracy: 0.691


100%|██████████| 72/72 [06:06<00:00,  5.08s/it]


Epoch: [21/30], Loss: 1.88644, Accuracy: 0.706


100%|██████████| 72/72 [06:03<00:00,  5.04s/it]


Epoch: [22/30], Loss: 1.79436, Accuracy: 0.709


100%|██████████| 72/72 [06:05<00:00,  5.08s/it]


Epoch: [23/30], Loss: 1.73550, Accuracy: 0.704


100%|██████████| 72/72 [05:51<00:00,  4.89s/it]


Epoch: [24/30], Loss: 1.78185, Accuracy: 0.709


100%|██████████| 72/72 [05:56<00:00,  4.95s/it]


Epoch: [25/30], Loss: 1.74719, Accuracy: 0.699


100%|██████████| 72/72 [05:49<00:00,  4.85s/it]


Epoch: [26/30], Loss: 1.72928, Accuracy: 0.703


100%|██████████| 72/72 [05:53<00:00,  4.90s/it]


Epoch: [27/30], Loss: 1.60821, Accuracy: 0.715


100%|██████████| 72/72 [05:50<00:00,  4.87s/it]


Epoch: [28/30], Loss: 1.73199, Accuracy: 0.702


100%|██████████| 72/72 [05:45<00:00,  4.80s/it]


Epoch: [29/30], Loss: 1.57693, Accuracy: 0.710


100%|██████████| 72/72 [05:52<00:00,  4.89s/it]

Epoch: [30/30], Loss: 1.54339, Accuracy: 0.717





In [69]:
loss_val = 0.0
acc_val = 0.0

# Evaluate in inference mode: eval() disables dropout, and no_grad() avoids
# building the autograd graph for every test batch.
model.eval()
with torch.no_grad():
    for sample in tqdm(dataloader_test):
        matrix, label = sample[0].to(device), sample[1].to(device)

        # Add the channel dimension to the data
        matrix = matrix.unsqueeze(1)

        with autocast(use_amp):
            pred = model(matrix)
            loss = loss_fn(pred, label)

        loss_item = loss.item()
        loss_val += loss_item

        acc_current = accuracy(pred.cpu().float(), label.cpu().float())
        acc_val += acc_current

# Both metrics are averaged over the number of test batches.
print(f'Loss: {loss_val/len(dataloader_test):.5f}, Accuracy: {acc_val/len(dataloader_test):.3f}')

100%|██████████| 18/18 [00:32<00:00,  1.83s/it]

Loss: 1.54388, Accuracy: 0.696





In [70]:
def aroma_map(matrix):
    """Plot the model's predicted aroma probabilities for one sample as a polar chart.

    A channel axis is inserted at position 1 before the input is passed to the
    notebook-level ``model``; the sigmoid of the first row of the output is
    drawn against the ten aroma categories on a radial axis fixed to [0, 1].
    """
    aroma_names = ['Herbs and spices', 'Tobacco/Smoke', 'Wood', 'Berries', 'Citrus',
                   'Fruits ', 'Nuts', 'Coffee', 'Chocolate/Cacao', 'Flowers']

    logits = model(matrix.unsqueeze(1)).to('cpu')
    probabilities = torch.sigmoid(logits).detach().numpy()[0]
    df = pd.DataFrame({'r': probabilities, 'theta': aroma_names})

    fig = px.line_polar(df, r='r', theta='theta', line_close=True, template="plotly_dark")
    fig.update_traces(fill='toself', line_color='maroon')
    fig.update_layout(polar=dict(radialaxis=dict(range=[0, 1])))
    fig.show()


In [71]:
# Plot the predicted aroma profile of training sample 42; unsqueeze(0) adds a leading batch axis.
aroma_map(X_train[42].to(device).unsqueeze(0))

In [72]:
def aroma_map_comparison(matrix, labels):
    """Overlay predicted and reference aroma profiles on one polar chart.

    The global ``model`` is run on ``matrix``; the sigmoid of its output
    (first batch item) is drawn as the 'Predict' trace and ``labels`` as the
    'Experiment' trace, both on a radial axis fixed to ``[0, 1]``. The label
    tensor is also printed.

    Args:
        matrix: Input tensor without a channel axis; a channel dimension is
            inserted at position 1 before the forward pass. Presumably shaped
            (batch, H, W) and already on the model's device - TODO confirm.
        labels: Reference values for the same sample, one per aroma category
            (a 1-D tensor of length 10 is what the chart construction needs).

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    # Single source for the axis categories so both traces always agree.
    # NOTE(review): order is assumed to match the model's output order and
    # the label columns - confirm against the training data.
    aroma_names = ['Herbs and spices', 'Tobacco/Smoke', 'Wood', 'Berries', 'Citrus',
       'Fruits ', 'Nuts', 'Coffee', 'Chocolate/Cacao', 'Flowers']

    matrix = matrix.unsqueeze(1)

    # Run inference in eval mode without gradient tracking, then put the
    # model back into whatever mode the caller left it in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            pred = model(matrix).to('cpu')
    finally:
        model.train(was_training)

    labels = labels.to('cpu')

    print(labels)

    df = pd.DataFrame(dict(
        r=torch.sigmoid(pred).numpy()[0],
        theta=aroma_names))

    df['Label'] = 'Predict'

    df1 = pd.DataFrame(dict(
        r=labels.detach().numpy(),
        theta=aroma_names))

    df1['Label'] = 'Experiment'

    fig = px.line_polar(pd.concat([df, df1]), color="Label",  r='r', theta='theta', line_close=True, line_shape='linear', color_discrete_sequence=['#008080', '#FFC0CB'],
                    template="plotly_dark")

    fig.update_traces(fill='toself')
    fig.update_layout(polar=dict(radialaxis=dict(visible=True, range=[0, 1])))
    fig.show()

In [73]:
# Draw prediction-vs-reference charts for the first test samples.
# The count is capped by the test-set size so a smaller split does not
# raise IndexError.
N_COMPARISON_PLOTS = 90
for i in range(min(N_COMPARISON_PLOTS, len(X_test))):
    matrix, label = X_test[i].to(device).unsqueeze(0), y_test[i].to(device)
    aroma_map_comparison(matrix, label)

tensor([0., 0., 0., 0., 1., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 1., 0., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 0., 1., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 1.], dtype=torch.float64)


tensor([0., 0., 0., 0., 1., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([0., 0., 0., 0., 1., 1., 0., 1., 0., 0.], dtype=torch.float64)


tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([0., 0., 0., 0., 1., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 1.], dtype=torch.float64)


tensor([0., 0., 0., 0., 0., 1., 0., 0., 1., 0.], dtype=torch.float64)


tensor([0., 0., 0., 1., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 1., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 1., 0., 1., 0., 0., 0., 0., 0., 1.], dtype=torch.float64)


tensor([0., 0., 0., 1., 0., 1., 1., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 0., 0., 1., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 0., 1., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 1., 1., 1., 0., 0., 0., 0., 1., 0.], dtype=torch.float64)


tensor([1., 0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 0., 1., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 1., 1., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 0., 0., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 0., 0., 0., 0., 1.], dtype=torch.float64)


tensor([1., 0., 1., 0., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 1., 1., 0., 0., 0., 0., 1., 0.], dtype=torch.float64)


tensor([1., 1., 1., 1., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([1., 0., 0., 0., 1., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 0., 1., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 1., 1., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 0., 1., 0., 1., 0., 0., 1.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([0., 0., 0., 0., 1., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 0., 0., 0., 1., 0.], dtype=torch.float64)


tensor([1., 1., 0., 1., 0., 0., 0., 0., 1., 0.], dtype=torch.float64)


tensor([1., 1., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 0., 0., 0., 0., 1.], dtype=torch.float64)


tensor([0., 0., 0., 0., 1., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 1., 1., 1., 1., 1., 0., 0., 1., 1.], dtype=torch.float64)


tensor([0., 0., 0., 1., 0., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 1., 0., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 1., 0., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 1., 1., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 1., 0., 1., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 0., 1., 0., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 0., 0., 1., 0., 1.], dtype=torch.float64)


tensor([1., 0., 0., 0., 0., 1., 1., 0., 0., 1.], dtype=torch.float64)


tensor([0., 0., 0., 1., 0., 0., 0., 0., 1., 0.], dtype=torch.float64)


tensor([1., 0., 0., 0., 1., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([1., 1., 0., 1., 0., 0., 0., 0., 1., 0.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 0., 1., 0., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 0., 0., 1., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([0., 0., 1., 1., 0., 1., 0., 0., 1., 0.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 0., 0., 0., 0., 1.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)


tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 1.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 0., 0., 0., 1., 0.], dtype=torch.float64)


tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)


tensor([1., 0., 0., 0., 1., 1., 0., 0., 0., 1.], dtype=torch.float64)


tensor([1., 0., 0., 1., 0., 0., 0., 1., 0., 1.], dtype=torch.float64)


tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 1.], dtype=torch.float64)
