In [3]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join
from copy import deepcopy
import pickle as pkl
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader

from sim_simulation import p
sys.path.append('models')
from models import Feedforward
# adaptive-wavelets modules
from losses import get_loss_f
from train import Trainer
from evaluate import Validator
from transform1d import DWT1d
from utils import get_1dfilts, get_wavefun
from wave_attributions import Attributer
from visualize import cshow, plot_1dfilts, plot_1dreconstruct

sys.path.append('../../lib/trim')
from trim import TrimModel
import acd
from acd.scores import cd_propagate
from acd.scores import cd

In [3]:
# seed
np.random.seed(p.seed)
torch.manual_seed(p.seed)

# wavelet transform 
wt = DWT1d(wave='db5', mode='zero', J=4, init_factor=1, noise_factor=0)
# freeze layers
for param in wt.parameters():
    param.requires_grad = False  

n = 55000
d = 64
X = torch.randn(n, 1, d)
X_t = wt(X)
# re = wt(torch.randn(1, 1, p.d).to(device))
# data_t = ()
# for scale in range(len(re)):
#     dim = re[scale].size(2)
#     pr = torch.bernoulli(0.1 * torch.ones(p.n, 1, dim))
#     rn = abs(torch.randn(p.n, 1, dim))
#     data_t += ((pr * rn).to(device),)
# data = wt.inverse(data_t)

# generate y
beta = ()
for i in range(len(X_t)):
    beta += (torch.zeros_like(X_t[i][0:1,...]),)
beta[p.scale_knockout][..., p.idx_knockout - p.window: p.idx_knockout + p.window + 1] = 2.0

y = 0
for x, b in zip(X_t, beta):
    y += torch.matmul(x.squeeze(), b.squeeze())
eps = 0.1 * torch.randn_like(y)
y = y + eps

create data

In [4]:
X_train = X[:50000]
y_train = y[:50000]
X_test = X[50000:]
y_test = y[50000:]

train_loader = DataLoader(TensorDataset(X_train, y_train), 
                          batch_size=100,
                          shuffle=True) 
test_loader = DataLoader(TensorDataset(X_test, y_test), 
                          batch_size=100,
                          shuffle=False) 

# load model
model = Feedforward(input_size=d).to(device)
model.load_state_dict(torch.load('models/FFN.pth'))
model.eval()
for param in model.parameters():
    param.requires_grad = False  

# load data and model

In [5]:
# get image
data = iter(test_loader).next()[0].to(device)

# wavelet transform 
wt = DWT1d(wave='db5', mode='zero', J=4, init_factor=1, noise_factor=0.0).to(device)

data_t = wt(data)
recon = wt.inverse(data_t)

print("Reconstruction error={:.5f}".format(torch.norm(recon - data)**2/data.size(0)))

Reconstruction error=0.00000


In [6]:
print('shape of data / recon', data.shape, recon.shape)
for i in range(len(data_t)):
    print('shape of wavelet coeffs', data_t[i].shape)

shape of data / recon torch.Size([100, 1, 64]) torch.Size([100, 1, 64])
shape of wavelet coeffs torch.Size([100, 1, 12])
shape of wavelet coeffs torch.Size([100, 1, 36])
shape of wavelet coeffs torch.Size([100, 1, 22])
shape of wavelet coeffs torch.Size([100, 1, 15])
shape of wavelet coeffs torch.Size([100, 1, 12])


In [7]:
print('reconstruction error: {:.5f}'.format(torch.norm(data - recon).item()/data.size(0)))

reconstruction error: 0.00000


In [8]:
# trim model
mt = TrimModel(model, wt.inverse, use_residuals=True).to(device)
out = mt(data_t, deepcopy(data))
print(out[:5] - model(data)[:5])

tensor([[[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]]], device='cuda:0', grad_fn=<SubBackward0>)


# calculate cd score in original domain

In [9]:
data = data[:30]
mask = torch.zeros_like(data)
mask[:,0,0] = 1

In [1]:
rel, irrel = cd.cd(data, model, mask, model_type=None, device='cuda')

NameError: name 'cd' is not defined

In [4]:
cd.cd

<function acd.scores.cd.cd(im_torch:torch.Tensor, model, mask=None, model_type=None, device='cuda', transform=None)>

# compute cd score in wavelet domain

In [None]:
data_t = wt(data) # select some coefs of data_t to interpret (so maybe zero all but one coefs)
rel = wt.inverse(data_t)
irrel = data - rel