In [4]:
# Notebook setup: live-reload edited project modules and render plots inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
# run on the GPU when one is available, otherwise on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join  # short alias for os.path.join
from copy import deepcopy
import pickle as pkl  # NOTE(review): only ever unpickle trusted files
from tqdm import tqdm

# experiment settings `p` (its `seed` is used below) and the data/model loader
from sim_cosmology import p, load_dataloader_and_pretrained_model
# adaptive-wavelets modules
from losses import get_loss_f
from train import Trainer
from evaluate import Validator
from transform2d import DWT2d
from utils import get_1dfilts, get_2dfilts
from wave_attributions import Attributer
from visualize import cshow, plot_1dfilts, plot_2dfilts, plot_2dreconstruct

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Load data and model

In [6]:
# Fetch the (train, test) dataloaders together with the pretrained model.
loaders, model = load_dataloader_and_pretrained_model(p, img_size=256)
train_loader, test_loader = loaders

# Optional sanity check (disabled): scatter the true vs. predicted value of
# parameter index 1 over the training set. Uncomment to run.
# with torch.no_grad():
#     ys, preds = [], []
#     for data, params in train_loader:
#         ys.append(params[:, 1].cpu())
#         preds.append(model(data.to(device))[:, 1].cpu())
# plt.scatter(torch.cat(ys), torch.cat(preds))
# plt.xlabel('true param')
# plt.ylabel('predicted param')
# plt.show()

## Initialize filters

In [None]:
import pywt

In [None]:
# Pick one image as a 2-D numpy array for the pywt visualization below.
# The seeded test batch is fetched here rather than reusing `im`, because `im`
# is otherwise only created in a later cell; this keeps Restart & Run All working.
torch.manual_seed(p.seed)
im = next(iter(test_loader))[0][0:64].to(device)
# first image, first channel (assumes `im` is (batch, channel, H, W) — TODO confirm)
x = im[0, 0, ...].detach().cpu().numpy()

In [None]:
# display the selected image with the project's `cshow` helper
cshow(x)

In [None]:
import numpy as np
import pywt
from matplotlib import pyplot as plt
from pywt._doc_utils import wavedec2_keys, draw_2d_wp_basis

# Baseline: a standard (non-adaptive) pywt 2-D DWT of `x` at increasing depth.
# Top row: subband layout of each decomposition. Bottom row: the coefficients.
shape = x.shape

max_lev = 3       # how many levels of decomposition to draw
label_levels = 3  # how many levels to explicitly label on the plots
wavelet = 'db2'   # pywt wavelet used for this baseline visualization
normalize_coeffs = False  # True: rescale each subband to [-1, 1] for visibility
vmin, vmax = -0.05, 0.15  # fixed grey-scale display range (when not normalizing)

# one column for the original image plus one per decomposition level;
# squeeze=False keeps `axes` 2-D even when max_lev == 0
ncols = max_lev + 1
fig, axes = plt.subplots(2, ncols, figsize=[3.5 * ncols, 8], squeeze=False)
for level in range(ncols):
    if level == 0:
        # show the original image before decomposition
        axes[0, 0].set_axis_off()
        axes[1, 0].imshow(x, cmap=plt.cm.gray, vmax=vmax, vmin=vmin)
        axes[1, 0].set_title('Image')
        axes[1, 0].set_axis_off()
        continue

    # plot subband boundaries of a standard DWT basis
    draw_2d_wp_basis(shape, wavedec2_keys(level), ax=axes[0, level],
                     label_levels=label_levels)
    axes[0, level].set_title('{} level\ndecomposition'.format(level))

    # compute the 2D DWT
    c = pywt.wavedec2(x, wavelet, mode='periodization', level=level)
    if normalize_coeffs:
        # normalize each coefficient array independently for better visibility
        c[0] = c[0] / np.abs(c[0]).max()
        for detail_level in range(level):
            c[detail_level + 1] = [d / np.abs(d).max() for d in c[detail_level + 1]]
        display_range = {}  # let imshow autoscale the normalized values
    else:
        display_range = {'vmin': vmin, 'vmax': vmax}
    # show the coefficients packed into a single array
    arr, slices = pywt.coeffs_to_array(c)
    axes[1, level].imshow(arr, cmap=plt.cm.gray, **display_range)
    axes[1, level].set_title('Coefficients\n({} level)'.format(level))
    axes[1, level].set_axis_off()

plt.tight_layout()
plt.show()

In [None]:
# get a fixed (seeded) batch of up to 64 test images
torch.manual_seed(p.seed)
# `next(iter(...))` instead of `iter(...).next()`: DataLoader iterators in newer
# PyTorch versions only implement `__next__`, so `.next()` raises AttributeError.
im = next(iter(test_loader))[0][0:64].to(device)

# wavelet transform (the trainable filters optimized further down)
wt = DWT2d(wave='db5', mode='symmetric', J=5, init_factor=0, noise_factor=0.1).to(device)

im_t = wt(im)
recon = wt.inverse(im_t)

# squared reconstruction error over the whole batch, divided by the number of images
print("Reconstruction error={:.5f}".format(torch.norm(recon - im)**2/im.size(0)))

# get 2d wavelet filters
filt = get_2dfilts(wt)

In [None]:
# plot original and reconstruction images
# (`recon` was produced by `wt` before any training)
plot_2dreconstruct(im, recon)

In [None]:
# plot wavelet filters of `wt` before training
# (share_min_max presumably puts all panels on one colour scale — confirm in visualize.py)
plot_2dfilts(filt, figsize=(4,4), share_min_max=True)

## Optimize filters

In [None]:
# Training setup: the optimizer only sees the wavelet transform's parameters;
# the pretrained `model` is handed to the Trainer but not to the optimizer.
params = [param for param in wt.parameters()]
optimizer = torch.optim.Adam(params, lr=0.005)
# lamL1attr presumably weights the L1 attribution term of the loss — confirm in losses.py
loss_f = get_loss_f(lamL1attr=10)
trainer = Trainer(model, wt, Attributer, optimizer, loss_f,
                  attr_methods='Saliency', device=device)

In [None]:
# Run the trainer configured above for 50 epochs over the training set;
# the recorded `trainer.train_losses` are plotted in the next cell.
trainer(train_loader, epochs=50)

In [None]:
# training curve: natural log of the training losses recorded per epoch
fig, ax = plt.subplots()
ax.plot(np.log(trainer.train_losses))
ax.set(xlabel="epochs", ylabel="log train loss", title='Log-train loss vs epochs')
plt.show()

In [None]:
# Re-apply the trained transform to the same batch `im` and report the
# squared reconstruction error divided by the number of images.
im_t = wt(im)
recon = wt.inverse(im_t)
sq_err = torch.norm(recon - im) ** 2
print("Reconstruction error={:.5f}".format(sq_err / im.size(0)))

# refresh the 2-D filters from the trained transform for plotting
filt = get_2dfilts(wt)

In [None]:
# plot original and reconstruction images after training
plot_2dreconstruct(im, recon)

In [None]:
# plot wavelet filters of `wt` after training
plot_2dfilts(filt, figsize=(4,4), share_min_max=True)

## Test error: adaptive vs. original filters

In [None]:
# Both transforms are scored on the test set with the same loss `loss_v`,
# so their numbers are directly comparable.
# NOTE(review): training used lamL1attr=10; confirm lamL1attr=1 is intended here.
loss_v = get_loss_f(lamL1attr=1)

# adaptive (trained) wavelet transform
validator = Validator(model, wt, Attributer, loss_v,
                      attr_methods='Saliency', device=device)
_, rec_loss, L1attr_loss = validator(test_loader)

# original wavelet transform (presumably plain db5: init_factor=1, noise_factor=0)
wt_o = DWT2d(wave='db5', mode='symmetric', J=5, init_factor=1, noise_factor=0).to(device)
validator_o = Validator(model, wt_o, Attributer, loss_v,
                        attr_methods='Saliency', device=device)
_, rec_loss_o, L1attr_loss_o = validator_o(test_loader)

summary = "\n\n Original filter:Reconstruction Error={:.5f} L1attribution loss={:.5f} \n Adaptive filter:Reconstruction Error={:.5f} L1attribution loss={:.5f}"
print(summary.format(rec_loss_o, L1attr_loss_o, rec_loss, L1attr_loss))

In [None]:
# Compare filter `h1` of the trained transform `wt` with `h1` of the original
# db5 transform `wt_o`. `.detach()` alone is enough; `.data` is redundant.
fig, ax = plt.subplots()
ax.plot(wt.h1.detach().squeeze().cpu().numpy(), label='adaptive')
ax.plot(wt_o.h1.detach().squeeze().cpu().numpy(), label='original (db5)')
ax.set(title='h1 filter: adaptive vs original', xlabel='index', ylabel='value')
ax.legend()
plt.show()

In [None]:
# Compare filter `h0` of the trained transform `wt` with `h0` of the original
# db5 transform `wt_o`. `.detach()` alone is enough; `.data` is redundant.
fig, ax = plt.subplots()
ax.plot(wt.h0.detach().squeeze().cpu().numpy(), label='adaptive')
ax.plot(wt_o.h0.detach().squeeze().cpu().numpy(), label='original (db5)')
ax.set(title='h0 filter: adaptive vs original', xlabel='index', ylabel='value')
ax.legend()
plt.show()