In [15]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import random

import numpy as np
import torch
import matplotlib.pyplot as plt
opj = os.path.join
import pickle as pkl

from ex_biology import p

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# adaptive-wavelets modules
from awd import awd
from awd.mdata.biology import get_dataloader, load_pretrained_model
from awd.awd.utils import get_wavefun, get_1dfilts
from awd.awd.visualize import plot_1dfilts, plot_wavefun

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
p.wave = 'db5'
p.J = 4
p.mode = 'zero'
p.init_factor = 0
p.noise_factor = 0.2
p.const_factor = 0
p.num_epochs = 600
p.attr_methods = 'Saliency'

lamWaveloss = 1
p.lamlSum = lamWaveloss
p.lamhSum = lamWaveloss
p.lamL2sum = lamWaveloss
p.lamCMF = lamWaveloss
p.lamConv = lamWaveloss
p.lamHighfreq = lamWaveloss
p.lamL1wave = 0.0001
p.lamL1attr = 0.5
p.target = 0

In [17]:
# load data and model
train_loader, test_loader = get_dataloader(p.data_path,
                                           batch_size=p.batch_size,
                                           is_continuous=p.is_continuous)

model = load_pretrained_model(p.model_path, device=device)

# prepare model
random.seed(p.seed)
np.random.seed(p.seed)
torch.manual_seed(p.seed)

wt = awd.DWT1d(wave=p.wave, mode=p.mode, J=p.J,
               init_factor=p.init_factor,
               noise_factor=p.noise_factor,
               const_factor=p.const_factor).to(device)
wt.train()

# train
params = list(wt.parameters())
optimizer = torch.optim.Adam(params, lr=p.lr)
loss_f = awd.get_loss_f(lamlSum=p.lamlSum, lamhSum=p.lamhSum, lamL2norm=p.lamL2norm, lamCMF=p.lamCMF,
                        lamConv=p.lamConv, lamL1wave=p.lamL1wave, lamL1attr=p.lamL1attr, lamHighfreq=p.lamHighfreq)
trainer = awd.Trainer(model, wt, optimizer, loss_f, target=p.target,
                      use_residuals=True, attr_methods=p.attr_methods, device=device, n_print=5)

In [None]:
# run
trainer(train_loader, epochs=p.num_epochs)

Starting Training Loop...
====> Epoch: 0 Average train loss: 31.6549
====> Epoch: 5 Average train loss: 11.4980
====> Epoch: 10 Average train loss: 8.3002
====> Epoch: 15 Average train loss: 6.0572
====> Epoch: 20 Average train loss: 4.5766
====> Epoch: 25 Average train loss: 3.0800
====> Epoch: 30 Average train loss: 1.6291
====> Epoch: 35 Average train loss: 0.9711
====> Epoch: 40 Average train loss: 0.7325
====> Epoch: 45 Average train loss: 0.5858
====> Epoch: 50 Average train loss: 0.4777
====> Epoch: 55 Average train loss: 0.4031
====> Epoch: 60 Average train loss: 0.3506
====> Epoch: 65 Average train loss: 0.3199
====> Epoch: 70 Average train loss: 0.2998
====> Epoch: 75 Average train loss: 0.2909
====> Epoch: 80 Average train loss: 0.2854
====> Epoch: 85 Average train loss: 0.2819
====> Epoch: 90 Average train loss: 0.2810
====> Epoch: 95 Average train loss: 0.2808
====> Epoch: 100 Average train loss: 0.2803
====> Epoch: 105 Average train loss: 0.2804
====> Epoch: 110 Average t

In [None]:
plt.plot(np.log(trainer.train_losses))
plt.xlabel("epochs")
plt.ylabel("log train loss")
plt.title('Log-train loss vs epochs')
plt.show()

In [None]:
print('calculating losses and metric...')
model.train()  # cudnn RNN backward can only be called in training mode
validator = awd.Validator(model, test_loader)
rec_loss, lsum_loss, hsum_loss, L2norm_loss, CMF_loss, conv_loss, L1wave_loss, L1saliency_loss, L1inputxgrad_loss = validator(
    wt, target=p.target)
print("Recon={:.5f}\n lsum={:.5f}\n hsum={:.5f}\n L2norm={:.5f}\n CMF={:.5f}\n conv={:.5f}\n L1wave={:.5f}\n Saliency={:.5f}\n Inputxgrad={:.5f}\n".format(rec_loss,
                                      lsum_loss,
                                      hsum_loss,
                                      L2norm_loss,
                                      CMF_loss,
                                      conv_loss,
                                      L1wave_loss,
                                      L1saliency_loss,
                                      L1inputxgrad_loss))


In [None]:
filt = get_1dfilts(wt)
phi, psi, x = get_wavefun(wt)

plot_1dfilts(filt, is_title=True, figsize=(2,2))
plot_wavefun((phi, psi, x), is_title=True, figsize=(2,1))