In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os, sys
opj = os.path.join
from tqdm import tqdm
from functools import partial
import acd
from copy import deepcopy
sys.path.append('..')
from transforms_torch import wavelet_filter
import transform_wrappers
sys.path.append('../dsets/mnist')
import dset
from model import Net, Net2c
from util import *
from torch import nn
from style import *
from pytorch_wavelets import DWTForward, DWTInverse
from captum.attr import *
from knockout import *
import warnings
warnings.filterwarnings("ignore")

In [2]:
# experiment settings are read from the dataset module
args = dset.get_args()

# build the MNIST train / test loaders from the configured batch sizes
train_loader, test_loader = dset.load_data(
    args.batch_size,
    args.test_batch_size,
    device,
)

# (no model is created in this cell; that happens in a later cell)

In [3]:
# wavelet transform: 3-level db4 DWT with symmetric padding, plus its inverse
# NOTE(review): these modules are not moved to `device` here, whereas a later cell
# rebuilds them with `.to(device)` -- confirm this cell is meant to work on CPU tensors.
xfm = DWTForward(wave='db4', mode='symmetric', J=3)
ifm = DWTInverse(wave='db4', mode='symmetric')


def t(x):
    """Apply the forward wavelet transform currently bound to `xfm`."""
    return xfm(x)


# inverse transform, wrapped by transform_wrappers.modularize
transform_i = transform_wrappers.modularize(lambda coeffs: ifm(coeffs))


def transformer(x):
    """Run `wavelet_filter` on `x` with the forward/inverse pair above (idx=2, p=0.5).

    `t` and `transform_i` are looked up when this is called, so rebinding them
    later in the notebook changes what this function uses.
    """
    return wavelet_filter(x, t, transform_i, idx=2, p=0.5)

In [None]:
# train Net2c (this runs training every time the cell is executed), save the
# weights, then load them back into a freshly constructed network
net2c_ckpt = opj('models/wt', 'net2c_' + str(0) + '.pth')
train_Net2c(train_loader, args, transformer, save_path=net2c_ckpt)
model = Net2c().to(device)
net2c_state = torch.load(net2c_ckpt, map_location=device)
model.load_state_dict(net2c_state)

# test model
test_loss, correct = test_Net2c(test_loader, model, transformer)
# the reported total is twice the dataset size
# (presumably test_Net2c scores each image twice -- confirm against its definition)
n_eval = 2 * len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
    test_loss, correct, n_eval, 100. * correct / n_eval))




# scores in wt space

In [None]:
# wavelet transform, rebuilt with both modules moved onto `device`
xfm = DWTForward(wave='db4', mode='symmetric', J=3).to(device)
ifm = DWTInverse(wave='db4', mode='symmetric').to(device)


def t(x):
    """Apply the forward wavelet transform currently bound to `xfm`."""
    return xfm(x)


# inverse transform, wrapped by transform_wrappers.modularize
transform_i = transform_wrappers.modularize(lambda coeffs: ifm(coeffs))

In [None]:
# test image: pull the first batch from the test loader.
# Use the built-in next(); iterators have no `.next()` method in Python 3.
x, _ = next(iter(test_loader))
# the 0:1 slice keeps only the first sample while preserving the leading batch dimension
x = x.to(device)[0:1]
# forward wavelet transform of that single image
x_t = t(x)
# x_t[0] and the first three entries of x_t[1] are printed (three levels, matching J=3)
print('Shape of wavelet coeffs\n', x_t[0].shape, x_t[1][0].shape, x_t[1][1].shape,x_t[1][2].shape)

In [None]:
def transform_i_re(x):
    """Regroup a flat coefficient sequence and apply the inverse transform.

    `x` is a flat, indexable sequence ``(first, d1, d2, ...)``. It is repacked
    into the nested form ``(first, (d1, d2, ...))`` and passed to
    `transform_i`, whose result is returned unchanged.

    Any number of trailing entries is accepted, so the function is not tied to
    exactly three of them; for a 4-element input the nested tuple is identical
    to the previous hard-coded ``(x[0], (x[1], x[2], x[3]))``.
    """
    # tuple() keeps the nested container a tuple even when `x` is a list
    x_tuple = (x[0], tuple(x[1:]))
    return transform_i(x_tuple)

In [None]:
# unpack the nested coefficients into four tensors and track gradients on each
a, (b, c, d) = x_t
for coeff in (a, b, c, d):
    coeff.requires_grad = True
x_t_re = (a, b, c, d)

# wrap the network so that transform_i_re is applied to its input first
m_t = transform_wrappers.Net_with_transform(model=model, transform=transform_i_re)
m_t = m_t.to(device)
m_t.eval()

# sanity check: the wrapped model on the coefficients vs. the plain model on the image
out_gap = torch.norm(m_t(x_t_re) - model(x))
print('Difference of the model outputs', out_gap.item())

In [None]:
# backward() adds into .grad rather than overwriting it, so drop any gradients
# left over from a previous run of this cell; otherwise re-running it would
# inflate the input-times-gradient scores computed next.
for coeff in x_t_re:
    coeff.grad = None

# scalar output for the first sample, index 1 of the model output
output = m_t(x_t_re)[0][1]
output.backward()

In [None]:
# input-times-gradient score for each coefficient tensor, with size-1 dims squeezed out
scores = [(coeff * coeff.grad).squeeze() for coeff in x_t_re]

In [None]:
def show_score_map(score_map, title=None):
    """Display one 2-D score tensor as an image with a colorbar and hidden axes.

    score_map: tensor to draw; it is moved to the CPU and detached before plotting.
    title: optional text shown above the image.
    """
    fig, ax = plt.subplots()
    image = ax.imshow(score_map.cpu().detach())
    fig.colorbar(image, ax=ax)
    ax.axis('off')
    if title is not None:
        ax.set_title(title)
    plt.show()


# Same seven panels, in the same order, as the former copy-pasted cells:
# scores[0], then sub-maps 0..2 of scores[1] and of scores[3].
# NOTE(review): scores[2] is not displayed, exactly as before -- confirm that
# skipping it is intentional.
show_score_map(scores[0], title='scores[0]')
for group in (1, 3):
    for band in range(3):
        show_score_map(scores[group][band],
                       title='scores[{}][{}]'.format(group, band))

In [None]:
# get interp scores
attr_methods = ['IG', 'DeepLift', 'SHAP', 'CD', 'InputXGradient']
# attribution classes, position-aligned with attr_methods; 'CD' has no class listed here
attr_funcs = [IntegratedGradients, DeepLift, GradientShap, None, InputXGradient]
name = 'IG'
# look the class up by name so that `name` and `func` cannot drift apart
# (previously `func` was chosen with a hard-coded index independent of `name`)
func = dict(zip(attr_methods, attr_funcs))[name]
if func is None:
    raise NotImplementedError(
        'no attribution class is configured for method {!r}'.format(name))
attributer = func(m_t)
class_num = 1
# attribute on a deep copy so the original coefficient tensors are left untouched
attributions = attributer.attribute(deepcopy(x_t_re), target=class_num)