In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os 
import sys
import json
import torch

sys.path.append("..")
%matplotlib inline

In [None]:
# Training run to inspect: its log directory holds config.json and a models/ folder
BASE_PATH = "../test/logs/Sun_Jun__9_16_15_40_201987405"
MODEL_BASE_PATH = os.path.join(BASE_PATH, "models")
CONFIG_PATH = os.path.join(BASE_PATH, "config.json")
EPOCH = 21  # epoch of the encoder checkpoint (encoder_epoch{EPOCH}.pt) to visualise

# Use a context manager so the file handle is closed deterministically;
# json.load(open(...)) leaves it open until garbage collection.
with open(CONFIG_PATH, "rt") as config_file:
    config = json.load(config_file)

# Dataset / model dimensions used by the cells below
dataset = config['data']['name']
n_features = config['data'][dataset]['dims']
n_atoms = config['data'][dataset]['atoms']
n_edge_types = config['model']['n_edge_types']
timesteps = config['data']['timesteps']

In [None]:
# Display the loaded run configuration that drives the encoder and data loaders below.
config

In [None]:
# Load encoder (and, when the checkpoint exists, its trained weights for EPOCH)
from src.model.utils import load_weights, gen_fully_connected
from train import create_encoder

encoder = create_encoder(config)

# Without loading a checkpoint the encoder keeps whatever weights create_encoder
# initialises it with, so every plot below would show an untrained model.
# The load is guarded so the notebook still runs for log directories that
# saved no models.
ENCODER_CHECKPOINT = os.path.join(MODEL_BASE_PATH, f"encoder_epoch{EPOCH}.pt")
if os.path.isfile(ENCODER_CHECKPOINT):
    load_weights(encoder, ENCODER_CHECKPOINT)
else:
    print(f"WARNING: {ENCODER_CHECKPOINT} not found; using untrained encoder weights")

# Inference only: disable training-time behaviour such as dropout.
# NOTE(review): assumes create_encoder returns a torch.nn.Module — confirm.
encoder.eval()

# Presumably receiver/sender relation encodings for a fully connected graph
# over n_atoms nodes (per the helper's name) — TODO confirm.
rel_rec, rel_send = gen_fully_connected(n_atoms)

In [None]:
# Build the data loaders from the run config. Only the test split is used
# below; the full mapping returned by load_data stays available as `data_loaders`.
from train import load_data

data_loaders = load_data(config)
test_loader = data_loaders['test_loader']

In [None]:
# Sample hard edge-type assignments from the encoder for one test batch
from src.model.utils import gumbel_softmax

SEED = 0          # seeds torch's RNG so re-running the notebook reproduces the samples
GUMBEL_TAU = 0.5  # temperature passed to gumbel_softmax

# NOTE(review): assumes gumbel_softmax (and any test-loader shuffling) draws from
# torch's global RNG — confirm, otherwise this seed does not pin the samples.
torch.manual_seed(SEED)

# First element of the batch tuple, truncated to the first `timesteps` entries of
# axis 2 (presumably the time axis — TODO confirm against the data loader).
data = next(iter(test_loader))[0][:, :, :timesteps, :]

# Inference only: skip building an autograd graph that is never backpropagated.
with torch.no_grad():
    logits = encoder(data, rel_rec, rel_send)
    edges = gumbel_softmax(logits, tau=GUMBEL_TAU, hard=True)

In [None]:
# Inspect the shape of the sampled edge tensor. The next cell indexes it as
# edges[sample, :, edge_type], so presumably (batch, n_atoms * (n_atoms - 1), n_edge_types) — TODO confirm.
edges.size()

In [None]:
# Scatter each sample's per-edge-type predictions into dense n_atoms x n_atoms matrices
from utils import get_offdiag_indices

n_samples = edges.size(0)
# Flat positions (in an n_atoms * n_atoms vector) that receive the edge values;
# presumably the off-diagonal entries, ordered to match dim 1 of `edges` — confirm.
indices = get_offdiag_indices(n_atoms)

# Vectorised over samples and edge types: `edges` is indexed as
# (sample, edge, edge_type), so move edge_type next to sample before scattering
# along the flattened matrix axis. Entries not covered by `indices` stay zero.
flat_graphs = torch.zeros(n_samples, n_edge_types, n_atoms * n_atoms)
flat_graphs[:, :, indices] = edges.detach().transpose(1, 2)

# float64 to match the np.empty-based array this cell previously produced
graphs = flat_graphs.view(n_samples, n_edge_types, n_atoms, n_atoms).numpy().astype(np.float64)

In [None]:
# Plot the sampled adjacency matrix of every edge type for one test sample
SAMPLE = 0

# squeeze=False keeps `axes` 2-D, so indexing also works when n_edge_types == 1.
fig, axes = plt.subplots(nrows=1, ncols=n_edge_types, squeeze=False)
for i in range(n_edge_types):
    # Fixed [0, 1] colour range so all panels share the scale shown by the single
    # colorbar (assumes edge values lie in [0, 1]); interpolation='none' draws one
    # crisp pixel per matrix entry (None would fall back to the rcParams default).
    im = axes[0, i].imshow(graphs[SAMPLE, i, :, :], cmap='gray', interpolation='none', vmin=0, vmax=1)
    axes[0, i].set_title(f"edge type {i}")
fig.suptitle(f"Sampled edges, test sample {SAMPLE}")
fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6);

In [None]:
# Averages: mean adjacency matrix per edge type over all samples in the batch
# (with hard one-hot samples, the fraction of samples in which each edge was drawn)
# squeeze=False keeps `axes` 2-D, so indexing also works when n_edge_types == 1.
fig, axes = plt.subplots(nrows=1, ncols=n_edge_types, squeeze=False)
for i in range(n_edge_types):
    # Fixed [0, 1] colour range so the shared colorbar is valid for every panel
    # (assumes edge values lie in [0, 1], hence so do their means).
    im = axes[0, i].imshow(np.mean(graphs[:, i, :, :], axis=0), cmap='gray', interpolation='none', vmin=0, vmax=1)
    axes[0, i].set_title(f"edge type {i}")
fig.suptitle(f"Mean edges over {graphs.shape[0]} test samples")
fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6);