In [1]:
# Notebook setup: auto-reload edited project modules, inline plots, and core imports.
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
# Use the GPU when one is available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join  # short path-join alias (not used in the cells shown here)
from copy import deepcopy
import pickle as pkl
from tqdm import tqdm

from sim_cosmology import p, load_dataloader_and_pretrained_model
# Point the cosmology params object at the local dataset directory.
# NOTE(review): model_path is set to the same directory as data_path -- presumably the
# pretrained weights are stored next to the data; confirm.
p.data_path = '../../src/dsets/cosmology/data'
p.model_path = '../../src/dsets/cosmology/data'
# adaptive-wavelets modules
from losses import get_loss_f
from train import Trainer
from evaluate import Validator
from transform2d import DWT2d
from utils import get_1dfilts, get_2dfilts, get_wavefun
from wave_attributions import Attributer
from visualize import cshow, plot_1dfilts, plot_2dfilts, plot_2dreconstruct, plot_wavefun

In [2]:
# Make the local `trim` library importable, then import the TRIM model wrapper and the
# individual loss functions from `losses` (used by the scratch, training and test cells).
sys.path.append('../../lib/trim')
from trim import TrimModel
from losses import _reconstruction_loss, _lsum_loss, _hsum_loss, _L2norm_loss, _CMF_loss, _conv_loss, _L1_wave_loss, _L1_attribution_loss
from utils import low_to_high
import torch.nn.functional as F

## Load data and model

In [3]:
# Train/test dataloaders and the pretrained model, as returned by the cosmology helper.
loaders, model = load_dataloader_and_pretrained_model(p, img_size=256)
train_loader, test_loader = loaders

In [16]:
# Learnable 2D wavelet transform initialized from db5, seeded for reproducibility.
torch.manual_seed(7)
wt = DWT2d(wave='db5', mode='symmetric', J=5, init_factor=1, noise_factor=0.0).to(device)
# Wrap the pretrained model with the inverse transform so attributions can be taken
# w.r.t. wavelet coefficients (the attributer is called on `wt(data)` below).
# NOTE(review): use_residuals=True presumably uses the raw images passed via
# additional_forward_args to form a residual term; confirm in TrimModel.
mt = TrimModel(model, wt.inverse, use_residuals=True)
attributer = Attributer(mt, attr_methods='Saliency', device=device)

# Grab one test batch (images only) for quick sanity checks.
# Use the builtin `next()`: newer PyTorch DataLoader iterators have no `.next()` method.
data = next(iter(test_loader))[0].to(device)
data_t = wt(data)
recon_data = wt.inverse(data_t)
# with torch.backends.cudnn.flags(enabled=False):
#     attributions = attributer(data_t, target=1, additional_forward_args=deepcopy(data))

In [15]:
# Scratch cell: evaluate a single loss term on the sample batch from the cell above.
# Uncomment exactly one `loss = ...` line plus the print/backward lines.
# `_L1_attribution_loss` additionally requires enabling the commented-out
# `attributions = ...` lines in the previous cell.
# loss = _reconstruction_loss(data, recon_data)
# loss = _lsum_loss(wt)
# loss = _hsum_loss(wt)
# loss = _L2norm_loss(wt)
# loss = _CMF_loss(wt)
# loss = _conv_loss(wt)
# loss = _L1_wave_loss(data_t) 
# loss = _L1_attribution_loss(attributions)
# print(loss.item())
# loss.backward()

4.6462245316180933e-17


## Optimize wavelet filters

In [None]:
# Learn the wavelet filters by minimizing the L1 norm of the Saliency attributions that
# the wrapped pretrained model assigns to the wavelet coefficients.
params = list(wt.parameters())
optimizer = torch.optim.Adam(params, lr=0.01)

epochs = 50
train_losses = np.empty(epochs)
wt.train()

for epoch in range(epochs):
    epoch_loss = 0.
    n_batches = 0
    for data, _ in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        # forward wavelet transform; TrimModel was built with wt.inverse, so no explicit
        # reconstruction is needed for this objective
        data_t = wt(data)
        # NOTE(review): cuDNN is disabled for the attribution pass, presumably because the
        # double backward through the model is not supported with cuDNN here; confirm.
        with torch.backends.cudnn.flags(enabled=False):
            attributions = attributer(data_t, target=1, additional_forward_args=deepcopy(data))
        loss = _L1_attribution_loss(attributions)

        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1

    # guard against an empty loader instead of failing on an undefined batch index
    mean_epoch_loss = epoch_loss / n_batches if n_batches else float('nan')
    train_losses[epoch] = mean_epoch_loss
    print('====> Epoch: {} Average train loss: {:.4f}'.format(epoch, mean_epoch_loss))


In [None]:
# Evaluate the learned filters on the held-out test set.
# No backward/optimizer step here: updating on test batches would leak test data into training.
wt.eval()
test_loss = 0.
n_batches = 0
for data, _ in test_loader:
    data = data.to(device)
    data_t = wt(data)
    with torch.backends.cudnn.flags(enabled=False):
        attributions = attributer(data_t, target=1, additional_forward_args=deepcopy(data))
    # same objective as the training loop so train/test numbers are comparable
    loss = _L1_attribution_loss(attributions)
    test_loss += loss.item()
    n_batches += 1

# fresh accumulator (not the training loop's epoch_loss); guard against an empty loader
mean_test_loss = test_loss / n_batches if n_batches else float('nan')
print('====> Average test loss: {:.4f}'.format(mean_test_loss))


In [None]:
# Training curve: log of the mean per-epoch L1 attribution loss.
fig, ax = plt.subplots()
ax.plot(np.log(train_losses))
ax.set(xlabel='epoch', ylabel='log(mean train loss)', title='L1 attribution loss during filter training');