In [1]:
import torch
import numpy as np

import pandas as pd
import matplotlib.pyplot as plt

from torchvision.models import alexnet
from torchvision.transforms import Compose, ToTensor, Resize, Grayscale
from torch.utils.data import DataLoader
from PytorchStorage import ForwardModuleStorage
from sklearn.cluster import KMeans
from tqdm import tqdm_notebook as tqdm
from collections import OrderedDict
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from fastai.vision import *
from fastai.layers import CrossEntropyFlat


In [2]:
%matplotlib notebook

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
TRAIN = False

In [4]:
def pca(x, k=2):
    """
    From http://agnesmustar.com/2017/11/01/principal-component-analysis-pca-implemented-pytorch/
    """
    # preprocess the data
    X_mean = torch.mean(x,0)
    x = x - X_mean.expand_as(x)
    # svd
    U,S,V = torch.svd(torch.t(x))
    return torch.mm(x,U[:,:k])

## Dataset

In [5]:
from torchvision.datasets import MNIST
from fastai.vision import *

tr = Compose([Grayscale(), ToTensor()])

train_ds = MNIST(root='~/Documents/datasets/', download=True, transform=tr)
train_dl = DataLoader(train_ds, num_workers=14, batch_size=128, shuffle=True)

val_ds = MNIST(root='~/Documents/datasets/', train=False, transform=tr)
val_dl = DataLoader(val_ds, num_workers=14, batch_size=128, shuffle=False)

data = ImageDataBunch(train_dl, val_dl)

## Model

In [6]:
model = simple_cnn((1,16,32,64)).to(device)

learn = Learner(data, model, path='./', loss_func=CrossEntropyFlat())
learn.metrics=[accuracy]

### Train

In [7]:
if TRAIN:
    learn.fit(25)
    learn.save('learn', return_path=True)

In [8]:
learn.load('./learn')

learn.validate(metrics=[accuracy])

[0.11849265, tensor(0.9632)]

In [9]:
def tensor2numpy(*args):
    return [e.cpu().numpy() for e in args]

In [49]:
import pandas as pd
from dataclasses import dataclass

@dataclass
class State():
    points: torch.tensor = torch.empty(0)
    y: torch.tensor = torch.empty(0).long()
    indeces: torch.tensor = torch.empty(0).long()
    
    def __repr__(self):
        return f"points={self.points.shape}"
    
class ModulePCA():
    reducers = ['kmeans']
    """
    Apply and visualize PCA with k-features of a specific CNN-layer. 
    It computes the PCA values batch-wise to reduce memory usage and increase performance.
    """
    def __init__(self, module, layer, dataloader):
        self.module, self.layer = module, layer
        self.storage = ForwardModuleStorage(module, [layer])
        self.dataloader = dataloader
        self.state = State()
      
    def points(self, dataloader, k=2, n_batches=None):
        """
        Batch-wise PCA. It returns the pca points, the labels and the inputs as Pytorch Tensors.
        """
        for i, (x, y) in enumerate(dataloader):
            y, x = y.to(device), x.to(device)
            self.storage(x) # run input into the storage
            with torch.no_grad():
                features = self.storage[self.layer][0]
                flat_features = features.view(features.shape[0], -1)
                pca_features = pca(flat_features, k=k)
                del self.storage.state[self.layer] # reinit storage -> save memory
                self.storage.state[self.layer] = []
                if n_batches is not None and i == n_batches: break
                yield pca_features, y, x
    
    def __call__(self, *args, **kwargs):
        for points, y, x in tqdm(self.points(self.dataloader, *args, **kwargs)):
            # store points and labels by bringing them to the cpu to save GPU memory
            self.state.points = torch.cat([self.state.points, points.cpu()])
            self.state.y = torch.cat([self.state.y, y.cpu()])
        
        self.state.indeces = torch.arange(len(self.state.points))
        return self
    
    def reduce(self, to=100, using='kmeans'):
        if using not in ModulePCA.reducers: raise ValueError(f"Parameter 'using' must be one of {ModulePCA.reducers}")
        points = self.state.points
        
        bar = tqdm(total=1)
        bar.set_description(f"Reducing {self.state.points.shape[0]} points to {to} using {using}...")
        if using == 'kmeans':
            kmeans = KMeans(n_clusters=to)
            kmeans.fit(module_pca.state.points.numpy(), y=module_pca.state.y.numpy())
            points = kmeans.cluster_centers_
        # update labels and indeces using the position of the clusters
        y = [self.state.y.numpy()[np.where(kmeans.labels_ == i)][0] for i in range(kmeans.n_clusters)]
        indeces = [self.state.indeces.numpy()[np.where(kmeans.labels_ == i)][0] for i in range(kmeans.n_clusters)]

        reduced_module_pca =  ModulePCA(self.module, self.layer, self.dataloader)
        reduced_module_pca.state = State(torch.from_numpy(points), 
                                          torch.from_numpy(np.array(y)),
                                          torch.from_numpy(np.array(indeces)))
        
        bar.update(1)
        
        return reduced_module_pca
    
    def _scatter(self):
        points, y = tensor2numpy(self.state.points, self.state.y)

        for i, label in enumerate(np.unique(y).tolist()):
                self.ax.scatter(points[y == label, 0], points[y == label, 1], label=label, alpha=0.5)
    
    def _legend(self):
        """
        Remove duplicates name in the legend
        """
        handles, labels = self.ax.get_legend_handles_labels()
        by_label = OrderedDict(zip(labels, handles))
        self.ax.legend(by_label.values(), by_label.keys())
        
    
    def plot2d(self):
        self.fig = plt.figure()
        self.ax = plt.subplot(111)\
        
        title = f"{len(self.state.points)} points"
        
        self._scatter()
        self._legend()

        plt.title(title)
        
        return self
    
    def annotate(self, zoom=1):
        self.fig = plt.figure()
        self.ax = plt.subplot(111)
        self._scatter()               
        for point, i in zip(self.state.points.numpy(), self.state.indeces.numpy()):
            x, y = point[0], point[1]
            img = self.dataloader.dataset[i][0]
            img_np = img.permute(1,2,0).numpy().squeeze()
            ab = AnnotationBbox(OffsetImage(img_np, zoom=zoom), (x, y), frameon=False)
            self.ax.add_artist(ab)
            
        return self

SyntaxError: f-string: empty expression not allowed (<ipython-input-49-4cb8f8fc87f3>, line 55)

In [44]:
plt.rcParams['figure.figsize'] = [10, 10]
module_pca = ModulePCA(learn.model, learn.model[3][0], learn.data.valid_dl)
module_pca(k=2, n_batches=None)
module_pca.plot2d()

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




<IPython.core.display.Javascript object>

<__main__.ModulePCA at 0x7f2a67cca6d8>

In [48]:
new_module_pca = module_pca.reduce(100)
new_module_pca.plot2d()
new_module_pca.annotate(zoom=0.5)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<__main__.ModulePCA at 0x7f2a67c026a0>