# Exploration of **Conditional Sampling** in EDM: E(3) Equivariant Diffusion Model for Molecule Generation in 3D
[GitHub repo](https://github.com/ehoogeboom/e3_diffusion_for_molecules)

- __Quantitative__: 
    - `main_quantitative`
    - Different options for task: `edm, qm9_second_half, naive`
    - Calls `qm9.property_prediction.main_qm9_prop.test`, which runs `train` on the `test` dataset
- __Qualitative__: `main_qualitative`
    - Calls `save_and_sample_conditional` $\to$ `qm9.sampling.sample_sweep_conditional`

## 1. Arguments
These arguments are largely irrelevant here, because sampling uses the arguments the model was trained with, which are loaded from `generators_path`.

In [1]:
import argparse, torch

parser = argparse.ArgumentParser()
parser.add_argument('--exp_name', type=str, default='exploration_luisa')
#parser.add_argument('--generators_path', type=str, default='outputs/exp_cond_alpha_pretrained') #original
parser.add_argument('--generators_path', type=str, default='outputs/exp_35_conditional_nf192_9l_alpha')
parser.add_argument('--classifiers_path', type=str, default='qm9/property_prediction/outputs/exp_class_alpha_pretrained')
parser.add_argument('--property', type=str, default='mu',
                    help="'alpha', 'homo', 'lumo', 'gap', 'mu', 'Cv'")
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA')
parser.add_argument('--debug_break', type=eval, default=False,
                    help='break point or not')
parser.add_argument('--log_interval', type=int, default=5,
                    help='number of iterations between log messages')
parser.add_argument('--batch_size', type=int, default=1,
                    help='batch size')
parser.add_argument('--iterations', type=int, default=20,
                    help='number of iterations')
parser.add_argument('--task', type=str, default='qualitative',
                    help='naive, edm, qm9_second_half, qualitative')
parser.add_argument('--n_sweeps', type=int, default=1,
                    help='number of sweeps for the qualitative conditional experiment')
args = parser.parse_args(args=[])
device = torch.device("cpu")
args.device = device

## 2. Qualitative
We use a pre-trained model that was saved to `outputs/exp_35_conditional_nf192_9l_alpha`. Inspecting its training arguments shows that this model was trained to condition on `alpha` only.

In [2]:
from eval_conditional_qm9 import get_args_gen, get_dataloader, get_generator

# get the args used for training the generative model
args_gen = get_args_gen(args.generators_path)
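
Reusing the training arguments at sampling time can be sketched as follows. This is a minimal, self-contained illustration that assumes the training run pickled its `argparse.Namespace` into the experiment directory; the file name `args.pickle` and the helpers `save_args` and `load_args` are hypothetical and not taken from the repository.

```python
import argparse
import os
import pickle
import tempfile


def save_args(args, dir_path, filename="args.pickle"):
    """Pickle a Namespace into dir_path (the file name is an assumption)."""
    with open(os.path.join(dir_path, filename), "wb") as f:
        pickle.dump(args, f)


def load_args(dir_path, filename="args.pickle"):
    """Load the Namespace that was stored by save_args."""
    with open(os.path.join(dir_path, filename), "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as run_dir:
    # Toy stand-in for the arguments of a training run.
    train_args = argparse.Namespace(conditioning=["alpha"], n_layers=9, nf=192)
    save_args(train_args, run_dir)
    # At sampling time only the directory is needed to recover them.
    restored_args = load_args(run_dir)

print(restored_args.conditioning)
```

Because the restored namespace carries the architecture and conditioning settings, the sampling script only needs the path to the experiment directory.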

In [3]:
# exploration of arguments used to train the model
print(args_gen)
print(args_gen.conditioning) # the model was trained to condition only on alpha
print(args_gen.context_node_nf) # a single context feature, because alpha is a global property (one scalar per molecule)

Namespace(exp_name='exp_35_conditional_nf192_9l_alpha', model='egnn_dynamics', probabilistic_model='diffusion', diffusion_steps=1000, diffusion_noise_schedule='polynomial_2', diffusion_noise_precision=1e-05, diffusion_loss_type='l2', n_epochs=3000, batch_size=64, lr=0.0001, brute_force=False, actnorm=True, break_train_epoch=False, dp=True, condition_time=True, clip_grad=True, trace='hutch', n_layers=9, inv_sublayers=1, nf=192, tanh=True, attention=True, norm_constant=1, sin_embedding=False, ode_regularization=0.001, dataset='qm9_second_half', datadir='qm9/temp', filter_n_atoms=None, dequantization='deterministic', n_report_steps=1, wandb_usr='vgsatorras', no_wandb=False, online=True, no_cuda=False, save_model=True, generate_epochs=1, num_workers=0, test_epochs=10, data_augmentation=False, conditioning=['alpha'], resume=None, start_epoch=0, ema_decay=0.999, augment_noise=0, n_stability_samples=500, normalize_factors=[1, 8, 1], remove_h=False, include_charges=False, visualize_every_batch

In [4]:
from qm9.utils import compute_mean_mad 
dataloaders = get_dataloader(args_gen)
property_norms = compute_mean_mad(dataloaders, args_gen.conditioning, args_gen.dataset)
model, nodes_dist, prop_dist, dataset_info = get_generator(args.generators_path, dataloaders, 
                                                           args.device, args_gen, property_norms)

Entropy of n_nodes: H[N] -2.4754221439361572


alphas2 [9.99990000e-01 9.99988000e-01 9.99982000e-01 ... 2.59676966e-05
 1.39959211e-05 1.00039959e-05]
gamma [-11.51291546 -11.33059532 -10.92513058 ...  10.55863126  11.17673063
  11.51251595]
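
The two printed arrays appear to be related by `gamma = log(sigma^2 / alpha^2)`, the negative log signal-to-noise ratio. The sketch below checks this on the first and last printed values, assuming a variance-preserving schedule where `sigma^2 = 1 - alpha^2`; the function name `gamma_from_alpha2` is made up for this illustration.

```python
import math


def gamma_from_alpha2(alpha2):
    """Negative log-SNR, assuming sigma^2 = 1 - alpha^2 (variance preserving)."""
    sigma2 = 1.0 - alpha2
    return math.log(sigma2) - math.log(alpha2)


# First and last entries of the printed `alphas2` array.
first = gamma_from_alpha2(9.99990000e-01)
last = gamma_from_alpha2(1.00039959e-05)

print(round(first, 4), round(last, 4))  # -11.5129 11.5125
```

Both results agree with the first and last entries of the printed `gamma` array, which supports this reading of the output.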


In [5]:
# exploration of the model
model.T # number of diffusion timesteps

1000

In [14]:
prop_dist.distributions.keys() # only alpha because the trained model was conditioning only on this property
prop_dist.distributions

{'alpha': {3: {'probs': Categorical(probs: torch.Size([1000])),
   'params': [tensor(12.9900), tensor(12.9900)]},
  4: {'probs': Categorical(probs: torch.Size([1000])),
   'params': [tensor(9.4600), tensor(27.7000)]},
  5: {'probs': Categorical(probs: torch.Size([1000])),
   'params': [tensor(13.2100), tensor(32.6600)]},
  6: {'probs': Categorical(probs: torch.Size([1000])),
   'params': [tensor(21.5700), tensor(38.5200)]},
  7: {'probs': Categorical(probs: torch.Size([1000])),
   'params': [tensor(24.0400), tensor(43.8600)]},
  8: {'probs': Categorical(probs: torch.Size([1000])),
   'params': [tensor(23.9500), tensor(72.3900)]},
  9: {'probs': Categorical(probs: torch.Size([1000])),
   'params': [tensor(28.1300), tensor(88.8500)]},
  10: {'probs': Categorical(probs: torch.Size([1000])),
   'params': [tensor(34.5600), tensor(100.8600)]},
  11: {'probs': Categorical(probs: torch.Size([1000])),
   'params': [tensor(34.7500), tensor(92.3400)]},
  12: {'probs': Categorical(probs: torch.Siz

For sampling we have to define `n_nodes`, the number of atoms in each generated molecule, and `n_frames`, the number of `alpha` values for which a molecule is generated.

`sample_sweep_conditional` generates `n_frames` molecules, one for each of `n_frames` values of `alpha` equally spaced between `min_alpha` and `max_alpha`.
All generated molecules have `n_nodes` atoms.
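
The construction of the swept conditioning values can be sketched in plain Python. This is an illustration under assumptions, not the repository's implementation: `sweep_context` is a hypothetical helper, the `[min, max]` pair is the one printed above for 10 atoms, and the `mean` and `mad` values are placeholders rather than the real dataset statistics.

```python
def sweep_context(min_val, max_val, mean, mad, n_frames):
    """Return n_frames equally spaced, normalized values (one row per molecule)."""
    if n_frames == 1:
        raw = [min_val]
    else:
        step = (max_val - min_val) / (n_frames - 1)
        raw = [min_val + i * step for i in range(n_frames)]
    # Normalize the same way as a single conditioning value: (value - mean) / mad.
    return [[(value - mean) / mad] for value in raw]


# [min, max] taken from the printed 'params' for n_nodes = 10.
# mean and mad are placeholders, NOT the statistics of the dataset.
rows = sweep_context(34.56, 100.86, mean=50.0, mad=10.0, n_frames=5)

for row in rows:
    print(row)
```

Passing `rows` to `torch.tensor(rows).float()` gives a tensor of shape `(n_frames, 1)`: one row per generated molecule and one column per conditioning property.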

In [11]:
from qm9.sampling import sample_sweep_conditional

one_hot, charges, x, node_mask = sample_sweep_conditional(args_gen, device, model, dataset_info, prop_dist, n_nodes=10, n_frames=5)

In [12]:
# save generated molecule to xyz file
import qm9.visualizer as vis

epoch = 0
id_from = 1

vis.save_xyz_file(
        'outputs/%s/analysis/run%s/' % (args.exp_name, epoch), one_hot, charges, x, dataset_info,
        id_from, name='conditional', node_mask=node_mask)

In [13]:
# visualize the `n_frames` different molecules with increasing value of `alpha`.
vis.visualize_chain("outputs/%s/analysis/run%s/" % (args.exp_name, epoch), dataset_info,
                        wandb=None, mode='conditional', spheres_3d=True)

Creating gif with 5 images


### 2.1 Qualitative sampling for a specific value of `alpha`

In [22]:
from qm9.sampling import sample

n_nodes = 10 # number of nodes/atoms
value_of_alpha = 40 # target value of alpha in physical (unnormalized) units

nodesxsample = torch.tensor([n_nodes])
# normalize the value of alpha with the statistics stored in prop_dist
mean, mad = prop_dist.normalizer['alpha']['mean'], prop_dist.normalizer['alpha']['mad']
alpha_normalized = (value_of_alpha - mean) / mad # separate name keeps the raw value intact

# put the context in the right format: one row per molecule, one column per property
context = []
context_row = torch.tensor([alpha_normalized]).unsqueeze(1)
context.append(context_row)
context = torch.cat(context, dim=1).float().to(device)


In [23]:
context

tensor([[0.7376]])
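
The normalization used for the context can be illustrated on toy data. The sketch assumes that `mad` denotes the mean absolute deviation around the mean; the helper names and the toy `alpha` values are made up for this example and are not the QM9 statistics.

```python
def mean_and_mad(values):
    """Mean and mean absolute deviation (MAD) of a list of numbers."""
    mean = sum(values) / len(values)
    mad = sum(abs(v - mean) for v in values) / len(values)
    return mean, mad


def normalize(value, mean, mad):
    """Map a physical value to the normalized scale used as context."""
    return (value - mean) / mad


def denormalize(z, mean, mad):
    """Inverse of normalize: recover the physical value."""
    return z * mad + mean


toy_alphas = [60.0, 70.0, 80.0, 90.0]  # made-up values, for illustration only
mean, mad = mean_and_mad(toy_alphas)   # mean = 75.0, mad = 10.0
z = normalize(40.0, mean, mad)

print(mean, mad, z, denormalize(z, mean, mad))  # 75.0 10.0 -3.5 40.0
```

The same inverse mapping converts a normalized context entry back to a physical value of `alpha` when interpreting a generated sample.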

In [24]:
one_hot, charges, x, node_mask = sample(args_gen, device, model, dataset_info, prop_dist, nodesxsample=nodesxsample, context=context, fix_noise=True)

In [25]:
# save generated molecule to xyz file
import qm9.visualizer as vis

epoch = 0
id_from = 1

vis.save_xyz_file(
        'outputs/%s/analysis/run%s/' % (args.exp_name, epoch), one_hot, charges, x, dataset_info,
        id_from, name='conditional', node_mask=node_mask)

In [26]:
vis.visualize_chain("outputs/%s/analysis/run%s/" % (args.exp_name, epoch), dataset_info,
                        wandb=None, mode='conditional', spheres_3d=True)

Creating gif with 1 images


## 3. Quantitative