In [1]:
import sys
import torch
import numpy as np
import pickle as pkl
# Make the repository root importable. The relative path assumes the notebook is
# run from a working directory three levels below the repo root.
sys.path.append('../../../')


# NOTE(review): `plot_networks` imported here is redefined in a later cell of this
# notebook, which shadows this import once that cell has run.
from experiments.supervised.product_manifold.script import product_manifold, plot_networks
from riemannian_geometry.computations.pullback_metric import pullback_all_metrics
from models.supervised.mlp.model import MLP

# Root of the saved runs; the loading cell reads {models_path}/{size}/mlp_{mode}/...
# Relative to the notebook's working directory, like the sys.path entry above.
models_path = "../../../models/supervised/mlp/saved_models"


In [2]:
# Seed every RNG this notebook touches so Restart & Run All reproduces the results.
SEED = 2
np.random.seed(SEED)
# manual_seed returns a torch.Generator; the trailing ';' keeps its repr out of the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7fce512ac4b0>

In [3]:
# Which saved run / checkpoint to analyse.
size = "vanilla"
mode = "blobs"
epoch = 199
# Architecture must match the checkpoint. The printed repr below shows a 2-D input,
# 7 layers of width 4, Tanh hidden activations and a Sigmoid output.
# NOTE(review): positional args presumably (in_dim, n_layers, hidden_dim, out_dim) -- confirm against MLP.
model = MLP(2, 7, 4, 4)

# Directory of this run, shared by the dataset and the checkpoint.
run_dir = f'{models_path}/{size}/mlp_{mode}'

# NOTE(review): pickle.load can execute arbitrary code -- only load trusted, locally produced files.
with open(f'{run_dir}/dataset.pkl', 'rb') as f:
    dataset = pkl.load(f)

full_path = f'{run_dir}/model_{epoch}.pth'
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine;
# load_state_dict then copies the tensors into the CPU-built model.
model.load_state_dict(torch.load(full_path, map_location="cpu"))

# Last expression: switches to eval mode and displays the architecture.
model.eval()


MLP(
  (layers): ModuleList(
    (0): Layer(
      (act_func): Tanh()
      (linear_map): Linear(in_features=2, out_features=4, bias=True)
    )
    (1-5): 5 x Layer(
      (act_func): Tanh()
      (linear_map): Linear(in_features=4, out_features=4, bias=True)
    )
    (6): Layer(
      (act_func): Sigmoid()
      (linear_map): Linear(in_features=4, out_features=4, bias=True)
    )
  )
)

In [4]:
# Settings forwarded to pullback_all_metrics.
N = 50
wrt = "output_wise"
sigma = 0.05

X = torch.from_numpy(dataset.X).float()
labels = dataset.y

# Call the module rather than .forward() so nn.Module.__call__ (and any hooks) runs.
# save_activations=True presumably makes the model retain per-layer activations
# for get_activations() -- confirm against MLP.forward.
model(X, save_activations=True)

activations = model.get_activations()
# .cpu() is a no-op for this CPU model but keeps the conversion valid if it is moved to a GPU.
activations_np = [a.detach().cpu().numpy() for a in activations]
g, dg, ddg, surface = pullback_all_metrics(model, activations, N, wrt=wrt, method="manifold", sigma=sigma, normalised=False)

In [5]:
layer = 0

# Currently selects every point; narrow `indices` to analyse a subset.
indices = np.arange(len(activations_np[layer]))
# .copy() keeps these independent of the source arrays even if `indices` is later
# changed to a slice (basic slicing returns a view).
surface_ = surface[layer][indices].copy()
activations_np_ = activations_np[layer][indices].copy()
g_ = g[layer][indices].copy()

# Frobenius norm of each per-point matrix g_[i] (norm taken over axes 1 and 2).
V = np.linalg.norm(g_, axis=(1, 2), ord="fro")

# Keep each run's quantiles separately; previously the second call overwrote the first.
components_pin, graphs_pin, quantiles_pin = product_manifold(surface_, V, plot=False, size=size, mode=mode, use_pin=True)
components_no_pin, graphs_no_pin, quantiles_no_pin = product_manifold(surface_, V, plot=False, size=size, mode=mode, use_pin=False)
# `quantiles` is used by later cells; it refers to the no-pin run, exactly as before.
quantiles = quantiles_no_pin


K=2 -> 684
Nodes remaining: {0}
K=3 -> 217
Nodes remaining: {100}
K=4 -> 13
Nodes remaining: {152}
K=5 -> 2
Nodes remaining: {528}
K=6 -> 1
{0: {320, 743, 202, 121, 862}, 1: {744, 779, 203, 238, 442}, 2: {842, 780, 301, 239, 606}, 3: {480, 365, 244, 539, 607}, 4: {608, 64, 335, 245, 540}, 5: {609, 246, 952, 474, 411}, 6: {187, 594, 729, 953, 412}, 7: {595, 730, 954, 188, 413}, 8: {955, 596, 731, 189, 414}, 9: {416, 732, 597, 956, 190}, 10: {464, 823, 282, 733, 191}, 11: {192, 465, 824, 283, 734}, 12: {193, 466, 825, 284, 735}, 13: {736, 194, 467, 826, 285}, 14: {737, 195, 468, 827, 286}, 15: {738, 196, 469, 828, 287}, 16: {288, 107, 561, 919, 829}, 17: {289, 108, 562, 761, 830}, 18: {290, 109, 563, 350, 831}, 19: {832, 291, 110, 564, 351}, 20: {381, 82, 565, 923, 893}, 21: {683, 83, 566, 924, 894}, 22: {801, 84, 567, 925, 895}, 23: {746, 205, 568, 317, 926}, 24: {747, 206, 569, 318, 927}, 25: {928, 387, 748, 435, 570}, 26: {929, 749, 207, 436, 571}, 27: {930, 750, 209, 437, 572}, 28: {

In [20]:
import matplotlib.pyplot as plt
import os
import networkx as nx
def plot_networks(graphs, quantiles, mode, size, activations, labels, surface,
                  out_root="figures", prefix="p_out"):
    """Save one figure per (graph, quantile) pair.

    Each figure draws the graph with node ``i`` placed at ``surface[i]``, overlays
    ``activations[:, 0]`` vs ``activations[:, 1]`` coloured by ``labels``, and marks
    in red the nodes of ``graphs[0]`` that are absent from the current graph.

    NOTE: this definition shadows the ``plot_networks`` imported from
    ``experiments.supervised.product_manifold.script`` in the first cell.

    Args:
        graphs: sequence of networkx graphs; ``graphs[0]`` is the reference node set.
        quantiles: values paired with ``graphs`` (``zip`` stops at the shorter);
            used in the title and the file name.
        mode, size: path components of the output directory.
        activations: array indexed as ``[:, 0]`` / ``[:, 1]`` (presumably shape (n, >=2)).
        labels: colour values for the activation scatter.
        surface: node positions, indexed as ``surface[i]`` and ``surface[idx_list, 0]``.
        out_root: root output directory (default ``"figures"``, the previous fixed value).
        prefix: file-name prefix (default ``"p_out"``, the previous fixed value).

    Returns:
        None. Writes ``{out_root}/{mode}/{size}/{prefix}_quantile_{q}.png`` per pair.
        Does nothing if ``graphs`` is empty.
    """
    if len(graphs) == 0:
        return

    out_dir = f'{out_root}/{mode}/{size}'
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(out_dir, exist_ok=True)

    pos = {i: surface[i] for i in range(len(surface))}
    # Loop-invariant: the reference node set, computed once.
    base_nodes = set(graphs[0].nodes)

    for graph, q in zip(graphs, quantiles):
        fig, ax = plt.subplots(figsize=(16, 16))
        try:
            removed_nodes = list(base_nodes - set(graph.nodes))
            nx.draw(graph, pos=pos, node_size=5, ax=ax)
            ax.scatter(activations[:, 0], activations[:, 1], c=labels, cmap="RdBu_r")
            if removed_nodes:
                ax.scatter(surface[removed_nodes, 0], surface[removed_nodes, 1], c="red", s=30)
            ax.set_title(f"Quantile {q}")
            fig.savefig(f"{out_dir}/{prefix}_quantile_{q}.png")
        finally:
            # Always release the figure, even if drawing or saving fails, to avoid memory build-up.
            plt.close(fig)

In [6]:
# Save the no-pin graph figures for the selected layer (returns None, so nothing is displayed).
plot_networks(
    graphs_no_pin,
    quantiles,
    mode=mode,
    size=size,
    activations=activations_np_,
    labels=dataset.y,
    surface=surface_,
)
