**Create a wrapper around `nn.Conv2d` that accepts wavelet filters and other parameters. It creates a Conv2d layer, assigns the wavelet filters as its weights, and disables their gradient.**

In [None]:
import torch
from torch import nn
from create_filters import filter_bank
def wavelet2d(filters: torch.Tensor, in_channels: int, stride: int = 1, dilation: int = 1, kernel_dtype = torch.complex64) -> nn.Conv2d:
    '''
    Create a depthwise nn.Conv2d whose frozen weights are `filters`.

    The returned layer has
    - bias = False
    - groups = in_channels, so every input channel is convolved independently
    - weights set to `filters` (the whole bank tiled once per input channel),
      with requires_grad = False so optimizers never update them
    - circular padding of dilation*(S-1)//2, i.e. "same" size for stride 1 and odd S
    - parameter dtype `kernel_dtype`, so inputs must be cast to that dtype

    filters must have shape [C_filter, S, S]
    for band pass filter, C_filter = nb_angles

    Output channel layout: channel `c * C_filter + k` is input channel `c`
    convolved with `filters[k]`.

    Example:
    image <-- tensor of shape [1, 3, 128, 128], one color image of size 128*128
    filters <-- tensor of shape [4, 3, 3], 4 oriented wavelet filters of size 3*3
    conv2d = wavelet2d(filters, image.shape[1])
    result = conv2d(image.to(torch.complex64)) -> shape [1, 12, 128, 128], each rgb channel is convolved separately with each oriented filter.

    Raises:
        ValueError: if `filters` is not a 3-D tensor of square filters.
    '''
    if filters.dim() != 3 or filters.shape[-2] != filters.shape[-1]:
        raise ValueError(f"filters must have shape [C_filter, S, S], got {tuple(filters.shape)}")
    nb_filters, size = filters.shape[0], filters.shape[-1]
    # With groups=in_channels, PyTorch splits the output channels into in_channels
    # contiguous groups of nb_filters, and group c only sees input channel c.
    # Tiling the whole bank (repeat, not repeat_interleave) gives every group all filters.
    weight = filters.unsqueeze(1).repeat(in_channels, 1, 1, 1)  # [in_channels * C_filter, 1, S, S]
    out_channels = in_channels * nb_filters
    padding = dilation * (size - 1) // 2  # "same" size padding for odd kernel sizes
    conv = nn.Conv2d(
        in_channels = in_channels,
        out_channels = out_channels,
        kernel_size = size,
        stride = stride,
        padding = padding,
        padding_mode = 'circular',
        dilation = dilation,
        groups = in_channels,
        dtype = kernel_dtype,
        bias = False
    )
    with torch.no_grad():
        conv.weight.copy_(weight)
    # Wavelet filters are fixed, not learned: exclude them from gradient computation.
    conv.weight.requires_grad_(False)
    return conv

In [None]:
class dense(nn.Module):
    '''
    Multi-scale wavelet feature extractor with an optional linear classifier head.

    For each of the `scale_J` scales, the current feature map is convolved by a
    frozen depthwise wavelet layer (built with `wavelet2d`). The result goes through
    the modulus non-linearity and is concatenated to the current map along the channel
    axis. The map is then average-pooled by 2**scale_J. If `nb_class` is given, the
    pooled map is flattened and fed to an nn.Linear.

    NOTE(review): the channel bookkeeping (in_channel * (angle_K + 1) per scale)
    assumes `filter_bank(8, S, angle_K)` returns `angle_K` filters — confirm.
    '''
    def __init__(self, scale_J:int,       # number of scales
                       angle_K:int,       # number of orientation
                       image_shape:tuple, # image_channel, square_image_size
                       kernel_size:int,   # square_filter_size
                       nb_class:int = None  # nb of classification class, None -> no linear head
                ):
        super().__init__()
        self.scale_J = scale_J
        self.angle_K = angle_K
        self.image_channel, self.image_size = image_shape
        self.kernel_size = kernel_size

        # Fail early on settings that would otherwise only break inside forward().
        self._check_parameters()

        self.load_filters()
        self.sequential_conv = nn.ModuleList()
        in_channel = self.image_channel
        for j in range(scale_J):
            self.sequential_conv.append(wavelet2d(self.filters[j], in_channel))
            # forward() concatenates the conv output (in_channel * angle_K channels)
            # to its input, so the channel count grows by a factor (angle_K + 1).
            in_channel = in_channel * (angle_K + 1)
        self.out_dim = in_channel * (self.image_size // 2**scale_J)**2
        self.pooling = nn.AvgPool2d(2**scale_J, 2**scale_J)
        self.linear = nn.Linear(self.out_dim, nb_class) if nb_class is not None else None

    def _check_parameters(self):
        '''Raise ValueError for hyper-parameters that would make forward() fail.'''
        if self.scale_J < 0:
            raise ValueError(f"scale_J must be >= 0, got {self.scale_J}")
        # Same-size padding only preserves the spatial size for odd kernels; otherwise
        # torch.cat in forward() fails on mismatched spatial sizes.
        if self.scale_J > 0 and (self.kernel_size < 1 or self.kernel_size % 2 == 0):
            raise ValueError(f"kernel_size must be a positive odd integer, got {self.kernel_size}")
        # The final AvgPool2d(2**scale_J) needs at least one full window.
        if self.image_size < 2**self.scale_J:
            raise ValueError(
                f"image_size ({self.image_size}) must be >= 2**scale_J ({2**self.scale_J})"
            )

    def non_linear(self, imgs):
        '''Modulus non-linearity: maps the complex conv response to its magnitude.'''
        return torch.abs(imgs)

    def cuda(self, device=None):
        '''Move all submodules and the cached filter banks to a CUDA device; returns self.'''
        # nn.Module.cuda moves the registered wavelet convs and the linear head.
        super().cuda(device)
        # self.filters is a plain list (not registered), so move it explicitly;
        # Tensor.cuda is not in-place, hence the rebinding.
        self.filters = [wavelet_filter.cuda(device) for wavelet_filter in self.filters]
        return self

    def load_filters(self):
        '''
        Build one filter bank per scale and store them in `self.filters` (length scale_J).

        Scale j uses filter_bank(8, S_j, angle_K) divided by 2**(2*j) (= 4**j), where
        S_0 = kernel_size and S_{j+1} = 2 * S_j - 1.
        '''
        size = self.kernel_size
        self.filters = []
        for j in range(self.scale_J):
            # NOTE(review): the 4**j divisor presumably compensates for the support
            # growing ~2x along each spatial axis per scale — confirm intent.
            # The meaning of the leading 8 is defined by create_filters.filter_bank.
            self.filters.append(filter_bank(8, size, self.angle_K) / (2**(2*j)))
            size = 2 * size - 1  # next scale has size 2*S - 1

    def forward(self, img):
        '''Return pooled multi-scale features, or class logits when a linear head exists.'''
        for conv in self.sequential_conv:
            # The wavelet convs are built with wavelet2d's default complex64 dtype,
            # so the input is cast; the modulus brings the response back to real.
            result = self.non_linear(conv(img.to(torch.complex64)))
            img = torch.cat([img, result], dim=1)
        img = self.pooling(img)
        if self.linear is not None:
            img = self.linear(img.reshape(img.shape[0], -1))
        return img