In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
# run on the GPU when torch reports one as available, otherwise on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join
from copy import deepcopy
from tqdm import tqdm
import pickle as pkl
import pandas as pd
import torch.nn.functional as F

from ex_biology import p
from dset import get_dataloader, load_pretrained_model

# adaptive-wavelets modules
from losses import get_loss_f
from train import Trainer
from evaluate import Validator
from transform1d import DWT1d
from utils import get_1dfilts, get_wavefun, low_to_high
from wave_attributions import Attributer
from visualize import cshow, plot_1dfilts, plot_1dreconstruct, plot_wavefun

# evaluation
from matplotlib import gridspec
from sklearn.linear_model import LinearRegression, Ridge, RidgeCV, LogisticRegression
from sklearn import metrics
from sklearn.model_selection import cross_val_score
from feature_transform import max_transformer

# trim model
from trim import TrimModel

# cd
import acd
from acd.scores import cd_propagate

# load data and model

In [2]:
# load data and model (paths and batch settings come from the experiment config `p`)
train_loader, test_loader = get_dataloader(p.data_path,
                                           batch_size=p.batch_size,
                                           is_continuous=p.is_continuous)

model = load_pretrained_model(p.model_path, device=device)

# wavelet transform
wt = DWT1d(wave='db5', mode='zero', J=4, init_factor=1, noise_factor=0.0).to(device)

# get one batch of input signals from the test set (element 0 of the batch).
# Use the builtin next(): `.next()` is not part of the Python 3 iterator
# protocol and only works on iterator classes that happen to define that alias.
data = next(iter(test_loader))[0]
data = data.to(device)

# wavelet transform and reconstruction
data_t = wt(data)
recon = wt.inverse(data_t)

In [3]:
# compare the raw batch with its reconstruction, then list every coefficient tensor
print('shape of data / recon', data.shape, recon.shape)
for coeffs in data_t:
    print('shape of wavelet coeffs', coeffs.shape)

shape of data / recon torch.Size([100, 1, 40]) torch.Size([100, 1, 40])
shape of wavelet coeffs torch.Size([100, 1, 10])
shape of wavelet coeffs torch.Size([100, 1, 24])
shape of wavelet coeffs torch.Size([100, 1, 16])
shape of wavelet coeffs torch.Size([100, 1, 12])
shape of wavelet coeffs torch.Size([100, 1, 10])


In [4]:
# norm of the reconstruction residual over the whole batch, divided by the batch size
recon_err = torch.norm(data - recon).item() / data.size(0)
print('reconstruction error: {:.5f}'.format(recon_err))

reconstruction error: 0.00000


In [5]:
# trim model
mt = TrimModel(model, wt.inverse, use_residuals=True).to(device)
out = mt(data_t, deepcopy(data))
print(out[:5])

tensor([[-0.8557],
        [-0.8095],
        [-0.4112],
        [-0.8514],
        [ 0.0790]], device='cuda:0', grad_fn=<SliceBackward>)


In [6]:
# predictions of the original model on the raw batch, for comparison with `out` above
preds = model(data)
print(preds[:5])

tensor([[-0.8557],
        [-0.8095],
        [-0.4112],
        [-0.8514],
        [ 0.0790]], device='cuda:0')


# calculate cd score in original domain

In [7]:
# data = data.permute(0, 2, 1)
# rel, irrel = cd_propagate.propagate_lstm(data, model.model.lstm, start=0, stop=30, my_device=device)
# rel = rel.squeeze(1)
# irrel = irrel.squeeze(1)
# rel, irrel = cd_propagate.propagate_conv_linear(rel, irrel, model.model.fc)
# rel + irrel

# compute cd score in wavelet domain

In [57]:
# get one batch of input signals from the test set (element 0 of the batch).
# Use the builtin next(): `.next()` is not part of the Python 3 iterator
# protocol and only works on iterator classes that happen to define that alias.
data = next(iter(test_loader))[0]
data = data.to(device)

# wavelet transform of the batch
# TODO: select which coefficients of data_t to interpret (e.g. zero all but one coefficient)
data_t = wt(data)

In [147]:
# x = torch.ones_like(recon[0:1]) # starts (batch_size, num_channels, seq_len)
# x.data = recon[0:1]
# x = x.permute(0, 2, 1) # should be (batch_size, seq_len, num_channels)

# with torch.no_grad():
#     rel, irrel = cd_propagate.propagate_lstm(x, model.model.lstm, start=0, stop=30, my_device=device)
# rel = rel.squeeze(1)
# irrel = irrel.squeeze(1)
# rel, irrel = cd_propagate.propagate_conv_linear(rel, irrel, model.model.fc)
# rel + irrel

# x = recon
# model(x)

In [125]:
def track_wave_rel(x, idx):
    """Split wavelet coefficients into a 'relevant' and an 'irrelevant' part.

    The tensors in ``x`` are concatenated along dim 2. The entries selected by
    ``idx`` on the last axis of the concatenation are kept as the relevant
    part, and everything else is the irrelevant part. Both parts are then cut
    back into pieces with the same widths as the input tensors.

    Params
    ------
    x: sequence of torch.Tensor
        coefficient tensors; ``size(2)`` of each one is read, so every tensor
        needs at least 3 dims
    idx:
        index applied to the last axis of the concatenated tensor

    Returns
    -------
    x_rel: tuple of torch.Tensor
        one tensor per entry of ``x``, zero everywhere except at ``idx``
    x_irrel: tuple of torch.Tensor
        one tensor per entry of ``x``, the concatenated input minus the relevant part
    """
    bands = [x[k] for k in range(len(x))]
    widths = [band.size(2) for band in bands]

    stacked = torch.cat(bands, 2)
    kept = torch.zeros_like(stacked)
    kept[..., idx] = stacked[..., idx]
    dropped = stacked - kept

    # walk along the concatenated axis with a running offset to recover each band
    x_rel, x_irrel = [], []
    lo = 0
    for width in widths:
        hi = lo + width
        x_rel.append(kept[..., lo:hi])
        x_irrel.append(dropped[..., lo:hi])
        lo = hi
    return tuple(x_rel), tuple(x_irrel)

In [128]:
# CD score of individual wavelet coefficients; one entry of `cd` per coefficient index
cd = []
for idx in range(1):
    # keep only coefficient `idx`, map it back to the signal domain, and treat
    # the remainder of the signal as the irrelevant part
    coeffs_rel, _ = track_wave_rel(data_t, idx)
    signal_rel = wt.inverse(coeffs_rel)
    signal_irrel = data - signal_rel

    # dims 1 and 2 are swapped before the LSTM propagation
    rel, irrel = cd_propagate.propagate_lstm_block(
        x_rel=signal_rel.permute(0, 2, 1),
        x_irrel=signal_irrel.permute(0, 2, 1),
        module=model.model.lstm,
        start=0, stop=10, my_device=device,
    )

    # drop dim 1 and push both parts through the final fully connected layer
    rel, irrel = cd_propagate.propagate_conv_linear(rel.squeeze(1),
                                                    irrel.squeeze(1),
                                                    model.model.fc)
    cd.append(rel)

In [130]:
# sum of the relevant and irrelevant parts returned by the last propagation step
# NOTE(review): presumably this is meant to be compared with the model's prediction — confirm
rel + irrel

tensor([[-0.6547],
        [-0.6197],
        [-0.2235],
        [-0.6737],
        [ 0.0282],
        [ 0.1258],
        [-0.6188],
        [ 0.3074],
        [ 0.1007],
        [ 0.0284],
        [ 0.3567],
        [-0.6164],
        [ 0.6005],
        [ 0.3181],
        [ 0.2275],
        [ 0.9623],
        [-0.6379],
        [ 0.0789],
        [-0.6355],
        [-0.1220],
        [-0.6689],
        [ 1.1056],
        [-0.5101],
        [-0.6518],
        [-0.3684],
        [-0.3889],
        [-0.2986],
        [-0.3265],
        [ 1.0566],
        [-0.3358],
        [ 0.6626],
        [-0.6092],
        [ 0.1583],
        [-0.6744],
        [ 0.1731],
        [-0.0026],
        [-0.4538],
        [-0.6580],
        [-0.2977],
        [-0.5217],
        [-0.4511],
        [ 0.6573],
        [ 1.2003],
        [ 0.4695],
        [ 0.2023],
        [-0.6286],
        [-0.6472],
        [-0.1565],
        [-0.6102],
        [-0.1459],
        [-0.7757],
        [-0.4848],
        [-0.

In [36]:
# number of coefficient tensors returned by the wavelet transform.
# `x` is not assigned by any cell above, so `len(x)` raises NameError on a
# fresh kernel; `data_t` is the 5-element object the recorded output refers to.
len(data_t)

5

In [None]:
# TODO: unfinished loop. This cell held only the keyword `for`, which is a
# SyntaxError and stops Restart & Run All; kept as a comment until it is written.

In [29]:
# rel, irrel (residuals should be put into irrel)
# NOTE(review): this cell consumes `rel` / `irrel` left behind by an earlier cell
# and then overwrites them, so it is not idempotent — confirm the intended inputs.

lstm_rel, lstm_irrel = cd_propagate.propagate_lstm_block(
    x_rel=rel,
    x_irrel=irrel,
    module=model.model.lstm,
    start=0, stop=10, my_device=device,
)

# drop dim 1 and push both parts through the final fully connected layer
rel, irrel = cd_propagate.propagate_conv_linear(lstm_rel.squeeze(1),
                                                lstm_irrel.squeeze(1),
                                                model.model.fc)
print('cd score', rel)

cd score tensor([[-6.2875e-01],
        [-6.6938e-01],
        [-2.5597e-01],
        [-6.6279e-01],
        [ 1.2557e-01],
        [ 7.1962e-01],
        [-6.3370e-01],
        [ 5.2765e-01],
        [ 1.7939e-01],
        [ 7.7513e-03],
        [ 4.9782e-01],
        [-6.7419e-01],
        [ 1.1866e+00],
        [ 4.2525e-01],
        [ 1.7598e-01],
        [ 1.4416e+00],
        [-4.6629e-01],
        [ 1.2258e-01],
        [-6.6849e-01],
        [ 1.0506e-01],
        [-6.0307e-01],
        [ 1.9166e+00],
        [ 3.2136e-03],
        [-5.7529e-01],
        [-4.2015e-01],
        [ 7.3039e-03],
        [-2.6258e-01],
        [-7.7889e-03],
        [ 1.3842e+00],
        [-3.6915e-01],
        [ 5.3586e-01],
        [-9.8854e-02],
        [ 1.3889e-01],
        [-6.5502e-01],
        [ 2.2923e-01],
        [-1.3928e-03],
        [ 2.5419e-02],
        [-6.6141e-01],
        [ 2.9942e-01],
        [-4.7287e-01],
        [ 4.7872e-02],
        [ 1.0815e+00],
        [ 1.5550e+00],
  

In [12]:
# inspect the gradient stored on wt.h0 before backward() is called in the cells below
wt.h0.grad

In [14]:
# scalar objective for backpropagation: the sum of all entries of the relevant CD score
loss = rel.sum()

In [15]:
# backpropagate the summed CD score; wt.h0.grad is inspected in the next cell
loss.backward()

In [16]:
# inspect wt.h0.grad again, now that loss.backward() has run
wt.h0.grad

tensor([[[242.6617, 371.2155, 396.1845, 212.0899,  63.1081,  61.0499,  43.7076,
            8.9517,  -0.3973,  -8.5035]]], device='cuda:0')