# Exploration of **Training** EDM: E(3) Equivariant Diffusion Model for Molecule Generation in 3D.
[Github repo](https://github.com/ehoogeboom/e3_diffusion_for_molecules)

## Main file `main_qm9.py`
The training is triggered with a command like:
```bash
python main_qm9.py --n_epochs 3000 --exp_name edm_qm9 --n_stability_samples 1000 --diffusion_noise_schedule polynomial_2 --diffusion_noise_precision 1e-5 --diffusion_steps 1000 --diffusion_loss_type l2 --batch_size 64 --nf 256 --n_layers 9 --lr 1e-4 --normalize_factors [1,4,10] --test_epochs 20 --ema_decay 0.9999
```
The script parses its options with `argparse`, which supplies a default value for any argument the user doesn't pass explicitly.

## 1. Dataset

In [1]:
import argparse
parser = argparse.ArgumentParser(description='E3Diffusion')
parser.add_argument('--dataset', default='qm9')
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--filter_n_atoms', type=int, default=None)
parser.add_argument('--datadir', type=str, default='qm9/temp')
parser.add_argument('--remove_h', action='store_true')
parser.add_argument('--include_charges', type=eval, default=True)
args = parser.parse_args(args=[])
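Inside a notebook, `parse_args(args=[])` ignores `sys.argv`, so every argument takes its default. To mimic a command line such as the training command above, pass the flags as a list of strings instead. The sketch below uses a reduced, illustrative copy of the parser (only a few of the notebook's flags) to show both cases, including how `type=eval` turns strings like `[1,4,10]` or `False` into Python values.

```python
import argparse

# Reduced copy of the notebook's parser: only a few of its flags, for illustration
parser = argparse.ArgumentParser(description='E3Diffusion')
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--remove_h', action='store_true')
parser.add_argument('--include_charges', type=eval, default=True)
parser.add_argument('--normalize_factors', type=eval, default=[1, 4, 1])

# No command-line input: every argument keeps its default
defaults = parser.parse_args(args=[])
# A list of strings reproduces part of the training command
custom = parser.parse_args(['--batch_size', '64', '--normalize_factors', '[1,4,10]',
                            '--include_charges', 'False'])

print(defaults.batch_size, custom.batch_size)             # 128 64
print(custom.normalize_factors, custom.include_charges)   # [1, 4, 10] False
```

`type=eval` is convenient for list- and boolean-valued flags, but it evaluates arbitrary Python from the command line; `ast.literal_eval` would accept the same literals more safely.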


In [2]:
args.dataset

'qm9'

In [3]:
from configs.datasets_config import get_dataset_info
from qm9 import dataset

In [4]:
dataset_info = get_dataset_info(args.dataset, args.remove_h)
dataset_info.keys()

dict_keys(['name', 'atom_encoder', 'atom_decoder', 'n_nodes', 'max_n_nodes', 'atom_types', 'distances', 'colors_dic', 'radius_dic', 'with_h'])

In [5]:
# Retrieve QM9 dataloaders
dataloaders, charge_scale = dataset.retrieve_dataloaders(args)
data_dummy = next(iter(dataloaders['train']))

A look at the dataloaders and the number of batches in each split:

In [6]:
print(charge_scale)
print(dataloaders.keys())
for key in dataloaders: print(f"{key}: {len(dataloaders[key])}")

tensor(9)
dict_keys(['train', 'valid', 'test'])
train: 782
valid: 139
test: 103
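The batch counts bound the size of each split. Assuming all three loaders use `batch_size=128` and keep the last partial batch, a split of n molecules yields ⌈n/128⌉ batches, so the counts above translate into the following ranges (`molecules_range` is a hypothetical helper, not part of the repo):

```python
def molecules_range(n_batches, batch_size=128):
    """Smallest and largest split size that produce n_batches (last partial batch kept)."""
    return (n_batches - 1) * batch_size + 1, n_batches * batch_size

for split, n in {'train': 782, 'valid': 139, 'test': 103}.items():
    lo, hi = molecules_range(n)
    print(f"{split}: {lo}-{hi} molecules")
# train: 99969-100096, valid: 17665-17792, test: 13057-13184
```

So the training split holds roughly 100k molecules.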


Next, a look at one example batch, `data_dummy`.

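This batch holds 128 molecules with different numbers of atoms. The per-atom tensors (`charges`, `positions`, `one_hot`, `atom_mask`) are padded to 27 atoms, which appears to be the size of the largest molecule in this batch rather than in the whole dataset: QM9 molecules have up to 29 atoms including hydrogens (see `dataset_info['max_n_nodes']`). For molecules with fewer atoms, the trailing entries are zeros and are marked `False` in `atom_mask`.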

_Conditioning_ uses one or more molecular properties, each a single value per molecule: `homo` | `lumo` | `alpha` | `gap` | `mu` | `Cv`
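A per-molecule property has to become a per-atom input before an EGNN can use it. One way to do this is to normalize the property and repeat it on every real atom of the molecule, so it enters the network as extra node features (the count of such features is what `args.context_node_nf` tracks; it is set to 0 further down, i.e. no conditioning). The NumPy sketch below assumes normalization by a mean and a mean absolute deviation; `tile_context` is a hypothetical helper, not the repository's function.

```python
import numpy as np

def tile_context(props, mean, mad, atom_mask):
    """
    props: (B, k) values of the k conditioning properties for each molecule
    mean, mad: (k,) normalization statistics (assumed: mean and mean absolute deviation)
    atom_mask: (B, N) bool, True for real atoms
    Returns a per-atom context of shape (B, N, k), zero on padding atoms.
    """
    normed = (props - mean) / mad                                       # (B, k)
    tiled = np.repeat(normed[:, None, :], atom_mask.shape[1], axis=1)   # (B, N, k)
    return np.where(atom_mask[:, :, None], tiled, 0.0)                  # zero out padding

props = np.array([[1.0], [3.0]])           # one property (e.g. alpha) for 2 molecules
atom_mask = np.array([[True, True, False],
                      [True, False, False]])
ctx = tile_context(props, mean=np.array([2.0]), mad=np.array([1.0]), atom_mask=atom_mask)
print(ctx.shape)                           # (2, 3, 1): one context feature per atom
```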


In [7]:
for key in data_dummy: print(f"{key}: {data_dummy[key].shape}")

num_atoms: torch.Size([128])
charges: torch.Size([128, 27, 1])
positions: torch.Size([128, 27, 3])
index: torch.Size([128])
A: torch.Size([128])
B: torch.Size([128])
C: torch.Size([128])
mu: torch.Size([128])
alpha: torch.Size([128])
homo: torch.Size([128])
lumo: torch.Size([128])
gap: torch.Size([128])
r2: torch.Size([128])
zpve: torch.Size([128])
U0: torch.Size([128])
U: torch.Size([128])
H: torch.Size([128])
G: torch.Size([128])
Cv: torch.Size([128])
omega1: torch.Size([128])
zpve_thermo: torch.Size([128])
U0_thermo: torch.Size([128])
U_thermo: torch.Size([128])
H_thermo: torch.Size([128])
G_thermo: torch.Size([128])
Cv_thermo: torch.Size([128])
one_hot: torch.Size([128, 27, 5])
atom_mask: torch.Size([128, 27])
edge_mask: torch.Size([93312, 1])


In [13]:
print(data_dummy['atom_mask'][10])
print(data_dummy['num_atoms'][10])
print(data_dummy['edge_mask'][10])

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True, False, False, False, False, False])
tensor(22)
tensor([True])
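Note that `edge_mask` is flattened: its first dimension is 128 × 27 × 27 = 93312, one entry per (molecule, atom i, atom j) pair. `data_dummy['edge_mask'][10]` is therefore a single edge (molecule 0, atoms 0 and 10), not the edge mask of molecule 10. The NumPy sketch below rebuilds masks with the same layout from the atom counts, assuming an edge is valid when both atoms are real and self-edges are excluded; `build_masks` is a hypothetical helper.

```python
import numpy as np

def build_masks(num_atoms, n_max):
    """Return atom_mask (B, N) and a flattened edge_mask (B*N*N, 1)."""
    atom_mask = np.arange(n_max)[None, :] < num_atoms[:, None]   # True for real atoms
    edge_mask = atom_mask[:, :, None] & atom_mask[:, None, :]     # both endpoints are real
    edge_mask &= ~np.eye(n_max, dtype=bool)[None]                 # assumed: no self-edges
    return atom_mask, edge_mask.reshape(-1, 1)

num_atoms = np.array([22, 27, 9])
atom_mask, edge_mask = build_masks(num_atoms, n_max=27)
print(edge_mask.shape)                        # (2187, 1) = 3 * 27 * 27

# Flat index 10 is molecule 0, row 0, column 10
per_mol = edge_mask.reshape(len(num_atoms), 27, 27)
print(per_mol[0, 0, 10], edge_mask[10, 0])    # True True
print(per_mol[2].sum())                       # 72 = 9 * 8 directed edges
```

In the notebook, the per-molecule view would be `data_dummy['edge_mask'].view(128, 27, 27)[10]`.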


## 2. EGNN model
❗️Conditioning is configured here (`--conditioning`, `args.context_node_nf`).

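The model is built from two nested modules:

__1. EGNN_dynamics_QM9:__ the equivariant GNN that predicts the noise (stored as `model.dynamics`)

__2. EnVariationalDiffusion:__ the diffusion model (noise schedule, loss, sampling), which wraps the EGNN as its `dynamics` network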

In [8]:
# pertinent arguments for the EGNN model
parser.add_argument("--conditioning", nargs='+', default=[],
                    help='arguments : homo | lumo | alpha | gap | mu | Cv' )
parser.add_argument('--condition_time', type=eval, default=True)

# pertinent for EGNN_dynamics_QM9
parser.add_argument('--model', type=str, default='egnn_dynamics')
parser.add_argument('--nf', type=int, default=128,
                    help='number of hidden features per node')
parser.add_argument('--n_layers', type=int, default=6,
                    help='number of layers')
parser.add_argument('--attention', type=eval, default=True,
                    help='use attention in the EGNN')
parser.add_argument('--tanh', type=eval, default=True,
                    help='use tanh in the coord_mlp')
parser.add_argument('--norm_constant', type=float, default=1,
                    help='diff/(|diff| + norm_constant)')
parser.add_argument('--inv_sublayers', type=int, default=1,
                    help='number of invariant (GCL) sublayers per equivariant block')
parser.add_argument('--sin_embedding', type=eval, default=False,
                    help='whether using or not the sin embedding')
parser.add_argument('--normalization_factor', type=float, default=1,
                    help="Normalize the sum aggregation of EGNN")
parser.add_argument('--aggregation_method', type=str, default='sum',
                    help='"sum" or "mean"')

# pertinent for EnVariationalDiffusion
parser.add_argument('--probabilistic_model', type=str, default='diffusion',
                    help='diffusion')
parser.add_argument('--diffusion_steps', type=int, default=500)
parser.add_argument('--diffusion_noise_schedule', type=str, default='polynomial_2',
                    help='learned, cosine, polynomial_<power>')
parser.add_argument('--diffusion_noise_precision', type=float, default=1e-5)
parser.add_argument('--diffusion_loss_type', type=str, default='l2',
                    help='vlb, l2')
parser.add_argument('--normalize_factors', type=eval, default=[1, 4, 1],
                    help='normalize factors for [x, categorical, integer]')

args = parser.parse_args(args=[])
args.context_node_nf = 0 # for the moment no conditioning

import torch
device = torch.device("cpu")

In [9]:
from qm9.models import get_model

In [10]:
model, nodes_dist, prop_dist = get_model(args, device, dataset_info, dataloaders['train'])

Entropy of n_nodes: H[N] -2.475700616836548
alphas2 [9.99990000e-01 9.99982000e-01 9.99958001e-01 9.99918003e-01
 9.99862007e-01 9.99790014e-01 9.99702026e-01 9.99598046e-01
 9.99478076e-01 9.99342118e-01 9.99190176e-01 9.99022254e-01
 9.98838355e-01 9.98638484e-01 9.98422646e-01 9.98190846e-01
 9.97943090e-01 9.97679383e-01 9.97399731e-01 9.97104143e-01
 9.96792624e-01 9.96465182e-01 9.96121825e-01 9.95762562e-01
 9.95387400e-01 9.94996350e-01 9.94589420e-01 9.94166620e-01
 9.93727960e-01 9.93273451e-01 9.92803104e-01 9.92316930e-01
 9.91814941e-01 9.91297149e-01 9.90763566e-01 9.90214206e-01
 9.89649081e-01 9.89068205e-01 9.88471593e-01 9.87859258e-01
 9.87231215e-01 9.86587480e-01 9.85928068e-01 9.85252996e-01
 9.84562278e-01 9.83855933e-01 9.83133976e-01 9.82396427e-01
 9.81643302e-01 9.80874619e-01 9.80090398e-01 9.79290657e-01
 9.78475416e-01 9.77644695e-01 9.76798513e-01 9.75936891e-01
 9.75059851e-01 9.74167412e-01 9.73259599e-01 9.72336431e-01
 9.71397932e-01 9.70444124e-01 9.
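This cell prints two diagnostics. First, the entropy of a discrete distribution can't be negative, so `H[N] -2.4757` presumably reports Σ p log p of the molecule-size histogram, i.e. H[N] ≈ 2.48 (nats, assuming a natural log). Second, `alphas2` is the noise schedule α_t². Its first printed values are reproduced by the polynomial form α_t² = (1 - 2s)(1 - (t/T)²)² + s with s = `diffusion_noise_precision` = 1e-5 and T = `diffusion_steps` = 500 (the defaults, since `parse_args(args=[])` was used). The NumPy sketch below assumes that form; the repository may apply extra clipping near t = T, and both helper names are hypothetical.

```python
import numpy as np

def entropy_nats(counts):
    """Entropy H[N] of the empirical distribution of molecule sizes."""
    p = np.asarray(counts, dtype=float)
    p = p[p > 0] / p.sum()              # drop empty bins: 0 * log 0 = 0
    return float(-(p * np.log(p)).sum())

def polynomial_alphas2(timesteps, s=1e-5, power=2):
    """alpha_t^2 for t = 0..T: (1 - 2s) * (1 - (t/T)^power)^2 + s."""
    t = np.arange(timesteps + 1)
    base = (1.0 - (t / timesteps) ** power) ** 2
    return (1.0 - 2.0 * s) * base + s

alphas2 = polynomial_alphas2(500)       # polynomial_2, precision 1e-5, 500 steps
sigmas2 = 1.0 - alphas2                 # variance preserving: alpha_t^2 + sigma_t^2 = 1
print(alphas2[:4])                      # compare with the first values printed above
```

The offset s keeps α_t² inside [s, 1 - s], so neither the signal nor the noise variance ever reaches exactly zero.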

In [11]:
model

EnVariationalDiffusion(
  (gamma): PredefinedNoiseSchedule()
  (dynamics): EGNN_dynamics_QM9(
    (egnn): EGNN(
      (embedding): Linear(in_features=7, out_features=128, bias=True)
      (embedding_out): Linear(in_features=128, out_features=7, bias=True)
      (e_block_0): EquivariantBlock(
        (gcl_0): GCL(
          (edge_mlp): Sequential(
            (0): Linear(in_features=258, out_features=128, bias=True)
            (1): SiLU()
            (2): Linear(in_features=128, out_features=128, bias=True)
            (3): SiLU()
          )
          (node_mlp): Sequential(
            (0): Linear(in_features=256, out_features=128, bias=True)
            (1): SiLU()
            (2): Linear(in_features=128, out_features=128, bias=True)
          )
          (att_mlp): Sequential(
            (0): Linear(in_features=128, out_features=1, bias=True)
            (1): Sigmoid()
          )
        )
        (gcl_equiv): EquivariantUpdate(
          (coord_mlp): Sequential(
            (0):
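The input widths in this printout follow from the arguments: 7 = 5 atom types (H, C, N, O, F) + 1 charge + 1 time feature (`condition_time=True`), 258 = 2 × 128 + 2 for the edge MLP, and 256 = 2 × 128 for the node MLP. The arithmetic below reproduces them; the 2 edge features are inferred from 258 - 2 × 128 (presumably distance terms), and the width with a conditioning property assumes the context is concatenated to the node features. `egnn_input_widths` is a hypothetical helper.

```python
def egnn_input_widths(n_atom_types=5, include_charges=True, condition_time=True,
                      context_node_nf=0, nf=128, n_edge_feats=2):
    """Input widths of the first Linear layers, derived from dataset and model arguments."""
    node_feats = n_atom_types + int(include_charges)                    # one-hot + charge
    embedding_in = node_feats + int(condition_time) + context_node_nf  # + time t (+ context)
    edge_mlp_in = 2 * nf + n_edge_feats      # [h_i, h_j] + edge features (assumed 2)
    node_mlp_in = 2 * nf                     # [h_i, aggregated messages]
    return embedding_in, edge_mlp_in, node_mlp_in

print(egnn_input_widths())                   # (7, 258, 256), matching the printout
print(egnn_input_widths(context_node_nf=1))  # (8, 258, 256) with one conditioning property
```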

In [31]:
print(type(model))
print(type(model.dynamics))
print(model.dynamics.mode)
print(type(model.dynamics.egnn))
print(f"number of layers EGNN (all Equivariant blocks): {model.dynamics.egnn.n_layers}")
print(type(model.dynamics.egnn.e_block_0))
print(f"number of GCL sublayers in Equivariant block 0: {model.dynamics.egnn.e_block_0.n_layers}")
print(type(model.dynamics.egnn.e_block_0.gcl_0))


<class 'equivariant_diffusion.en_diffusion.EnVariationalDiffusion'>
<class 'egnn.models.EGNN_dynamics_QM9'>
egnn_dynamics
<class 'egnn.egnn_new.EGNN'>
number of layers EGNN (all Equivariant blocks): 6
<class 'egnn.egnn_new.EquivariantBlock'>
number of layers Equivariant block 0: 1
<class 'egnn.egnn_new.GCL'>
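Each `EquivariantBlock` ends with a `gcl_equiv` (`EquivariantUpdate`) that moves every atom along its relative position vectors, normalized as in the `norm_constant` help text, diff/(|diff| + norm_constant). Because the step sizes depend only on invariant quantities, the update commutes with rotations, reflections and translations. The NumPy sketch below shows this for one molecule; the learned `coord_mlp` (which in the EGNN also sees the node features h_i, h_j) is replaced by a hypothetical fixed function tanh(0.5 · d²).

```python
import numpy as np

def coord_update(x, norm_constant=1.0, scale=0.5):
    """One E(n)-equivariant coordinate update for a single molecule, x: (n, 3)."""
    diff = x[:, None, :] - x[None, :, :]                  # diff[i, j] = x_i - x_j
    dist = np.linalg.norm(diff, axis=-1, keepdims=True)   # invariant to rotation/translation
    phi = np.tanh(scale * dist ** 2)                      # hypothetical stand-in for coord_mlp
    not_self = 1.0 - np.eye(len(x))[:, :, None]           # skip the j == i term
    return x + (diff / (dist + norm_constant) * phi * not_self).sum(axis=1)

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))              # random orthogonal matrix
t = rng.normal(size=3)                                    # random translation
# Transforming the input then updating equals updating then transforming
print(np.allclose(coord_update(x @ Q.T + t), coord_update(x) @ Q.T + t))  # True
```

The `+ norm_constant` in the denominator also keeps the update finite when two atoms (or an atom and itself) coincide.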


## 3. Optimizer

In [95]:
parser.add_argument('--lr', type=float, default=2e-4)
args = parser.parse_args(args=[])

In [55]:
from qm9.models import get_optim

In [57]:
optim = get_optim(args, model)

## 4. Additional models

In [96]:
parser.add_argument('--ema_decay', type=float, default=0.999,
                    help='Amount of EMA decay, 0 means off. A reasonable value is 0.999.')

args = parser.parse_args(args=[])

In [59]:
# no multi-GPU wrapper on CPU: the "data parallel" model is the model itself
model_dp = model

In [64]:
# ema
from equivariant_diffusion import utils as flow_utils
import copy
model_ema = copy.deepcopy(model)
ema = flow_utils.EMA(args.ema_decay)
model_ema_dp = model_ema
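`model_ema` is a copy whose weights track an exponential moving average of the trained weights. With decay β = 0.999 the average covers roughly the last 1/(1 - β) = 1000 optimizer steps; the 0.9999 from the training command stretches that to about 10,000. The sketch below assumes the standard update ema ← β · ema + (1 - β) · current on plain NumPy arrays; `EMASketch` is hypothetical and not the code of `flow_utils.EMA`.

```python
import numpy as np

class EMASketch:
    """Parameter EMA: ema <- beta * ema + (1 - beta) * current (hypothetical stand-in)."""
    def __init__(self, beta):
        self.beta = beta          # beta = 0 just copies the current weights

    def update(self, ema_params, params):
        for name, p in params.items():
            ema_params[name] = self.beta * ema_params[name] + (1.0 - self.beta) * p
        return ema_params

ema = EMASketch(0.9)
ema_params = {'w': np.zeros(2)}
for _ in range(3):                # three steps towards constant weights of 1
    ema.update(ema_params, {'w': np.ones(2)})
print(ema_params['w'])            # each entry equals 1 - 0.9**3 = 0.271
```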

## 5. Training 👟

In [97]:
# Property normalization statistics, only needed when conditioning (None here)
property_norms = None

import utils
gradnorm_queue = utils.Queue()
gradnorm_queue.add(3000)  # Add large value that will be flushed

parser.add_argument('--augment_noise', type=float, default=0)
parser.add_argument('--data_augmentation', type=eval, default=False, help='apply random rotations to the input molecules')
parser.add_argument('--ode_regularization', type=float, default=1e-3)
parser.add_argument('--clip_grad', type=eval, default=True, help='True | False')
parser.add_argument('--test_epochs', type=int, default=10)
parser.add_argument('--visualize_every_batch', type=int, default=int(1e8),
                    help="Can be used to visualize multiple times per epoch")
parser.add_argument('--break_train_epoch', type=eval, default=False,
                    help='True | False')
parser.add_argument('--n_report_steps', type=int, default=1)
args = parser.parse_args(args=[])
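The `gradnorm_queue` seeded with 3000 points to an adaptive gradient-clipping threshold computed from recent gradient norms: the large seed makes the first thresholds so high that nothing is clipped, and it is flushed once enough real norms have been recorded. The sketch below assumes a window of 50 norms and a threshold of 1.5 × mean + 2 × standard deviation; both are assumptions to check against `utils.py` in the repo, and `GradNormQueue` / `clip_and_record` are hypothetical names.

```python
from collections import deque
import statistics

class GradNormQueue:
    """Rolling window of recent gradient norms (hypothetical stand-in for utils.Queue)."""
    def __init__(self, max_len=50):
        self.items = deque(maxlen=max_len)   # oldest values drop out automatically

    def add(self, value):
        self.items.append(float(value))

    def threshold(self):
        # Assumed rule: 1.5 * mean + 2 * standard deviation of the recent norms
        return 1.5 * statistics.fmean(self.items) + 2 * statistics.pstdev(self.items)

def clip_and_record(grad_norm, queue):
    """Return the (possibly clipped) gradient norm and push it into the queue."""
    max_norm = queue.threshold()
    clipped = min(grad_norm, max_norm)
    queue.add(clipped)                       # store the clipped value so spikes don't inflate the threshold
    return clipped

queue = GradNormQueue()
queue.add(3000)                              # large seed: the first threshold is 4500
print(clip_and_record(48.3, queue))          # 48.3, not clipped
```

In a PyTorch loop, the threshold would be passed as `max_norm` to `torch.nn.utils.clip_grad_norm_`, which clips in place and returns the total norm before clipping.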

In [102]:
# wandb
import wandb
wandb.init(project='e3_diffusion-notebook', config=args)

In [None]:
from train_test import train_epoch
for epoch in range(10):
    train_epoch(args=args, loader=dataloaders['train'], epoch=epoch, model=model, model_dp=model_dp,
                model_ema=model_ema, ema=ema, device=device, dtype=torch.float32,
                property_norms=property_norms, nodes_dist=nodes_dist, dataset_info=dataset_info,
                gradnorm_queue=gradnorm_queue, optim=optim, prop_dist=prop_dist)

Epoch: 0, iter: 0/782, Loss 2.82, NLL: 2.82, RegTerm: 0.0, GradNorm: 48.3
Epoch: 0, iter: 1/782, Loss 2.89, NLL: 2.89, RegTerm: 0.0, GradNorm: 12.0
Epoch: 0, iter: 2/782, Loss 2.79, NLL: 2.79, RegTerm: 0.0, GradNorm: 19.5
Epoch: 0, iter: 3/782, Loss 2.89, NLL: 2.89, RegTerm: 0.0, GradNorm: 23.7
Epoch: 0, iter: 4/782, Loss 2.73, NLL: 2.73, RegTerm: 0.0, GradNorm: 18.5
Epoch: 0, iter: 5/782, Loss 2.79, NLL: 2.79, RegTerm: 0.0, GradNorm: 16.2
Epoch: 0, iter: 6/782, Loss 2.79, NLL: 2.79, RegTerm: 0.0, GradNorm: 10.6
Epoch: 0, iter: 7/782, Loss 2.66, NLL: 2.66, RegTerm: 0.0, GradNorm: 6.9
Epoch: 0, iter: 8/782, Loss 2.70, NLL: 2.70, RegTerm: 0.0, GradNorm: 4.4
Epoch: 0, iter: 9/782, Loss 2.76, NLL: 2.76, RegTerm: 0.0, GradNorm: 8.7
Epoch: 0, iter: 10/782, Loss 2.74, NLL: 2.74, RegTerm: 0.0, GradNorm: 9.3
Epoch: 0, iter: 11/782, Loss 2.73, NLL: 2.73, RegTerm: 0.0, GradNorm: 10.6
Epoch: 0, iter: 12/782, Loss 2.72, NLL: 2.72, RegTerm: 0.0, GradNorm: 7.1
Epoch: 0, iter: 13/782, Loss 2.66, NLL: 
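With `--diffusion_loss_type l2`, each iteration trains the EGNN to predict the noise that was added to a molecule. The NumPy sketch below covers the positional part only: positions are centered on their center of mass, a timestep t is drawn uniformly from {0, ..., T} per molecule (an assumption), noise in the zero-center-of-mass subspace is mixed in as z_t = α_t x + σ_t ε, and the loss is the masked squared error between ε and the prediction. The real model also noises the atom-type and charge features and may normalize the loss differently; all names here are hypothetical.

```python
import numpy as np

def remove_mean(x, mask):
    """Center each molecule on its real atoms; x: (B, N, 3), mask: (B, N, 1)."""
    n_real = mask.sum(axis=1, keepdims=True)                   # (B, 1, 1)
    center = (x * mask).sum(axis=1, keepdims=True) / n_real    # (B, 1, 3)
    return (x - center) * mask                                 # padding stays zero

def l2_noise_loss(x, mask, alphas2, predict_eps, rng):
    """One Monte-Carlo estimate of the noise-prediction (l2) loss on positions only."""
    T = len(alphas2) - 1
    t = rng.integers(0, T + 1, size=x.shape[0])                # one timestep per molecule
    alpha = np.sqrt(alphas2[t])[:, None, None]
    sigma = np.sqrt(1.0 - alphas2[t])[:, None, None]
    eps = remove_mean(rng.normal(size=x.shape), mask)          # zero-CoM Gaussian noise
    z_t = alpha * x + sigma * eps                              # noised positions
    eps_hat = predict_eps(z_t, t / T, mask)                    # network prediction
    per_mol = (((eps - eps_hat) ** 2) * mask).sum(axis=(1, 2))
    return per_mol.mean()

def zero_model(z, t, m):
    """Untrained stand-in for the EGNN: always predicts zero noise."""
    return np.zeros_like(z)

rng = np.random.default_rng(0)
mask = np.zeros((2, 4, 1))
mask[0, :3] = 1.0                                              # molecule 0: 3 real atoms
mask[1, :2] = 1.0                                              # molecule 1: 2 real atoms
x = remove_mean(rng.normal(size=(2, 4, 3)), mask)
alphas2 = np.linspace(1.0 - 1e-5, 1e-5, 501)                   # any decreasing schedule in (0, 1)
print(l2_noise_loss(x, mask, alphas2, zero_model, rng))
```

Centering both x and ε keeps the whole process in the zero-center-of-mass subspace, which is how the model handles translation invariance.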
