<a href="https://colab.research.google.com/github/Scurrra/WaveletNN-PyTorch/blob/master/notebooks/waveletnn_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# WaveletNN Training

Example of training 1D orthonormal and biorthogonal wavelets against [PyWavelets](https://pywavelets.readthedocs.io/)-generated analysis, i.e. learning an existing wavelet.

## Install the `waveletnn` and `PyWavelets` (`pywt`) packages

In [1]:
!pip install -qq waveletnn
!pip install -qq PyWavelets

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m62.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m55.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m36.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import torch
import numpy as np

from waveletnn import *
import pywt

## Adapt `pywt` analysis to match that of `waveletnn`

In [3]:
def analyze(signal, wavelet, padding: int, mode:str="antireflect", levels:int=1):
    """Multi-level 1D DWT via PyWavelets with explicit padding and trimming.

    At every level the current approximation is extended by ``padding``
    samples on both sides using ``mode``, transformed with
    ``pywt.dwt(..., mode="zero")``, and ``padding`` coefficients are then
    dropped from both ends of each output band.

    Args:
        signal: 1D array-like input accepted by ``pywt.pad``.
        wavelet: wavelet name or object forwarded to ``pywt.dwt``.
        padding: samples added per side before the transform and coefficients
            trimmed per side afterwards. ``0`` disables both.
        mode: PyWavelets extension mode used for the explicit padding.
        levels: number of decomposition levels.

    Returns:
        ``(signals, details)``: two lists with one torch tensor per level
        (first entry = first decomposition level) holding the trimmed
        approximation and detail coefficients.

    Raises:
        ValueError: if ``mode`` is not listed in ``pywt.Modes.modes``.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if mode not in pywt.Modes.modes:
        raise ValueError(
            f"unknown padding mode {mode!r}; expected one of {pywt.Modes.modes}"
        )

    signals, details = [], []

    for _ in range(levels):
        padded = pywt.pad(signal, (padding,), mode)
        approx, detail = pywt.dwt(padded, wavelet, mode="zero")

        # Use an explicit stop index: ``x[padding:-padding]`` would be empty
        # for ``padding == 0`` because ``-0 == 0``.
        approx = approx[padding:len(approx) - padding]
        detail = detail[padding:len(detail) - padding]

        signals.append(torch.from_numpy(approx))
        details.append(torch.from_numpy(detail))

        # The trimmed approximation is the input of the next level.
        signal = approx

    return signals, details

In [4]:
def synthesize(packet, wavelet, padding: int, levels:int=1):
    """Invert ``analyze``: rebuild a 1D signal from its coefficient packet.

    Starting from the last (coarsest) approximation, each step zero-pads the
    approximation and the matching detail band by ``padding`` on both sides,
    applies ``pywt.idwt(..., mode="zero")`` and trims ``padding`` samples from
    both ends of the result.

    Args:
        packet: ``(signals, details)`` pair as returned by ``analyze``; only
            ``signals[-1]`` and the detail bands are read.
        wavelet: wavelet name or object forwarded to ``pywt.idwt``.
        padding: per-side zero padding applied before each inverse step and
            trimmed after it. ``0`` disables both.
        levels: number of inverse steps; expected to equal the number of
            levels stored in ``packet`` (the last entries are always used
            for the first step).

    Returns:
        The reconstructed signal as returned by ``pywt.idwt``, trimmed by
        ``padding`` on both sides.
    """
    def _trim(coeffs):
        # Explicit stop index: ``coeffs[padding:-padding]`` is empty when
        # ``padding == 0``.
        return coeffs[padding:len(coeffs) - padding]

    def _merge(approx, detail):
        # One inverse level: zero-pad both bands, then run the inverse DWT.
        return pywt.idwt(
            pywt.pad(approx, (padding, padding), "zero"),
            pywt.pad(detail, (padding, padding), "zero"),
            wavelet, mode="zero"
        )

    signal, details = packet

    # Coarsest level first ...
    z = _merge(signal[-1], details[-1])
    # ... then walk back through the remaining detail bands.
    for i in range(levels - 2, -1, -1):
        z = _merge(_trim(z), details[i])

    return _trim(z)

In [5]:
mode = "antireflect"
wavelet = "db8"
# "db8" is of length 16 so the padding is 7, this will be shown later
padding = 7

In [6]:
x = np.random.rand(512)

In [7]:
signal, details = analyze(x, wavelet, padding, levels=3, mode=mode)
[len(signal[i]) for i in range(3)], [len(details[i]) for i in range(3)]

([256, 128, 64], [256, 128, 64])

In [8]:
z = synthesize(
    analyze(x, wavelet, padding, levels=3, mode=mode),
    wavelet, padding, levels=3
)
len(z)

512

In [9]:
# reconstruction is not perfect because the pywt output has to be trimmed at the borders
sum((x-z) ** 2)

15.86329858816335

## Closer integration with `waveletnn` and batch support


In [10]:
def analyze_batch(signal, wavelet, padding: int, mode:str="antireflect", levels:int=1):
    """Batched multi-level DWT: ``PadSequence`` padding + PyWavelets transform.

    At every level the current approximation is padded with
    ``PadSequence(padding, padding, mode)``, transformed along the last axis
    with ``pywt.dwt(..., mode="zero")``, and ``padding`` coefficients are
    dropped from both ends of the last axis of each output band.

    Args:
        signal: torch tensor whose last axis is the signal axis
            (the notebook uses shape ``(batch, 1, length)``).
        wavelet: wavelet name or object forwarded to ``pywt.dwt``.
        padding: per-side padding before the transform and per-side trimming
            afterwards. ``0`` disables the trimming
            (NOTE(review): confirm ``PadSequence`` accepts zero padding).
        mode: padding mode handed to ``PadSequence``.
        levels: number of decomposition levels.

    Returns:
        ``(signals, details)``: two lists with one tensor per level
        (first entry = first decomposition level), placed on the same device
        as the padded input.
    """
    pad = PadSequence(padding, padding, mode)

    signals, details = [], []

    for _ in range(levels):
        padded = pad(signal)
        device = padded.device
        # ``.numpy()`` alone raises for tensors that require grad or live on
        # an accelerator; detach and move to CPU first.
        approx, detail = pywt.dwt(padded.detach().cpu().numpy(), wavelet, mode="zero")

        # Explicit stop index: ``[..., padding:-padding]`` is empty when
        # ``padding == 0``.
        approx = approx[..., padding:approx.shape[-1] - padding]
        detail = detail[..., padding:detail.shape[-1] - padding]

        signals.append(torch.from_numpy(approx).to(device))
        details.append(torch.from_numpy(detail).to(device))

        # The trimmed approximation feeds the next level.
        signal = signals[-1]

    return signals, details

In [11]:
def synthesize_batch(packet, wavelet, padding: int, levels:int=1):
    """Invert ``analyze_batch``: rebuild batched signals from a coefficient packet.

    Starting from the last (coarsest) approximation, each step pads the
    approximation and the matching detail band with
    ``PadSequence(padding, padding, "constant")``, applies
    ``pywt.idwt(..., mode="zero")`` along the last axis and trims ``padding``
    samples from both ends of that axis.

    Args:
        packet: ``(signals, details)`` pair as returned by ``analyze_batch``;
            only ``signals[-1]`` and the detail bands are read.
        wavelet: wavelet name or object forwarded to ``pywt.idwt``.
        padding: per-side padding before each inverse step and per-side
            trimming after it. ``0`` disables the trimming
            (NOTE(review): confirm ``PadSequence`` accepts zero padding).
        levels: number of inverse steps; expected to equal the number of
            levels stored in ``packet``.

    Returns:
        CPU torch tensor with the reconstructed signals, trimmed by
        ``padding`` on both sides of the last axis.
    """
    pad = PadSequence(padding, padding, "constant")

    def _trim(coeffs):
        # Explicit stop index: ``[..., padding:-padding]`` is empty when
        # ``padding == 0``.
        return coeffs[..., padding:coeffs.shape[-1] - padding]

    def _to_numpy(tensor):
        # ``.numpy()`` alone raises for tensors that require grad or live on
        # an accelerator.
        return tensor.detach().cpu().numpy()

    def _merge(approx, detail):
        # One inverse level on padded bands, returned as a torch tensor.
        return torch.from_numpy(pywt.idwt(
            _to_numpy(pad(approx)),
            _to_numpy(pad(detail)),
            wavelet, mode="zero"
        ))

    signal, details = packet

    # Coarsest level first, then walk back through the remaining detail bands.
    z = _merge(signal[-1], details[-1])
    for i in range(levels - 2, -1, -1):
        z = _merge(_trim(z), details[i])

    return _trim(z)

In [12]:
x = torch.randn(5,1,512)
x.shape

torch.Size([5, 1, 512])

In [13]:
signal, details = analyze_batch(x, wavelet, padding, levels=3, mode=mode)
[signal[i].shape for i in range(3)], [details[i].shape for i in range(3)]

([torch.Size([5, 1, 256]), torch.Size([5, 1, 128]), torch.Size([5, 1, 64])],
 [torch.Size([5, 1, 256]), torch.Size([5, 1, 128]), torch.Size([5, 1, 64])])

In [14]:
z = synthesize_batch(
    analyze_batch(x, wavelet, padding, levels=3, mode=mode),
    wavelet, padding, levels=3
)
z.shape

torch.Size([5, 1, 512])

In [15]:
((x-z) ** 2).sum()

tensor(79.7877)

## Train and Test functions

In [16]:
def train_loop(model, wavelet, loss_fn, loss_reg, optimizer, batch_size=8, num_of_steps=500, log_step=100, signal_length=512):
    """Train ``model`` to reproduce the PyWavelets analysis of random signals.

    Every step draws a fresh uniform-random batch, analyses it with ``model``
    and with ``analyze_batch`` (the PyWavelets reference), and minimises the
    filter regularisation plus the per-level coefficient losses.

    Args:
        model: wavelet block called as ``model(X, return_filters=True)`` and
            exposing ``padding`` and ``levels`` attributes.
        wavelet: reference wavelet forwarded to ``analyze_batch``.
        loss_fn: loss between model coefficients and reference coefficients.
        loss_reg: regulariser called on the filters ``(h, g)`` returned by
            the model.
        optimizer: optimizer that updates the model parameters.
        batch_size: number of random signals per step.
        num_of_steps: number of optimisation steps.
        log_step: the current loss is printed every ``log_step`` steps.
        signal_length: length of each random training signal.
    """
    model.train()
    for batch in range(num_of_steps):
        # Compute prediction and reference analysis for a fresh random batch
        X = torch.rand(batch_size, 1, signal_length)
        (Ys, Yd), (h, g) = model(X, return_filters=True)
        # NOTE(review): analyze_batch uses its default padding mode here;
        # presumably that matches the model's own padding - confirm.
        (Zs, Zd) = analyze_batch(X, wavelet, padding=model.padding, levels=model.levels)

        # Accumulate out-of-place so the tensor returned by loss_reg is never
        # mutated.
        loss = loss_reg(h, g)
        for i in range(model.levels):
            loss = loss + loss_fn(Ys[i], Zs[i]) + loss_fn(Yd[i], Zd[i])

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % log_step == 0:
            loss_value = loss.item()
            print(f"loss: {loss_value:>7f}  [{batch:>5d}/{num_of_steps:>5d}]")


def test_loop(model, wavelet, loss_fn, loss_reg, batch_size=8, n_tests=100, signal_length=512):
    """Evaluate ``model`` against the PyWavelets analysis on random signals.

    Runs ``n_tests`` random batches without gradient tracking and prints the
    average coefficient loss (summed over all levels and both bands) and the
    average regularisation loss.

    Args:
        model: wavelet block called as ``model(X, return_filters=True)`` and
            exposing ``padding`` and ``levels`` attributes.
        wavelet: reference wavelet forwarded to ``analyze_batch``.
        loss_fn: loss between model coefficients and reference coefficients.
        loss_reg: regulariser called on the filters ``(h, g)`` returned by
            the model.
        batch_size: number of random signals per test batch.
        n_tests: number of random batches to average over.
        signal_length: length of each random test signal.
    """
    # Evaluation mode matters for layers such as batch norm and dropout;
    # kept here as a habit.
    model.eval()
    test_loss_fn, test_loss_reg = 0.0, 0.0

    # No gradients are needed for evaluation, which also saves memory.
    with torch.no_grad():
        for _ in range(n_tests):
            X = torch.rand(batch_size, 1, signal_length)
            (Ys, Yd), (h, g) = model(X, return_filters=True)
            (Zs, Zd) = analyze_batch(X, wavelet, padding=model.padding, levels=model.levels)

            # Both accumulators hold plain Python floats.
            test_loss_reg += loss_reg(h, g).item()
            for i in range(model.levels):
                test_loss_fn += loss_fn(Ys[i], Zs[i]).item()
                test_loss_fn += loss_fn(Yd[i], Zd[i]).item()

    test_loss_fn /= n_tests
    test_loss_reg /= n_tests
    print(f"Test Error: Avg MSE loss: {test_loss_fn:>8f}, Wavelet loss: {test_loss_reg:>8f} \n")

## Train orthonormal wavelet

In [17]:
wl_db9 = pywt.Wavelet("db9")
wl_db9.dec_len

18

In [18]:
wavelet_h = wl_db9.dec_lo.copy()
wavelet_h.reverse()
wavelet_h = torch.tensor(wavelet_h, dtype=torch.get_default_dtype())
wavelet_h

tensor([ 3.8078e-02,  2.4383e-01,  6.0482e-01,  6.5729e-01,  1.3320e-01,
        -2.9327e-01, -9.6841e-02,  1.4854e-01,  3.0726e-02, -6.7633e-02,
         2.5095e-04,  2.2362e-02, -4.7232e-03, -4.2815e-03,  1.8476e-03,
         2.3039e-04, -2.5196e-04,  3.9347e-05])

In [19]:
wavelet_g = wl_db9.dec_hi.copy()
wavelet_g.reverse()
wavelet_g = torch.tensor(wavelet_g, dtype=torch.get_default_dtype())
wavelet_g

tensor([ 3.9347e-05,  2.5196e-04,  2.3039e-04, -1.8476e-03, -4.2815e-03,
         4.7232e-03,  2.2362e-02, -2.5095e-04, -6.7633e-02, -3.0726e-02,
         1.4854e-01,  9.6841e-02, -2.9327e-01, -1.3320e-01,  6.5729e-01,
        -6.0482e-01,  2.4383e-01, -3.8078e-02])

In [20]:
model = OrthonormalWaveletBlock1D(wl_db9.dec_len, levels=3)
print(model)

OrthonormalWaveletBlock1D(
  (pad): PadSequence()
)


In [21]:
# Coefficient loss shared by training and the evaluation cells below.
loss_fn = torch.nn.MSELoss()
# Final regulariser used for evaluation: n_moments is fixed at 8.
# NOTE(review): PyWavelets documents dbN as having N vanishing moments (9 for db9);
# confirm that 8 is the intended number of moment conditions here.
loss_reg = OrthonormalWaveletRegularization(n_moments=lambda _: 8)

In [22]:
# Staged training: stage p trains with n_moments = p, for p = 1..8.
for p in range(1, 8+1):
    # Optimizer and scheduler are re-created per stage, so each stage starts again at lr=0.001.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Halve the learning rate after every 5 epochs within the stage.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    # `max_p=p` binds the current stage value at definition time (avoids late binding of `p`).
    loss_reg_p = OrthonormalWaveletRegularization(n_moments=lambda _, max_p=p: max_p)
    # 20 epochs per stage; the printed epoch counter runs continuously across stages.
    for t in range(20):
        print(f"Epoch {(p-1)*20+t+1}\n-------------------------------")
        train_loop(model, wl_db9, loss_fn, loss_reg_p, optimizer, batch_size=8)
        test_loop(model, wl_db9, loss_fn, loss_reg_p)
        scheduler.step()
print("Done!")

Epoch 1
-------------------------------
loss: 10.752133  [    0/  500]
loss: 2.902512  [  100/  500]
loss: 2.642475  [  200/  500]
loss: 4.660428  [  300/  500]
loss: 1.621223  [  400/  500]
Test Error: Avg MSE loss: 1.303443, Wavelet loss: 0.119339 

Epoch 2
-------------------------------
loss: 1.223843  [    0/  500]
loss: 1.558472  [  100/  500]
loss: 1.339987  [  200/  500]
loss: 0.919167  [  300/  500]
loss: 1.171949  [  400/  500]
Test Error: Avg MSE loss: 0.837195, Wavelet loss: 0.081192 

Epoch 3
-------------------------------
loss: 0.999043  [    0/  500]
loss: 0.805553  [  100/  500]
loss: 0.830221  [  200/  500]
loss: 0.604362  [  300/  500]
loss: 0.597777  [  400/  500]
Test Error: Avg MSE loss: 0.502033, Wavelet loss: 0.037185 

Epoch 4
-------------------------------
loss: 0.736610  [    0/  500]
loss: 0.461653  [  100/  500]
loss: 0.318172  [  200/  500]
loss: 0.525857  [  300/  500]
loss: 0.228041  [  400/  500]
Test Error: Avg MSE loss: 0.183256, Wavelet loss: 0.0273

In [23]:
X = torch.rand(10, 1, 512)
(Ys, Yd), (h, g) = model(X, return_filters=True)
(Zs, Zd) = analyze_batch(X, wl_db9, padding=model.padding, levels=model.levels)

In [24]:
loss = loss_reg(h, g)
print(loss)
for i in range(model.levels):
    loss += loss_fn(Yd[i], Zd[i])
    loss += loss_fn(Ys[i], Zs[i])
    print(loss)

tensor(2.1434e-09, grad_fn=<AddBackward0>)
tensor(2.1950e-09, grad_fn=<AddBackward0>)
tensor(2.3903e-09, grad_fn=<AddBackward0>)
tensor(3.1452e-09, grad_fn=<AddBackward0>)


In [25]:
loss_fn(h.detach(), wavelet_h)

tensor(9.2117e-12)

In [26]:
h.detach() - wavelet_h

tensor([ 1.4156e-07, -2.2352e-07, -4.1723e-07,  6.5565e-07,  2.0862e-07,
        -6.5565e-07, -2.2352e-07, -5.0664e-07, -4.0978e-08, -8.9407e-08,
         7.4413e-07, -3.0175e-07, -3.8184e-08,  7.0641e-07, -6.2527e-07,
         5.2986e-06, -2.1083e-07, -1.1606e-05])

In [27]:
g.detach() - wavelet_g

tensor([-1.1606e-05,  2.1083e-07,  5.2986e-06,  6.2527e-07,  7.0641e-07,
         3.8184e-08, -3.0175e-07, -7.4413e-07, -8.9407e-08,  4.0978e-08,
        -5.0664e-07,  2.2352e-07, -6.5565e-07, -2.0862e-07,  6.5565e-07,
         4.1723e-07, -2.2352e-07, -1.4156e-07])

In [28]:
r = torch.arange(len(g), dtype=torch.get_default_dtype())
pNorm=lambda p: np.sqrt(2) / (2 ** p)
n_moments=lambda g: 8
[
    (torch.dot(r ** (p-1), g.detach()) * pNorm(p)) ** 2
    for p in range(1, n_moments(g))
]

[tensor(1.9771e-11),
 tensor(4.8033e-12),
 tensor(2.9104e-11),
 tensor(1.4261e-09),
 tensor(4.6566e-10),
 tensor(0.),
 tensor(0.)]

In [29]:
[
    (torch.dot(r ** (p-1), wavelet_g) * pNorm(p)) ** 2
    for p in range(1, n_moments(wavelet_g))
]

[tensor(4.4409e-16),
 tensor(0.),
 tensor(4.5475e-13),
 tensor(2.6193e-10),
 tensor(1.8626e-09),
 tensor(2.9802e-08),
 tensor(1.9073e-06)]

In [30]:
inv_model = InverseWaveletBlock1D(model.kernel_size, levels=3, scaling_kernel=h, wavelet_kernel=g)
inv_model

InverseWaveletBlock1D(
  (pad): PadSequence()
)

In [31]:
XX = inv_model(Ys[-1], Yd)
XX.shape

torch.Size([10, 1, 512])

In [32]:
((X - XX) ** 2).sum()

tensor(99.9788, grad_fn=<SumBackward0>)

## Train biorthogonal wavelet

In [33]:
wl_bior68 = pywt.Wavelet("bior6.8")
wl_bior68.dec_len

18

In [34]:
wavelet_h = wl_bior68.dec_lo.copy()
wavelet_h.reverse()
wavelet_h = torch.tensor(wavelet_h, dtype=torch.get_default_dtype())
wavelet_h

tensor([ 0.0019, -0.0019, -0.0170,  0.0119,  0.0497, -0.0773, -0.0941,  0.4208,
         0.8259,  0.4208, -0.0941, -0.0773,  0.0497,  0.0119, -0.0170, -0.0019,
         0.0019,  0.0000])

In [35]:
wavelet_g = wl_bior68.dec_hi.copy()
wavelet_g.reverse()
wavelet_g = torch.tensor(wavelet_g, dtype=torch.get_default_dtype())
wavelet_g

tensor([ 0.0000, -0.0000,  0.0000, -0.0000,  0.0144, -0.0145, -0.0787,  0.0404,
         0.4178, -0.7589,  0.4178,  0.0404, -0.0787, -0.0145,  0.0144, -0.0000,
         0.0000, -0.0000])

In [36]:
model = BiorthogonalWaveletBlock1D(wl_bior68.dec_len, levels=3)
print(model)

BiorthogonalWaveletBlock1D(
  (pad): PadSequence()
)


In [37]:
loss_fn = torch.nn.MSELoss()
loss_reg = BiorthogonalWaveletRegularization(n_moments=lambda _: model.kernel_size // 3)

In [38]:
# Staged training: stage p trains with n_moments = p, for p = 1..kernel_size // 3.
for p in range(1, model.kernel_size // 3 + 1):
    # Optimizer and scheduler are re-created per stage, so each stage starts again at lr=0.01.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    # Halve the learning rate after every 5 epochs within the stage.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    # `max_p=p` binds the current stage value at definition time (avoids late binding of `p`).
    loss_reg_p = BiorthogonalWaveletRegularization(n_moments=lambda _, max_p=p: max_p)
    # 20 epochs per stage; the printed epoch counter runs continuously across stages.
    for t in range(20):
        print(f"Epoch {(p-1)*20+t+1}\n-------------------------------")
        train_loop(model, wl_bior68, loss_fn, loss_reg_p, optimizer, batch_size=8)
        test_loop(model, wl_bior68, loss_fn, loss_reg_p)
        scheduler.step()
print("Done!")

Epoch 1
-------------------------------
loss: 23.081833  [    0/  500]
loss: 1.026341  [  100/  500]
loss: 0.812883  [  200/  500]
loss: 0.662254  [  300/  500]
loss: 0.424941  [  400/  500]
Test Error: Avg MSE loss: 0.188710, Wavelet loss: 0.000822 

Epoch 2
-------------------------------
loss: 0.189314  [    0/  500]
loss: 0.062022  [  100/  500]
loss: 0.019098  [  200/  500]
loss: 0.007316  [  300/  500]
loss: 0.002492  [  400/  500]
Test Error: Avg MSE loss: 0.000814, Wavelet loss: 0.000001 

Epoch 3
-------------------------------
loss: 0.000780  [    0/  500]
loss: 0.000235  [  100/  500]
loss: 0.000064  [  200/  500]
loss: 0.000014  [  300/  500]
loss: 0.000003  [  400/  500]
Test Error: Avg MSE loss: 0.000000, Wavelet loss: 0.000000 

Epoch 4
-------------------------------
loss: 0.000000  [    0/  500]
loss: 0.000000  [  100/  500]
loss: 0.000000  [  200/  500]
loss: 0.000000  [  300/  500]
loss: 0.000000  [  400/  500]
Test Error: Avg MSE loss: 0.000000, Wavelet loss: 0.0000

In [39]:
X = torch.rand(10, 1, 512)
(Ys, Yd), (h, g) = model(X, return_filters=True)
(Zs, Zd) = analyze_batch(X, wl_bior68, padding=model.padding, levels=model.levels)

In [40]:
loss = loss_reg(h, g)
print(loss)
for i in range(model.levels):
    loss += loss_fn(Yd[i], Zd[i])
    loss += loss_fn(Ys[i], Zs[i])
    print(loss)

tensor(1.0492e-05, grad_fn=<AddBackward0>)
tensor(1.0492e-05, grad_fn=<AddBackward0>)
tensor(1.0492e-05, grad_fn=<AddBackward0>)
tensor(1.0492e-05, grad_fn=<AddBackward0>)


In [41]:
loss_fn(h[0].detach(), wavelet_h)

tensor(2.3983e-13)

In [42]:
h[0].detach() - wavelet_h

tensor([ 2.4040e-07, -1.5658e-07,  3.9116e-07, -1.0990e-07, -6.7055e-08,
        -3.5763e-07,  2.5332e-07,  5.9605e-08,  4.7684e-07,  1.4901e-07,
         4.3213e-07, -9.3132e-07, -3.5018e-07, -2.8964e-07,  2.7940e-07,
        -1.1884e-06,  3.5902e-07,  8.6026e-07])

In [43]:
g[0].detach() - wavelet_g

tensor([-1.9170e-06, -1.4400e-07,  2.2368e-06,  2.4637e-07,  8.0094e-08,
         3.6228e-07,  3.2783e-07, -1.1176e-07, -5.0664e-07,  2.3842e-07,
        -2.3842e-07,  9.6858e-08,  1.4901e-07,  1.6671e-07, -4.2841e-08,
         8.2682e-08,  4.4308e-08,  1.3144e-07])

In [44]:
r = torch.arange(len(g[0]), dtype=torch.get_default_dtype())
pNorm=lambda p: np.sqrt(2) / (2 ** p)
n_moments=lambda g: len(g) // 3
[
    (torch.dot(r ** (p-1), g[0].detach()) * pNorm(p)) ** 2
    for p in range(1, n_moments(g[0]))
]

[tensor(7.8337e-13),
 tensor(1.9213e-11),
 tensor(3.8244e-10),
 tensor(1.8925e-08),
 tensor(1.0286e-06)]

In [45]:
[
    (torch.dot(r ** (p-1), wavelet_g) * pNorm(p)) ** 2
    for p in range(1, n_moments(wavelet_g))
]

[tensor(4.4409e-16),
 tensor(2.8422e-14),
 tensor(4.5475e-13),
 tensor(2.9104e-11),
 tensor(4.6566e-10)]

In [46]:
inv_model = InverseWaveletBlock1D(model.kernel_size, levels=3, scaling_kernel=h[1], wavelet_kernel=g[1])
inv_model

InverseWaveletBlock1D(
  (pad): PadSequence()
)

In [47]:
XX = inv_model(Ys[-1], Yd)
XX.shape

torch.Size([10, 1, 512])

In [48]:
((X - XX) ** 2).sum()

tensor(9.6925, grad_fn=<SumBackward0>)