In [1]:
!pip install PyWavelets



In [2]:
import torch
import torch.nn as nn
import pywt
import numpy as np

def get_wavelet_kernels(wavelet_name='haar'):
    wavelet = pywt.Wavelet(wavelet_name)

    # Get low-pass and high-pass filters for decomposition
    filter_bank = wavelet.filter_bank  # (dec_lo, dec_hi, rec_lo, rec_hi)
    dec_lo = np.outer(filter_bank[0], filter_bank[0])  # Low-pass filter in 2D
    dec_hi = np.outer(filter_bank[1], filter_bank[1])  # High-pass filter in 2D

    # Create the wavelet filters for 2D convolution
    filters = torch.tensor([dec_lo, dec_hi], dtype=torch.float32)

    return filters

# Example: Haar wavelet kernel
wavelet_kernels = get_wavelet_kernels('haar')
print("Wavelet kernels: ", wavelet_kernels.shape)


KeyboardInterrupt: 

In [None]:
class WaveletConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, wavelet_name='haar', stride=1, padding=0):
        super(WaveletConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding

        # Create wavelet kernels
        wavelet_kernels = get_wavelet_kernels(wavelet_name)
        wavelet_kernels = wavelet_kernels.unsqueeze(1)  # Add input channel dimension

        # Repeat wavelet kernels to match out_channels
        # out_channels should be divisible by the number of wavelet kernels (here assumed as 1 for simplicity)
        self.filters = nn.Parameter(wavelet_kernels.repeat(self.out_channels, 1, 1, 1), requires_grad=False)

    def forward(self, x):
        # Perform convolution using wavelet kernels
        # Ensure `groups=1` if not performing grouped convolutions
        x = nn.functional.conv2d(x, self.filters, stride=self.stride, padding=self.padding, groups=1)
        return x


class WaveletCNN(nn.Module):
    def __init__(self):
        super(WaveletCNN, self).__init__()
        # Wavelet-based convolution layers
        self.wavelet_conv1 = WaveletConv2d(in_channels=1, out_channels=16, wavelet_name='haar', stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.wavelet_conv2 = WaveletConv2d(in_channels=16, out_channels=32, wavelet_name='haar', stride=1, padding=1)

        # Fully connected layers
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Pass through wavelet-based convolutions
        x = self.pool(torch.relu(self.wavelet_conv1(x)))
        x = self.pool(torch.relu(self.wavelet_conv2(x)))

        # Flatten for fully connected layers
        x = x.view(-1, 32 * 7 * 7)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)

        return x

# Example usage
model = WaveletCNN()
x = torch.randn(1, 1, 28, 28)  # Batch size 1, 1 channel, 28x28 image
output = model(x)
print(output)


In [5]:
import torch
import torch.nn as nn
import pywt
import numpy as np

def get_wavelet_kernels(wavelet_name='haar'):
    wavelet = pywt.Wavelet(wavelet_name)

    # Get low-pass and high-pass filters for decomposition
    filter_bank = wavelet.filter_bank  # (dec_lo, dec_hi, rec_lo, rec_hi)
    dec_lo = np.outer(filter_bank[0], filter_bank[0])  # Low-pass filter in 2D
    dec_hi = np.outer(filter_bank[1], filter_bank[1])  # High-pass filter in 2D

    # Create the wavelet filters for 2D convolution
    filters = torch.tensor([dec_lo, dec_hi], dtype=torch.float32)

    return filters

class WaveletConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, wavelet_name='haar', stride=1, padding=0):
        super(WaveletConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding

        # Get wavelet kernels (2 filters: low-pass and high-pass)
        wavelet_kernels = get_wavelet_kernels(wavelet_name)  # Shape: (2, filter_height, filter_width)

        # Expand to match in_channels and out_channels
        # Repeat the filters to match the required number of channels
        filters = wavelet_kernels.unsqueeze(1)  # Shape: (2, 1, filter_height, filter_width)
        filters = filters.repeat(out_channels // 2, in_channels, 1, 1)  # Adjust to match in_channels and out_channels

        # Register the filters as a non-learnable parameter
        self.filters = nn.Parameter(filters, requires_grad=False)

    def forward(self, x):
        # Perform convolution using wavelet kernels
        x = nn.functional.conv2d(x, self.filters, stride=self.stride, padding=self.padding)
        return x

class WaveletCNN(nn.Module):
    def __init__(self):
        super(WaveletCNN, self).__init__()
        # Wavelet-based convolution layers
        self.wavelet_conv1 = WaveletConv2d(in_channels=1, out_channels=16, wavelet_name='haar', stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.wavelet_conv2 = WaveletConv2d(in_channels=16, out_channels=32, wavelet_name='haar', stride=1, padding=1)

        # Fully connected layers
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Pass through wavelet-based convolutions
        x = self.pool(torch.relu(self.wavelet_conv1(x)))  # Output: [batch_size, 16, 14, 14]
        x = self.pool(torch.relu(self.wavelet_conv2(x)))  # Output: [batch_size, 32, 7, 7]

        # Flatten for fully connected layers
        x = x.view(-1, 32 * 7 * 7)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)

        return x

# Example usage
model = WaveletCNN()
x = torch.randn(1, 1, 28, 28)  # Batch size 1, 1 channel, 28x28 image
output = model(x)
print(output)


  filters = torch.tensor([dec_lo, dec_hi], dtype=torch.float32)


tensor([[  6.8028,  -4.7056,   3.5435,   4.5865,   2.8061,  -3.7799,   4.0897,
           6.8766,  -8.1368, -14.6483]], grad_fn=<AddmmBackward0>)


# ***Correlation computation***

In [None]:
import numpy as np
from scipy.integrate import dblquad

# Assume U_x is defined somewhere and returns a complex number
def U_x(u, lambd, alpha):
    # Example of how it could be structured
    # return some complex function based on u, lambd, and alpha
    pass

def C_x(lambd, alpha, lambd_prime, alpha_prime, u_bounds):
    """
    Calculate C_x(λ, α, λ', α') by integrating U_x(u, λ, α) * conj(U_x(u, λ', α')) over R^2.

    Args:
    - lambd: λ parameter.
    - alpha: α parameter.
    - lambd_prime: λ' parameter.
    - alpha_prime: α' parameter.
    - u_bounds: The bounds of integration for the u vector in R^2, given as ((u1_min, u1_max), (u2_min, u2_max)).

    Returns:
    - result: The value of the integral.
    """

    # Define the real and imaginary parts of the function to integrate
    def integrand(u1, u2, lambd, alpha, lambd_prime, alpha_prime):
        u = np.array([u1, u2])
        return U_x(u, lambd, alpha) * np.conj(U_x(u, lambd_prime, alpha_prime))

    # Perform the integration over u1 and u2 (R^2 space)
    u1_bounds, u2_bounds = u_bounds
    result, error = dblquad(
        lambda u1, u2: np.real(integrand(u1, u2, lambd, alpha, lambd_prime, alpha_prime)),  # Real part
        u1_bounds[0], u1_bounds[1],
        lambda u1: u2_bounds[0], lambda u1: u2_bounds[1]
    )

    result_imag, error_imag = dblquad(
        lambda u1, u2: np.imag(integrand(u1, u2, lambd, alpha, lambd_prime, alpha_prime)),  # Imaginary part
        u1_bounds[0], u1_bounds[1],
        lambda u1: u2_bounds[0], lambda u1: u2_bounds[1]
    )

    # Combine the real and imaginary parts to form the complex result
    return result + 1j * result_imag

# Example usage (assuming bounds of integration are known)
# u_bounds = ((u1_min, u1_max), (u2_min, u2_max))
# result = C_x(lambd, alpha, lambd_prime, alpha_prime, u_bounds)
