In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
# general-purpose imports
import matplotlib.pyplot as plt
import numpy as np
import torch
# run on the GPU when CUDA is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join  # short alias for os.path.join
from copy import deepcopy
import pickle as pkl
from tqdm import tqdm
import acd

# NOTE(review): this import runs before any of the sys.path.append calls below,
# so sim_biology is presumably resolved from the notebook's own directory -- confirm.
from sim_biology import p, load_dataloader_and_pretrained_model
# override the data location on the imported parameter object `p`;
# the path is relative, so it depends on the notebook's working directory
p.data_path = "../../src/dsets/biology/data"
# adaptive-wavelets modules
# (the two source directories must be on sys.path before the imports that follow)
sys.path.append('../../src')
sys.path.append('../../src/adaptive_wavelets')
from losses import get_loss_f
from train import Trainer
from evaluate import Validator
from transform1d import DWT1d
from utils import get_1dfilts
from wave_attributions import Attributer
from visualize import cshow, plot_1dreconstruct, plot_1dfilts
# TRIM library: make lib/trim importable, then pull in the model wrapper
sys.path.append('../../lib/trim')
from trim import TrimModel

## load data and model

In [2]:
# Load the train/test dataloaders and the pretrained model configured by `p`.
(train_loader, test_loader), model = load_dataloader_and_pretrained_model(p)

# Wavelet transform, moved to the same device as the data.
wt = DWT1d(wave='db5', mode='zero', J=4, init_factor=1, noise_factor=0.0).to(device)

# Grab the first test batch and keep its first element (the inputs).
# Use the built-in next(): the Python-2 style `.next()` method is not part of
# the Python 3 iterator protocol and is not available on DataLoader iterators
# in newer PyTorch releases.
data = next(iter(test_loader))[0].to(device)
# Forward wavelet transform, then reconstruct the batch from its coefficients.
data_t = wt(data)
recon = wt.inverse(data_t)

In [3]:
# Report the shape of the input batch, then the shape of every tensor
# returned by the wavelet transform, in order.
print('shape of data', data.shape)
for coeffs in data_t:
    print('shape of wavelet coeffs', coeffs.shape)

shape of data torch.Size([100, 1, 40])
shape of wavelet coeffs torch.Size([100, 1, 10])
shape of wavelet coeffs torch.Size([100, 1, 24])
shape of wavelet coeffs torch.Size([100, 1, 16])
shape of wavelet coeffs torch.Size([100, 1, 12])
shape of wavelet coeffs torch.Size([100, 1, 10])


In [4]:
# Norm of the reconstruction residual (torch's default norm over all elements),
# divided by the size of the first dimension of `data`.
recon_err = (data - recon).norm().item() / data.shape[0]
print('reconstruction error: {:.5f}'.format(recon_err))

reconstruction error: 0.00000


In [5]:
# trim model
# Wrap the pretrained model together with the inverse wavelet transform and move
# the wrapper to `device`.
# NOTE(review): presumably this lets the model be evaluated directly on wavelet
# coefficients by applying wt.inverse first -- confirm against lib/trim.
mt = TrimModel(model, wt.inverse, use_residuals=True).to(device)
# Call the wrapper on the wavelet coefficients, passing a deep copy of the
# original batch as the second argument.
# NOTE(review): the copy is presumably what use_residuals=True consumes -- confirm.
out = mt(data_t, deepcopy(data))
# show the first five rows of the output
print(out[:5])

tensor([[-0.8557],
        [-0.8095],
        [-0.4112],
        [-0.8514],
        [ 0.0790]], device='cuda:0', grad_fn=<SliceBackward>)
