In [4]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import torch
import random
device = 'cuda' if torch.cuda.is_available() else 'cpu'
from scipy.ndimage import gaussian_filter
import sys
from tqdm import tqdm
from functools import partial
import acd
from copy import deepcopy
sys.path.append('..')
sys.path.append('../trim')
sys.path.append('../trim/transforms')
from transforms_torch import bandpass_filter
# plt.style.use('dark_background')
sys.path.append('../../dsets/mnist')
import dset_mnist as dset
from model_mnist import Net, Net2c
from util import *
from numpy.fft import *
from torch import nn
from style import *
from captum.attr import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
# import modules
from acd_wooseok.acd.scores import cd
from funcs import *
from matfac import *

# Load the dataset and the pretrained model

In [6]:
# Notebook-wide parameters.
# img_size: side length in pixels of the MNIST input images (28 x 28).
# class_num: not referenced by the cells shown here — presumably a target class
#            index for later analysis; confirm before relying on it.
img_size, class_num = 28, 1

In [7]:
# Experiment arguments from the MNIST dataset helper, with local overrides.
# NOTE(review): `args.cuda` is set but the cells shown here use `device` instead.
args = dset.get_args()
args.epochs = 20
args.cuda = (not args.no_cuda) and torch.cuda.is_available()

# Pretrained MNIST classifier: build it on `device`, restore the saved weights
# (mapped onto `device`), and switch it to eval mode.
model = Net().to(device)
model.load_state_dict(
    torch.load('../../dsets/mnist/mnist.model', map_location=device)
)
model = model.eval()

# Train / test data loaders. With return_indices=True the loaders are expected to
# yield three items per batch (the training loop unpacks `(data, _, _)`).
train_loader, test_loader = dset.load_data(
    args.batch_size,
    args.test_batch_size,
    device,
    return_indices=True,
)

# Convolutional sparse coding

In [8]:
# --- Convolutional sparse-coding model and its hyper-parameters ---
n_components = 32  # number of dictionary atoms (conv filters)
kernel_size = 7    # side length of each filter
# Side length of each activation map: img_size + kernel_size - 1, the size whose
# "valid" convolution with a kernel_size filter gives back img_size
# (presumably how Conv_SpCoding reconstructs images — confirm in matfac).
n_dim = img_size + kernel_size - 1

csc = Conv_SpCoding(kernel_size, n_dim, n_components).to(device)
# To resume from a saved checkpoint instead of a fresh model:
# csc.load_state_dict(torch.load('./model/csc_maxCD_0.pth'))

# Weight of the L1 sparsity penalty (passed to csc_optimizer and L1Reg_loss).
lamb = 1e-3

# Inner iterations per batch in the training loop:
# n_inner_w for the activation-map step, n_inner_c for the dictionary step.
n_inner_c, n_inner_w = 10, 100
# Loss history. It is seeded with a placeholder so the progress print can read
# losses[-3] after the first batch appends its two entries.
losses = [1000]

# Learning rates handed to csc_optimizer (see that helper for which parameter
# group each one controls).
lr_c = lr_w = 0.1

In [None]:
# Alternating optimisation over mini-batches. For each batch:
#   (1) fresh activation maps are created and fitted for n_inner_w steps,
#   (2) the dictionary is then updated for n_inner_c steps.
# The L1 weight `lamb` is passed to csc_optimizer (presumably applied inside its
# step, e.g. as a proximal operator — confirm in matfac), so `loss` below is only
# the data-fidelity term.
# NOTE(review): this iterates `test_loader`, and progress is reported against the
# same loader. Switch both if training on `train_loader` was intended.
num_epochs = 20
for epoch in range(num_epochs):
    for batch_indx, (data, _, _) in enumerate(test_loader):
        X = data.to(device)
        n_batch = len(X)
        # new activation maps for this batch, then make sure they live on `device`
        csc.init_maps(n_batch)
        csc.to(device)
        # rebuilt every batch, presumably so it tracks the newly created maps
        optimizer = csc_optimizer(csc, lr_c, lr_w, lamb)

        # --- step 1: fit the activation maps ---
        unfreeze(csc, param='map')
        for i in range(n_inner_w):
            optimizer.zero_grad()  # clear the old gradients
            X_ = csc()
            # squared reconstruction error, halved and averaged over the batch
            loss = torch.norm(X - X_)**2 / (2*n_batch)
            loss.backward()
            optimizer.step(1)  # 1 / 0 presumably select map / dictionary updates — confirm

        reg_loss = L1Reg_loss(csc, X, lamb)
        # NOTE(review): if L1Reg_loss returns a tensor attached to the autograd
        # graph, storing it keeps the graph alive; consider float(reg_loss).
        losses.append(reg_loss)

        # --- step 2: update the dictionary ---
        unfreeze(csc, param='dict', obj_type='csc')
        for i in range(n_inner_c):
            optimizer.zero_grad()  # clear the old gradients
            X_ = csc()
            loss = torch.norm(X - X_)**2 / (2*n_batch)
            loss.backward()
            optimizer.step(0)

        reg_loss = L1Reg_loss(csc, X, lamb)
        losses.append(reg_loss)

        # Relative reconstruction error. X_ comes from the last inner iteration,
        # i.e. from before that iteration's optimizer step.
        err = torch.norm(X - X_).item() / torch.norm(X).item()
        # Fraction of non-zero activation entries. The denominator assumes
        # n_batch * n_components maps of n_dim x n_dim — confirm against Conv_SpCoding.
        nnz = sum(np.count_nonzero(feature_map.data.cpu().numpy()) for feature_map in csc.maps)
        nnz_W = nnz / (n_dim*n_dim*n_components*n_batch)

        # Arguments follow placeholder order: samples seen / dataset size
        # (percent done), then the relative error, the last three losses, and NNZ.
        print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tError: {:.6f}\tLoss: [{:.6f}, {:.6f}, {:.6f}]\tNNZ: {:.6f}%'.format(
            epoch, batch_indx * len(data), len(test_loader.dataset),
            100. * batch_indx / len(test_loader), err,
            losses[-3], losses[-2], losses[-1], 100. * nnz_W), end='')



In [None]:
from skimage.transform import rescale

# Show every learned filter as one tile of a gray-scale mosaic.
Nimages = len(csc.convs)
n_col = 8
# Enough rows to hold every filter (e.g. 4 rows of 8 for 32 filters), and never
# fewer than one so the mosaic is not empty.
n_row = max(1, -(-Nimages // n_col))

plt.figure(figsize=(25,25))
# plot filters (left half of the figure)
plt.subplot(1, 2, 1)
p = kernel_size + 2  # tile side: the filter plus a 1-pixel zero border on each side
mosaic = np.zeros((p*n_row, p*n_col))
for indx in range(Nimages):
    i, j = divmod(indx, n_col)  # row-major tile position
    # assumes each conv weight squeezes to (kernel_size, kernel_size) — confirm
    im = csc.convs[indx].weight.data.cpu().squeeze().numpy()
    im = im - np.min(im)
    peak = np.max(im)
    if peak > 0:
        # Min-max normalise to [0, 1]. A constant filter is left as all zeros
        # instead of producing 0/0 = NaN.
        im = im / peak
    mosaic[i*p:(i+1)*p, j*p:(j+1)*p] = np.pad(im, (1,1), mode='constant')

# Upsample 4x so the small kernels are visible.
plt.imshow(rescale(mosaic, 4, mode='constant'), cmap='gray')
plt.axis('off')
plt.show()

In [None]:
# DISABLED: earlier variant of the training loop that repeatedly fits a single
# batch instead of iterating a data loader. It reads `X`, `n_batch` and appends
# to `losses` without defining them, so it relies on the state left behind by the
# last batch of the batched training cell and only works after that cell has run.
# Kept for reference only — consider deleting it or turning it into a function.
# num_epochs = 50
# for epoch in range(num_epochs):
#     # initialize act maps
#     csc.init_maps(n_batch)
#     csc.to(device)
#     optimizer = csc_optimizer(csc, lr_c, lr_w, lamb)
    
#     # update weight
#     unfreeze(csc, param='map')
#     # inner loop
#     for i in range(n_inner_w):
#         optimizer.zero_grad() # clear the old gradients
#         # comp loss
#         X_ = csc()
#         loss = torch.norm(X-X_)**2/(2*n_batch)    
#         # backward
#         loss.backward()
#         # update step
#         optimizer.step(1)    

#     reg_loss = L1Reg_loss(csc, X, lamb)
#     losses.append(reg_loss)      

#     # update dict
#     unfreeze(csc, param='dict', obj_type='csc')
#     # inner loop
#     for i in range(n_inner_c):
#         optimizer.zero_grad() # clear the old gradients
#         # comp loss
#         X_ = csc()
#         loss = torch.norm(X-X_)**2/(2*n_batch)
#         # backward
#         loss.backward()
#         # update step
#         optimizer.step(0)    

#     reg_loss = L1Reg_loss(csc, X, lamb)
#     losses.append(reg_loss)             

#     if epoch % 1 == 0:
#         # recon-error, nnz
#         err = torch.norm(X-X_).data.item() / torch.norm(X).data.item() 
#         nnz = 0
#         for feature_map in csc.maps:
#             nnz += np.count_nonzero(feature_map.data.cpu().numpy())
#         nnz_W = nnz/(n_dim*n_dim*n_components*n_batch)

#         print('\rTrain Epoch: {} [({:.0f}%)]\tError: {:.6f}\tLoss: [{:.6f}, {:.6f}, {:.6f}]\tNNZ: {:.6f}%'.format(
#             epoch, 100. * epoch / num_epochs, err, losses[-3], losses[-2], losses[-1], 100. * nnz_W), end='')         