In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import jax, jax.numpy as jnp, optax, numpy as np

from mgd.dataset.dataloader import GraphBatchLoader
from mgd.model.embeddings import GraphEmbedder
from mgd.model.denoiser import MPNNDenoiser
from mgd.model.diffusion_model import GraphDiffusionModel
from mgd.diffusion.schedules import cosine_beta_schedule
from mgd.training import create_train_state, train_loop



In [None]:
# Hyperparameters
batch_size = 64      # batch size handed to GraphBatchLoader
n_timesteps = 1000   # argument of cosine_beta_schedule

# Embedding widths forwarded to GraphEmbedder.
atom_dim = 32
hybrid_dim = 16
cont_dim = 16

# Widths forwarded to MPNNDenoiser; node_dim and edge_dim are also
# passed to GraphEmbedder as its hidden widths.
node_dim = 32
edge_dim = 16
mess_dim = 32
time_dim = 32

lr = 1e-3            # learning rate given to optax.adam
num_epochs = 2  # demo

In [4]:
# Data
# np.load on an .npz archive returns a lazy NpzFile that keeps the file open.
# Copy its arrays into a plain dict inside a `with` block so the handle is
# closed as soon as everything has been read.
with np.load("../data/processed/qm9_splits.npz") as splits_npz:
    splits = dict(splits_npz)
with np.load("../data/processed/qm9_dense.npz") as data_npz:
    data = dict(data_npz)

train_loader = GraphBatchLoader(
    data,
    indices=splits["train"],
    batch_size=batch_size,
    key=jax.random.PRNGKey(0),
)
# Pull one batch up front; it is reused below when inspecting the encoder output.
batch = next(iter(train_loader))

In [5]:
# Model
# The embedder and the denoiser receive the same node_dim / edge_dim values.
embedder = GraphEmbedder(
    node_hidden_dim=node_dim,
    edge_hidden_dim=edge_dim,
    atom_embed_dim=atom_dim,
    hybrid_embed_dim=hybrid_dim,
    cont_embed_dim=cont_dim,
    edge_embed_dim=edge_dim,
)
denoiser = MPNNDenoiser(
    node_dim=node_dim,
    edge_dim=edge_dim,
    mess_dim=mess_dim,
    time_dim=time_dim,
)
schedule = cosine_beta_schedule(n_timesteps)
model = GraphDiffusionModel(
    embedder=embedder,
    denoiser=denoiser,
    schedule=schedule,
)

In [6]:
# Optimizer and state
rng = jax.random.PRNGKey(42)
tx = optax.adam(learning_rate=lr)
# A fresh batch from the loader is handed to create_train_state together
# with the model, the optimizer and the PRNG key.
init_batch = next(iter(train_loader))
state = create_train_state(model, init_batch, tx, rng)

In [7]:
import operator

def n_parameters(params):
    """Return the total number of scalar entries across all leaves of ``params``.

    Args:
        params: A pytree whose leaves expose a ``.shape`` attribute
            (for example the parameter dict held by the train state).

    Returns:
        The summed element count as a plain Python ``int``; ``0`` when the
        tree has no leaves.
    """
    import math  # local import keeps the cell self-contained

    # Counting with Python integers avoids a device computation, cannot
    # overflow a fixed-width integer dtype, and only needs tree_leaves,
    # which is available in every JAX release.
    return sum(math.prod(leaf.shape) for leaf in jax.tree_util.tree_leaves(params))

# Report the overall parameter count of the freshly initialised state.
total_params = n_parameters(state.params)
print('Total number of parameters: ', total_params)

# Show the shape of every parameter leaf while keeping the tree layout.
jax.tree.map(lambda leaf: leaf.shape, state.params)

Total number of parameters:  14624


{'denoiser': {'backbone': {'mpnn_0': {'LayerNorm_0': {'bias': (32,),
     'scale': (32,)},
    'LayerNorm_1': {'bias': (16,), 'scale': (16,)},
    'LayerNorm_2': {'bias': (32,), 'scale': (32,)},
    'edge_mlp': {'dense_0': {'bias': (16,), 'kernel': (80, 16)},
     'dense_1': {'bias': (16,), 'kernel': (16, 16)}},
    'mess_mlp': {'dense_0': {'bias': (32,), 'kernel': (48, 32)},
     'dense_1': {'bias': (32,), 'kernel': (32, 32)}},
    'node_mlp': {'dense_0': {'bias': (32,), 'kernel': (32, 32)},
     'dense_1': {'bias': (32,), 'kernel': (32, 32)}}},
   'time_embedding': {'Dense_0': {'bias': (32,), 'kernel': (32, 32)},
    'Dense_1': {'bias': (32,), 'kernel': (32, 32)},
    'Dense_2': {'bias': (16,), 'kernel': (32, 16)}}},
  'edge_head': {'bias': (16,), 'kernel': (16, 16)},
  'node_head': {'bias': (32,), 'kernel': (32, 32)}},
 'embedder': {'edge_embedding': {'LayerNorm_0': {'bias': (16,),
    'scale': (16,)},
   'edge_embedding': {'embedding': (5, 16)},
   'fuse': {'bias': (16,), 'kernel':

In [8]:
# Run only the encoder (GraphDiffusionModel.encode -> GraphLatent) on the
# batch drawn in the data cell, using the current parameters.
batch_latent = state.model.apply(
    {"params": state.params},
    batch,
    method=state.model.encode,
)

# Mean and std are taken along axis 0 and then averaged into one scalar each,
# first for the node latent and then for the edge latent.
node_latent, edge_latent = batch_latent.node, batch_latent.edge
(
    node_latent.mean(axis=0).mean(),
    node_latent.std(axis=0).mean(),
    edge_latent.mean(axis=0).mean(),
    edge_latent.std(axis=0).mean(),
)

(Array(9.531605e-08, dtype=float32),
 Array(1.2165462, dtype=float32),
 Array(2.6490756e-07, dtype=float32),
 Array(0.8285135, dtype=float32))

In [None]:
# Train
# Split the key: one half drives the training loop, the other replaces `rng`.
rng, loop_rng = jax.random.split(rng)
state, history = train_loop(
    state,
    train_loader,
    num_epochs=num_epochs,
    rng=loop_rng,
    log_every=10,
)

# The last entry of `history` is printed as the final epoch's metrics.
final_metrics = history[-1]
print("Final epoch metrics:", final_metrics)

  0%|          | 0/2 [00:00<?, ?it/s]

epoch 1 step 10: loss=47.4578
epoch 1 step 20: loss=47.1974
epoch 1 step 30: loss=46.3313
epoch 1 step 40: loss=45.6620
epoch 1 step 50: loss=44.9829
epoch 1 step 60: loss=44.6170
epoch 1 step 70: loss=43.4967
epoch 1 step 80: loss=43.1435
epoch 1 step 90: loss=42.1102
epoch 1 step 100: loss=41.4489
epoch 1 step 110: loss=40.2686
epoch 1 step 120: loss=40.1459
epoch 1 step 130: loss=37.9873
epoch 1 step 140: loss=36.8564
epoch 1 step 150: loss=36.2764
epoch 1 step 160: loss=36.2488
epoch 1 step 170: loss=35.6040
epoch 1 step 180: loss=34.7582
epoch 1 step 190: loss=32.6646
epoch 1 step 200: loss=31.1810
epoch 1 step 210: loss=31.3534
epoch 1 step 220: loss=30.2289
epoch 1 step 230: loss=28.3020
epoch 1 step 240: loss=29.1624
epoch 1 step 250: loss=27.8789
epoch 1 step 260: loss=27.2817
epoch 1 step 270: loss=26.5716
epoch 1 step 280: loss=26.6128
epoch 1 step 290: loss=25.2432
epoch 1 step 300: loss=24.0106
epoch 1 step 310: loss=24.8673
epoch 1 step 320: loss=24.5211
epoch 1 step 330:

In [44]:
# NOTE(review): 20 and 4 are unexplained scaling factors — confirm what they
# represent and move them to the hyperparameter cell if they are meaningful.
# NOTE(review): this cell carries execution count 44 and its recorded output
# does not follow from the std values displayed by the encode cell above;
# re-run from a fresh kernel before trusting it.
edge_std = batch_latent.edge.std(axis=0).mean()
node_std = batch_latent.node.std(axis=0).mean()
edge_std * 20 * 4, node_std * 20

(Array(1.0685542, dtype=float32), Array(1.0747564, dtype=float32))

In [None]:
# Mean over all entries of the edge latent (no axis argument is given).
batch_latent.edge.mean()

Array(-0.20964783, dtype=float32)