In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import torch
import random
device = 'cuda' if torch.cuda.is_available() else 'cpu'
from scipy.ndimage import gaussian_filter
import sys
from tqdm import tqdm
from functools import partial
import acd
from copy import deepcopy
sys.path.append('..')
sys.path.append('../..')
from transforms_torch import bandpass_filter
# plt.style.use('dark_background')
sys.path.append('../../dsets/mnist')
import dset
from model import Net, Net2c
from util import *
from numpy.fft import *
from torch import nn
from style import *
from captum.attr import (
    GradientShap,
    DeepLift,
    DeepLiftShap,
    IntegratedGradients,
    LayerConductance,
    NeuronConductance,
    NoiseTunnel,
)
import pickle as pkl
from torchvision import datasets, transforms
from sklearn.decomposition import NMF
import transform_wrappers
import visualize as viz
from model import Net, Net2c
torch.manual_seed(42)
np.random.seed(42)

In [None]:
from acd_wooseok.acd.scores import cd
from acd_wooseok.acd.util import tiling_2d
from acd_wooseok.acd.scores import score_funcs
from torchvision import datasets, transforms
# import modules
from funcs import *
from matfac import *

# Dataset

In [None]:
import os

# load args
args = dset.get_args()
# halve the batch size: every training batch is later doubled by the knockout augmentation
args.batch_size = args.batch_size // 2
args.epochs = 50
args.cuda = not args.no_cuda and torch.cuda.is_available()

# load mnist dataloaders (batches are (data, target, data_indx) triples)
train_loader, test_loader = dset.load_data_with_indices(args.batch_size, args.test_batch_size, device)


def _flatten_to_unit(images):
    """Flatten a raw (0-255) image tensor to a float32 array of shape (n_samples, n_pixels) in [0, 1]."""
    flat = images.numpy().astype(np.float32)
    return flat.reshape(flat.shape[0], -1) / 255


# dataset as flat numpy arrays (input for NMF)
X = _flatten_to_unit(train_loader.dataset.data)
Y = train_loader.dataset.targets.numpy()
X_test = _flatten_to_unit(test_loader.dataset.data)
Y_test = test_loader.dataset.targets.numpy()

# NMF dictionary: load the cached model, or fit it once and cache it so Run-All works from scratch
NMF_PATH = './results/nmf_30.pkl'
if os.path.exists(NMF_PATH):
    # NOTE: pickle.load can execute arbitrary code -- only load trusted, locally produced files
    with open(NMF_PATH, 'rb') as f:
        nmf = pkl.load(f)
else:
    nmf = NMF(n_components=30, max_iter=1000)
    nmf.fit(X)
    os.makedirs(os.path.dirname(NMF_PATH), exist_ok=True)
    with open(NMF_PATH, 'wb') as f:
        pkl.dump(nmf, f)
D = nmf.components_  # dictionary atoms, shape (n_components, n_pixels)
# per-sample NMF activations
W = nmf.transform(X)
W_test = nmf.transform(X_test)

In [None]:
# Keep only images whose activation on dictionary atom 0 is non-zero, so that knocking that
# atom out actually changes the image, and rebuild the loaders on this subset.
# The activations W / W_test from the previous cell are reused instead of re-running nmf.transform.
indx = np.flatnonzero(W[:, 0] > 0)
indx_t = np.flatnonzero(W_test[:, 0] > 0)

# subset dataloader
# NOTE(review): later cells index W / W_test with the loader's data_indx; this assumes
# load_data_with_indices yields indices into the full dataset, not positions in the subset -- confirm.
train_loader, test_loader = dset.load_data_with_indices(args.batch_size,
                                                        args.test_batch_size,
                                                        device,
                                                        subset_index=[indx, indx_t])

# Train model

In [None]:
def nmf_transform(W: np.ndarray, data_indx, list_dict_indx=[0], components=None, norm_std=0.3081):
    """Reconstruct the image part spanned by selected NMF dictionary atoms.

    Args:
        W: NMF activations, shape (n_samples, n_components).
        data_indx: sample indices (array-like or 1-D CPU tensor) selecting rows of W.
        list_dict_indx: indices of the dictionary atoms to reconstruct from (not mutated).
        components: dictionary of shape (n_components, 784); defaults to the global ``D``.
        norm_std: divisor applied to the reconstruction. 0.3081 is presumably the MNIST std used by
            the loader's normalisation, so the part is on the same scale as the loaded images -- confirm.

    Returns:
        float32 tensor of shape (len(data_indx), 1, 28, 28).
    """
    if components is None:
        components = D
    # np.asarray turns a tensor index (even a 1-element one) into a proper 1-D integer array
    rows = np.asarray(data_indx).reshape(-1)
    atoms = list(list_dict_indx)
    im_parts = W[rows][:, atoms] @ components[atoms] / norm_std
    # derive the batch dimension from the data instead of a global `batch_size`
    return torch.as_tensor(im_parts, dtype=torch.float32).reshape(-1, 1, 28, 28)


def nmf_knockout_augment(im: torch.Tensor, W: np.ndarray, data_indx, list_dict_indx=[0],
                         components=None, norm_std=0.3081):
    """Double a batch by appending copies with the selected NMF atoms subtracted.

    Args:
        im: image batch, shape (batch, 1, 28, 28).
        W, data_indx, list_dict_indx, components, norm_std: forwarded to ``nmf_transform``.

    Returns:
        tensor of shape (2 * batch, 1, 28, 28): the original images followed by
        ``im - nmf_transform(...)`` for the same samples.
    """
    im_parts = nmf_transform(W, data_indx, list_dict_indx, components, norm_std).to(im)
    # torch.cat allocates a fresh tensor, so no defensive copy of `im` is needed
    return torch.cat((im, im - im_parts), dim=0)

In [None]:
# explicit imports so this cell survives a fresh kernel without relying on star-imports
import torch.nn.functional as F
import torch.optim as optim

# set seed
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# create model
model = Net2c()
if args.cuda:
    model.cuda()

# optimizer
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

# train: each batch is doubled with its atom-0 knockout copy; the model learns to separate
# original images (label 0) from knocked-out images (label 1)
for epoch in range(1, args.epochs + 1):
    model.train()
    for batch_indx, (data, target, data_indx) in enumerate(train_loader):
        # assigned at notebook scope; keep it, helpers in this notebook may read it as a global
        batch_size = len(data)
        data = nmf_knockout_augment(data, W, data_indx, list_dict_indx=[0])
        # first half: originals (class 0); second half: knocked-out images (class 1)
        target = torch.zeros(2 * batch_size, dtype=target.dtype)
        target[batch_size:] = 1
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        # assumes Net2c returns log-probabilities (required by nll_loss) -- confirm
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_indx % args.log_interval == 0:
            print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_indx * len(data), 2 * len(train_loader.dataset),
                100. * batch_indx / len(train_loader), loss.item()), end='')

# Test model

In [None]:
import torch.nn.functional as F

# eval mode
model.eval()
if args.cuda:
    model.cuda()

# test: every test batch is doubled with its atom-0 knockout copy (labels 0 / 1)
test_loss = 0.0
correct = 0
with torch.no_grad():  # inference only, no autograd graph needed
    for batch_indx, (data, target, data_indx) in tqdm(enumerate(test_loader), total=len(test_loader)):
        # assigned at notebook scope; keep it, helpers in this notebook may read it as a global
        batch_size = len(data)
        data = nmf_knockout_augment(data, W_test, data_indx, list_dict_indx=[0])
        target = torch.zeros(2 * batch_size, dtype=target.dtype)
        target[batch_size:] = 1
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        output = model(data)
        test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
        pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
        # .item() keeps `correct` a plain int, so the report prints a number, not "tensor(N)"
        correct += pred.eq(target.view_as(pred)).sum().item()

n_test = 2 * len(test_loader.dataset)  # originals + knocked-out copies
test_loss /= n_test

print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, n_test, 100. * correct / n_test))

# CD score

In [None]:
# CD score of every NMF dictionary atom, for each test image in the subset
test_num = len(indx_t)
# follow the model's placement instead of hard-coding 'cuda' (presumably cd.cd moves inputs here)
cd_device = 'cuda' if args.cuda else 'cpu'

scores_o = torch.zeros(test_num, nmf.n_components)  # cd score for class 0 (original img)
scores_f = torch.zeros(test_num, nmf.n_components)  # cd score for class 1 (transformed img)

# eval mode is loop-invariant: set it once
model.eval()
if args.cuda:
    model.cuda()

n = 0
for batch_indx, (data, target, data_indx) in enumerate(test_loader):
    # assigned at notebook scope; keep it, helpers in this notebook may read it as a global
    batch_size = len(data)
    for comp_indx in range(nmf.n_components):
        im_parts = nmf_transform(W_test, data_indx, list_dict_indx=[comp_indx])
        score_o = cd.cd(data, model, mask=None, model_type='mnist', device=cd_device,
                        transform=None, relevant=im_parts)[0].data.cpu()
        # NOTE(review): the transformed image here is `data - im_parts` (current atom removed).
        # If the intent is the class-1 training transformation (atom 0 removed), use
        # `data - nmf_transform(W_test, data_indx, list_dict_indx=[0])` instead -- confirm.
        score_f = cd.cd(data - im_parts, model, mask=None, model_type='mnist', device=cd_device,
                        transform=None, relevant=im_parts)[0].data.cpu()

        scores_o[n:n + batch_size, comp_indx] = score_o[:, 0]
        # class-1 score of the *transformed* image comes from score_f, not score_o
        scores_f[n:n + batch_size, comp_indx] = score_f[:, 1]
        print('\r batch index: {} [component index: {}]'.format(batch_indx, comp_indx), end='')
    n += batch_size

In [None]:
# Per-atom CD scores of the original (class-0) images: mean ± 1 std, and which atom scores highest.
knockout_indx = 0  # dictionary atom removed to build class 1
n_comp = nmf.n_components
list_of_x = np.arange(n_comp)

scores_o_np = scores_o.numpy()
mean_o = scores_o_np.mean(axis=0)
std_o = scores_o_np.std(axis=0, ddof=1)  # ddof=1 matches torch's default (unbiased) std

fig, ax = plt.subplots(1, 2, figsize=(13, 5))
ax[0].plot(list_of_x, mean_o, alpha=0.5, color='blue', label='class(original)', linewidth=4.0)
ax[0].fill_between(list_of_x, mean_o - std_o, mean_o + std_o, color='#888888', alpha=0.4)
ax[0].axvline(x=knockout_indx, linestyle='--', color='green', label='knocked-out dictionary', linewidth=2.0)
ax[0].set_xlabel('dictionary index')
ax[0].set_ylabel('cd score')
ax[0].set_title('Averaged CD score')
ax[0].legend()

# one bin per dictionary atom, centred on the integer index
bins = np.arange(n_comp + 1) - 0.5
ax[1].hist(scores_o_np.argmax(axis=1), bins=bins, alpha=0.4)
ax[1].axvline(x=knockout_indx, linestyle='--', color='green', label='knocked-out dictionary', linewidth=2.0)
ax[1].legend()
ax[1].set_xlabel('dictionary index')
ax[1].set_ylabel('frequency')
ax[1].set_title('Maximum dictionary index for each data point')
plt.show()