In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join
from copy import deepcopy
from tqdm import tqdm
import pickle as pkl
import pandas as pd
import torch.nn.functional as F

from ex_biology import p
from dset import get_dataloader, load_pretrained_model

# adaptive-wavelets modules
from losses import get_loss_f
from train import Trainer
from evaluate import Validator
from transform1d import DWT1d
from utils import get_1dfilts, get_wavefun, low_to_high
from wave_attributions import Attributer
from visualize import cshow, plot_1dfilts, plot_1dreconstruct, plot_wavefun

# evaluation
from matplotlib import gridspec
from sklearn.linear_model import LinearRegression, Ridge, RidgeCV, LogisticRegression
from sklearn import metrics
from sklearn.model_selection import cross_val_score
from feature_transform import max_transformer

# trim model
from trim import TrimModel

# cd
import acd
from acd.scores import cd_propagate

# load data and model

In [2]:
# load data and model
train_loader, test_loader = get_dataloader(p.data_path,
                                           batch_size=p.batch_size,
                                           is_continuous=p.is_continuous)

model = load_pretrained_model(p.model_path, device=device)

# wavelet transform: db5 wavelet, zero-padding mode, J=4 decomposition levels
wt = DWT1d(wave='db5', mode='zero', J=4, init_factor=1, noise_factor=0.0).to(device)

# get one batch of 1-D signals from the test set.
# next(iter(...)) works on every PyTorch version; the iterator's `.next()`
# alias (a Python 2 leftover) is not available in newer PyTorch releases.
data = next(iter(test_loader))[0]
data = data.to(device)

# wavelet transform and reconstruction
data_t = wt(data)
recon = wt.inverse(data_t)

In [3]:
# Compare input/reconstruction shapes, then list the shape of each
# coefficient tensor returned by the wavelet transform.
print('shape of data / recon', data.shape, recon.shape)
for coeffs in data_t:
    print('shape of wavelet coeffs', coeffs.shape)

shape of data / recon torch.Size([100, 1, 40]) torch.Size([100, 1, 40])
shape of wavelet coeffs torch.Size([100, 1, 10])
shape of wavelet coeffs torch.Size([100, 1, 24])
shape of wavelet coeffs torch.Size([100, 1, 16])
shape of wavelet coeffs torch.Size([100, 1, 12])
shape of wavelet coeffs torch.Size([100, 1, 10])


In [4]:
# Frobenius norm of the whole-batch residual, divided by the batch size.
residual_norm = torch.norm(data - recon).item()
print(f'reconstruction error: {residual_norm / data.size(0):.5f}')

reconstruction error: 0.00000


In [5]:
# trim model
mt = TrimModel(model, wt.inverse, use_residuals=True).to(device)
out = mt(data_t, deepcopy(data))
print(out[:5])

tensor([[-0.8557],
        [-0.8095],
        [-0.4112],
        [-0.8514],
        [ 0.0790]], device='cuda:0', grad_fn=<SliceBackward>)


In [6]:
# Reference: the original model applied to the raw signal. The trimmed model
# in the previous cell should reproduce it, so check the whole batch
# programmatically instead of eyeballing the first five rows.
model_out = model(data)
print(model_out[:5])
print('trimmed model matches original on full batch:',
      torch.allclose(out.detach(), model_out.detach(), atol=1e-5))

tensor([[-0.8557],
        [-0.8095],
        [-0.4112],
        [-0.8514],
        [ 0.0790]], device='cuda:0')


# Compute CD (contextual decomposition) scores in the original domain

In [7]:
# Single-sample input for contextual decomposition.
# detach() shares storage with recon but drops autograd history -- the same
# effect as the previous ones_like + `.data = ...` assignment, without
# mutating `.data` directly, which PyTorch discourages.
x = recon[0:1].detach()  # (batch_size, num_channels, seq_len)
x = x.permute(0, 2, 1)   # LSTM layout: (batch_size, seq_len, num_channels)

In [8]:
# Contextual decomposition of the single sample x through the LSTM, run
# without autograd since only the decomposed activations are needed.
# NOTE(review): start/stop presumably bound the "relevant" span of timesteps
# (here [0, 30) of the length-40 sequence) -- confirm against the acd docs.
with torch.no_grad():
    lstm_rel, lstm_irrel = cd_propagate.propagate_lstm(
        x, model.model.lstm, start=0, stop=30, my_device=device
    )

# squeeze out axis 1, then push both parts through the final fc layer
rel, irrel = cd_propagate.propagate_conv_linear(
    lstm_rel.squeeze(1), lstm_irrel.squeeze(1), model.model.fc
)

# sanity check: rel + irrel should equal the model's prediction for this sample
rel + irrel

tensor([[-0.8557]], device='cuda:0')

# Compute CD (contextual decomposition) scores in the wavelet domain

In [9]:
def track_wave_rel(x, idx):
    """Split wavelet coefficients into a "relevant" part and its complement.

    All coefficient tensors in ``x`` are concatenated along their last axis
    into one flat coefficient axis. The entries selected by ``idx`` form the
    relevant part; every other entry forms the irrelevant part. Both parts
    are split back into the original per-tensor layout, so that
    ``x_rel[i] + x_irrel[i]`` equals ``x[i]`` for every ``i``.

    The last axis is used consistently for concatenating, indexing and
    splitting, so inputs of any rank work (for the 3-D tensors produced by
    ``DWT1d`` this is the same as axis 2).

    Parameters
    ----------
    x : sequence of torch.Tensor
        Wavelet coefficients, e.g. the output of ``DWT1d``. All tensors must
        agree on every axis except the last.
    idx : int, slice, list or tensor
        Index (or indices) into the concatenated last axis selecting the
        coefficient(s) treated as relevant.

    Returns
    -------
    x_rel, x_irrel : tuple of torch.Tensor
        Relevant and irrelevant parts, each shaped like the tensors in ``x``.
    """
    sizes = [coeffs.size(-1) for coeffs in x]
    flat = torch.cat(tuple(x), dim=-1)

    rel = torch.zeros_like(flat)
    rel[..., idx] = flat[..., idx]
    irrel = flat - rel

    # torch.split returns a tuple of views, one chunk per original tensor
    return torch.split(rel, sizes, dim=-1), torch.split(irrel, sizes, dim=-1)

In [10]:
# get a batch of images
data = iter(test_loader).next()[0]
data = data.to(device)

# wavelet transform and reconstruction
data_t = wt(data) # select some coefs of data_t to interpret (so maybe zero all but one coefs)

In [11]:
cd = []
for idx in range(72):
    rel, _ = track_wave_rel(data_t, idx)
    rel = wt.inverse(rel)
    irrel = data - rel
    
    rel = rel.permute(0, 2, 1)
    irrel = irrel.permute(0, 2, 1)    
    
    rel, irrel = cd_propagate.propagate_lstm_block(x_rel=rel, x_irrel=irrel,
                                                   module=model.model.lstm,
                                                   start=0, stop=10, my_device=device)    
    rel = rel.squeeze(1)
    irrel = irrel.squeeze(1)
    rel, irrel = cd_propagate.propagate_conv_linear(rel, irrel, model.model.fc)
    cd.append(rel)
cd = torch.cat(cd, 1)

In [12]:
# Touch `cd` explicitly first: with IPython automagic on, a bare `cd` line runs
# the %cd magic (changing the working directory) whenever no variable named
# cd exists, e.g. if the previous cell was skipped. cd.shape raises NameError
# instead, stopping the cell before that can happen.
print('cd shape (batch, n_coefs):', tuple(cd.shape))
cd

tensor([[ 0.1462,  0.1462,  0.1462,  ...,  0.1473,  0.1460,  0.1461],
        [ 0.1630,  0.1630,  0.1630,  ...,  0.1636,  0.1628,  0.1630],
        [ 0.0896,  0.0896,  0.0896,  ...,  0.0887,  0.0896,  0.0896],
        ...,
        [ 0.1636,  0.1636,  0.1636,  ...,  0.1647,  0.1635,  0.1636],
        [-0.1761, -0.1761, -0.1761,  ..., -0.1678, -0.1766, -0.1762],
        [-0.0515, -0.0515, -0.0515,  ..., -0.0451, -0.0509, -0.0515]],
       device='cuda:0', grad_fn=<CatBackward>)