In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join
from copy import deepcopy
from tqdm import tqdm
import pickle as pkl
import pandas as pd
import torch.nn.functional as F

from ex_biology import p
from dset import get_dataloader, load_pretrained_model

# adaptive-wavelets modules
from losses import get_loss_f
from train import Trainer
from evaluate import Validator
from transform1d import DWT1d
from utils import get_1dfilts, get_wavefun, low_to_high
from wave_attributions import Attributer
from visualize import cshow, plot_1dfilts, plot_1dreconstruct, plot_wavefun

# evaluation
from matplotlib import gridspec
from sklearn.linear_model import LinearRegression, Ridge, RidgeCV, LogisticRegression
from sklearn import metrics
from sklearn.model_selection import cross_val_score
from feature_transform import max_transformer

# trim model
from trim import TrimModel

# cd
import acd
from acd.scores import cd_propagate

# load data and model

In [2]:
# load data and model
train_loader, test_loader = get_dataloader(p.data_path, 
                                           batch_size=p.batch_size,
                                           is_continuous=p.is_continuous)   

model = load_pretrained_model(p.model_path, device=device)    

# wavelet transform 
wt = DWT1d(wave='db5', mode='zero', J=4, init_factor=1, noise_factor=0.0).to(device)

# get a batch of images
data = iter(test_loader).next()[0]
data = data.to(device)

# wavelet transform and reconstruction
data_t = wt(data)
recon = wt.inverse(data_t)

In [3]:
print('shape of data / recon', data.shape, recon.shape)
for i in range(len(data_t)):
    print('shape of wavelet coeffs', data_t[i].shape)

shape of data / recon torch.Size([100, 1, 40]) torch.Size([100, 1, 40])
shape of wavelet coeffs torch.Size([100, 1, 10])
shape of wavelet coeffs torch.Size([100, 1, 24])
shape of wavelet coeffs torch.Size([100, 1, 16])
shape of wavelet coeffs torch.Size([100, 1, 12])
shape of wavelet coeffs torch.Size([100, 1, 10])


In [4]:
print('reconstruction error: {:.5f}'.format(torch.norm(data - recon).item()/data.size(0)))

reconstruction error: 0.00000


In [5]:
# trim model
mt = TrimModel(model, wt.inverse, use_residuals=True).to(device)
out = mt(data_t, deepcopy(data))
print(out[:5])

tensor([[-0.8557],
        [-0.8095],
        [-0.4112],
        [-0.8514],
        [ 0.0790]], device='cuda:0', grad_fn=<SliceBackward>)


In [6]:
print(model(data)[:5])

tensor([[-0.8557],
        [-0.8095],
        [-0.4112],
        [-0.8514],
        [ 0.0790]], device='cuda:0')


# calculate cd score in original domain

In [7]:
x = torch.ones_like(recon[0:1]) # starts (batch_size, num_channels, seq_len)
x.data = recon[0:1]
x = x.permute(0, 2, 1) # should be (batch_size, seq_len, num_channels)

In [9]:
with torch.no_grad():
    rel, irrel = cd_propagate.propagate_lstm(x, model.model.lstm, start=0, stop=30, my_device=device)
rel = rel.squeeze(1)
irrel = irrel.squeeze(1)
rel, irrel = cd_propagate.propagate_conv_linear(rel, irrel, model.model.fc)
rel + irrel

tensor([[-0.8557]], device='cuda:0')

# compute cd score in wavelet domain

In [24]:
data_t = wt(data) # select some coefs of data_t to interpret (so maybe zero all but one coefs)
rel = wt.inverse(data_t)
irrel = data - rel

In [26]:
# rel, irrel (residuals shoul be put into irrel)
rel = rel.permute(0, 2, 1)
irrel = irrel.permute(0, 2, 1) 
with torch.no_grad():
    rel, irrel = cd_propagate.propagate_lstm_block(x_rel=rel, x_irrel=irrel,
                                                   module=model.model.lstm,
                                                   start=0, stop=2, my_device=device)
rel = rel.squeeze(1)
irrel = irrel.squeeze(1)
rel, irrel = cd_propagate.propagate_conv_linear(rel, irrel, model.model.fc)
# print('cd score', rel)


In [27]:
rel + irrel - model(data)

tensor([[ 0.2010],
        [ 0.1898],
        [ 0.1877],
        [ 0.1776],
        [-0.0508],
        [-0.6380],
        [-0.2458],
        [-0.1799],
        [-0.0256],
        [ 0.2721],
        [-0.2880],
        [ 0.1929],
        [-0.3667],
        [-0.0306],
        [ 0.1335],
        [-0.1186],
        [-0.2300],
        [ 0.0196],
        [ 0.1616],
        [-0.1821],
        [ 0.0545],
        [-0.0046],
        [ 0.2458],
        [ 0.1000],
        [ 0.4037],
        [-0.4669],
        [-0.1043],
        [-0.1486],
        [ 0.1395],
        [ 0.1988],
        [ 0.2528],
        [-0.5685],
        [ 0.0834],
        [ 0.1693],
        [-0.0849],
        [ 0.0227],
        [-0.9211],
        [ 0.3313],
        [-0.9261],
        [-0.0167],
        [-0.7438],
        [-0.0384],
        [ 0.2715],
        [-0.0997],
        [-0.2468],
        [-0.5197],
        [ 0.2099],
        [-0.4587],
        [ 0.4068],
        [-0.5164],
        [-1.3669],
        [-0.2368],
        [ 0.