In [77]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import torch
import random
device = 'cuda' if torch.cuda.is_available() else 'cpu'
from scipy.ndimage import gaussian_filter
import sys
from tqdm import tqdm
from functools import partial
import acd
from copy import deepcopy
sys.path.append('..')
from transforms_torch import bandpass_filter
plt.style.use('dark_background')
sys.path.append('../../dsets/mnist')
import dset
from model import Net
from util import *
from numpy.fft import *
from torch import nn
from style import *
from captum.attr import (
    GradientShap,
    DeepLift,
    DeepLiftShap,
    IntegratedGradients,
    LayerConductance,
    NeuronConductance,
    NoiseTunnel,
)
import pickle as pkl
from torchvision import datasets, transforms
from sklearn.decomposition import NMF
import transform_wrappers

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
im_torch, im_orig, label = dset.get_im_and_label(251, device=device) # this will download the mnist dataset
model = Net().to(device)
im_torch = im_torch.to(device)
model.load_state_dict(torch.load('../../dsets/mnist/mnist.model', map_location=device))
model = model.eval().to(device)

# nmf stuff

In [29]:
mnist = datasets.MNIST('mnist/data', train=True, download=True)
X = mnist.data.numpy().astype(np.float32)
X = X.reshape(X.shape[0], -1)

In [3]:
nmf = NMF(n_components=30, random_state=42)
# nmf.fit(X)
# pkl.dump(nmf, open('nmf_30.pkl', 'wb'))
nmf = pkl.load(open('nmf_30.pkl', 'rb'))

In [None]:
D = nmf.components_.reshape(-1, 28, 28)
print('D.shape', D.shape)

R, C = 5,6
i = 0
vmin = np.min(D)
vmax = np.max(D)
plt.figure(figsize=(C * 3, R * 3), dpi=200)
for r in range(R):
    for c in range(C):
        plt.subplot(R, C, i + 1)
        plt.imshow(D[i], vmin=vmin, vmax=vmax, cmap='viridis')
        plt.axis('off')
        i += 1
plt.tight_layout()
plt.show()

In [4]:
X_t = nmf.transform(X) # (n, num_bases)

In [5]:
D = nmf.components_ # (num_bases, input_size)

In [6]:
X_ = X_t @ D.astype(np.float32) # inverse transform

In [7]:
np.mean(np.square(X_ - X)) / np.mean(np.square(X))

0.2155938798355274

In [83]:
transform = transform_wrappers.lay_from_w(D)
norm = transform_wrappers.Norm_Layer(mu=0.1307, std=0.3081)
net = transform_wrappers.Net_with_transform(model, transform=transform, norm=norm).to(device)

In [84]:
x = X[:2]
x_t = nmf.transform(x)
x_t = torch.Tensor(x_t).to(device)

In [85]:
net(x_t)

post transform torch.Size([2, 784])


tensor([[-1259.3894, -2353.4309, -2388.2671,  -631.3194, -2564.6746,     0.0000,
         -1631.9843, -2212.8752, -1528.3231, -1069.9592],
        [    0.0000, -3568.0764, -2241.5349, -3170.4446, -3573.6001, -2630.3381,
         -2349.1653, -2713.8767, -2651.4229, -2509.8071]], device='cuda:0',
       grad_fn=<LogSoftmaxBackward>)

In [74]:
list(net.modules())

[Net_with_transform(
   (transform): Linear(in_features=30, out_features=784, bias=False)
   (norm): Norm_Layer()
   (model): Net(
     (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
     (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
     (conv2_drop): Dropout2d(p=0.5, inplace=False)
     (fc1): Linear(in_features=320, out_features=50, bias=True)
     (fc2): Linear(in_features=50, out_features=10, bias=True)
   )
 ), Linear(in_features=30, out_features=784, bias=False), Norm_Layer(), Net(
   (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
   (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
   (conv2_drop): Dropout2d(p=0.5, inplace=False)
   (fc1): Linear(in_features=320, out_features=50, bias=True)
   (fc2): Linear(in_features=50, out_features=10, bias=True)
 ), Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1)), Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1)), Dropout2d(p=0.5, inplace=False), Linear(in_features=320, out_features=50, bias=Tru

In [76]:
for mod in net.modules():
    print(type(mod))

<class '__main__.Net_with_transform'>
<class 'torch.nn.modules.linear.Linear'>
<class '__main__.Norm_Layer'>
<class 'model.Net'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.dropout.Dropout2d'>
<class 'torch.nn.modules.linear.Linear'>
<class 'torch.nn.modules.linear.Linear'>
