In [None]:
import random

import numpy as np
import torch
import torch.nn as nn

from src import (train_sdf, get_sdf_data_loader, deep_mind_loss, plot_sdf_results, plot_sdf_results_over_line,
                 EncodeProcessDecode, EncodeProcessDecodeNEW, GraphNetworkIndependentBlock, GraphNetworkBlock)

In [None]:
# Fix the Python, NumPy and torch RNGs before anything random happens so that
# Restart Kernel -> Run All gives repeatable results. NOTE(review): presumably
# these cover the train/eval split inside get_sdf_data_loader and the weight
# init in the next cell — confirm which RNG(s) src actually draws from.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# get data: build train/eval graph loaders from the meshes in data_folder
n_objects = 600                  # number of objects requested from the loader
data_folder = "../mesh/"
batch_size = 5
edge_method = 'edge'             # or 'proximity'
edge_params = {'radius': 0.05}   # NOTE(review): radius presumably only matters for 'proximity' — confirm in get_sdf_data_loader
# eval_frac=0.1: presumably the fraction of objects held out as test_data — verify against the loader
train_data, test_data = get_sdf_data_loader(n_objects, data_folder, batch_size, eval_frac=0.1,
                                            edge_method=edge_method, edge_params=edge_params)

In [None]:
# Model architecture: (input, output) feature widths for edges, nodes and
# globals, the hidden width shared by the MLPs, and how many processor steps
# EncodeProcessDecodeNEW runs. NOTE(review): the meaning of each width
# (e.g. node inputs = 3 coordinates?) is fixed by src's data pipeline — confirm there.
n_edge_feat_in = 5
n_edge_feat_out = 1
n_node_feat_in = 3
n_node_feat_out = 1
n_global_feat_in = 3
n_global_feat_out = 3
mlp_latent_size = 64
num_processing_steps = 5

model_kwargs = dict(
    n_edge_feat_in=n_edge_feat_in,
    n_edge_feat_out=n_edge_feat_out,
    n_node_feat_in=n_node_feat_in,
    n_node_feat_out=n_node_feat_out,
    n_global_feat_in=n_global_feat_in,
    n_global_feat_out=n_global_feat_out,
    mlp_latent_size=mlp_latent_size,
    num_processing_steps=num_processing_steps,
    # block types used for each stage of the encode-process-decode pipeline
    encoder=GraphNetworkIndependentBlock,
    processor=GraphNetworkBlock,
    decoder=GraphNetworkIndependentBlock,
    output_transformer=GraphNetworkIndependentBlock,
    # full_output=True: the visualization cell indexes the result as x[-1][...],
    # so the model output is a sequence here
    full_output=True,
)
model = EncodeProcessDecodeNEW(**model_kwargs)

In [None]:
# Training settings passed to train_sdf in the next cell.
# NOTE(review): lr_0 / step_size / gamma look like StepLR-style decay settings
# (initial LR, epochs per decay, decay factor) — confirm inside train_sdf.
lr_0 = 0.001
step_size = 500
gamma = 0.25
n_epoch = 1500
print_every = 25   # logging interval handed to train_sdf

# Run identifier encoding the edge method and the architecture size.
save_name = f"deep_mind_graph_{edge_method}_nn_{mlp_latent_size}_nlayers_{num_processing_steps}"

In [None]:
# Train the model; train_sdf receives a list of loss functions (here just one).
loss_funcs = [deep_mind_loss]
train_sdf(
    model,
    train_data,
    test_data,
    loss_funcs,
    n_epoch=n_epoch,
    lr_0=lr_0,
    step_size=step_size,
    gamma=gamma,
    print_every=print_every,
    save_name=save_name,
)

In [None]:
# visualization: plot the trained model's predictions on 10 objects, batch size 1,
# with nothing held out (eval_frac=0).
data_loader, _ = get_sdf_data_loader(10, data_folder, 1, eval_frac=0, edge_method=edge_method, edge_params=edge_params)


def output_func(x):
    """Turn the model output into a flat 1-D NumPy array for plotting.

    NOTE(review): with full_output=True the model returns a sequence; x[-1][1]
    presumably picks a feature tensor of the final step — confirm against
    EncodeProcessDecodeNEW.
    """
    # detach()/cpu() keep .numpy() from raising if the tensor still requires
    # grad or lives on a GPU; they are no-ops for a detached CPU tensor.
    return x[-1][1].detach().cpu().numpy().reshape(-1)


# levels=None here; an alternative previously tried was levels=[-0.1, 0., 0.1]
plot_sdf_results(model, data_loader, save_name=save_name, output_func=output_func, levels=None)