# Data block

### View module

In [1]:
from data.config import DataConfig

data_config = DataConfig()

# Inspect the dataset paths for both splits. Each result is bound to a name
# because only a cell's final expression is displayed by Jupyter; bare calls
# in the middle of the cell would otherwise be thrown away.
data_config.set_train_valid_dataset()
train_valid_ds_paths = data_config.get_dataset_paths()

data_config.set_test_dataset()
test_ds_paths = data_config.get_dataset_paths()

# Switch the config back to the train/valid split; these are the paths that
# the data-loading cell below consumes.
data_config.set_train_valid_dataset()
node_ds_paths, edge_ds_paths = data_config.get_dataset_paths()


In [2]:
import os

# Show only the final path component of each node-dataset entry under 'H'.
for ds_name in map(os.path.basename, node_ds_paths['H']):
    print(ds_name)

vel
pos


### Load dataset

In [3]:
from data.load import load_spring_particle_data

train_loader, valid_loader, test_loader = load_spring_particle_data(
    node_ds_paths, edge_ds_paths
)

# Draw one batch from the training loader purely to read off tensor sizes
# needed to build the encoder/decoder below.
dataiter = iter(train_loader)
data = next(dataiter)

# NOTE(review): assumes data[0] carries time on axis 2 and the feature
# dimension on axis 3 — confirm against load_spring_particle_data.
first_input_shape = data[0].shape
n_timesteps = first_input_shape[2]
n_dims = first_input_shape[3]

In [4]:
import numpy as np
import torch

def encode_onehot(labels):
    """One-hot encode a sequence of hashable labels.

    Each distinct label gets its own column. Columns are assigned in sorted
    label order, so integer node indices ``0..K-1`` map to columns ``0..K-1``;
    the receiver/sender relation matrices built from ``np.where`` indices
    depend on that correspondence. Iterating a bare ``set`` (as before) does
    not guarantee sorted order, so the column assignment could be permuted.

    Args:
        labels: 1-D iterable of hashable labels (e.g. a numpy integer array).
            Materialized once, so one-shot iterators are handled correctly.

    Returns:
        ``np.int32`` array of shape ``(len(labels), n_distinct_labels)`` with
        a single 1 per row.
    """
    labels = list(labels)
    unique = set(labels)
    try:
        classes = sorted(unique)
    except TypeError:
        # Labels that cannot be ordered against each other (e.g. mixed types):
        # fall back to the set's iteration order, as the original code did.
        classes = list(unique)
    column_of = {c: i for i, c in enumerate(classes)}
    rows = np.array([column_of[label] for label in labels], dtype=np.intp)
    # Build the identity once and select rows, instead of one identity per class.
    return np.identity(len(classes), dtype=np.int32)[rows]

# Generate off-diagonal interaction graph: every node connected to every
# other node, with no self-loops.
n_nodes = 5  # NOTE(review): should match the node axis of the loaded data — confirm
off_diag = np.ones([n_nodes, n_nodes]) - np.eye(n_nodes)
print("Off-diagonal interaction graph:")
print(off_diag)

# One row per directed edge, one column per node: rec_rel one-hot encodes
# each edge's row index of off_diag, send_rel its column index.
receivers, senders = np.where(off_diag)
rec_rel = torch.from_numpy(encode_onehot(receivers).astype(np.float32))
send_rel = torch.from_numpy(encode_onehot(senders).astype(np.float32))

Off-diagonal interaction graph:
[[0. 1. 1. 1. 1.]
 [1. 0. 1. 1. 1.]
 [1. 1. 0. 1. 1.]
 [1. 1. 1. 0. 1.]
 [1. 1. 1. 1. 0.]]


### Encoder and Decoder

In [5]:
from topology_estimation.config import TopologyEstimatorConfig
from topology_estimation.encoder_blocks import Encoder
from torchinfo import summary

par = TopologyEstimatorConfig()
par.set_encoder_params()

# Build the encoder from the parameters populated by set_encoder_params().
# The `*_emd_configs` keyword spelling is Encoder's own parameter name.
encoder = Encoder(
    n_timesteps=n_timesteps,
    n_dims=n_dims,
    pipeline=par.encoder_pipeline,
    n_edge_types=par.n_edge_types,
    is_residual_connection=par.is_residual_connection,
    edge_emd_configs=par.edge_emb_configs_enc,
    node_emd_configs=par.node_emb_configs_enc,
    drop_out_prob=par.dropout_prob_enc,
    batch_norm=par.batch_norm_enc,
    attention_output_size=par.attention_output_size,
)
encoder.set_input_graph(rec_rel, send_rel)

# Dummy input size for the layer summary: 64 is a batch size chosen for the
# summary and 5 matches the 5-node interaction graph built earlier.
# NOTE(review): assumed layout is (batch, nodes, timesteps, dims) — confirm.
encoder_input_size = (64, 5, n_timesteps, n_dims)
print(summary(encoder, encoder_input_size))

Layer (type:depth-idx)                   Output Shape              Param #
Encoder                                  [64, 20, 2]               --
├─ModuleDict: 1-1                        --                        --
│    └─MLP: 2-1                          [64, 5, 8]                --
│    │    └─ModuleList: 3-1              --                        15,352
│    └─MLP: 2-2                          [64, 5, 8]                --
│    │    └─ModuleList: 3-2              --                        3,320
│    └─MLP: 2-3                          [64, 20, 8]               --
│    │    └─ModuleList: 3-3              --                        3,320
│    └─MLP: 2-4                          [64, 5, 8]                --
│    │    └─ModuleList: 3-4              --                        3,320
│    └─MLP: 2-5                          [64, 5, 8]                --
│    │    └─ModuleList: 3-5              --                        3,320
│    └─MLP: 2-6                          [64, 20, 8]               --

In [6]:
# Decoder
from topology_estimation.decoder_blocks import Decoder
par.set_decoder_params()
decoder = Decoder(n_dim=n_dims,
                  msg_out_size=par.msg_out_size,
                  n_edge_types=par.n_edge_types,
                  skip_first=par.skip_first_edge_type,
                  edge_mlp_config=par.edge_mlp_config_dec,
                  recurrent_emd_type=par.recurrent_emd_type,
                  out_mlp_config=par.out_mlp_config_dec,
                  do_prob=par.dropout_prob_dec,
                  is_batch_norm=par.is_batch_norm_dec)

# Random edge matrix, presumably standing in for the encoder's output
# (batch, n_edges, n_edge_types). The shape is derived instead of hard-coding
# (64, 20, 2): rec_rel has one row per edge and one column per node, and the
# last axis must agree with the n_edge_types passed to Decoder. A dedicated
# seeded generator makes the cell reproducible without reseeding the global RNG.
summary_batch_size = 64
n_edges, n_graph_nodes = rec_rel.shape
edge_rng = torch.Generator().manual_seed(0)
edge_matrix = torch.rand((summary_batch_size, n_edges, par.n_edge_types),
                         generator=edge_rng)
decoder.set_input_graph(rec_rel, send_rel)
decoder.set_edge_matrix(edge_matrix)
decoder.set_run_params()

print(summary(decoder, (summary_batch_size, n_graph_nodes, n_timesteps, n_dims)))

Layer (type:depth-idx)                   Output Shape              Param #
Decoder                                  [64, 5, 48, 4]            --
├─GRU: 1-1                               [64, 5, 64]               --
│    └─Linear: 2-1                       [64, 5, 64]               320
│    └─Linear: 2-2                       [64, 5, 64]               4,160
│    └─Linear: 2-3                       [64, 5, 64]               320
│    └─Linear: 2-4                       [64, 5, 64]               4,160
│    └─Linear: 2-5                       [64, 5, 64]               320
│    └─Linear: 2-6                       [64, 5, 64]               4,160
├─MLP: 1-2                               [64, 5, 64]               --
│    └─ModuleList: 2-383                 --                        (recursive)
│    │    └─Linear: 3-1                  [64, 5, 64]               4,160
│    │    └─BatchNorm1d: 3-2             [320, 64]                 128
│    │    └─Tanh: 3-3                    [64, 5, 64]        