In [1]:
import torch
from torch.utils import data

import warnings
warnings.filterwarnings("ignore", category=UserWarning) 

## Dummy Data Generation

### Parameters
Let's generate some random data to demonstrate the setup.

In [2]:
# Each split below is one tensor of shape (N_SAMPLES, N_ATOMS, N_STEPS, N_FEATURES).
N_SAMPLES = 99   # samples to generate per split
N_ATOMS = 5      # entities per sample, e.g. agents in a simulation
N_STEPS = 10     # time-series length, e.g. trajectory length
N_FEATURES = 3   # features per atom per timestep

### Generate Random Data

In [3]:
# Seed torch's global RNG so Restart & Run All reproduces the same dummy data
# (and, since later cells draw from the same RNG, the same model initialisation).
SEED = 42
torch.manual_seed(SEED)


def make_random_loader(shuffle: bool = False) -> data.DataLoader:
    """Return a DataLoader over a fresh uniform-random tensor.

    The tensor has shape (N_SAMPLES, N_ATOMS, N_STEPS, N_FEATURES). Because
    ``TensorDataset`` always yields tuples, each item is a 1-tuple ``(sample,)``.

    Args:
        shuffle: whether the loader reshuffles the samples every epoch.
    """
    samples = torch.rand(N_SAMPLES, N_ATOMS, N_STEPS, N_FEATURES)
    return data.DataLoader(data.TensorDataset(samples), shuffle=shuffle)


data_loaders = dict(
    # Shuffle only the training split; validation/test order stays fixed.
    train_loader=make_random_loader(shuffle=True),
    valid_loader=make_random_loader(),
    test_loader=make_random_loader(),
)

## Model Definition
### Model Hyperparameters

In [4]:
N_EDGE_TYPES = 3         # number of latent edge types (shared by encoder, decoder, config)
DIM_ENCODER_HIDDEN = 20  # hidden width passed to MLPEncoder
DIM_DECODER_HIDDEN = 20  # hidden width passed to RNNDecoder (n_hid)

In [8]:
from src.model.modules import MLPEncoder, RNNDecoder

# NOTE(review): the encoder's first positional argument is presumably the
# per-atom input size, i.e. each atom's whole trajectory flattened
# (N_STEPS * N_FEATURES) — confirm against MLPEncoder's signature.
encoder = MLPEncoder(N_STEPS * N_FEATURES, DIM_ENCODER_HIDDEN, N_EDGE_TYPES)
decoder = RNNDecoder(
    n_in_node=N_FEATURES,
    edge_types=N_EDGE_TYPES,
    n_hid=DIM_DECODER_HIDDEN,
)

### Training Setup

In [None]:
from src.trainer import generate_config

config = generate_config(
    n_atoms=N_ATOMS,
    n_edges=N_EDGE_TYPES,
    epochs=3,                   # kept tiny: this notebook only demonstrates the pipeline
    early_stopping_patience=2,
    temp=0.5,                   # softmax temperature for the latent graph
)

In [None]:
from src.trainer import Trainer

trainer = Trainer(
    encoder=encoder,
    decoder=decoder,
    data_loaders=data_loaders,
    config=config,
)
# Train the encoder/decoder pair with the loaders and `config` passed above.
trainer.train()

## Evaluation

In [None]:
from src.evaluation import Evaluator

In [None]:
# Evaluate the trained encoder/decoder on the held-out test split only.
evaluator = Evaluator(encoder=encoder, decoder=decoder,
                      data_loader=data_loaders['test_loader'], config=config)

In [None]:
# Keep the test result for later inspection; as the cell's last expression it is
# rendered with Jupyter's rich display instead of plain `print` text.
test_results = evaluator.test()
test_results

In [None]:
# NOTE(review): `_extract_latent_graphs` is a private Evaluator method;
# switch to a public accessor if the Evaluator API offers one.
(graphs, mean_graph) = evaluator._extract_latent_graphs()

In [None]:
# Bare last expression -> rich display. The plot below indexes
# `mean_graph[:, :, i]` per edge type, so this is presumably 3-D with edge
# types on the last axis — TODO confirm.
mean_graph.shape

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

fig, ax = plt.subplots(1, N_EDGE_TYPES, figsize=(20, 10))
for i in range(N_EDGE_TYPES):
    ax[i].imshow(mean_graph[:,:,i], cmap='gray')
    ax[i].set_title(f"Edge type {i}", {"fontsize":20})