In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch

import os,sys
from copy import deepcopy
import pickle as pkl
from tqdm import tqdm
import acd
from acd.scores import cd_propagate
import sim_biology
from sim_biology import p, load_dataloader_and_pretrained_model
opj = os.path.join

# p.data_path = "../../src/dsets/biology/data"
p.data_path = "data"
p.model_path = "models"
sys.path.append('../../src')
sys.path.append('../../src/adaptive_wavelets')
sys.path.append('../../lib/trim')
from losses import get_loss_f
from train import Trainer
from evaluate import Validator
from transform1d import DWT1d
from utils import get_1dfilts
from wave_attributions import Attributer
from visualize import cshow, plot_1dreconstruct, plot_1dfilts
from trim import TrimModel

device = 'cpu' # 'cuda' if torch.cuda.is_available() else 'cpu'

# load data and model

In [2]:
# get dataloader and model
model = sim_biology.load_model(p)
# (train_loader, test_loader), model = load_dataloader_and_pretrained_model(p)

# wavelet transform 
wt = DWT1d(wave='db5', mode='zero', J=4, init_factor=1, noise_factor=0.0).to(device)

# get a batch of images
data = torch.randn(100, 1, 40)
# data = iter(test_loader).next()[0].to(device)
# wavelet transform and reconstruction
data_t = wt(data)
recon = wt.inverse(data_t)

In [3]:
print('shape of data / recon', data.shape, recon.shape)
for i in range(len(data_t)):
    print('shape of wavelet coeffs', data_t[i].shape)

shape of data / recon torch.Size([100, 1, 40]) torch.Size([100, 1, 40])
shape of wavelet coeffs torch.Size([100, 1, 10])
shape of wavelet coeffs torch.Size([100, 1, 24])
shape of wavelet coeffs torch.Size([100, 1, 16])
shape of wavelet coeffs torch.Size([100, 1, 12])
shape of wavelet coeffs torch.Size([100, 1, 10])


In [4]:
print('reconstruction error: {:.5f}'.format(torch.norm(data - recon).item()/data.size(0)))

reconstruction error: 0.00000


In [5]:
# trim model
mt = TrimModel(model, wt.inverse, use_residuals=True).to(device)
out = mt(data_t, deepcopy(data))
print(out[:5])

tensor([[-0.0134],
        [-0.7077],
        [-0.4140],
        [-0.4139],
        [-0.5116]], grad_fn=<SliceBackward>)


# calculate cd score in original domain

In [6]:
x = torch.ones_like(recon[0:1]) # starts (batch_size, num_channels, seq_len)
x.data = recon[0:1]
x = x.permute(0, 2, 1) # should be (batch_size, seq_len, num_channels)

In [7]:
with torch.no_grad():
    rel, irrel = cd_propagate.propagate_lstm(x, model.model.lstm, start=0, stop=30, my_device='cpu')
rel = rel.squeeze(1)
irrel = irrel.squeeze(1)
rel, irrel = cd_propagate.propagate_conv_linear(rel, irrel, model.model.fc)
rel + irrel

tensor([[-0.0134]])

# compute cd score in wavelet domain

In [8]:
data_t = wt(data) # select some coefs of data_t to interpret (so maybe zero all but one coefs)
rel = wt.inverse(data_t)
irrel = data - rel

In [9]:
# rel, irrel (residuals shoul be put into irrel)
with torch.no_grad():
    rel, irrel = cd_propagate.propagate_lstm_block(x_rel=x, x_irrel=x-2,
                                                   module=model.model.lstm,
                                                   start=0, stop=2, my_device='cpu')
rel = rel.squeeze(1)
irrel = irrel.squeeze(1)
rel, irrel = cd_propagate.propagate_conv_linear(rel, irrel, model.model.fc)
print('cd score', rel)

cd score tensor([[0.2571]])
