## Refactor wavelet initialisation

Notebook to play around with classes and objects during the great Kymatio refactor of 2022

In [1]:
from learnable_wavelets.models.models_factory import baseModelFactory, topModelFactory
from learnable_wavelets.models.sn_hybrid_models import sn_HybridModel
from learnable_wavelets.models.camels_models import get_architecture 
from learnable_wavelets.camels.camels_dataset import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## Check if CUDA available
if torch.cuda.is_available():
    print("CUDA Available")
    device = torch.device('cuda')
    use_cuda=True
else:
    print('CUDA Not Available')
    device = torch.device('cpu')
    use_cuda=False

CUDA Not Available


In [3]:
camels_path="/home/chrisp/Data/CAMELs/"
fparams    = camels_path+"/params_IllustrisTNG.txt"
fmaps      = [
              camels_path+"maps_Mtot.npy"
             ]

In [4]:
## dataloader options
seed=123

batch_size=10
splits=1
fmaps_norm=[None]
num_workers=1
rot_flip_in_mem=True
channels=1
features=1

train_loader = create_dataset_multifield('train', seed, fmaps, fparams, batch_size, splits, fmaps_norm,
                                                 num_workers=num_workers, rot_flip_in_mem=rot_flip_in_mem, verbose=True)

Found 1 channels
Reading data...
6.054e+09 < F(all|orig) < 2.176e+15
9.782 < F(all|resc)  < 15.338
-2.696 < F(all|norm) < 8.631
Channel 0 contains 7200 maps
-2.696 < F < 8.631



In [5]:
scatteringBase = baseModelFactory( #creat scattering base model
    architecture='scattering',
    J=2,
    N=256,
    M=256,
    channels=channels,
    max_order=2,
    initialization="Random",
    seed=234,
    learnable=True,
    lr_orientation=0.03,
    lr_scattering=0.03,
    skip=True,
    split_filters=True,
    filter_video=False,
    subsample=4,
    device=device,
    use_cuda=use_cuda,
    plot=False
)

## Now create a network to follow the scattering layers
## can be MLP, linear, or cnn at the moment
## (as in https://github.com/bentherien/ParametricScatteringNetworks/ )
top = topModelFactory( #create cnn, mlp, linear_layer, or other
    base=scatteringBase,
    architecture="linear_layer",
    num_classes=features,
    width=3,
    average=True,
    use_cuda=use_cuda
)

hybridModel = sn_HybridModel(scatteringBase=scatteringBase, top=top, use_cuda=use_cuda)
model=hybridModel
model.to(device=device)

(256, 256)
(128, 128)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


sn_HybridModel(
  (scatteringBase): sn_ScatteringBase()
  (top): sn_LinearLayer(
    (fc1): Linear(in_features=73, out_features=1, bias=True)
    (bn0): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [6]:
x,y=next(iter(train_loader))
model.to(device=device)
x=x.to(device=device)

In [7]:
def do_convolutions(x, backend, J, phi, wavelets, max_order,
        split_filters, subsample):
    """ Function to take an input image and perform a series of scattering
    convolutions."""
    subsample_fourier = backend.subsample_fourier
    modulus = backend.modulus
    fft = backend.fft
    cdgmm = backend.cdgmm
    concatenate = backend.concatenate
    
    wavelets = wavelets.real.contiguous().unsqueeze(3)

    # Define lists for output.
    out_S_0, out_S_1, out_S_2 = [], [], []
    
    ## Map to complex
    complex_maps = x.new_zeros(x.shape + (2,))
    complex_maps[..., 0] = x

    U_0_c = fft(complex_maps, 'C2C')
    print(U_0_c.shape)
    print(U_0_c)
    

    # First low pass filter
    U_1_c = cdgmm(U_0_c, phi[0])
    U_1_c = subsample_fourier(U_1_c, k=subsample)


    S_0 = fft(U_1_c, 'C2R', inverse=True)

    out_S_0.append({'coef': S_0})

    if split_filters:
        for n1 in range(int(len(wavelets)/2)):

            ## Wavelet convolution
            
            U_1_c = cdgmm(U_0_c, wavelets[n1])

            U_1_c = fft(U_1_c, 'C2C', inverse=True)
            U_1_c = modulus(U_1_c)
            U_1_c = fft(U_1_c, 'C2C')

            ## Second low pass filter
            S_1_c = cdgmm(U_1_c, phi[0])
            S_1_c = subsample_fourier(S_1_c, k=subsample)

            S_1_r = fft(S_1_c, 'C2R', inverse=True)

            out_S_1.append({'coef': S_1_r})

            if max_order < 2:
                continue
            for n2 in range(int(len(psi)/2),len(psi)):
                

                U_2_c = cdgmm(U_1_c, wavelets[n2])
                U_2_c = fft(U_2_c, 'C2C', inverse=True)
                U_2_c = modulus(U_2_c)
                U_2_c = fft(U_2_c, 'C2C')

                ## Low pass filter
                S_2_c = cdgmm(U_2_c, phi[0])
                
                S_2_c = subsample_fourier(S_2_c, k=subsample)

                S_2_r = fft(S_2_c, 'C2R', inverse=True)
                

                out_S_2.append({'coef': S_2_r})
    else:
        for n1 in range(len(wavelets)):
            ## Wavelet convolution
            U_1_c = cdgmm(U_0_c, wavelets[n1])

            U_1_c = fft(U_1_c, 'C2C', inverse=True)
            U_1_c = modulus(U_1_c)
            U_1_c = fft(U_1_c, 'C2C')

            ## Second low pass filter
            S_1_c = cdgmm(U_1_c, phi[0])
            S_1_c = subsample_fourier(S_1_c, k=subsample)

            S_1_r = fft(S_1_c, 'C2R', inverse=True)

            out_S_1.append({'coef': S_1_r})

            if max_order < 2:
                continue
            for n2 in range(len(wavelets)):
                
                U_2_c = cdgmm(U_1_c, wavelets[n2])
                U_2_c = fft(U_2_c, 'C2C', inverse=True)
                U_2_c = modulus(U_2_c)
                U_2_c = fft(U_2_c, 'C2C')

                ## Low pass filter
                S_2_c = cdgmm(U_2_c, phi[0])
                S_2_c = subsample_fourier(S_2_c, k=subsample)
                S_2_r = fft(S_2_c, 'C2R', inverse=True)
                

                out_S_2.append({'coef': S_2_r})

    out_S = []
    out_S.extend(out_S_0)
    out_S.extend(out_S_1)
    out_S.extend(out_S_2)

    out_S = concatenate([x['coef'] for x in out_S])

    return out_S

def convolve_fields(input, backend, J, phi, wavelets, max_order, split_filters, subsample):
    """  
        Wrapper function for a loop that will convovle each wavelet with the input fields

        Parameters:
            input      -- input data
            psi        -- dictionnary of filters that is used in the kymatio code
            split_filters -- split first and second order filters
        Returns:
            S -- Fields after being convolved with wavelets
    """

    batch_shape = input.shape[:-2]
    signal_shape = input.shape[-2:]

    input = input.reshape((-1,) + signal_shape)

    S = do_convolutions(input, backend, J, phi, wavelets,
                        max_order, split_filters, subsample)

    ## S will always be a numpy array
    scattering_shape = S.shape[-3:]
    S = S.reshape(batch_shape + scattering_shape)

    return S

In [8]:
x.shape

torch.Size([10, 1, 256, 256])

In [9]:
convolve_fields(x, scatteringBase.backend, scatteringBase.J, scatteringBase.phi, scatteringBase.wavelets,
                                    scatteringBase.max_order, scatteringBase.split_filters,scatteringBase.subsample)

torch.Size([10, 256, 256, 2])
tensor([[[[-4.0709e+04,  0.0000e+00],
          [ 6.2421e+03, -1.6293e+03],
          [-3.3327e+03, -8.7347e+03],
          ...,
          [-3.1848e+03, -1.1559e+03],
          [-3.3327e+03,  8.7347e+03],
          [ 6.2421e+03,  1.6293e+03]],

         [[-1.0732e+04, -1.3816e+04],
          [ 3.4436e+03, -6.0446e+03],
          [ 7.3200e+02,  2.7856e+03],
          ...,
          [ 3.3609e+03,  2.4519e+02],
          [-1.0972e+03, -1.7042e+03],
          [ 4.7773e+03,  8.0153e+03]],

         [[-3.7985e+03,  5.6110e+03],
          [-2.9254e+03,  1.2903e+02],
          [-3.6758e+03,  2.9053e+03],
          ...,
          [ 1.0454e+03,  3.1514e+03],
          [-2.2407e+03,  2.7526e+03],
          [ 3.9179e+03, -3.5238e+02]],

         ...,

         [[ 2.2235e+03, -2.5056e+03],
          [ 1.8452e+02,  1.5053e+03],
          [ 5.9136e+00, -2.4585e+03],
          ...,
          [ 2.8552e+03, -1.0652e+03],
          [ 4.4578e+03, -3.1558e+03],
          [ 2.9

TypeError: The input should be complex (i.e. last dimension is 2).