In [None]:
import jraph
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
import numpy as np


import sys
sys.path.append('..')


In [None]:
# Concatenate a scalar ("0e") array with a "0e + 1o" array along axis 0, then
# display the raw result, its simplified form, and its e3nn.sum side by side.
log_radius_weights = e3nn.IrrepsArray("0e", jnp.asarray([1.]))
log_angular_coeffs = e3nn.IrrepsArray("0e + 1o", jnp.arange(4).astype(float))
x = e3nn.concatenate([log_radius_weights, log_angular_coeffs], axis=0)
x, x.simplify(), e3nn.sum(x)

In [None]:
# Batched variant: 64 rows holding one "0e" value each, concatenated along
# axis 1 with the four "0e + 1o" coefficients tiled to 64 rows.
log_radius_weights = e3nn.IrrepsArray("0e", jnp.arange(64).reshape((64, 1)))
log_angular_coeffs = e3nn.IrrepsArray("0e + 1o", jnp.arange(4))
log_angular_coeffs = e3nn.IrrepsArray(log_angular_coeffs.irreps, jnp.tile(log_angular_coeffs.array, (64, 1)))
print(log_angular_coeffs.shape)
e3nn.concatenate([log_radius_weights, log_angular_coeffs], axis=1)

In [None]:
# Evaluate a set of "0e + 1o" coefficients on a sphere grid, do the same for the
# coefficients multiplied by 2, and display the elementwise ratio of the two grids.
coeffs = e3nn.IrrepsArray("0e + 1o", jnp.asarray([1.0, 1.0, 2.0, 3.0]))
coeffs_on_grid = e3nn.to_s2grid(coeffs, res_beta=30, res_alpha=59, quadrature="soft").grid_values

scaled_coeffs = coeffs * 2.0
scaled_coeffs_on_grid = e3nn.to_s2grid(scaled_coeffs, res_beta=30, res_alpha=59, quadrature="soft").grid_values

# NOTE(review): grid points where coeffs_on_grid is 0 would give nan/inf here.
scaled_coeffs_on_grid / coeffs_on_grid

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2
import train
import models
import datatypes
import analyses.generate_plots as generate_plots
import input_pipeline_tf
import configs.mace as mace
import analyses.analysis as analysis

In [None]:
# Override the training molecule range, rebuild the datasets, and iterate the
# whole train split to see how many steps it yields (progress every 50000 steps).
config = mace.get_config()
print(config.train_molecules)
config.train_molecules = (0, 2976)
# NOTE(review): the first argument is None here, while a later cell passes a
# PRNG key to the same function — confirm None is accepted.
datasets = input_pipeline_tf.get_datasets(None, config)
for step, _ in enumerate(datasets['train']):
    if step % 50000 == 0:
        print(step)
# Last step index seen; `step` is undefined if the split yielded nothing.
print(step)
    

In [None]:
# Collect results for every model/metric under the sample-complexity workdirs,
# then display the validation rows for the 'mace' model ordered by max_l.
basedir = "../potato_workdirs/extras/sample_complexity"
results = analysis.get_results_as_dataframe(generate_plots.ALL_MODELS, generate_plots.ALL_METRICS, basedir)
results['val'][results['val']['model'] == 'mace'].sort_values(by='max_l')

In [None]:
# Copy the validation results to the system clipboard (side effect only, no output).
results['val'].to_clipboard()

In [None]:
def eval():
    """Print a node count for each of the first 10 batches of the "val" split.

    For every batch, the per-graph node counts (`n_node`) are filtered with the
    mask returned by `jraph.get_graph_padding_mask` and summed before printing.

    NOTE(review): this definition shadows the builtin `eval` for the rest of the
    kernel session.
    """
    first_val_batches = datasets["val"].take(10).as_numpy_iterator()
    for batch in first_val_batches:
        graph_mask = jraph.get_graph_padding_mask(batch)
        nodes_per_graph = batch.n_node[graph_mask]
        print(nodes_per_graph.sum())

eval()


In [None]:
import tensorflow as tf

# Toy dataset built from range(10), used to poke at tf.data behaviour.
ds = tf.data.Dataset.range(10)

In [None]:
# Print the first three elements of `ds` as numpy values.
for x in ds.take(3).as_numpy_iterator():
    print(x)

In [None]:
# Reload a fresh config and the raw QM9 datasets, then print the sum of
# 'target_positions' for the first 10 validation examples.
config = mace.get_config()
datasets = input_pipeline_tf.get_raw_qm9_datasets(config)
for x in datasets["val"].take(10).as_numpy_iterator():
    print(x['target_positions'].sum())

In [None]:
# Convert each raw validation example to a GraphsTuple (deterministic ordering),
# then batch the graphs with jraph using the padding budgets from the config.
ds = datasets['val'].map(
    input_pipeline_tf._convert_to_graphstuple,
    num_parallel_calls=tf.data.AUTOTUNE,
    deterministic=True,
)

# ds_b wraps a single pass over iter(ds); re-run this cell to iterate again.
ds_b = jraph.dynamically_batch(graphs_tuple_iterator=iter(ds),
                                n_node=config.max_n_nodes,
                                n_edge=config.max_n_edges,
                                n_graph=config.max_n_graphs)

In [None]:
# Print the masked node-count sum for the first 11 batches (steps 0 through 10).
for step, x in enumerate(ds_b):
    mask = jraph.get_graph_padding_mask(x)
    print(x.n_node[mask].sum())
    if step == 10:
        break

In [None]:
import functools

# Pad one example graph to the configured budgets so its structure can be used
# as the output signature of a generator-backed tf.data.Dataset.
example_graph = next(ds.as_numpy_iterator())
example_padded_graph = jraph.pad_with_graphs(
    example_graph, n_node=config.max_n_nodes, n_edge=config.max_n_edges, n_graph=config.max_n_graphs
)
padded_graphs_spec = input_pipeline_tf._specs_from_graphs_tuple(example_padded_graph)

# Batch and pad each split separately.
# NOTE(review): iter(ds) is evaluated once, when the partial is built, so every
# call of batching_fn shares that single iterator — presumably why the result is
# cached below; confirm.
batching_fn = functools.partial(
    jraph.dynamically_batch,
    graphs_tuple_iterator=iter(ds),
    n_node=config.max_n_nodes,
    n_edge=config.max_n_edges,
    n_graph=config.max_n_graphs,
)
ds_tf = tf.data.Dataset.from_generator(
    batching_fn, output_signature=padded_graphs_spec
)
# Keep only the first 100 batches and cache them.
ds_tf = ds_tf.take(100).cache()

In [None]:
# Same check as for ds_b, but reading batches back through the tf.data wrapper:
# print the masked node-count sum for steps 0 through 10.
for step, x in enumerate(ds_tf.as_numpy_iterator()):
    mask = jraph.get_graph_padding_mask(x)
    print(x.n_node[mask].sum())
    if step == 10:
        break

In [None]:
# Mimic what we do in train.py.
config = mace.get_config()

# Override the graph budget per batch for this experiment.
config.max_n_graphs = 32
rng = jax.random.PRNGKey(0)
rng, dataset_rng = jax.random.split(rng)

# Obtain graphs.
datasets = input_pipeline_tf.get_datasets(dataset_rng, config)
train_iter = datasets["train"].as_numpy_iterator()
init_graphs = next(train_iter)

# Set up dummy variables to obtain the structure.
rng, init_rng = jax.random.split(rng)
net = train.create_model(config, run_in_evaluation_mode=False)
# Initialise parameters from the first training batch, under jit.
params = jax.jit(net.init)(init_rng, init_graphs)

#ds = ds.shuffle(buffer_size=5, reshuffle_each_iteration=False)

In [None]:
# Print the masked node-count sum for every batch of the "train_eval_final" split.
for x in datasets["train_eval_final"].as_numpy_iterator():
    mask = jraph.get_graph_padding_mask(x)
    print(x.n_node[mask].sum())

In [None]:
import tensorflow as tf
# Ten integers shuffled with a buffer of 5; reshuffle_each_iteration=False is set
# so that repeated iterations of `ds` can be compared in the next cell.
ds = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
ds = ds.shuffle(buffer_size=5, reshuffle_each_iteration=False)

In [None]:
# Print the elements of `ds` in the order the iterator yields them.
for x in ds.as_numpy_iterator():
    print(x)

In [None]:
# Show the padding budgets currently set on `config`.
print("padding")
print("nodes", config.max_n_nodes)
print("edges", config.max_n_edges)
print("graphs", config.max_n_graphs)

In [None]:
# Estimate the average number of unmasked graphs (per jraph.get_graph_padding_mask)
# in a training batch, over at most the first 10000 batches. A running average is
# printed whenever step % 1000 == 1.
count = 0
num_batches = 0
for step, graphs in enumerate(datasets["train"].as_numpy_iterator()):
    if step % 1000 == 1:
        # At this point exactly `step` batches have been accumulated.
        print(step, count / step)
    if step == 10000:
        break

    # jax.tree_map is deprecated and absent from newer JAX releases;
    # jax.tree_util.tree_map is the stable spelling of the same function.
    graphs = jax.tree_util.tree_map(jnp.array, graphs)
    graphs = datatypes.Fragments.from_graphstuple(graphs)
    count += jraph.get_graph_padding_mask(graphs).sum()
    num_batches += 1

# Divide by the number of batches actually accumulated: `step` is one short when
# the dataset runs out before the break above, and undefined when it is empty.
print(count / max(num_batches, 1))


In [None]:
# Hand-built batch of three graphs with 2, 1 and 1 nodes (4 node rows in total)
# and one edge each; every edge is a self-loop on node 0. Later cells read `graphs`.
graphs = jraph.GraphsTuple(
    nodes=jnp.array([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 8, 9]]),
    edges=jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
    globals=jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
    n_node=jnp.array([2, 1, 1]),
    n_edge=jnp.array([1, 1, 1]),
    senders=jnp.array([0, 0, 0]),
    receivers=jnp.array([0, 0, 0]),
)

In [None]:
# Sample one species index per graph from a uniform distribution over elements,
# using an independent PRNG key for each graph.
num_graphs = 4

# Split once into num_graphs + 1 keys. The first continues as `rng`; the rest are
# kept as one stacked array (not a Python list) so vmap can map over its leading
# axis, one key per graph.
split_rngs = jax.random.split(jax.random.PRNGKey(0), num_graphs + 1)
rng, species_rngs = split_rngs[0], split_rngs[1:]

target_species_probs = jnp.ones((num_graphs, models.NUM_ELEMENTS)) / models.NUM_ELEMENTS
# Pass the stacked keys directly: wrapping them as `(species_rngs,)` would hand
# vmap a tuple pytree and make it map over each key's own axis instead.
target_species = jax.vmap(lambda key, p: jax.random.choice(
    key, models.NUM_ELEMENTS, p=p))(species_rngs, target_species_probs)


In [None]:
# Wrap ten values as ten "0e" irreps and display the resulting IrrepsArray.
e3nn.IrrepsArray("10x0e", jnp.arange(10))

In [None]:

# NOTE(review): `myhost` is not assigned in any visible cell, so on a fresh
# kernel this line raises NameError — confirm whether this scratch cell is needed.
myhost

In [None]:
# Get results.
# Rebinds `results` (set earlier from the sample-complexity directory) to the
# results gathered from the v4 workdirs.
results = analysis.get_results_as_dataframe(generate_plots.ALL_MODELS, generate_plots.ALL_METRICS, "../potato_workdirs/v4/")

In [None]:
# Display the validation results ordered by 'total_loss' (ascending by default).
results['val'].sort_values('total_loss')

In [None]:
import re
# Pull the three digit runs (seed, range start, range end) out of a fragments
# file name and convert them to ints; display the range bounds.
seed, start, end = map(int, re.findall(r'\d+', "fragments_seed01_from130944_to133920"))
start, end

In [None]:
def segment_sample(probabilities, segment_ids, num_segments, rng):
    """Draw one node index per segment, weighted by per-node probabilities.

    For each segment, the probabilities of nodes outside that segment are
    zeroed and a single index is sampled from the remaining weights.

    Args:
        probabilities: A 1D array with one probability per node.
        segment_ids: A 1D array giving the segment id of each node.
        num_segments: The number of segments to sample for.
        rng: A PRNG key; it is split into one key per segment.
    Returns:
        A 1D array of length num_segments holding the sampled node indices.
    """
    candidate_indices = jnp.arange(len(segment_ids))
    per_segment_keys = jax.random.split(rng, num_segments)
    all_segments = jnp.arange(num_segments)

    def draw_one(key, segment):
        # Only nodes that belong to this segment keep a non-zero weight.
        in_segment = segment == segment_ids
        weights = jnp.where(in_segment, probabilities, 0.)
        return jax.random.choice(key, candidate_indices, p=weights)

    return jax.vmap(draw_one)(per_segment_keys, all_segments)

In [None]:
# Turn per-node logits (row sums of the toy node features) into per-graph
# softmax probabilities, then sample one node per graph for 100 different seeds.
focus_logits = graphs.nodes.sum(axis=1)
probs = jraph.partition_softmax(focus_logits, graphs.n_node)
print(probs)
# The hard-coded segment ids [0, 0, 1, 2] mirror graphs.n_node == [2, 1, 1].
for seed in range(100):
    print(segment_sample(probs, jnp.asarray([0, 0, 1, 2]), 3, jax.random.PRNGKey(seed)))

In [None]:
# Compare two ways of computing log(1 + sum_i exp(node_i)) per graph: directly,
# and with the per-graph maximum factored out of the exponentials.
nodes = jnp.asarray([2, 3, 1, 4])
segment_max = e3nn.scatter_max(nodes, nel=graphs.n_node)
# map_back=True is used to get the per-graph max aligned with `nodes`.
segment_max_expanded = e3nn.scatter_max(nodes, map_back=True, nel=graphs.n_node)
print(segment_max_expanded)
# segment_max_expanded = jnp.asarray([segment_max[0], segment_max[0], segment_max[1], segment_max[2]])

# Direct form (shift of 0).
expected = 0 + jnp.log(1 + e3nn.scatter_sum(jnp.exp(nodes - 0), nel=graphs.n_node))
# Shifted form: m + log(exp(-m) + sum_i exp(node_i - m)), with m the per-graph max.
computed = segment_max + jnp.log(jnp.exp(-segment_max) + e3nn.scatter_sum(jnp.exp(nodes - segment_max_expanded), nel=graphs.n_node))

expected, computed

In [None]:
import analysis


In [None]:
# Call analysis.load_from_workdir on one MACE workdir (interactions=1, l=0,
# channels=32). The path is relative to the notebook, like the other
# "../potato_workdirs/..." paths in this file, instead of a user-specific
# absolute path that only exists on one machine.
analysis.load_from_workdir("../potato_workdirs/workdirs/mace/interactions=1/l=0/channels=32")

In [None]:
# Build an Irreps object from e3nn.Irrep.iterator(lmax) and wrap a (10, irreps.dim)
# array of ones in it, to inspect the resulting IrrepsArray.
lmax = 3
irreps = e3nn.Irreps(e3nn.Irrep.iterator(lmax))
e3nn.IrrepsArray(irreps=irreps, array=jnp.ones((10, irreps.dim)))

In [None]:
# Sum the toy node features within each graph, using graphs.n_node as the
# number of elements per segment.
e3nn.scatter_sum(data=graphs.nodes, nel=graphs.n_node)