In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os, sys
opj = os.path.join
from tqdm import tqdm
from functools import partial
import acd
from copy import deepcopy
sys.path.append('..')
from transforms_torch import transform_bandpass, tensor_t_augment, batch_fftshift2d, batch_ifftshift2d
import transform_wrappers
sys.path.append('../dsets/mnist')
import dset
from model import Net, Net2c
from util import *
from torch import nn
from style import *
from captum.attr import *
from knockout import *
from attributions import *
import warnings
warnings.filterwarnings("ignore")
sys.path.append('../..')
from acd_wooseok.acd.scores import cd, score_funcs, cd_propagate

In [2]:
# Test-set batch size used when computing attributions over `test_loader`.
TEST_BATCH_SIZE = 100

# Start from dset's default experiment arguments, then override the test batch size.
args = dset.get_args()
args.test_batch_size = TEST_BATCH_SIZE

# Build the MNIST train/test loaders from the batch sizes in `args`, passing `device` through.
train_loader, test_loader = dset.load_data(
    args.batch_size,
    args.test_batch_size,
    device,
)

# scores in fft space

In [3]:
# 2-D FFT transform pair. Complex values are carried as a trailing (real, imag)
# axis of size 2 so that attribution methods see a real-valued tensor.
# torch < 1.8 exposes `torch.fft` / `torch.ifft` as functions. Those were removed
# in 1.8, where `torch.fft` became a module. Pick the matching API once.
_LEGACY_TORCH_FFT = callable(torch.fft)


def _fft2_real(x):
    """2-D FFT over the last two dims of real `x`.

    Returns a tensor of shape ``x.shape + (2,)`` holding (real, imag) parts,
    matching the legacy ``torch.fft(stack((x, 0), -1), 2)`` output.
    """
    if _LEGACY_TORCH_FFT:
        return torch.fft(torch.stack((x, torch.zeros_like(x)), dim=-1), 2)
    return torch.view_as_real(torch.fft.fftn(x, dim=(-2, -1)))


def _ifft2_real(x):
    """Inverse 2-D FFT of `x` in (..., H, W, 2) real/imag layout; returns the real part.

    Both branches scale by 1/(H*W), as legacy ``torch.ifft`` did.
    """
    if _LEGACY_TORCH_FFT:
        return torch.ifft(x, 2)[..., 0]
    # view_as_complex needs the size-2 last dim to have stride 1, hence .contiguous()
    x_c = torch.view_as_complex(x.contiguous())
    return torch.view_as_real(torch.fft.ifftn(x_c, dim=(-2, -1)))[..., 0]


# t: pixel space -> Fourier space; transform_i: Fourier space -> pixel space (as a module)
t = _fft2_real
transform_i = transform_wrappers.modularize(_ifft2_real)

# attribution methods to compute; these names key the get_attributions results below
attr_methods = ['IG', 'DeepLift', 'SHAP', 'CD', 'InputXGradient']

In [None]:
# Frequency bands: one trained model per band center, and the same bands are
# used to aggregate attribution scores.
band_centers = np.linspace(0.15, 0.85, 10)
band_width = 0.05

# Bandpass masks depend only on the band parameters, not on the model, so build them once.
# n=28 is the MNIST image side length. ifftshift moves each mask to the
# unshifted FFT layout that `t` produces.
masks_bandpass = [
    ifftshift(freq_band(n=28, band_center=band_center, band_width=band_width))
    for band_center in band_centers
]

# scores_fft[name] is a flat list of len(band_centers)**2 arrays appended in
# (model i, band k) order. Each array holds one summed in-band score per test
# observation.
scores_fft = {name: [] for name in attr_methods}

os.makedirs('results', exist_ok=True)

for i, true_center in enumerate(band_centers):
    # load the model saved for band index i
    model = Net2c().to(device)
    model.load_state_dict(torch.load(opj('models/freq', 'net2c_' + str(i) + '.pth'), map_location=device))

    # wrap the model so it consumes Fourier-space input (transform_i maps back to pixels first)
    m_t = transform_wrappers.Net_with_transform(model=model, transform=transform_i).to(device)
    m_t.eval()

    scores = {name: [] for name in attr_methods}
    for j, (data, _) in enumerate(test_loader):
        x = data.to(device)
        x_t = t(x)
        results = get_attributions(x_t, m_t, class_num=1)
        print('\rIterations =', i, j, end='')
        for name in attr_methods:
            scores[name].append(results[name])

    # Save this model's raw per-batch scores once, instead of re-pickling the growing
    # list after every batch. The file keeps the same content as before: the
    # unstacked lists for the most recent model. `pkl` presumably comes from a
    # star import above (pickle); TODO confirm.
    with open('results/scores_mnist.pkl', 'wb') as f:
        pkl.dump(scores, f)

    # stack batches; presumably each becomes (n_obs, 28, 28) given the sum over axes 1, 2 below
    for name in attr_methods:
        scores[name] = np.vstack(scores[name])

    for mask_bandpass in masks_bandpass:
        for name in attr_methods:
            scores_fft[name].append((scores[name] * mask_bandpass[np.newaxis]).sum(axis=(1, 2)))

    # checkpoint the aggregated band scores after each model
    with open('results/scores_fft_mnist.pkl', 'wb') as f:
        pkl.dump(scores_fft, f)

Iterations = 0 36

In [None]:
# Show each attribution method's Fourier-space map for one observation.
# NOTE: `results` is left over from the last batch of the last model in the loop
# above, so that cell must run first. fftshift moves the zero-frequency
# component to the image center for display.
obs_idx = 0
fig, axes = plt.subplots(1, len(attr_methods), figsize=(16, 5), dpi=100, squeeze=False)
for ax, name in zip(axes[0], attr_methods):
    interp_scores = fftshift(results[name][obs_idx])
    # Center the diverging colormap at 0 so red and blue encode the attribution sign.
    vmax = float(np.abs(interp_scores).max())
    ax.imshow(interp_scores, cmap='RdBu', vmin=-vmax, vmax=vmax)
    ax.set_title(name)
    ax.axis('off')
fig.tight_layout()
plt.show()