# Exploration of **Sampling** EDM: E(3) Equivariant Diffusion Model for Molecule Generation in 3D.
[Github repo](https://github.com/ehoogeboom/e3_diffusion_for_molecules)

## 1. Main file `eval_sample.py`
The evaluation and **visualization** are triggered with a command like:
```bash
python eval_sample.py --model_path outputs/edm_qm9 --n_samples 10_000
```
The evaluation and **analysis** are triggered with a command like:
```bash
python eval_analyze.py --model_path outputs/edm_qm9 --n_samples 10_000
```

### 1. Arguments
_Note:_ I changed the defaults to match those of the `edm_qm9` example.

```python
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--exp_name', type=str, default='edm_qm9')
parser.add_argument('--model', type=str, default='egnn_dynamics',
                    help='our_dynamics | schnet | simple_dynamics | '
                         'kernel_dynamics | egnn_dynamics | gnn_dynamics')
parser.add_argument('--probabilistic_model', type=str, default='diffusion',
                    help='diffusion')
# Training complexity is O(1) (unaffected), but sampling complexity is O(steps).
parser.add_argument('--diffusion_steps', type=int, default=1000)
parser.add_argument('--diffusion_noise_schedule', type=str, default='polynomial_2',
                    help='learned | cosine | polynomial_<power>')
parser.add_argument('--diffusion_noise_precision', type=float, default=1e-5)
parser.add_argument('--diffusion_loss_type', type=str, default='l2',
                    help='vlb | l2')

parser.add_argument('--n_epochs', type=int, default=200)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--brute_force', type=eval, default=False,
                    help='True | False')
parser.add_argument('--actnorm', type=eval, default=True,
                    help='True | False')
parser.add_argument('--break_train_epoch', type=eval, default=False,
                    help='True | False')
parser.add_argument('--dp', type=eval, default=True,
                    help='True | False')
parser.add_argument('--condition_time', type=eval, default=True,
                    help='True | False')
parser.add_argument('--clip_grad', type=eval, default=True,
                    help='True | False')
parser.add_argument('--trace', type=str, default='hutch',
                    help='hutch | exact')
# EGNN args -->
parser.add_argument('--n_layers', type=int, default=9,
                    help='number of layers')
parser.add_argument('--inv_sublayers', type=int, default=1,
                    help='number of invariant sublayers per equivariant block')
parser.add_argument('--nf', type=int, default=256,
                    help='number of hidden features')
parser.add_argument('--tanh', type=eval, default=True,
                    help='use tanh in the coord_mlp')
parser.add_argument('--attention', type=eval, default=True,
                    help='use attention in the EGNN')
parser.add_argument('--norm_constant', type=float, default=1,
                    help='diff/(|diff| + norm_constant)')
parser.add_argument('--sin_embedding', type=eval, default=False,
                    help='whether to use the sin embedding')
# <-- EGNN args
parser.add_argument('--ode_regularization', type=float, default=1e-3)
parser.add_argument('--dataset', type=str, default='qm9',
                    help='qm9 | qm9_second_half (train only on the last 50K samples of the training dataset)')
parser.add_argument('--datadir', type=str, default='qm9/temp',
                    help='qm9 directory')
parser.add_argument('--filter_n_atoms', type=int, default=None,
                    help='When set to an integer value, QM9 will only contain molecules with that number of atoms')
parser.add_argument('--dequantization', type=str, default='argmax_variational',
                    help='uniform | variational | argmax_variational | deterministic')
parser.add_argument('--n_report_steps', type=int, default=1)
parser.add_argument('--wandb_usr', type=str)
parser.add_argument('--no_wandb', action='store_true', help='Disable wandb')
parser.add_argument('--online', type=eval, default=True, help='True = wandb online | False = wandb offline')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--save_model', type=eval, default=True,
                    help='save model')
parser.add_argument('--generate_epochs', type=int, default=1)
parser.add_argument('--num_workers', type=int, default=0, help='Number of workers for the dataloader')
parser.add_argument('--test_epochs', type=int, default=20)
parser.add_argument('--data_augmentation', type=eval, default=False, help='use data augmentation (random rotations)')
parser.add_argument('--conditioning', nargs='+', default=[],
                    help='properties to condition on: homo | lumo | alpha | gap | mu | Cv')
parser.add_argument('--resume', type=str, default=None,
                    help='')
parser.add_argument('--start_epoch', type=int, default=0,
                    help='')
parser.add_argument('--ema_decay', type=float, default=0.999,
                    help='Amount of EMA decay, 0 means off. A reasonable value is 0.999.')
parser.add_argument('--augment_noise', type=float, default=0)
parser.add_argument('--n_stability_samples', type=int, default=1000,
                    help='Number of samples to compute the stability')
parser.add_argument('--normalize_factors', type=eval, default=[1, 4, 10],
                    help='normalize factors for [x, categorical, integer]')
parser.add_argument('--remove_h', action='store_true')
parser.add_argument('--include_charges', type=eval, default=True,
                    help='include atom charge or not')
parser.add_argument('--visualize_every_batch', type=int, default=int(1e8),
                    help='Can be used to visualize multiple times per epoch')
parser.add_argument('--normalization_factor', type=float, default=1,
                    help='Normalize the sum aggregation of EGNN')
parser.add_argument('--aggregation_method', type=str, default='sum',
                    help='"sum" or "mean"')
# Specific to eval_sample
parser.add_argument('--model_path', type=str, default="outputs/edm_qm9", help='Specify model path')
parser.add_argument('--n_tries', type=int, default=10, help='number of tries to find a stable molecule for the gif animation')
parser.add_argument('--n_nodes', type=int, default=19, help='number of atoms in the molecule for the gif animation')
# Parse an empty list so that the command-line arguments of the Jupyter kernel are ignored
args = parser.parse_args(args=[])

args.context_node_nf = 0  # no property conditioning
device = torch.device("cpu")
```
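Most boolean flags above use `type=eval` rather than `type=bool`. The difference only shows up when a value is passed on the command line: `bool()` of any non-empty string is `True`, so `bool('False')` would silently keep the flag on. A minimal standard-library demonstration (the flag names `--with_bool` and `--with_eval` are illustrative only):

```python
import argparse

# Two ways of declaring a boolean flag (flag names are illustrative only).
parser = argparse.ArgumentParser()
parser.add_argument('--with_bool', type=bool, default=True)
parser.add_argument('--with_eval', type=eval, default=True)

# Passing "False" for both flags on the command line:
args = parser.parse_args(['--with_bool', 'False', '--with_eval', 'False'])

# bool('False') is True because any non-empty string is truthy,
# whereas eval('False') evaluates the Python literal False.
print(args.with_bool, args.with_eval)  # True False

# Passing args=[] (as in the notebook) ignores sys.argv and keeps the defaults.
defaults = parser.parse_args(args=[])
print(defaults.with_bool, defaults.with_eval)  # True True
```

Note that `eval` executes arbitrary Python, which is acceptable for a local research script but not for untrusted input.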

### 2. Dataset

```python
from configs.datasets_config import get_dataset_info
from qm9 import dataset

dataset_info = get_dataset_info(args.dataset, args.remove_h)
```



```python
# Retrieve QM9 dataloaders
dataloaders, charge_scale = dataset.retrieve_dataloaders(args)
```
```text
dict_keys([0, 1, 6, 7, 8, 9])
dict_keys([0, 1, 6, 7, 8, 9])
dict_keys([0, 1, 6, 7, 8, 9])
```
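The printed keys appear to be the atomic numbers found in the data (presumably one line per split), with 0 most likely used for padding: H=1, C=6, N=7, O=8, F=9. Five real elements is consistent with the five atom-type channels of `one_hot` sampled later. A sketch of that mapping, assuming padded entries carry charge 0 and the channel order is H, C, N, O, F (`charges_to_one_hot` is a hypothetical helper, not part of the repo):

```python
import numpy as np

# Assumed channel order for the QM9 atom types (atomic numbers).
ATOMIC_NUMBERS = [1, 6, 7, 8, 9]  # H, C, N, O, F

def charges_to_one_hot(charges: np.ndarray) -> np.ndarray:
    """Map a (n_atoms,) array of atomic numbers to a (n_atoms, 5) one-hot array.

    Padding entries (charge 0) become all-zero rows.
    """
    table = np.array(ATOMIC_NUMBERS)
    return (charges[:, None] == table[None, :]).astype(np.float32)

# Water-like toy example: O, H, H, followed by two padding slots.
charges = np.array([8, 1, 1, 0, 0])
one_hot = charges_to_one_hot(charges)
print(one_hot.sum(axis=1))     # [1. 1. 1. 0. 0.]
print(one_hot.argmax(axis=1))  # [3 0 0 0 0]
```

Padding rows also `argmax` to channel 0 (H), which is why a separate `node_mask` is needed to tell real atoms from padding.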


### 3. EGNN model

```python
from qm9.models import get_model
import torch
```

```python
flow, nodes_dist, prop_dist = get_model(args, device, dataset_info, dataloaders['train'])
```
```text
Entropy of n_nodes: H[N] -2.475700616836548
alphas2 [9.99990000e-01 9.99988000e-01 9.99982000e-01 ... 2.59676966e-05
 1.39959211e-05 1.00039959e-05]
gamma [-11.51291546 -11.33059532 -10.92513058 ...  10.55863126  11.17673063
  11.51251595]
```
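The printed `alphas2` and `gamma` arrays can be reproduced from `diffusion_noise_schedule='polynomial_2'`, `diffusion_steps=1000` and `diffusion_noise_precision=1e-5`. The sketch below is my own re-implementation, not the repo's code. It assumes a polynomial schedule of power 2 over 1001 points, per-step ratios clipped from below at 0.001, a rescaling into [s, 1 - s] with s the noise precision, and `gamma = -log(alpha^2 / sigma^2)` with `sigma^2 = 1 - alpha^2`. Under these assumptions its first and last values agree with the output above.

```python
import numpy as np

def clip_noise_schedule(alphas2: np.ndarray, clip_value: float = 0.001) -> np.ndarray:
    """Clip the per-step ratios alphas2[t] / alphas2[t-1] from below for numerical stability."""
    alphas2 = np.concatenate([np.ones(1), alphas2], axis=0)
    alphas_step = np.clip(alphas2[1:] / alphas2[:-1], a_min=clip_value, a_max=1.0)
    return np.cumprod(alphas_step, axis=0)

def polynomial_schedule(timesteps: int, s: float, power: float) -> np.ndarray:
    """alpha_t^2 for a polynomial schedule, rescaled into [s, 1 - s]."""
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas2 = (1 - np.power(x / steps, power)) ** 2
    alphas2 = clip_noise_schedule(alphas2, clip_value=0.001)
    precision = 1 - 2 * s
    return precision * alphas2 + s

# 'polynomial_2' with diffusion_steps=1000 and diffusion_noise_precision=1e-5
alphas2 = polynomial_schedule(timesteps=1000, s=1e-5, power=2.0)
sigmas2 = 1 - alphas2
# gamma is the negative log signal-to-noise ratio: -log(alpha^2 / sigma^2)
gamma = -(np.log(alphas2) - np.log(sigmas2))

print('alphas2', alphas2)
print('gamma', gamma)
```

The clipping only bites at the very last step, where the unclipped ratio would be 0; this is why the final `alphas2` is slightly above 1e-5 instead of exactly s.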


```python
flow
```
```text
EnVariationalDiffusion(
  (gamma): PredefinedNoiseSchedule()
  (dynamics): EGNN_dynamics_QM9(
    (egnn): EGNN(
      (embedding): Linear(in_features=7, out_features=256, bias=True)
      (embedding_out): Linear(in_features=256, out_features=7, bias=True)
      (e_block_0): EquivariantBlock(
        (gcl_0): GCL(
          (edge_mlp): Sequential(
            (0): Linear(in_features=514, out_features=256, bias=True)
            (1): SiLU()
            (2): Linear(in_features=256, out_features=256, bias=True)
            (3): SiLU()
          )
          (node_mlp): Sequential(
            (0): Linear(in_features=512, out_features=256, bias=True)
            (1): SiLU()
            (2): Linear(in_features=256, out_features=256, bias=True)
          )
          (att_mlp): Sequential(
            (0): Linear(in_features=256, out_features=1, bias=True)
            (1): Sigmoid()
          )
        )
        (gcl_equiv): EquivariantUpdate(
          (coord_mlp): Sequential(
            (0):
```

### 4. State_dict
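In the repo there are some pre-trained models located in `outputs`.
Here we load the trained model's `state_dict`: its weights and biases, plus extra entries such as `buffer` and `gamma.gamma` (the noise schedule). Despite the `.npy` extension, these files are PyTorch checkpoints and are read with `torch.load`.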

```python
from os.path import join

fn = 'generative_model_ema.npy' if args.ema_decay > 0 else 'generative_model.npy'
flow_state_dict = torch.load(join(args.model_path, fn), map_location=device)
print(flow_state_dict.keys())
```
```text
odict_keys(['buffer', 'gamma.gamma', 'dynamics.egnn.embedding.weight', 'dynamics.egnn.embedding.bias', 'dynamics.egnn.embedding_out.weight', 'dynamics.egnn.embedding_out.bias', 'dynamics.egnn.e_block_0.gcl_0.edge_mlp.0.weight', 'dynamics.egnn.e_block_0.gcl_0.edge_mlp.0.bias', 'dynamics.egnn.e_block_0.gcl_0.edge_mlp.2.weight', 'dynamics.egnn.e_block_0.gcl_0.edge_mlp.2.bias', 'dynamics.egnn.e_block_0.gcl_0.node_mlp.0.weight', 'dynamics.egnn.e_block_0.gcl_0.node_mlp.0.bias', 'dynamics.egnn.e_block_0.gcl_0.node_mlp.2.weight', 'dynamics.egnn.e_block_0.gcl_0.node_mlp.2.bias', 'dynamics.egnn.e_block_0.gcl_0.att_mlp.0.weight', 'dynamics.egnn.e_block_0.gcl_0.att_mlp.0.bias', 'dynamics.egnn.e_block_0.gcl_equiv.coord_mlp.0.weight', 'dynamics.egnn.e_block_0.gcl_equiv.coord_mlp.0.bias', 'dynamics.egnn.e_block_0.gcl_equiv.coord_mlp.2.weight', 'dynamics.egnn.e_block_0.gcl_equiv.coord_mlp.2.bias', 'dynamics.egnn.e_block_0.gcl_equiv.coord_mlp.4.weight', 'dynamics.egnn.e_block_1.gcl_0.edge_mlp.0.weight'
```
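The `ema_decay > 0` test picks the checkpoint holding an exponential moving average (EMA) of the weights instead of the raw trained weights. A minimal NumPy sketch of the usual EMA update, assuming the standard rule `ema = decay * ema + (1 - decay) * param` applied after each optimizer step; this is an illustration of the technique, not the repo's own EMA class:

```python
import numpy as np

def ema_update(ema_params: dict, params: dict, decay: float = 0.999) -> None:
    """In-place EMA update: ema <- decay * ema + (1 - decay) * param."""
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value

# Toy "model" with a single weight that jumps from 0 to 1 and stays there.
params = {'w': np.array(1.0)}
ema_params = {'w': np.array(0.0)}

for step in range(1000):
    ema_update(ema_params, params, decay=0.999)

# Starting from 0, the EMA has covered 1 - decay**n of the jump after n steps (about 63% here).
print(float(ema_params['w']), 1 - 0.999 ** 1000)
```

With `decay=0.999` the EMA effectively averages over roughly the last 1 / (1 - 0.999) = 1000 steps, which smooths out step-to-step noise in the final weights.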

```python
# Now we apply those weights and biases to the model flow
flow.load_state_dict(flow_state_dict)
```
```text
<All keys matched successfully>
```

### 5. Sampling

#### 5.1 Sampling a handful of molecules
The molecules have different sizes: the number of atoms of each one is drawn from `nodes_dist`.

```python
from qm9.sampling import sample

n_samples = 30
nodesxsample = nodes_dist.sample(n_samples)
one_hot, charges, x, node_mask = sample(args, device, flow, dataset_info, nodesxsample=nodesxsample)
```
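`nodes_dist` draws the number of atoms of each molecule from the distribution of molecule sizes seen in training. A NumPy sketch of that idea with a made-up histogram (the counts below are illustrative, not QM9's). Note that the `H[N]` printed when the model was built is negative, which suggests it is computed as the sum of p log p, i.e. the negative of the usual Shannon entropy; the sketch prints both for comparison.

```python
import numpy as np

# Hypothetical histogram: molecule size (number of atoms) -> count in the training set.
size_counts = {9: 5, 12: 20, 15: 40, 18: 60, 19: 50, 25: 10, 29: 2}

sizes = np.array(list(size_counts.keys()))
probs = np.array(list(size_counts.values()), dtype=float)
probs /= probs.sum()

rng = np.random.default_rng(0)
nodesxsample = rng.choice(sizes, size=30, p=probs)  # one size per molecule
print(nodesxsample)

shannon_entropy = -np.sum(probs * np.log(probs))
print('H[N] =', shannon_entropy, '| sum p log p =', -shannon_entropy)
```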

```python
print(charges.shape)
print(x.shape)
print(one_hot.shape)
print(node_mask.shape)
```
```text
torch.Size([30, 29, 1])
torch.Size([30, 29, 3])
torch.Size([30, 29, 5])
torch.Size([30, 29, 1])
```
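All tensors are padded to 29 atoms, the largest molecule size in QM9 with hydrogens, and `one_hot` has one channel per atom type. `node_mask` marks which rows are real atoms. A NumPy sketch with random stand-in arrays of the same layout, assuming real atoms occupy the first rows, recovering each molecule's atoms the same way the stability loop of section 5.2 does:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, max_atoms, n_types = 3, 29, 5
n_atoms = np.array([9, 19, 29])  # stand-in molecule sizes

# node_mask[b, i, 0] is 1 for real atoms and 0 for padding (real atoms come first).
node_mask = (np.arange(max_atoms)[None, :] < n_atoms[:, None]).astype(np.float32)[..., None]
x = rng.normal(size=(batch, max_atoms, 3)) * node_mask
one_hot = np.eye(n_types)[rng.integers(0, n_types, size=(batch, max_atoms))] * node_mask

for i in range(batch):
    num_atoms = int(node_mask[i].sum())
    atom_type = one_hot[i, :num_atoms].argmax(axis=1)  # (num_atoms,)
    positions = x[i, :num_atoms]                        # (num_atoms, 3)
    print(num_atoms, atom_type.shape, positions.shape)
```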


```python
import qm9.visualizer as vis

# Save .txt files (1 per molecule) containing the coordinates of the atoms and their type
vis.save_xyz_file(
    join(args.model_path, 'eval/molecules/'), one_hot, charges, x,
    id_from=0, name='molecule', dataset_info=dataset_info,
    node_mask=node_mask)
```
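The saved files follow an XYZ-style text layout: an atom count, a comment line, then one `symbol x y z` line per atom. A standard-library sketch of writing that layout; the exact header and number formatting used by `vis.save_xyz_file` may differ, and `to_xyz` and `ATOM_DECODER` are hypothetical names assuming the H, C, N, O, F channel order:

```python
ATOM_DECODER = ['H', 'C', 'N', 'O', 'F']  # assumed channel order

def to_xyz(atom_types, positions, comment=''):
    """Format atoms as XYZ text: count line, comment line, 'symbol x y z' rows."""
    lines = [str(len(atom_types)), comment]
    for t, (px, py, pz) in zip(atom_types, positions):
        lines.append(f'{ATOM_DECODER[t]} {px:.6f} {py:.6f} {pz:.6f}')
    return '\n'.join(lines) + '\n'

# Toy water molecule (illustrative coordinates in angstroms).
text = to_xyz([3, 0, 0], [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)])
print(text)
```

Because the format is plain text, the saved molecules can be opened in most molecular viewers or parsed back with a few lines of Python.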

#### 5.2 Sampling stable molecules

```python
n_tries = 20
n_samples = 10
```

```python
from qm9.analyze import check_stability

# Sample n_tries molecules, with sizes drawn from nodes_dist
nodesxsample = nodes_dist.sample(n_tries)
one_hot, charges, x, node_mask = sample(args, device, flow, dataset_info, nodesxsample=nodesxsample)
counter = 0  # number of molecules saved so far

for i in range(n_tries):
    # Keep only the real atoms of molecule i (the remaining rows are padding)
    num_atoms = int(node_mask[i:i+1].sum().item())
    atom_type = one_hot[i:i+1, :num_atoms].argmax(2).squeeze(0).cpu().detach().numpy()
    x_squeeze = x[i:i+1, :num_atoms].squeeze(0).cpu().detach().numpy()
    mol_stable = check_stability(x_squeeze, atom_type, dataset_info)[0]

    num_remaining_attempts = n_tries - i - 1
    num_remaining_samples = n_samples - counter

    # Save stable molecules; once too few attempts remain to fill every slot,
    # save the next molecules regardless of stability
    if mol_stable or num_remaining_attempts <= num_remaining_samples:
        if mol_stable:
            print('Found stable mol.')
        vis.save_xyz_file(
            join(args.model_path, 'eval/molecules/'),
            one_hot[i:i+1], charges[i:i+1], x[i:i+1],
            id_from=counter, name='molecule_stable',
            dataset_info=dataset_info,
            node_mask=node_mask[i:i+1])
        counter += 1

        if counter >= n_samples:
            break
```
```text
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
```
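The loop saves stable molecules first, but as soon as the remaining attempts are no more than the remaining slots it saves whatever comes next, so all `n_samples` slots get filled whenever `n_tries >= n_samples`. The same rule as a pure function over precomputed stability flags (`select_for_saving` is a hypothetical helper for illustration, not part of the repo):

```python
def select_for_saving(stable_flags, n_samples):
    """Return the indices the loop would save, given one stability flag per try."""
    n_tries = len(stable_flags)
    saved = []
    for i, mol_stable in enumerate(stable_flags):
        num_remaining_attempts = n_tries - i - 1
        num_remaining_samples = n_samples - len(saved)
        if mol_stable or num_remaining_attempts <= num_remaining_samples:
            saved.append(i)
            if len(saved) >= n_samples:
                break
    return saved

print(select_for_saving([False, True, False, False], n_samples=2))  # [1, 2]
print(select_for_saving([False] * 20, n_samples=10))  # [9, 10, 11, 12, 13, 14, 15, 16, 17, 18]
```

In the run above all ten saved molecules happened to be stable, so the fallback branch was never needed.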


### 6. Visualization
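Generates a .png image from each .txt file saved above (the repo's `qm9.visualizer` draws them with matplotlib; `spheres_3d=True` renders atoms as 3D spheres).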

```python
vis.visualize(join(args.model_path, 'eval/molecules/'), dataset_info, max_num=100, spheres_3d=True)
```
```text
Average distance between atoms 3.2808544635772705
Average distance between atoms 2.8001558780670166
Average distance between atoms 3.092637062072754
Average distance between atoms 2.813894033432007
Average distance between atoms 2.7666893005371094
Average distance between atoms 3.466494560241699
Average distance between atoms 3.187229633331299
Average distance between atoms 3.276587724685669
Average distance between atoms 3.6992554664611816
Average distance between atoms 3.2033517360687256
Average distance between atoms 3.9616551399230957
Average distance between atoms 3.4325289726257324
Average distance between atoms 3.205986499786377
Average distance between atoms 3.283141851425171
Average distance between atoms 3.0159482955932617
Average distance between atoms 3.1348373889923096
Average distance between atoms 3.3704986572265625
Average distance between atoms 3.2689449787139893
Average distance between atoms 3.0262720584869385
Average distance between atoms 3.088954210281372
```
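Each line above reports a mean interatomic distance for one rendered molecule. A NumPy sketch of one way to compute such a value, assuming it is the mean of the full pairwise distance matrix; whether `qm9.visualizer` includes the zero diagonal, as the default here does, is an assumption, so both variants are shown:

```python
import numpy as np

def average_distance(positions: np.ndarray, include_diagonal: bool = True) -> float:
    """Mean pairwise Euclidean distance for an (n_atoms, 3) array."""
    diff = positions[:, None, :] - positions[None, :, :]
    dists = np.linalg.norm(diff, axis=-1)  # (n_atoms, n_atoms), zeros on the diagonal
    if include_diagonal:
        return float(dists.mean())
    n = len(positions)
    return float(dists.sum() / (n * (n - 1)))

# Three points on a line at 0, 1 and 2: the full matrix sums to 8.
pts = np.array([[0.0, 0, 0], [1.0, 0, 0], [2.0, 0, 0]])
print(average_distance(pts))                          # 8 / 9
print(average_distance(pts, include_diagonal=False))  # 8 / 6
```

Including the diagonal pulls the average down, and the effect is larger for small molecules, so the two variants are not interchangeable when comparing numbers.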

## 2. Visualizing the sampling chain

```python
from qm9.sampling import sample_chain
import qm9.visualizer as vis

num_chains = 100
id_from = 0

for i in range(num_chains):
    target_path = f'eval/chain_{i}/'
    # Sample one chain (up to args.n_tries attempts to end on a stable molecule)
    one_hot, charges, x = sample_chain(args, device, flow, args.n_tries, dataset_info)
    # Save the frames of the chain to .txt files
    vis.save_xyz_file(join(args.model_path, target_path), one_hot, charges, x, dataset_info, id_from, name='chain')
    # Render the frames and assemble them into a gif
    vis.visualize_chain_uncertainty(join(args.model_path, target_path), dataset_info, spheres_3d=True)
```
```text
Found stable molecule to visualize :)
Creating gif with 108 images
Found stable molecule to visualize :)
Creating gif with 108 images
Found stable molecule to visualize :)
Creating gif with 108 images
Found stable molecule to visualize :)
Creating gif with 108 images
Found stable molecule to visualize :)
Creating gif with 108 images
Found stable molecule to visualize :)
Creating gif with 108 images
Found stable molecule to visualize :)
Creating gif with 108 images
Found stable molecule to visualize :)
Creating gif with 108 images
Found stable molecule to visualize :)
Creating gif with 108 images
Found stable molecule to visualize :)
Creating gif with 108 images
Found stable molecule to visualize :)
Creating gif with 108 images
Found stable molecule to visualize :)

KeyboardInterrupt:
```