In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os,sys
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

from ex_cosmology import p
from dset import get_dataloader, load_pretrained_model

# adaptive-wavelets modules
from losses import get_loss_f
from train import Trainer
from evaluate import Validator
from transform1d import DWT1d
from transform2d import DWT2d
from utils import get_2dfilts, get_wavefun
from wave_attributions import Attributer
from visualize import cshow, plot_1dfilts, plot_2dfilts, plot_2dreconstruct, plot_wavefun

device = 'cuda' if torch.cuda.is_available() else 'cpu'
opj = os.path.join

# load data and model

In [3]:
# Build the train/test loaders from the experiment config `p`, then load the
# pretrained ResNet-18 whose predictions the wavelet will be fitted to explain.
loader_opts = dict(img_size=p.img_size[2],
                   split_train_test=True,
                   batch_size=p.batch_size)
train_loader, test_loader = get_dataloader(p.data_path, **loader_opts)

model = load_pretrained_model(model_name='resnet18',
                              device=device,
                              data_path=p.model_path)

# define wavelet

In [None]:
# wavelet transform, initialized from the wavelet / level settings in the experiment config `p`
wt = DWT2d(wave=p.wave, mode='zero', J=p.J, init_factor=p.init_factor, noise_factor=p.noise_factor).to(device)

# get one test batch (first element of the batch = the images).
# `next(iter(...))` is used because DataLoader iterators in current PyTorch no
# longer provide the Python-2 style `.next()` method.
data = next(iter(test_loader))[0].to(device)

# Forward + inverse transform. This is only a diagnostic, so no autograd graph
# is needed through the wavelet parameters.
with torch.no_grad():
    data_t = wt(data)
    recon = wt.inverse(data_t)

# squared L2 norm of the reconstruction residual over the whole batch, divided by the batch size
rec_err = (torch.norm(recon - data) ** 2 / data.size(0)).item()
print("Reconstruction error={:.5f}".format(rec_err))

# get 2d wavelet filters
filt = get_2dfilts(wt)

In [None]:
# visual check of the test batch against its reconstruction by the freshly initialized (untrained) wavelet
plot_2dreconstruct(data, recon)

In [None]:
# Inspect the wavelet before any training.
# phi / psi are presumably the scaling function and wavelet sampled on grid x (see get_wavefun).
phi, psi, x = get_wavefun(wt)
wavefun = (phi, psi, x)

# 1d filters (first element returned by get_2dfilts)
plot_1dfilts(filt[0], is_title=True, figsize=(2, 2))

# 2d filters (second element returned by get_2dfilts)
plot_2dfilts(filt[1], figsize=(5, 5))

# sampled wavelet functions
plot_wavefun(wavefun, is_title=True, figsize=(3, 1))

# optimize the wavelet filter

In [None]:
# Only the wavelet's parameters are handed to the optimizer; `model` is used
# by the Trainer but is not among the optimized parameters here.
params = list(wt.parameters())
optimizer = torch.optim.Adam(params, lr=0.001)

# weights of the individual loss terms
loss_weights = dict(lamlSum=1, lamhSum=1, lamL2norm=1, lamCMF=1, lamConv=1,
                    lamL1wave=0.1, lamL1attr=0.01)
loss_f = get_loss_f(**loss_weights)

trainer = Trainer(model, wt, optimizer, loss_f,
                  target=1,
                  use_residuals=True,
                  attr_methods='Saliency',
                  device=device,
                  n_print=1)

In [None]:
# run the optimization over the training set
n_epochs = 2
trainer(train_loader, epochs=n_epochs)

In [None]:
# training curve on a log scale (natural log of the recorded losses)
fig, ax = plt.subplots()
ax.plot(np.log(trainer.train_losses))
ax.set_xlabel("epochs")
ax.set_ylabel("log train loss")
ax.set_title('Log-train loss vs epochs')
plt.show()

In [None]:
# Re-run the same test batch through the trained wavelet. This is only a
# diagnostic, so no autograd graph is built through the wavelet parameters.
with torch.no_grad():
    data_t = wt(data)
    recon = wt.inverse(data_t)

# squared L2 norm of the reconstruction residual over the whole batch, divided by the batch size
rec_err = (torch.norm(recon - data) ** 2 / data.size(0)).item()
print("Reconstruction error={:.5f}".format(rec_err))

# get 2d wavelet filters of the trained transform
filt = get_2dfilts(wt)

In [None]:
# visual check of the same test batch against its reconstruction by the trained wavelet
plot_2dreconstruct(data, recon)

# test-set evaluation: adaptive vs. original wavelet

In [None]:
# Score the trained (adaptive) wavelet and an untrained baseline on the test set.
validator = Validator(model, test_loader)

# Both transforms are scored against the same target the wavelet was trained for
# (the Trainer above uses target=1); with different targets the attribution
# losses of the two filters would not be comparable.
target = 1

# adaptive wavelet transform
losses = tuple(validator(wt, target=target))
rec_loss, lsum_loss, hsum_loss, L2norm_loss, CMF_loss, conv_loss, L1wave_loss, L1saliency_loss, L1inputxgrad_loss = losses

# Original wavelet transform: same 2d wavelet / levels as `wt`, but with
# init_factor=1 and noise_factor=0 (presumably the unperturbed wavelet).
# The previous baseline was a hard-coded 1d DWT ('db5', J=4) that did not match `wt`.
wt_o = DWT2d(wave=p.wave, mode='zero', J=p.J, init_factor=1, noise_factor=0).to(device)
losses_o = tuple(validator(wt_o, target=target))
rec_loss_o, lsum_loss_o, hsum_loss_o, L2norm_loss_o, CMF_loss_o, conv_loss_o, L1wave_loss_o, L1saliency_loss_o, L1inputxgrad_loss_o = losses_o

# labels in the order the validator returns the losses
loss_names = ('Reconstruction Error', 'lsum loss', 'hsum loss', 'L2norm loss', 'CMF loss',
              'conv loss', 'L1wave loss', 'L1saliency loss', 'L1inputxgrad loss')

print()
for label, vals in (('Original filter', losses_o), ('Adaptive filter', losses)):
    print("{}: ".format(label) + " ".join("{}={:.5f}".format(name, val) for name, val in zip(loss_names, vals)))

In [None]:
# Compare the adaptive (trained) wavelet `wt` with the original wavelet `wt_o`:
# for each plot type the adaptive version is drawn first, then the original.
filt, filt_o = get_2dfilts(wt), get_2dfilts(wt_o)
(phi, psi, x), (phi_o, psi_o, x_o) = get_wavefun(wt), get_wavefun(wt_o)

# 1d filters
for fs in (filt, filt_o):
    plot_1dfilts(fs[0], is_title=True, figsize=(2, 2))

# 2d filters
for fs in (filt, filt_o):
    plot_2dfilts(fs[1], is_title=True, figsize=(2, 2))

# sampled wavelet functions
for wf in ((phi, psi, x), (phi_o, psi_o, x_o)):
    plot_wavefun(wf, is_title=True, figsize=(3, 1))