In [None]:
import pyshearlab
from PIL import Image
import numpy as np
from skimage.transform import resize
import matplotlib.pyplot as plt
from pyshearlab import dfilters, modulate2, SLprepareFilters2D, SLgetShearletIdxs2D, MirrorFilt, SLupsample, SLpadArray, SLdshear
import scipy

# Short alias for numpy's FFT module; the FFT-based filter code in later cells
# calls fft2/ifft2/fftshift/ifftshift through it.
fftlib = np.fft

In [None]:
# 9-tap lowpass scaling filter (nearly symmetric about the centre tap); its taps
# sum to 1 up to floating-point rounding.
# NOTE(review): these values look like pyshearlab's default scaling filter — confirm.
scalingFilter = np.asarray(
    (
        0.0104933261758410, -0.0263483047033631, -0.0517766952966370,
        0.276348304703363, 0.582566738241592, 0.276348304703363,
        -0.0517766952966369, -0.0263483047033631, 0.0104933261758408,
    )
)

#### The bandpass–lowpass function takes the scaling-filter weights and derives their quadrature-mirror-filter (QMF) partner from them. A separate question is the support function of the shearlet, which here is the directional filter. In principle this filter can be anything, so it could also become another learnable weight of the module. Reasonable choices are the directional filters used here or a Gaussian, as CoShRem does.

+ Because this function is called on every forward pass, it has to be fast.
+ Because this function sits between the learnable weights and the computation, it also has to be differentiable.
+ Because we want all operations to happen in the real (spatial) domain, this function must return the real-valued filter rather than its complex frequency-domain representation.

In [None]:
def get_bandpass_lowpass(scalingFilter):
    """Return the 1-D two-scale bandpass and lowpass filters for a scaling filter.

    ``scalingFilter`` is the 1-D lowpass half of a quadrature mirror filter
    (QMF) pair and is expected to sum to one. Its highpass partner is derived
    here. Each of the two filters is then convolved with the scaling filter
    after being upsampled by ``SLupsample(., 2, 1)``. The results are
    real-valued spatial-domain filters.

    Parameters
    ----------
    scalingFilter : array_like
        1-D scaling filter. Lists/tuples are accepted. An array with a single
        non-singleton axis (e.g. shape ``(1, n)``) is flattened.

    Returns
    -------
    bandpass : ndarray
        ``np.convolve(h, SLupsample(g, 2, 1))``, where ``h`` is the scaling
        filter and ``g`` its highpass partner.
    lowpass : ndarray
        ``np.convolve(h, SLupsample(h, 2, 1))``.

    Raises
    ------
    ValueError
        If ``scalingFilter`` is empty or has more than one non-singleton axis.
    """
    h = np.asarray(scalingFilter)
    if sum(dim > 1 for dim in h.shape) > 1:
        raise ValueError("scalingFilter must be one-dimensional")
    # Accept (n,), (1, n), (n, 1) alike; the original `.size` call crashed on lists.
    h = h.reshape(-1)
    if h.size == 0:
        raise ValueError("scalingFilter must not be empty")

    # Highpass partner g[k] = (-1)**k * h[n-1-k]: alternating signs applied to
    # the time-reversed lowpass (the usual QMF construction).
    waveletFilter = np.power(-1, np.arange(h.size)) * h[::-1]

    lowpass = np.convolve(h, SLupsample(h, 2, 1))
    bandpass = np.convolve(h, SLupsample(waveletFilter, 2, 1))

    # Frequency-domain counterparts, kept for reference only. They are not
    # computed because the module wants real-valued spatial filters. For an
    # image of size (rows, cols) they would roughly be:
    #   bandpass_f = fftlib.fftshift(fftlib.fft2(fftlib.ifftshift(
    #       SLpadArray(bandpass, np.array([rows, cols])))))
    #   lowpass_f = fftlib.fftshift(fftlib.fft2(fftlib.ifftshift(
    #       SLpadArray(np.outer(lowpass, lowpass), np.array([rows, cols])))))

    return bandpass, lowpass

#### The function that computes the wedge is called only once, so it does not have to be fast. It should return only a PyTorch tensor containing the wedge.

+ This function should operate in the real (spatial) domain, so the Fourier-domain convolutions must be replaced with `torch.nn.functional.conv2d`.
+ When using `conv2d`, keep in mind that it actually computes a cross-correlation (flip the kernel to get a true convolution). Also, without padding it returns a 'valid' output that is smaller than the input, whereas the full convolution used here (`scipy.signal.convolve2d`) returns a larger array, so we need to pad accordingly.
+ Eventually we want the option to replace the directional filters with a learnable weight or a Gaussian.
    + If we do replace it with a weight, the wedge will have to be recomputed very frequently, so these operations will have to be fast and differentiable.

In [None]:
def _centered_fft_convolve(a, b):
    """Circularly convolve two equally shaped, centred 2-D arrays via the FFT.

    This equals the expression previously inlined twice in ``get_wedge``:
    ``fftshift(ifft2(ifftshift(fftshift(fft2(ifftshift(a))) * fftshift(fft2(ifftshift(b))))))``.
    The inner ``ifftshift(fftshift(.))`` pair is a permutation and its inverse,
    so it cancels exactly and is dropped. The result is complex.
    """
    return fftlib.fftshift(
        fftlib.ifft2(fftlib.fft2(fftlib.ifftshift(a)) * fftlib.fft2(fftlib.ifftshift(b)))
    )


def get_wedge(shearLevels, scalingFilter, rows=256, cols=256, show_plots=True):
    """Build the frequency-domain wedge (directional) filters for given shear level(s).

    For each requested shear level ``s``:

    * The directional filter (``dfilters("dmaxflat4", "d")``, modulated by
      ``modulate2(., "c")`` and normalised to unit L1 norm) is upsampled
      vertically.
    * It is filtered with a cascaded version of ``scalingFilter``, padded to
      ``(rows, cols)`` and upsampled horizontally.
    * The result is sheared with ``SLdshear`` for every ``k`` in ``[-2**s, 2**s]``,
      downsampled back to ``cols`` columns, and transformed with a centred 2-D FFT.

    Parameters
    ----------
    shearLevels : int or iterable of int
        Shear level(s) to compute. A single int (the original usage) or any
        iterable of non-negative ints is accepted; duplicates are ignored.
    scalingFilter : array_like
        1-D lowpass scaling filter (an array with one non-singleton axis, e.g.
        ``(1, n)``, is flattened). It is NOT modified. The original reshaped the
        caller's array in place, which broke later 1-D uses of it.
    rows, cols : int, optional
        Spatial size of each wedge filter (default 256 x 256, as before).
    show_plots : bool, optional
        Display intermediate arrays with matplotlib, as the original
        exploratory code always did (default True).

    Returns
    -------
    list
        ``wedge[1:]``, where ``wedge[s]`` is a complex array of shape
        ``(rows, cols, 2**(s+1) + 1)`` for each requested level ``s`` and
        ``None`` for levels that were not requested. Slot ``wedge[:, :, 2**s - k]``
        holds shear ``k``. Level 0 is dropped, as in the original. The original
        built this list but never returned it.

    Raises
    ------
    ValueError
        If no shear level is given, a level is negative, or ``scalingFilter``
        is not one-dimensional.
    """
    # `import scipy` alone does not reliably expose `scipy.signal` on older SciPy.
    import scipy.signal

    levels = np.unique(np.atleast_1d(np.asarray(shearLevels)).astype(int))
    if levels.size == 0:
        raise ValueError("shearLevels must contain at least one level")
    if levels.min() < 0:
        raise ValueError("shear levels must be non-negative")
    maxLevel = int(levels.max())

    h = np.asarray(scalingFilter, dtype=float)
    if sum(dim > 1 for dim in h.shape) > 1:
        raise ValueError("scalingFilter must be one-dimensional")
    # reshape returns a new array object, so the caller's array is never reshaped.
    h = h.reshape(-1)

    def _show(image):
        # Debug display of an intermediate array (kept from the exploratory version).
        if show_plots:
            plt.imshow(image)
            plt.show()

    wedge = [None] * (maxLevel + 1)

    h0, _ = dfilters("dmaxflat4", "d")
    h0 = h0 / np.sqrt(2)
    # modulate2(h0, "c") multiplies h0 by an alternating (-1)**n sign pattern
    # measured from the array centre (the earlier excerpt showed the column
    # factor m2 = (-1)**n2). Presumably "c" applies it along both axes, shifting
    # the frequency response by pi — TODO confirm against pyshearlab.
    directionalFilter = modulate2(h0, "c")
    directionalFilter = directionalFilter / np.sum(np.abs(directionalFilter))

    # Cascaded 1-D lowpass filters: filterLow2[-1] = h, and
    # filterLow2[j] = h * upsample(filterLow2[j + 1]).
    filterLow2 = [None] * (maxLevel + 1)
    filterLow2[-1] = h
    for j in range(maxLevel - 1, -1, -1):
        filterLow2[j] = np.convolve(h, SLupsample(filterLow2[j + 1], 2, 1))
    # Row-vector (1, n) views used for the 2-D operations below. Built once, so the
    # original repeated in-place reshaping (which gave (1, 1, n) at level 0) is gone.
    lowRows = [f.reshape(1, -1) for f in filterLow2]

    for shearLevel in levels.tolist():
        step = 2 ** shearLevel
        wedge[shearLevel] = np.zeros((rows, cols, 2 * step + 1), dtype=complex)

        # Upsample the directional filter vertically (dims=1) with 2**(s+1) - 1 zeros.
        directionalFilterUpsampled = SLupsample(directionalFilter, 1, 2 * step - 1)
        _show(directionalFilterUpsampled)

        # Filter along the columns with the transposed (column-vector) cascaded lowpass.
        wedgeHelp = scipy.signal.convolve2d(directionalFilterUpsampled, lowRows[-1 - shearLevel].T)
        # Pad the result to the size of the input image.
        wedgeHelp = SLpadArray(wedgeHelp, np.array([rows, cols]))
        _show(wedgeHelp)

        # Upsample along the columns dimension (dims=2) with 2**s - 1 zeros.
        wedgeUpsampled = SLupsample(wedgeHelp, 2, step - 1)
        _show(wedgeUpsampled)

        # Pad the matching lowpass filter to the size of the upsampled wedge.
        lowpassHelp = SLpadArray(lowRows[-1 - max(shearLevel - 1, 0)], np.asarray(wedgeUpsampled.shape))
        if shearLevel >= 1:
            # NOTE: most entries are zeros; a direct sparse convolution would be cheaper.
            wedgeUpsampled = _centered_fft_convolve(lowpassHelp, wedgeUpsampled)
            _show(wedgeUpsampled.real)
        lowpassHelpFlip = np.fliplr(lowpassHelp)

        for k in range(-step, step + 1):
            # Discrete shear with parameter k (pyshearlab axis argument 2).
            wedgeUpsampledSheared = SLdshear(wedgeUpsampled, k, 2)
            if shearLevel >= 1:
                wedgeUpsampledSheared = _centered_fft_convolve(lowpassHelpFlip, wedgeUpsampledSheared)
            # Keep every step-th column; this is why the arrays were padded/upsampled above.
            # Same columns as the original `0 : step*cols - 1 : step` slice for s >= 1,
            # and it also yields `cols` columns at s = 0 (the original gave cols - 1).
            downsampled = step * wedgeUpsampledSheared[:, : step * (cols - 1) + 1 : step]
            wedge[shearLevel][:, :, step - k] = fftlib.fftshift(
                fftlib.fft2(fftlib.ifftshift(downsampled))
            )
        _show(fftlib.fftshift(fftlib.ifft2(fftlib.ifftshift(wedge[shearLevel][..., 0]))).real)

    return wedge[1:]

In [None]:
# Shearlet index table for the requested shear levels (not used further in this cell).
shearletIdxs = SLgetShearletIdxs2D(shearLevels)
nShearlets = shearletIdxs.shape[0]

# Assumes `wedge` is a list of (rows, cols, nDirections) frequency-domain arrays,
# and that `bandpass[..., 0]` and `lowpass` are (rows, cols) frequency-domain
# filters. NOTE(review): get_bandpass_lowpass returns 1-D spatial filters, so
# these globals must come from elsewhere — confirm.
bandpass_shearlets = []
lowpass_shearlets = []

for w in wedge:
    nDirections = w.shape[-1]
    for direction in range(nDirections):
        # Cone 1: wedge times conjugated scale filter.
        bandpass_shearlet = w[..., direction] * np.conj(bandpass[..., 0])
        lowpass_shearlet = w[..., direction] * np.conj(lowpass)
        bandpass_shearlets.append(bandpass_shearlet)
        lowpass_shearlets.append(lowpass_shearlet)

        # Cone 2: transposes of the same products (computed once, not twice),
        # skipping the first and last direction (presumably the diagonals
        # shared by both cones — TODO confirm).
        if 0 < direction < nDirections - 1:
            bandpass_shearlets.append(np.transpose(bandpass_shearlet))
            lowpass_shearlets.append(np.transpose(lowpass_shearlet))

# Interleave as [bandpass_0, lowpass_0, bandpass_1, lowpass_1, ...]. A flat
# comprehension avoids the quadratic `sum(list_of_lists, [])`.
interleaved = [f for pair in zip(bandpass_shearlets, lowpass_shearlets) for f in pair]
shearlets = np.stack(interleaved, -1)

In [None]:
# Vectorized version. Here `wedge` is treated as a single (rows, cols, nDirections)
# array that `bandpass` broadcasts against. NOTE(review): the previous cell
# iterates `wedge` as a list of arrays — confirm which form is intended.
# Cone 2 is built by swapping the first two axes, so concatenating along the
# last axis requires square filters (rows == cols).
bandpass_wedge = np.concatenate(
    (wedge * np.conj(bandpass), np.transpose(wedge[..., 1:-1] * np.conj(bandpass), (1, 0, 2))), -1
)
lowpass_wedge = np.concatenate(
    (
        wedge * np.conj(lowpass[..., np.newaxis]),
        np.transpose(wedge[..., 1:-1] * np.conj(lowpass[..., np.newaxis]), (1, 0, 2)),
    ),
    -1,
)
shearlets = np.concatenate((bandpass_wedge, lowpass_wedge), -1)

# The large products only need computing twice rather than four times:
# compute cone 1 once, then append its transposed interior directions as cone 2.
cone1_bandpass = wedge * np.conj(bandpass)
cone1_lowpass = wedge * np.conj(lowpass[..., np.newaxis])

cone1_bandpass = np.concatenate((cone1_bandpass, np.transpose(cone1_bandpass[..., 1:-1], (1, 0, 2))), -1)
cone1_lowpass = np.concatenate((cone1_lowpass, np.transpose(cone1_lowpass[..., 1:-1], (1, 0, 2))), -1)

assert np.allclose(shearlets, np.concatenate((cone1_bandpass, cone1_lowpass), -1))

# cone1_bandpass / cone1_lowpass now hold both cones, so concatenate them directly.
# Appending another transposed slice of cone1_bandpass here would duplicate the
# cone-2 filters and discard the result validated by the assert above.
shearlets = np.concatenate((cone1_bandpass, cone1_lowpass), -1)


In [None]:
# Inspect the final filter stack: (rows, cols, number of filters along the last axis).
shearlets.shape

In [None]:
# One figure per filter (real part). Each figure gets a title so it can be
# identified when skimming. It is closed after display so memory stays bounded
# when there are many filters.
for i in range(shearlets.shape[-1]):
    fig, ax = plt.subplots()
    ax.imshow(shearlets[..., i].real)
    ax.set_title(f"shearlets[..., {i}] (real part)")
    plt.show()
    plt.close(fig)