In [34]:
from ase.db import connect
from matplotlib import pyplot as plt
import math
import numpy as np
import os
import pandas as pd
import pickle
import sys

sys.path.append("..")
sys.path.append("../analyses")

from analyses import generate_plots
from analyses import check_distances

In [42]:
def normalize_histogram(hist):
    hist = np.array(hist)
    prob = hist / np.sum(hist)
    return prob


def kl_divergence(p1, p2):
    return np.sum(p1*np.log(p1 / p2))


def js_divergence(h1, h2):
    p1 = normalize_histogram(h1) + 1e-10
    p2 = normalize_histogram(h2) + 1e-10
    M = (p1 + p2)/2
    js = (kl_divergence(p1, M) + kl_divergence(p2, M)) / 2
    return js

version = 6
model = 'nequip'
interactions = 4
l = 5

step = 'best'
init = 'C6H5'

stats_list = []

element_numbers = {1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F'}
max_dist = 15  # max interatomic distance, for histogram purposes
bin_size = 0.5

distances_qm9 = [0 for _ in np.arange(max_dist / bin_size)]
with connect('../qm9_data/qm9-all.db') as conn:
    for row in conn.select():
        for dist in check_distances.get_interatomic_distances(row.toatoms().positions):
            distances_qm9[math.floor(dist / bin_size)] += 1

In [44]:
for channels in [32, 64, 128]:
    for beta in [1, 10, 100, 1000]:

        stats_file = os.path.join(
            '../analyses/molecules/generated/',
            f'v{version}',
            model,
            f'interactions={interactions}',
            f'l={l}',
            f'channels={channels}',
            f'beta={beta:.1f}',
            f'step={step}',
            f'generated_molecules_init={init}_statistics.pkl'
        )
        mol_path = os.path.join(
            '../analyses/molecules/generated/',
            f'v{version}',
            model,
            f'interactions={interactions}',
            f'l={l}',
            f'channels={channels}',
            f'beta={beta:.1f}',
            f'step={step}',
            f'generated_molecules_init={init}.db'
        )

        with open(stats_file, 'rb') as f:
            stats = pickle.load(f)
        stats_gen = stats['generated_stats']

        molecules = []
        with connect(mol_path) as conn:
            for row in conn.select():
                molecules.append(row.toatoms())

        # number of valid mols
        cond_valid = stats_gen['valid_mol'] == 1
        cond_unique = stats_gen['duplicating'] == -1
        num_unique = 64 - stats_gen[~cond_unique].shape[0]
        num_valid = stats_gen[cond_valid].shape[0]
        num_valid_unique = stats_gen[cond_valid & cond_unique].shape[0]

        # atom distributions
        atom_type_counts = {'H': 0, 'C': 0, 'N': 0, 'O': 0, 'F': 0}
        for mol in molecules:
            for atom in mol.numbers:
                atom_type_counts[element_numbers[atom]] += 1

        # distance distributions - buckets every 0.5 A, distances from 0 to 15 A
        distances = [0 for _ in np.arange(max_dist / bin_size)]
        for mol in molecules:
            for distance in check_distances.get_interatomic_distances(mol.get_positions()):
                distances[math.floor(distance / bin_size)] += 1

        stats_list.append([
            # model info
            model,
            version,
            channels,
            beta,
            init,
            # stats
            num_unique,  # number of unique molecules
            num_valid,  # number of valid molecules (including duplicates)
            num_valid_unique,  # number of unique and valid molecules
            num_unique / 64,  # fraction of mols that are unique
            num_valid / 64,  # fraction of mols that are unique
            num_valid_unique / 64,  # fraction of mols that are valid & unique
            stats_gen['valid_atoms'].sum() / stats_gen['n_atoms'].sum(),  # fraction of atoms that are valid
            atom_type_counts['H'],
            atom_type_counts['C'],
            atom_type_counts['N'],
            atom_type_counts['O'],
            atom_type_counts['F'],
            js_divergence(distances, distances_qm9),  # Jensen-Shannon divergence btwn distance distributions
            *distances
        ])


In [46]:
distance_labels = [f'{i/2:.1f}-{(i+1)/2:.1f}' for i in range(2*max_dist)]
model_stats = pd.DataFrame(stats_list, columns=[
    'model',
    'version',
    'channels', 
    'beta',
    'init',
    'num_unique',
    'num_valid',
    'num_valid_unique',
    'frac_unique',
    'frac_valid',
    'frac_valid_unique',
    'frac_valid_atoms',
    'H',
    'C',
    'N',
    'O',
    'F',
    'js_divergence',
    *distance_labels
])

In [48]:
model_stats['js_divergence']

0     110.000000
1      47.000000
2       5.000000
3       0.000000
4      52.000000
5       7.000000
6       1.000000
7       0.000000
8     178.000000
9     103.000000
10    113.000000
11    128.000000
12      0.011553
13      0.009392
14      0.019955
15      0.030112
16      0.011799
17      0.012066
18      0.011955
19      0.022409
20      0.016590
21      0.010982
22      0.017508
23      0.027140
Name: js_divergence, dtype: float64