In [1]:
# Convolutional kernel for LUT-based models
import torch
import torch.nn as nn

from difflut.models.feedforward import feedforward_core

  from .dwn_node import DWNNode
  from .probabilistic_node import ProbabilisticNode
  from .probabilistic_stable_node import ProbabilisticStableNode
  from .hybrid_node import HybridNode
  from .fourier_node import FourierNode
  from .dwn_stable_node import DWNStableNode


In [2]:
class ConvolutionalLUTLayer(nn.Module):
    """
    Convolutional layer using LUT-based nodes
    """
    
    def __init__(
            self,
            tree_depth: int,
            in_channels: int,
            out_channels: int,
            receptive_field: int | tuple[int, int] = 5,
            stride: int | tuple[int, int] = 1,
            padding: int | tuple[int, int] = 0,
            node_type: str = 'dwn',
            layer_type: str = 'random',
            n_inputs_per_node: int = 6,
            # node_kwargs: dict | None = None
            ):
        super(ConvolutionalLUTLayer, self).__init__()
        self.tree_depth = tree_depth
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.receptive_field = self._pair(receptive_field)
        self.stride = self._pair(stride)
        self.padding = self._pair(padding)
        self.node_type = node_type
        self.layer_type = layer_type
        self.n_inputs_per_node = n_inputs_per_node
        # self.node_kwargs = node_kwargs if node_kwargs is not None else {}

        # Create trees (one for each output channel)
        # Each tree is a small feedforward network of LUT nodes
        hidden_layers = [self.n_inputs_per_node ** (self.tree_depth - i) for i in range(self.tree_depth + 1)]
        self.trees = nn.ModuleList()
        for _ in range(out_channels):
            tree = feedforward_core(
                input_size=in_channels * self.receptive_field[0] * self.receptive_field[1],
                hidden_sizes=hidden_layers,
                node_type=self.node_type,
                layer_type=self.layer_type,
                n_inputs=self.n_inputs_per_node,
                # node_kwargs=self.node_kwargs
            )
            self.trees.append(tree)
        
        # For convolution, we use the unfold operation
        self.unfold = nn.Unfold(kernel_size=receptive_field, padding=0, stride=1)

    def _pair(self, x: int | tuple[int, int]) -> tuple[int, int]:
        if isinstance(x, int):
            return (x, x)
        return x

    def forward(self, x):

        batch_size = x.shape[0]
        
        # Extract patches: (batch, patch_size, num_patches)
        patches = self.unfold(x)
        num_patches = patches.shape[2]
        
        # Reshape to (batch*num_patches, patch_size)
        patches = patches.transpose(1, 2).contiguous()
        patches = patches.view(-1, self.receptive_field * self.in_channels)

        # Process each patch through each tree
        output = [tree(patches) for tree in self.trees]
        output = torch.stack(output, dim=1)  # (batch*num_patches, out_channels)

        output = output.view(batch_size, num_patches, self.out_channels)
        output = output.transpose(1, 2)  # (batch, out_channels, num_patches)
        
        # Calculate output spatial dimensions
        out_h = (x.shape[2] + 2 * self.padding[0] - self.receptive_field[0]) // self.stride[0] + 1
        out_w = (x.shape[3] + 2 * self.padding[1] - self.receptive_field[1]) // self.stride[1] + 1

        output = output.view(batch_size, self.out_channels, out_h, out_w)
        
        return output
        

In [3]:
# Smoke-test instantiation: a single-channel input with 5x5 patches gives each
# tree 25 inputs, and depth 3 with 6 inputs per node gives hidden layers
# [216, 36, 6, 1].
# NOTE(review): the recorded run of this cell raised a TypeError about
# 'input_dim' from difflut's DWNNode; the cause is in difflut code not shown
# in this notebook.
conv_lut_layer = ConvolutionalLUTLayer(
    # Patch geometry
    in_channels=1,
    out_channels=2,
    receptive_field=5,
    stride=1,
    padding=0,
    # Tree construction
    tree_depth=3,
    n_inputs_per_node=6,
    node_type='dwn',
    layer_type='random',
)


Building DiffLUT model:
  Node type: dwn
  Input size: 25
  Hidden layers: [216, 36, 6, 1]


TypeError: difflut.nodes.dwn_node.DWNNode() got multiple values for keyword argument 'input_dim'