In [None]:
from sklearn import neighbors, cluster
from sklearn.metrics import pairwise
from sklearn.manifold import SpectralEmbedding

from scipy.spatial.distance import pdist, squareform
import pygsp
import torch
import numpy as np


def create_graph_from_embedding(embedding, name, k=10, n_clusters=8):
    latent_dim, batch_size = embedding.shape
    if name =='gaussian':
        # Compute a gaussian kernel over the node activations
        node_distances = squareform(pdist(embedding, 'sqeuclidean'))
        s = 1
        K = np.exp(-node_distances / s**2)
        K[K < 0.1] = 0
        A = K * (np.ones((latent_dim, latent_dim)) - np.identity(latent_dim))
        return A
    elif name == 'knn':
        mat = neighbors.kneighbors_graph(embedding, n_neighbors=k, metric='cosine', mode='distance')
        mat.data = 1 - mat.data
        A = mat.toarray()
        A = (A + A.T) / 2
        return A
    elif name == 'knn-flat':
        A = neighbors.kneighbors_graph(embedding, n_neighbors=k, metric='cosine').toarray()
        A = (A + A.T) / 2
        return A
    elif name == 'adaptive': # It's super slow
        # Find distance of k-th nearest neighbor and set as bandwidth
        neigh = neighbors.NearestNeighbors(n_neighbors=3)
        neigh.fit(embedding)
        dist, _ = neigh.kneighbors(embedding, return_distance=True)
        kdist = dist[:,-1]
        # Apply gaussian kernel with adaptive bandwidth
        node_distances = squareform(pdist(embedding, 'sqeuclidean'))
        K = np.exp(-node_distances / kdist**2)
        A = K * (np.ones((latent_dim, latent_dim)) - np.identity(latent_dim))
        A = (A + np.transpose(A)) / 2 # Symmetrize knn graph
        return A
    elif name == 'full':
        A = pairwise.cosine_similarity(embedding)
        np.fill_diagonal(A, 0)
        return np.maximum(A, 0)
    elif name == 'hclust':
        d = pairwise.cosine_distances(embedding)
        clusts = cluster.AgglomerativeClustering(n_clusters=n_clusters, affinity="precomputed", linkage="average").fit(d).labels_
        A = np.zeros(d.shape)
        for i in range(clusts.max() + 1):
            A[np.ix_(clusts == i, clusts == i)] = 1.0

        np.fill_diagonal(A, 0)
        return A
    elif 'knn-spectral':
        mat = neighbors.kneighbors_graph(embedding, n_neighbors=k, metric='cosine', mode='distance')
        mat.data = 1 - mat.data
        A = mat.toarray()
        A = (A + A.T) / 2
        clusts = cluster.SpectralClustering(n_clusters=n_clusters, affinity='precomputed').fit(A).labels_

        mask = np.zeros(A.shape)
        for i in range(clusts.max() + 1):
            mask[np.ix_(clusts == i, clusts == i)] = 1.0

        return (A > 1e-5) * mask
    else:
        raise RuntimeError('Unknown graph name %s' % name)


def create_lap_from_embedding(embedding, *args, **kwargs):
    adj_mat = create_graph_from_embedding(embedding, *args, **kwargs)
    graph = pygsp.graphs.Graph(adj_mat)
    graph.compute_laplacian(lap_type='normalized')
    return torch.Tensor(graph.L.A)


def graph_loss(activations, lap):
    return (activations.mm(lap) * activations).sum() / activations.shape[1]


def create_graph_from_layered_embedding(embs, frac:float = 0.1, n_clusters:int = 0):
    n_hidden = len(embs)
    layers = [e.shape[0] for e in embs]

    neighs = [neighbors.NearestNeighbors(n_neighbors=int(frac * layers[i] + 1), metric='cosine').fit(embs[i]) for i in range(n_hidden)]
    ids_per_layer = [sum(layers[:i]) + np.arange(layers[i], dtype=int) for i in range(n_hidden)]

    adj_mat = np.zeros((sum(layers), sum(layers)))

    for i in range(n_hidden):
        dist, ids = [x[:, 1:] for x in neighs[i].kneighbors(embs[i], return_distance=True)]
        for v1,v2s in enumerate(ids):
            adj_mat[ids_per_layer[i][v1], ids_per_layer[i][v2s]] = 1 - dist[v1,:]

        if i != n_hidden - 1:
            dist, ids = [x[:, :-1] for x in neighs[i + 1].kneighbors(embs[i], return_distance=True)]
            for v1,v2s in enumerate(ids):
                adj_mat[ids_per_layer[i][v1], ids_per_layer[i + 1][v2s]] = 1 - dist[v1,:]

        if i != 0:
            dist, ids = [x[:, :-1] for x in neighs[i - 1].kneighbors(embs[i], return_distance=True)]
            for v1,v2s in enumerate(ids):
                adj_mat[ids_per_layer[i][v1], ids_per_layer[i - 1][v2s]] = 1 - dist[v1,:]

    adj_mat = (adj_mat + adj_mat.T) / 2

    if n_clusters <= 0:
        return adj_mat

    clusts = cluster.SpectralClustering(n_clusters=n_clusters, affinity='precomputed').fit(adj_mat).labels_

    mask = np.zeros_like(adj_mat)
    for i in range(clusts.max() + 1):
        mask[np.ix_(clusts == i, clusts == i)] = 1.0

    return (adj_mat > 1e-5) * mask

In [2]:
import sys
sys.path.append("..")

%load_ext autoreload
%autoreload 2

import gc
import collections
from functools import partial

import seaborn as sns

import matplotlib.pyplot as plt

from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

%config Completer.use_jedi = False

dev = 'cuda:0'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


4

In [8]:
class Net(nn.Module):
    def __init__(self, layers, drop_p=0.2):
        super(Net,self).__init__()
        self.hidden = nn.ModuleList()
        for li,lo in zip(layers, layers[1:]):
            self.hidden.append(nn.Linear(li, lo))
        self.droput = nn.Dropout(drop_p)
        
    def forward(self,x):
        x = x.view(-1, INPUT_SHAPE)
        for i,l in enumerate(self.hidden):
            x = l(x)
            if i < len(self.hidden) - 1:
                x = self.droput(F.relu(x))
        return x

def eval_nn(nn, testloader):
    correct = 0
    loss_sublist = []
    for x,y in testloader:
        x,y = x.to(dev), y.to(dev)
        nn.eval()
        z = nn(x)
        _, yh = torch.max(z.data, 1)
        correct += (yh == y).sum().item()
        loss_sublist.append(crit(z, y).data.item())
    acc = correct / n_test
    return acc, round(np.mean(loss_sublist), 4)

def run_nn_train(nn, x, y, optimizer, crit):
    x,y = x.to(dev), y.to(dev)
    nn.train()
    optimizer.zero_grad()
    z = nn(x)
    return crit(z, y)

def clean_mem():
    gc.collect()
    torch.cuda.empty_cache()

In [3]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.1307,), (0.3081,))])

trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
trainset_raw = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1024, shuffle=True, num_workers=12)

testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1024, shuffle=False, num_workers=12)

INPUT_SHAPE = 1 * 28 * 28
OUTPUT_SHAPE = 10

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data\FashionMNIST\raw\train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ./data\FashionMNIST\raw\train-images-idx3-ubyte.gz to ./data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw\train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ./data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ./data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ./data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ./data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw

Processing...


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Done!




## Simple MLP

In [7]:
layers = [INPUT_SHAPE, 40, 40, 30, OUTPUT_SHAPE]
mlp = Net(layers, drop_p=0.3).to(dev)

crit = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=0.001)

AssertionError: Torch not compiled with CUDA enabled

In [11]:
%%time

N_EPOCHS = 20
train_loss_list = []
test_loss_list = []
accuracy_list = []
correct = 0
n_test = len(testset)
graph = None

for e in range(N_EPOCHS):
    activations = collections.defaultdict(list)
    # Train
    loss_sublist = []
    gl_sublist = []
    for x,y in trainloader:
        loss = run_nn_train(mlp, x, y, optimizer=optimizer, crit=crit)
        loss_sublist.append(loss.data.item())

        loss.backward()
        optimizer.step()
        
    train_loss_list.append(np.mean(loss_sublist))

    # Test
    acc, test_loss = eval_nn(mlp, testloader)
    del x,y,loss
    clean_mem()

    test_loss_list.append(test_loss)
    print(f'{e}. Accuracy: {round(acc, 4)}, CE: {test_loss_list[-1]}')

NameError: name 'mlp' is not defined

In [None]:
l1 = dict(mlp.named_modules())['hidden.0']
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(20, 14))
for i in range(16):
    sns.heatmap(list(l1.parameters())[0][i,:].reshape(28, 28).cpu().detach().numpy(), xticklabels=False, yticklabels=False, center=0, ax=axes.flatten()[i])

plt.tight_layout()

## Graph MLP

In [9]:
layers = [INPUT_SHAPE, 40, 40, 30, OUTPUT_SHAPE]
n_hidden = len(layers) - 2
ids_per_layer = [sum(layers[1:(i+1)]) + np.arange(layers[i+1], dtype=int) for i in range(n_hidden)]

graph_mlp = Net(layers, drop_p=0.3).to(dev)

crit = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(graph_mlp.parameters(), lr=0.001)

activations = collections.defaultdict(list)
def save_activation(name, mod, inp, out):
    activations[name].append(out.cpu())

for name, m in graph_mlp.named_modules():
    if type(m)==nn.Linear:
        m.register_forward_hook(partial(save_activation, name))

AssertionError: Torch not compiled with CUDA enabled

In [10]:
%%time

N_EPOCHS = 20
train_loss_list = []
test_loss_list = []
accuracy_list = []
correct = 0
n_test = len(testset)
lap = None

for e in range(N_EPOCHS):
    activations = collections.defaultdict(list)
    # Train
    loss_sublist = []
    gl_sublist = []
    for x,y in trainloader:
        loss = run_nn_train(graph_mlp, x, y, optimizer=optimizer, crit=crit)
        loss_sublist.append(loss.data.item())

        if lap is not None:
            gl = graph_loss(torch.hstack([activations[f'hidden.{li}'][-1] for li in range(n_hidden)]), lap) * 0.005
            gl_sublist.append(gl.data.item())
            loss += gl

        loss.backward()
        optimizer.step()
        
    train_loss_list.append(np.mean(loss_sublist))
    
    # Loss update
    embs = [torch.vstack(activations[f'hidden.{i}']).T.detach().numpy() for i in range(n_hidden)]
    adj_mat = create_graph_from_layered_embedding(embs, frac=0.3)
    graph = pygsp.graphs.Graph(adj_mat)
    graph.compute_laplacian(lap_type='normalized')
    lap = torch.Tensor(graph.L.A)

    # Test
    acc, test_loss = eval_nn(graph_mlp, testloader)
    del x,y,loss
    clean_mem()

    if len(gl_sublist) > 0:
        print(f'{e}. Accuracy: {round(acc, 4)}, CE: {test_loss}, GL: {round(np.mean(gl_sublist), 3)}')
    else:
        print(f'{e}. Accuracy: {round(acc, 4)}, CE: {test_loss}')

NameError: name 'graph_mlp' is not defined

In [None]:
embs = [torch.vstack(activations[f'hidden.{i}']).T.detach().numpy() for i in range(n_hidden)]
adj_mat = create_graph_from_layered_embedding(embs, frac=0.3)

spec_emb = SpectralEmbedding(affinity='precomputed').fit_transform(adj_mat)
clust_labels = cluster.k_means(spec_emb, 8)[1]

for i in set(clust_labels):
    mask = (clust_labels == i)
    plt.scatter(spec_emb[mask,0], spec_emb[mask,1], s=5, label=i)
plt.legend();

In [None]:
for i,ids in enumerate(ids_per_layer):
    plt.scatter(spec_emb[ids,0], spec_emb[ids,1], s=5, label=i)
plt.legend();

In [None]:
l1 = dict(graph_mlp.named_modules())['hidden.0']
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(20, 14))
for i in range(16):
    sns.heatmap(list(l1.parameters())[0][i,:].reshape(28, 28).cpu().detach().numpy(), xticklabels=False, yticklabels=False, center=0, ax=axes.flatten()[i])

plt.tight_layout()