In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
# run on the GPU when one is available, otherwise on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join  # shorthand for os.path.join
from copy import deepcopy
import pickle as pkl

# make the project's local modules importable; these paths are relative to the
# kernel's working directory (normally the folder containing this notebook)
sys.path.append('../../src')
sys.path.append('../../src/vae/models')
sys.path.append('../../src/dsets/cosmology')
from dset import get_dataloader
from losses import _reconstruction_loss
from viz import viz_im_r, cshow, viz_filters
# `p` is presumably the experiment parameter object: later cells pass it to
# load_dataloader_and_pretrained_model and read p.seed -- avoid reusing the name `p`
from sim_cosmology import p, load_dataloader_and_pretrained_model
# NOTE(review): wildcard import; no captum name is referenced directly in the visible cells
from captum.attr import *

from wavelet_transform import DTCWT_Transform, DTCWT_Mask, tuple_Attributer, TrimModel

## load data and model

In [6]:
# Load the training data loader and the pretrained model described by `p`.
train_loader, model = load_dataloader_and_pretrained_model(p)

# Dual-tree complex wavelet transform (J=5) used for both analysis and synthesis.
wt = DTCWT_Transform(J=5)


def transform_i(x, x_orig, indx=0):
    '''
    Splice one wavelet component into a coefficient set, then invert it
    --------------------------------------------------------------------
    x : torch.Tensor
        replacement for component ``indx`` of the coefficients
    x_orig : tuple
        detached original wavelet coefficients (the tuple itself is not modified)
    indx : int
        position in ``x_orig`` that ``x`` replaces (default 0)

    Returns whatever ``wt.inverse`` produces for the spliced coefficients.
    '''
    coeffs = [*x_orig]  # fresh list copy so the caller's tuple is left untouched
    coeffs[indx] = x
    return wt.inverse(coeffs)


# Prepend the inverse transform to the network (presumably so the wrapped model
# can be driven directly by wavelet coefficients).
model_t = TrimModel(model, transform_i)

## Compute wavelet-coefficient attributions (ranked in decreasing order in the next cell)

In [10]:
# input: the first two images of the first (seeded) training batch
torch.manual_seed(p.seed)
# Python 3 iterators expose `__next__`, not `.next()`; use the builtin `next()`
im = next(iter(train_loader))[0][0:2].to(device)
im.requires_grad = True

# wavelet transform
im_w = wt(im)

# interp score: Integrated Gradients attributions of the wrapped model w.r.t. the
# wavelet coefficients, for target 1
attributer = tuple_Attributer(model_t, attr_methods='IntegratedGradients')
attributions = attributer(im_w, target=1)

In [None]:
# sparsity levels: how many wavelet coefficients each tile keeps (geometric spacing)
num_sp = 20
sp_level = np.geomspace(1, 262144, num_sp).astype(int)  # builtin int: `np.int` was removed in NumPy 1.24

# viz settings
vmax = 0.15
vmin = -0.05
n_row = 2
n_col = 10
assert n_row * n_col == num_sp, "need exactly one tile per sparsity level"
# tile side = 256px image + 1px border on each side. Named `tile` rather than `p`
# so the experiment params `p` (read via p.seed in the attribution cell) survive.
tile = 256 + 2
sample = 0  # index of the batch image to visualise

with torch.no_grad():
    # Rank the coefficients of one image by attribution, largest first. The ranking
    # does not depend on the sparsity level, so compute it once rather than per tile.
    # NOTE(review): assumes each im_w[k] / attributions[k] is batch-first and that
    # attributions[k] has the same shape as im_w[k] -- TODO confirm in wavelet_transform.
    coeffs = [c[sample:sample + 1] for c in im_w]
    scores = [a[sample:sample + 1].detach().cpu().reshape(-1) for a in attributions]
    sizes = [s.numel() for s in scores]
    sort_order = torch.argsort(torch.cat(scores), descending=True)

    mosaic = np.zeros((tile * n_row, tile * n_col))
    for indx in range(num_sp):
        i, j = divmod(indx, n_col)
        # binary mask keeping the sp_level[indx] highest-attribution coefficients
        m = torch.zeros(sort_order.numel())
        m[sort_order[:sp_level[indx]]] = 1
        # split the flat mask back into one mask per coefficient tensor and apply it
        wm_list = tuple(
            mk.reshape(c.shape).to(c.device) * c
            for mk, c in zip(torch.split(m, sizes), coeffs)
        )
        # assumes the reconstruction has the same (N, C, H, W) layout as `im`
        im_ = wt.inverse(wm_list)[0, 0].cpu().numpy()
        mosaic[i * tile:(i + 1) * tile, j * tile:(j + 1) * tile] = np.pad(im_, (1, 1), mode='constant')

plt.figure(figsize=(25, 25))
plt.title("varying sparsity penalty on wavelet coefficients")
plt.imshow(mosaic, cmap='magma', vmax=vmax, vmin=vmin)
plt.axis('off')
plt.show()
        

In [None]:
# Sanity check: forward DTCWT followed by its inverse should reproduce the input.
# The transform object in this notebook is `wt` (`xfm` / `ifm` were never defined).
with torch.no_grad():
    res = wt(im)
    rec = wt.inverse(res)
viz_im_r(im[0, 0].detach(), rec[0, 0])

In [None]:
# reconstruction error: Frobenius norm of (input - reconstruction) for the first image/channel
(im[0, 0] - rec[0, 0]).norm()