In [24]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
%matplotlib inline
import numpy as np
import torch

# Use the GPU when one is available. NOTE: the compression cell below rebinds
# `device` to 'cpu', so anything created after that cell uses the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from copy import deepcopy
import pickle as pkl

# Experiment parameters (data_path, model_path, img_size, batch_size are read below).
from ex_cosmology import p

# adaptive-wavelets modules
import awave
from awave.data.cosmology import get_dataloader, load_pretrained_model
from awave.data.cosmology import get_validation
from awave.utils.misc import tuple_to_tensor, get_2dfilts
from awave.utils.wave_attributions import Attributer
from awave.trim import TrimModel

# evaluation
from eval_cosmology import load_results, rmse_bootstrap, extract_patches
from peak_counting import PeakCount, rmse

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Build the train/validation loaders, the pretrained CNN, and the test loader.
# All loaders share the same image size and batch size from the experiment params.
loader_kwargs = dict(img_size=p.img_size[2], batch_size=p.batch_size)

train_loader, val_loader = get_dataloader(
    p.data_path, split_train_test=True, **loader_kwargs
)

model = load_pretrained_model(
    model_name='resnet18', device=device, data_path=p.model_path
)

# validation dataset (used as the final test set below)
test_loader = get_validation(p.data_path, **loader_kwargs)

# load results

In [3]:
# Result directory of the DB5 saliency warm-start run; only the first return
# value of load_results (the per-run dictionaries) is used below.
dirs = ["db5_saliency_warmstart_seed=1_new"]
(dics, _, _) = load_results(
    dirs,
    include_interp_loss=False,
)

# select the optimal bin range and filter using the held-out validation set

In [7]:
# DB5 baseline wavelet
wt_o = awave.DWT2d(wave='db5', mode='zero', J=4,
                   init_factor=1, noise_factor=0, const_factor=0)

# Extract peak-counting kernels for every candidate wavelet.
# Column layout of `scores` below: column 0 is the DB5 baseline, columns 1..
# are the learned (AWD) wavelets in the iteration order of dics[0]['wt'].
kern_list = []
for wt in [wt_o] + list(dics[0]['wt'].values()):
    filt = get_2dfilts(wt)
    h = filt[0][0]
    g = filt[0][1]
    kern_list.append(extract_patches(h, g))

# Candidate upper bounds b; each fit uses np.linspace(0, b, 23) as `bins`.
bds = np.linspace(0.015, 0.035, 5)
scores = np.zeros((len(bds), len(kern_list)))

for i, b in enumerate(bds):
    for j, kernels in enumerate(kern_list):
        pcw = PeakCount(peak_counting_method='custom',
                        bins=np.linspace(0, b, 23),
                        kernels=kernels)
        pcw.fit(train_loader)
        y_preds, y_params = pcw.predict(val_loader)
        scores[i, j] = rmse(y_params, y_preds)
        # Checkpoint the (partial) score table after every fit. The `with`
        # block guarantees the file is flushed and closed before the next slow
        # iteration instead of relying on garbage collection.
        with open('results/scores_ablation.pkl', 'wb') as fh:
            pkl.dump(scores, fh)
        print(
            "\riteration bd={}/{} kern={}/{}".format(
                i + 1, len(bds), j + 1, len(kern_list)
            ),
            end="",
        )

print('\n', np.min(scores))

iteration bd=5/5 kern=4/4
 0.014662965005035403


# evaluate the optimal filter (and the DB5 baseline) on the test set

In [23]:
# load optimal wavelet for prediction on heldout dataset


def _test_rmse(wavelet, bin_max):
    """Fit a custom-kernel peak counter built from `wavelet` and score it on the test set.

    The kernels come from the first level of `get_2dfilts(wavelet)` via
    `extract_patches`; `bins` is np.linspace(0, bin_max, 23), matching the
    bin-selection sweep. The counter is fit on `train_loader`, predicts on
    `test_loader`, and the result of `rmse_bootstrap` is returned
    (unpacked by the caller as value, std).
    """
    filt = get_2dfilts(wavelet)
    kernels = extract_patches(filt[0][0], filt[0][1])
    pcw = PeakCount(peak_counting_method='custom',
                    bins=np.linspace(0, bin_max, 23),
                    kernels=kernels)
    pcw.fit(train_loader)
    y_preds, y_params = pcw.predict(test_loader)
    return rmse_bootstrap(y_preds, y_params)


# Local file written by the sweep cell; pickle must only be loaded from trusted files.
with open('results/scores_ablation.pkl', 'rb') as fh:
    scores = pkl.load(fh)

# Column 0 of `scores` is the DB5 baseline; columns 1.. are the learned wavelets
# in dics[0]['wt'] order. Search only the learned columns: otherwise, when DB5
# is best (col == 0), `col - 1 == -1` silently selects the *last* learned wavelet.
awd_scores = scores[:, 1:]
if awd_scores.size == 0:
    raise ValueError("scores contains no learned-wavelet columns")
row, awd_col = np.unravel_index(np.argmin(awd_scores, axis=None), awd_scores.shape)
col = awd_col + 1  # column of the selected wavelet in the full `scores` table
bd_opt = bds[row]
idx1, idx2 = list(dics[0]['wt'].keys())[awd_col]
wt = dics[0]['wt'][(idx1, idx2)]
lamL1wave = dics[0]['lamL1wave'][(idx1, idx2)]
lamL1attr = dics[0]['lamL1attr'][(idx1, idx2)]
print('lambda: {} gamma: {}'.format(lamL1wave, lamL1attr))

# AWD prediction performance
acc, std = _test_rmse(wt, bd_opt)
print("AWD: ", acc, std)

# original wavelet prediction performance, using DB5's own best bin bound
acc, std = _test_rmse(wt_o, bds[np.argmin(scores[:, 0])])
print("DB5: ", acc, std)

lambda: 0.005 gamma: 0.0
AWD:  0.013533081788635406 0.0004662344866841157
DB5:  0.015692681086327664 0.00048067312692403594


# compression rate of the AWD vs. DB5 representations

In [25]:
def _flatten_chunks(chunks):
    """Concatenate per-batch tensors along dim 0 and flatten to 1-D (empty tensor if no batches)."""
    if not chunks:
        return torch.tensor([])
    return torch.cat(chunks, dim=0).reshape(-1)


def _compression_rate(rep, attr, rep_thresh, attr_thresh):
    """Fraction of entries where |rep| > rep_thresh and |attr| > attr_thresh.

    `rep` and `attr` are 1-D tensors of equal length; returns a 0-d tensor.
    """
    kept = (abs(rep) > rep_thresh) & (abs(attr) > attr_thresh)
    return 1.0 * kept.sum() / rep.shape[0]


# define trim model
# NOTE: attribution runs on the CPU; this rebinds the notebook-wide `device`.
device = 'cpu'
# `model` was loaded on the *original* device (CUDA when available); move it so
# the trimmed models, wavelets, and data all live on the same device.
model = model.to(device)
mt = TrimModel(model, wt.inverse, use_residuals=True)
mt_o = TrimModel(model, wt_o.inverse, use_residuals=True)
attributer = Attributer(mt, attr_methods='Saliency', device=device)
attributer_o = Attributer(mt_o, attr_methods='Saliency', device=device)

# compute compression rate and representations
wt, wt_o = wt.to(device), wt_o.to(device)
# (label, wavelet transform, attributer) for each wavelet being compared
setups = [('AWD', wt, attributer), ('DB5', wt_o, attributer_o)]
# Collect per-batch chunks and concatenate once: calling torch.cat on a growing
# tensor every batch is quadratic in the dataset size.
rep_chunks = {name: [] for name, _, _ in setups}
attr_chunks = {name: [] for name, _, _ in setups}
for data, _ in test_loader:
    data = data.to(device)
    for name, w, attr_fn in setups:
        data_t = w(data)
        with torch.backends.cudnn.flags(enabled=False):
            attributions = attr_fn(data_t, target=0, additional_forward_args=deepcopy(data))
        y, _ = tuple_to_tensor(data_t)
        rep_chunks[name].append(y.detach().cpu())
        z, _ = tuple_to_tensor(attributions)
        attr_chunks[name].append(z.detach().cpu())

reps = {name: _flatten_chunks(chunks) for name, chunks in rep_chunks.items()}
attrs = {name: _flatten_chunks(chunks) for name, chunks in attr_chunks.items()}

thresh1 = 1e-3  # threshold on |wavelet coefficient|
thresh2 = 1e-3  # threshold on |attribution|
c_rate_AWD = _compression_rate(reps['AWD'], attrs['AWD'], thresh1, thresh2)
c_rate_DB5 = _compression_rate(reps['DB5'], attrs['DB5'], thresh1, thresh2)
print(c_rate_AWD.item(), c_rate_DB5.item())

0.6155855059623718 0.6201185584068298
