In [97]:
from ase.db import connect
from matplotlib import pyplot as plt
import math
import numpy as np
import os
import pandas as pd
import pickle
from plotly import express as px
from plotly import graph_objects as go
import sys

sys.path.append("..")
sys.path.append("../analyses")

from analyses import check_distances

In [158]:
# --- Run selection / analysis configuration ------------------------------
interactions = 4  # value of the `interactions=` component of each run directory
l = 5             # value of the `l=` component of each run directory

init = 'C6H5'     # value used in the `init=` part of the generated-file names

# params[version][model] holds two parallel lists: channels[i] is evaluated at steps[i].
params = {
    6: {
        'nequip': dict(channels=[32, 64, 128], steps=['best', 'best', 'best']),
        'nequip-l2': dict(channels=[32, 64, 128], steps=[885000, 675000, 375000]),
    },
    7: {
        'nequip-l2': dict(channels=[32], steps=['best']),
    },
}

stats_list = []  # one row per evaluated run, filled by the statistics loop

# atomic number -> element symbol, for every element the analysis counts
element_numbers = dict(zip([1, 6, 7, 8, 9], ['H', 'C', 'N', 'O', 'F']))

max_dist = 15    # upper edge of the interatomic-distance histogram, in angstroms
bin_size = 0.5   # width of each distance-histogram bin, in angstroms

In [159]:
def normalize_histogram(hist):
    """Convert raw bin counts into probabilities that sum to one.

    Args:
        hist: Sequence (list, array, ...) of bin counts.

    Returns:
        numpy.ndarray of the counts divided by their total. An all-zero input
        yields NaNs (division by zero).
    """
    counts = np.asarray(hist)
    return counts / counts.sum()


def kl_divergence(p1, p2):
    """Kullback-Leibler divergence D_KL(p1 || p2), using the natural log (nats).

    Args:
        p1: Probability vector (list or array).
        p2: Probability vector of the same length. Wherever p1 is positive,
            p2 must be positive too, otherwise the result is infinite.

    Returns:
        sum_i p1[i] * ln(p1[i] / p2[i]), with the standard convention that
        bins where p1[i] == 0 contribute 0.
    """
    p1 = np.asarray(p1, dtype=float)
    p2 = np.asarray(p2, dtype=float)
    # Mask out p1 == 0 bins: evaluating them directly gives 0 * log(0) = nan.
    support = p1 > 0
    return np.sum(p1[support] * np.log(p1[support] / p2[support]))


def js_divergence(h1, h2):
    """Jensen-Shannon divergence between two histograms of raw counts.

    Each histogram is normalised to probabilities and shifted by a tiny
    epsilon so that no bin is exactly zero. Both are then compared against
    their mixture M = (p1 + p2) / 2.

    Args:
        h1, h2: Bin counts of equal length.

    Returns:
        (KL(p1 || M) + KL(p2 || M)) / 2, using kl_divergence's natural log.
    """
    eps = 1e-10  # keeps every bin strictly positive for the logarithms
    p1, p2 = (normalize_histogram(h) + eps for h in (h1, h2))
    mixture = (p1 + p2) / 2
    return (kl_divergence(p1, mixture) + kl_divergence(p2, mixture)) / 2

# Reference statistics over the whole QM9 database: per-element atom counts and
# a histogram of the distances returned by check_distances.get_interatomic_distances
# (presumably all atom pairs of each molecule — confirm).
n_bins = math.ceil(max_dist / bin_size)  # same length as np.arange(max_dist / bin_size)
elements_qm9 = {'H': 0, 'C': 0, 'N': 0, 'O': 0, 'F': 0}
distances_qm9 = [0] * n_bins
with connect('../qm9_data/qm9-all.db') as conn:
    for row in conn.select():
        mol = row.toatoms()
        for atom in mol.numbers:
            elements_qm9[element_numbers[atom]] += 1
        for dist in check_distances.get_interatomic_distances(mol.positions):
            # A distance >= max_dist would index past the end of the histogram
            # (IndexError). Count it in the last bin instead.
            distances_qm9[min(math.floor(dist / bin_size), n_bins - 1)] += 1

# Column layout of every row in stats_list. The columns are, in order:
#   - the run identifiers,
#   - the summary statistics,
#   - the per-element atom counts,
#   - one column per distance bin, labelled by the bin's lower edge in angstroms.
atom_labels = ['H', 'C', 'N', 'O', 'F']
distance_labels = list(np.arange(max_dist / bin_size) * bin_size)
col_labels = (
    ['model', 'version', 'channels', 'beta', 'init']
    + ['num_unique', 'num_valid', 'num_valid_unique',
       'frac_unique', 'frac_valid', 'frac_valid_unique', 'frac_valid_atoms',
       'js_divergence']
    + atom_labels
    + distance_labels
)

In [160]:
# Collect one row of statistics per generation run (model, version, channels,
# step, beta), in the column order of `col_labels`.

# Molecules generated per run; used as the denominator of the frac_* columns.
# NOTE(review): hard-coded in the original analysis — confirm it matches the generator.
n_generated = 64
n_bins = math.ceil(max_dist / bin_size)  # same length as np.arange(max_dist / bin_size)


def _run_file(version, model, channels, beta, step, filename):
    """Return the path of `filename` inside the output directory of one generation run."""
    return os.path.join(
        '../analyses/molecules/generated/',
        f'v{version}',
        model,
        f'interactions={interactions}',
        f'l={l}',
        f'channels={channels}',
        f'beta={beta:.1f}',
        f'step={step}',
        filename,
    )


def _load_molecules(db_path):
    """Return every row of the ASE database at `db_path` as an Atoms object."""
    with connect(db_path) as conn:
        return [row.toatoms() for row in conn.select()]


def _count_elements(molecules):
    """Count atoms of each element, keyed by element symbol, over all `molecules`."""
    counts = {'H': 0, 'C': 0, 'N': 0, 'O': 0, 'F': 0}
    for mol in molecules:
        for atom in mol.numbers:
            counts[element_numbers[atom]] += 1
    return counts


def _distance_histogram(molecules):
    """Build a histogram of interatomic distances over `molecules` with `bin_size`-wide bins.

    Distances >= max_dist go into the last bin instead of raising IndexError.
    """
    hist = [0] * n_bins
    for mol in molecules:
        for distance in check_distances.get_interatomic_distances(mol.get_positions()):
            hist[min(math.floor(distance / bin_size), n_bins - 1)] += 1
    return hist


stats_list = []  # reset so that re-running this cell does not append duplicate rows
for version in [6, 7]:
    for model, model_params in params[version].items():
        for channels, step in zip(model_params['channels'], model_params['steps']):
            for beta in [1, 10, 100, 1000]:
                stats_file = _run_file(version, model, channels, beta, step,
                                       f'generated_molecules_init={init}_statistics.pkl')
                mol_path = _run_file(version, model, channels, beta, step,
                                     f'generated_molecules_init={init}.db')

                # NOTE: pickle.load can execute arbitrary code. Only load trusted, locally produced files.
                with open(stats_file, 'rb') as f:
                    stats = pickle.load(f)
                stats_all = stats['generated_stats']  # presumably one row per generated molecule
                cond_valid = stats_all['valid_mol'] == 1
                cond_unique = stats_all['duplicating'] == -1  # -1 is treated as "not a duplicate"
                stats_gen = stats_all[cond_unique]

                # Count on the unfiltered table. Counting after dropping duplicates
                # would make num_valid identical to num_valid_unique.
                num_unique = int(cond_unique.sum())
                num_valid = int(cond_valid.sum())
                num_valid_unique = int((cond_valid & cond_unique).sum())

                # Atom and distance distributions over all generated molecules (duplicates included).
                molecules = _load_molecules(mol_path)
                atom_type_counts = _count_elements(molecules)
                distances = _distance_histogram(molecules)

                stats_list.append([
                    # model info
                    model,
                    version,
                    channels,
                    beta,
                    init,
                    # stats
                    num_unique,  # number of unique molecules
                    num_valid,  # number of valid molecules (including duplicates)
                    num_valid_unique,  # number of unique and valid molecules
                    num_unique / n_generated,  # fraction of mols that are unique
                    num_valid / n_generated,  # fraction of mols that are valid
                    num_valid_unique / n_generated,  # fraction of mols that are valid & unique
                    # fraction of atoms that are valid, over the unique molecules
                    stats_gen['valid_atoms'].sum() / stats_gen['n_atoms'].sum(),
                    js_divergence(distances, distances_qm9),  # Jensen-Shannon divergence btwn distance distributions
                    *(atom_type_counts[symbol] for symbol in atom_labels),  # same order as col_labels
                    *distances,
                ])


In [161]:
# One row per run, indexed by the run identifiers, then written to disk in
# two formats: a human-readable CSV and a dtype-preserving pickle.
model_stats = (
    pd.DataFrame(stats_list, columns=col_labels)
    .set_index(['model', 'version', 'channels', 'beta', 'init'])
)
model_stats.to_csv('generated_molecules_statistics.csv')
model_stats.to_pickle('generated_molecules_statistics.pkl')

In [162]:
# Run whose statistics are plotted below. The tuple order matches the
# (model, version, channels, beta, init) index of model_stats.
name, version, channels, beta, init = 'nequip', 6, 128, 100, 'C6H5'
selected_model = (name, version, channels, beta, init)

In [163]:
# Element frequencies, normalised to fractions: QM9 reference vs. the selected run.
fig = go.Figure(data=[
    go.Bar(
        x=atom_labels,
        # Look up by symbol so the bars stay aligned with atom_labels regardless of dict order.
        y=normalize_histogram([elements_qm9[symbol] for symbol in atom_labels]),
        name='QM9',
    ),
    go.Bar(
        x=atom_labels,
        y=normalize_histogram(list(model_stats.loc[selected_model, atom_labels])),
        name='Model',
    ),
])
fig.update_layout(
    barmode='group',  # QM9 and model bars side by side for each element
    title='Element distribution: QM9 vs. ' + ' / '.join(map(str, selected_model)),
    xaxis_title='Element',
    yaxis_title='Fraction of atoms',
)
fig.show()

In [164]:
# Interatomic-distance distributions, normalised to fractions: QM9 reference vs.
# the selected run. Each bar is positioned at the lower edge of its bin_size-wide bin.
fig = go.Figure(data=[
    go.Bar(
        x=distance_labels,
        y=normalize_histogram(distances_qm9),
        name='QM9',
    ),
    go.Bar(
        x=distance_labels,
        y=normalize_histogram(list(model_stats.loc[selected_model, distance_labels])),
        name='Model',
    ),
])
fig.update_layout(
    barmode='group',  # QM9 and model bars side by side for each distance bin
    title='Interatomic distance distribution: QM9 vs. ' + ' / '.join(map(str, selected_model)),
    xaxis_title='Interatomic distance, bin lower edge (Å)',
    yaxis_title='Fraction of distances',
)
fig.show()