In [None]:
import jraph
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
import optax

import sys
sys.path.append('../')

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2
import input_pipeline_tf
import models
import loss
import configs.platonic_solids.nequip as nequip_config
import analyses.analysis as analysis

In [None]:
# Overrides applied on top of the base NequIP platonic-solids config.
# NOTE(review): train_on_single_graph indexes channel 0 of position_coeffs,
# which presumably relies on there being a single channel here.
NUM_POSITION_CHANNELS = 1
MAX_ELL = 5

config = nequip_config.get_config()
config.target_position_predictor.num_channels = NUM_POSITION_CHANNELS
config.max_ell = MAX_ELL
config

In [None]:
# Root PRNG key; it is handed straight to the dataset builder below.
SEED = 0
rng = jax.random.PRNGKey(SEED)
datasets = input_pipeline_tf.get_unbatched_platonic_solids_datasets(rng, config)

In [None]:
# Take one training example explicitly; later cells rely on `graph`.
# Using next() (rather than leaking a for-loop variable) raises immediately if
# the split is empty instead of silently keeping a stale `graph` from a
# previous run of the notebook.
graph = next(iter(datasets['train'].take(1).as_numpy_iterator()))
print(graph)

In [None]:
# Build the model from `config` with run_in_evaluation_mode=False (presumably
# training mode — confirm in models.create_model). `net` is used below via
# net.init(rng, graph) and net.apply(params, None, graph).
net = models.create_model(config, run_in_evaluation_mode=False)

In [None]:
def loss_fn(params, padded_graph):
    """Return the position loss of entry 0 of `padded_graph` under `params`.

    The model is applied without an RNG (`None`), and the result is scored with
    `loss.generation_loss` using `config.loss_kwargs`. Only the position-loss
    term is kept; its element 0 is returned (assumed to be the real, unpadded
    graph — TODO confirm the loss is reported per graph).
    """
    model_outputs = net.apply(params, None, padded_graph)
    _, (_, per_graph_position_loss) = loss.generation_loss(
        model_outputs, padded_graph, **config.loss_kwargs
    )
    return per_graph_position_loss[0]


@jax.jit
def train_on_single_graph(padded_graph, rng, num_steps: int, lr: float):
    """Fit freshly initialised parameters to one padded graph with Adam.

    Args:
        padded_graph: graph produced by `jraph.pad_with_graphs`. Index 0 is
            assumed to be the real graph, with padding graph(s) after it.
        rng: PRNG key used only for `net.init`.
        num_steps: number of Adam steps. Under `jax.jit` this is traced, so
            `jax.lax.fori_loop` runs as a `while_loop`.
        lr: Adam learning rate.

    Returns:
        `(preds, coeffs)`: the model output after optimisation with every leaf
        indexed at `[0]`, and `preds.globals.position_coeffs[0, radius_index]`,
        where `radius_index` selects the entry of `models.RADII` closest to the
        norm of the first target position.
    """
    tx = optax.adam(lr)
    init_params = net.init(rng, padded_graph)
    init_opt_state = tx.init(init_params)

    def update_fn(params, opt_state):
        # One Adam step on `loss_fn` for this single graph.
        grads = jax.grad(loss_fn)(params, padded_graph)
        updates, opt_state = tx.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, opt_state

    def body_fn(_step, carry):
        # fori_loop passes (step_index, carry); the step index is unused.
        params, opt_state = carry
        return update_fn(params, opt_state)

    optimized_params, _ = jax.lax.fori_loop(
        0, num_steps, body_fn, (init_params, init_opt_state)
    )
    preds = net.apply(optimized_params, None, padded_graph)
    # Index every leaf at [0]; for per-graph leaves (globals) this is the real
    # graph. NOTE(review): per-node/per-edge leaves are also indexed at [0].
    # jax.tree_util.tree_map replaces the deprecated (and removed) jax.tree_map.
    preds = jax.tree_util.tree_map(lambda x: x[0], preds)

    # Choose the radius bin closest to the distance of the first target.
    target_position_distance = jnp.linalg.norm(
        padded_graph.globals.target_positions, axis=-1
    )[0]
    radius_index = jnp.argmin(jnp.abs(models.RADII - target_position_distance))
    position_coeffs = preds.globals.position_coeffs
    # Leading 0 presumably selects channel 0 (config sets num_channels = 1).
    return preds, position_coeffs[0, radius_index]


In [None]:
# Shapes of the example graph's target positions. `preds` is only created in a
# later cell, so inspect the input graph here to keep Restart & Run All working.
jax.tree_util.tree_map(jnp.shape, graph.globals.target_positions)

In [None]:
# Fixed padding sizes; they must be large enough to hold the real graph plus
# padding (see jraph.pad_with_graphs).
PAD_N_NODE = 20
PAD_N_EDGE = 400
PAD_N_GRAPH = 2
NUM_TRAIN_STEPS = 10
LEARNING_RATE = 1e-3

padded_graph = jraph.pad_with_graphs(
    graph, n_node=PAD_N_NODE, n_edge=PAD_N_EDGE, n_graph=PAD_N_GRAPH
)
# `rng` was already passed to the dataset builder; derive a distinct key for
# parameter initialisation instead of reusing the same key.
init_rng = jax.random.fold_in(rng, 1)
preds, position_coeffs = train_on_single_graph(
    padded_graph, init_rng, NUM_TRAIN_STEPS, LEARNING_RATE
)

In [None]:
# Pairwise Euclidean distances between all node positions of the example graph.
node_positions = graph.nodes.positions
displacements = node_positions[:, None, :] - node_positions[None, :, :]
squared_norms = jnp.sum(displacements ** 2, axis=-1)
distance_matrix = jnp.sqrt(squared_norms)
distance_matrix

In [None]:
# Display the example graph's node positions next to the target positions
# stored in its globals.
(
    graph.nodes.positions,
    graph.globals.target_positions,
)

In [None]:
# Shape of every array leaf in the example graph. Uses jax.tree_util.tree_map
# because the jax.tree_map alias is deprecated and removed in newer JAX.
jax.tree_util.tree_map(jnp.shape, graph)