In [2]:
# Notebook setup: auto-reload edited project modules, render plots inline,
# and import everything the later cells use.
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join
from copy import deepcopy
import pickle as pkl
# NOTE(review): np, opj and pkl are not used in the cells shown here.

# Project-local modules: each sys.path entry must be appended before the
# import that resolves against it, so keep this ordering.
sys.path.append('models')
from sim_cosmology import p, load_dataloader_and_pretrained_model
# wt modules (adaptive wavelet transform, losses, training, viz) and the
# cosmology dataset helpers
sys.path.append('../../src')
sys.path.append('../../src/adaptive_wavelets')
sys.path.append('../../src/dsets/cosmology')
from dset import get_dataloader
from losses import get_loss_f
from train import Trainer, Validator
from wavelet_transform import Wavelet_Transform, Attributer, get_2dfilts, initialize_filters
from utils import tuple_L1Loss, tuple_L2Loss, thresh_attrs
from viz import viz_im_r, cshow, viz_filters, viz_list

## Load data and pretrained model

In [None]:
# get dataloader and model (img_size=256 is forwarded to the project helper)
(train_loader, test_loader), model = load_dataloader_and_pretrained_model(p, img_size=256)

# Optional sanity check: scatter true vs. predicted value of output column 1
# over the whole training set. Disabled by default because it runs the model
# over every training batch; set CHECK_PREDICTION = True to enable it.
CHECK_PREDICTION = False
if CHECK_PREDICTION:
    with torch.no_grad():
        result = {'y': [], 'pred': []}
        # Loop variable is named `targets` (not `params`) so enabling this
        # check does not overwrite a notebook-level name.
        for data, targets in train_loader:
            result['y'].append(targets[:, 1].detach().cpu())
            result['pred'].append(model(data.to(device))[:, 1].detach().cpu())
    plt.scatter(torch.cat(result['y']), torch.cat(result['pred']))
    plt.xlabel('true param')
    plt.ylabel('predicted param')
    plt.show()

## Initialize wavelet filters

In [None]:
# Grab one batch of test images (first 64) for visual checks. The seed is set
# first, presumably so the batch is reproducible if the loader shuffles.
torch.manual_seed(p.seed)
# next(iter(...)): DataLoader iterators no longer expose a `.next()` method.
im = next(iter(test_loader))[0][0:64].to(device)

# Standard db3 DWT (J=5) - kept unmodified and re-used as the baseline at the end.
# Use the auto-detected `device` so this also runs on CPU-only machines.
wt_orig = Wavelet_Transform(wt_type='DWT', wave='db3', mode='symmetric', device=device, J=5)
# Reconstruct once, without building an autograd graph.
with torch.no_grad():
    im_rec = wt_orig.inverse(wt_orig(im))
viz_im_r(im[0], im_rec[0])
# Squared L2 reconstruction error summed over pixels, averaged over the batch.
print("Recon={:.5f}".format(torch.norm(im_rec - im)**2/im.size(0)))

filt = get_2dfilts(wt_orig)
viz_list(filt, figsize=(4,4))

In [None]:
# Starting point for optimization: the db3 filters perturbed with noise.
# deepcopy guarantees `wt_orig` remains an untouched baseline (it is validated
# again at the end of the notebook) even if `initialize_filters` modifies or
# shares its argument, and keeps this cell safe to re-run.
wt = initialize_filters(deepcopy(wt_orig), init_level=1, noise_level=0.2)
filt = get_2dfilts(wt)
# Reconstruct once, without building an autograd graph through the filters.
with torch.no_grad():
    im_rec = wt.inverse(wt(im))
viz_im_r(im[0], im_rec[0])
print("Recon={:.5f}".format(torch.norm(im_rec - im)**2/im.size(0)))

viz_list(filt, figsize=(4,4))

## Optimize wavelet filters

In [None]:
# Train: optimize the parameters of both `wt.xfm` and `wt.ifm` (presumably the
# forward and inverse filter banks) with Adam, using a loss built with an
# L1-attribution weight of 50 and Saliency attributions from `model`.
params = [*wt.xfm.parameters(), *wt.ifm.parameters()]
optimizer = torch.optim.Adam(params, lr=0.01)
loss_f = get_loss_f(lamL1attr=50)
trainer = Trainer(
    model, wt, Attributer, optimizer, loss_f,
    attr_methods='Saliency', device=device,
)
trainer(train_loader, epochs=10)

In [None]:
# Continue training the same `wt` with the existing trainer/optimizer.
# Not idempotent: every re-run of this cell adds another `extra_epochs` epochs.
extra_epochs = 10
trainer(train_loader, epochs=extra_epochs)

In [None]:
# Inspect the filters and reconstruction of the trained transform.
filt = get_2dfilts(wt)
# Reconstruct once and without autograd: the filters are now trainable
# parameters, so a plain forward pass would build an unused graph.
with torch.no_grad():
    im_rec = wt.inverse(wt(im))
viz_im_r(im[0], im_rec[0])
print("Recon={:.5f}".format(torch.norm(im_rec - im)**2/im.size(0)))

In [None]:
# Plot the current 2D filters of `wt`. Recomputed here rather than reusing
# `filt` from an earlier cell so the figure is never stale after more training.
filt = get_2dfilts(wt)
viz_list(filt, figsize=(4,4))

In [None]:
# Evaluate the trained transform `wt` on the test set (L1-attribution weight 1,
# Saliency attributions) and report reconstruction and L1-attribution losses.
loss_v = get_loss_f(lamL1attr=1)
validator = Validator(
    model, wt, Attributer, loss_v,
    attr_methods='Saliency', device=device,
)
_, rec_loss, L1attr_loss = validator(test_loader)
print("\nRecon={:.5f} L1attr={:.5f}".format(rec_loss, L1attr_loss))

In [None]:
# Baseline for comparison: the same evaluation applied to the original db3
# transform `wt_orig` instead of the trained `wt`.
loss_v = get_loss_f(lamL1attr=1)
validator = Validator(
    model, wt_orig, Attributer, loss_v,
    attr_methods='Saliency', device=device,
)
_, rec_loss, L1attr_loss = validator(test_loader)
print("\nRecon={:.5f} L1attr={:.5f}".format(rec_loss, L1attr_loss))