In [95]:
from ase.db import connect
import itertools
from matplotlib import pyplot as plt
import math
import numpy as np
import os
import pandas as pd
import pickle
from plotly import express as px
from plotly import graph_objects as go
import sys

sys.path.append("..")
sys.path.append("../analyses")

from analyses import analysis
from analyses import check_distances

In [96]:
interactions = 4
l = 5

params = {
    6: {
        'nequip': {'channels': [32, 64, 128], 'steps': ['best', 'best', 'best']},
        'nequip-l2': {'channels': [32, 64, 128], 'steps': [885000, 675000, 375000]},
    },
    7: {
        'nequip-l2': {'channels': [32], 'steps': ['best']},
    }
}

element_numbers = {1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F'}
max_dist = 15  # max interatomic distance for histogram, in angstroms
bin_size = 0.5  # bin size for histogram, in angstroms

In [120]:
def normalize_histogram(hist):
    hist = np.array(hist)
    prob = hist / np.sum(hist)
    return prob


def kl_divergence(p1, p2):
    return np.sum(p1*np.log(p1 / p2))


def js_divergence(h1, h2):
    p1 = normalize_histogram(h1) + 1e-10
    p2 = normalize_histogram(h2) + 1e-10
    M = (p1 + p2)/2
    js = (kl_divergence(p1, M) + kl_divergence(p2, M)) / 2
    return js

elements_qm9 = {'H': 0, 'C': 0, 'N': 0, 'O': 0, 'F': 0}
distances_qm9 = [0 for _ in np.arange(max_dist / bin_size)]
distances_CC_qm9 = [0 for _ in np.arange(max_dist / bin_size)]
distances_CH_qm9 = [0 for _ in np.arange(max_dist / bin_size)]
distances_CO_qm9 = [0 for _ in np.arange(max_dist / bin_size)]
# get atom/distance counts for QM9
with connect('../qm9_data/qm9-all.db') as conn:
    for row in conn.select():
        mol = row.toatoms()
        for atom in mol.numbers:
            elements_qm9[element_numbers[atom]] += 1
        for i, j in itertools.combinations(range(len(mol.numbers)), 2):
            dist = np.linalg.norm(mol.positions[i] - mol.positions[j])
            distances_qm9[math.floor(dist / bin_size)] += 1
            if mol.numbers[i] == 6 and mol.numbers[j] == 6:
                distances_CC_qm9[math.floor(dist / bin_size)] += 1
            if (mol.numbers[i] == 6 and mol.numbers[j] == 1) or (mol.numbers[i] == 1 and mol.numbers[j] == 6):
                distances_CH_qm9[math.floor(dist / bin_size)] += 1
            if (mol.numbers[i] == 6 and mol.numbers[j] == 8) or (mol.numbers[i] == 8 and mol.numbers[j] == 6):
                distances_CO_qm9[math.floor(dist / bin_size)] += 1


atom_labels = ['H', 'C', 'N', 'O', 'F']
distance_labels = [i*bin_size for i in np.arange(max_dist/bin_size)]
distance_CC_labels = [f'CC_{i*bin_size}' for i in np.arange(max_dist/bin_size)]
distance_CH_labels = [f'CH_{i*bin_size}' for i in np.arange(max_dist/bin_size)]
distance_CO_labels = [f'CO_{i*bin_size}' for i in np.arange(max_dist/bin_size)]

col_labels = [
    'model',
    'version',
    'channels', 
    'beta',
    'init',
    'num_unique',
    'num_known',
    'num_valid',
    'num_valid_unique',
    'frac_unique',
    'frac_valid',
    'frac_valid_unique',
    'frac_valid_atoms',
    'frac_valid_H',
    'frac_valid_C',
    'frac_valid_N',
    'frac_valid_O',
    'frac_valid_F',
    'js_divergence',
    *atom_labels,
    *distance_labels,
    *distance_CC_labels,
    *distance_CH_labels,
    *distance_CO_labels,
]

In [121]:
stats_list = []

for version in [6, 7]:
    for model in params[version]:
        for channels, step in zip(params[version][model]['channels'], params[version][model]['steps']):
            for beta in [1, 10, 100, 1000]:
                for init in ['CH3', 'C6H5']:
                    stats_file = os.path.join(
                        '../analyses/molecules/generated/',
                        f'v{version}',
                        model,
                        f'interactions={interactions}',
                        f'l={l}',
                        f'channels={channels}',
                        f'beta={beta:.1f}',
                        f'step={step}',
                        f'generated_molecules_init={init}_statistics.pkl'
                    )
                    mol_path = os.path.join(
                        '../analyses/molecules/generated/',
                        f'v{version}',
                        model,
                        f'interactions={interactions}',
                        f'l={l}',
                        f'channels={channels}',
                        f'beta={beta:.1f}',
                        f'step={step}',
                        f'generated_molecules_init={init}.db'
                    )

                    with open(stats_file, 'rb') as f:
                        stats = pickle.load(f)
                    stats_gen = stats['generated_stats']
                    stats_gen = stats_gen[stats_gen['duplicating']==-1]

                    molecules = []
                    with connect(mol_path) as conn:
                        for row in conn.select():
                            molecules.append(row.toatoms())

                    # number of valid mols
                    cond_valid = stats_gen['valid_mol'] == 1
                    cond_unique = stats_gen['duplicating'] == -1
                    num_unique = stats_gen.shape[0]
                    num_known = stats_gen['known'].sum()
                    num_valid = stats_gen[cond_valid].shape[0]
                    num_valid_unique = stats_gen[cond_valid & cond_unique].shape[0]

                    # atom distributions
                    atom_type_counts = {'H': 0, 'C': 0, 'N': 0, 'O': 0, 'F': 0}
                    for mol in molecules:
                        for atom in mol.numbers:
                            atom_type_counts[element_numbers[atom]] += 1

                    # distance distributions - buckets every 0.5 A, distances from 0 to 15 A
                    distances = [0 for _ in np.arange(max_dist / bin_size)]
                    distances_CC = [0 for _ in np.arange(max_dist / bin_size)]
                    distances_CH = [0 for _ in np.arange(max_dist / bin_size)]
                    distances_CO = [0 for _ in np.arange(max_dist / bin_size)]
                    for mol in molecules:
                        for i, j in itertools.combinations(range(len(mol.numbers)), 2):
                            dist = np.linalg.norm(mol.positions[i] - mol.positions[j])
                            distances[math.floor(dist / bin_size)] += 1
                            if mol.numbers[i] == 6 and mol.numbers[j] == 6:
                                distances_CC[math.floor(dist / bin_size)] += 1
                            if (mol.numbers[i] == 6 and mol.numbers[j] == 1) or (mol.numbers[i] == 1 and mol.numbers[j] == 6):
                                distances_CH[math.floor(dist / bin_size)] += 1
                            if (mol.numbers[i] == 6 and mol.numbers[j] == 8) or (mol.numbers[i] == 8 and mol.numbers[j] == 6):
                                distances_CO[math.floor(dist / bin_size)] += 1

                    stats_list.append([
                        # model info
                        model,
                        version,
                        channels,
                        beta,
                        init,
                        # stats
                        num_unique,  # number of unique molecules
                        num_known,
                        num_valid,  # number of valid molecules (including duplicates)
                        num_valid_unique,  # number of unique and valid molecules
                        num_unique / 64,  # fraction of mols that are unique
                        num_valid / 64,  # fraction of mols that are unique
                        num_valid_unique / 64,  # fraction of mols that are valid & unique
                        stats_gen['valid_atoms'].sum() / stats_gen['n_atoms'].sum(),  # fraction of atoms that are valid
                        stats_gen['valid_H'].sum() / stats_gen['H'].sum(),
                        stats_gen['valid_C'].sum() / stats_gen['C'].sum(),
                        stats_gen['valid_N'].sum() / stats_gen['N'].sum(),
                        stats_gen['valid_O'].sum() / stats_gen['O'].sum(),
                        stats_gen['valid_F'].sum() / stats_gen['F'].sum(),
                        js_divergence(distances, distances_qm9),  # Jensen-Shannon divergence btwn distance distributions
                        atom_type_counts['H'],
                        atom_type_counts['C'],
                        atom_type_counts['N'],
                        atom_type_counts['O'],
                        atom_type_counts['F'],
                        *distances,
                        *distances_CC,
                        *distances_CH,
                        *distances_CO
                    ])



invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in double_scalars


invalid value encountered in doub

In [122]:
model_stats = pd.DataFrame(stats_list, columns=col_labels)
# model_stats.to_csv('generated_molecules_statistics.csv')
# model_stats.to_pickle('generated_molecules_statistics.pkl')

In [112]:
cond_model = model_stats['model']=='nequip'
cond_version = model_stats['version'] == 6
cond_channels = model_stats['channels'] == 128
cond_beta = model_stats['beta'] == 100
cond_init = model_stats['init'] == 'C6H5'
data = model_stats[cond_model & cond_version & cond_channels & cond_beta & cond_init].iloc[0, :]

print(f'For initial molecule {init}:')
print(f"{data['frac_valid_atoms'] * 100 :.2f}% of atoms are valid")
print(f"{data['frac_valid_H'] * 100 :.2f}% of hydrogens are valid")
print(f"{data['frac_valid_C'] * 100 :.2f}% of carbons are valid")
print(f"{data['frac_valid_N'] * 100 :.2f}% of nitrogens are valid")
print(f"{data['frac_valid_O'] * 100 :.2f}% of oxygens are valid")
print(f"{data['frac_valid_F'] * 100 :.2f}% of fluorines are valid")

For initial molecule C6H5:
75.67% of atoms are valid
99.25% of hydrogens are valid
54.20% of carbons are valid
0.00% of nitrogens are valid
15.79% of oxygens are valid
nan% of fluorines are valid


In [113]:
cond_model = model_stats['model']=='nequip'
cond_version = model_stats['version'] == 6
cond_channels = model_stats['channels'] == 128
cond_beta = model_stats['beta'] == 100
cond_init = model_stats['init'] == 'CH3'
data = model_stats[cond_model & cond_version & cond_channels & cond_beta & cond_init].iloc[0, :]

print(f'For initial molecule {init}:')
print(f"{data['frac_valid_atoms'] * 100 :.2f}% of atoms are valid")
print(f"{data['frac_valid_H'] * 100 :.2f}% of hydrogens are valid")
print(f"{data['frac_valid_C'] * 100 :.2f}% of carbons are valid")
print(f"{data['frac_valid_N'] * 100 :.2f}% of nitrogens are valid")
print(f"{data['frac_valid_O'] * 100 :.2f}% of oxygens are valid")
print(f"{data['frac_valid_F'] * 100 :.2f}% of fluorines are valid")

For initial molecule C6H5:
80.43% of atoms are valid
94.85% of hydrogens are valid
64.04% of carbons are valid
60.00% of nitrogens are valid
67.50% of oxygens are valid
nan% of fluorines are valid


In [114]:
cond_model = model_stats['model'] == 'nequip'
cond_version = model_stats['version'] == 6
cond_channels = model_stats['channels'] == 128
cond_beta = model_stats['beta'] == 100
cond_init = model_stats['init'] == 'CH3'
data = model_stats[cond_model & cond_version & cond_channels & cond_beta & cond_init].iloc[0, :]

fig = go.Figure(data=[
    go.Bar(
        x=atom_labels,
        y=normalize_histogram(list(elements_qm9.values())), 
        name='QM9'
    ),
    go.Bar(
        x=atom_labels,
        y=normalize_histogram(list(data[atom_labels])),
        name='Model'
    )
])
# Change the bar mode
fig.update_layout(barmode='group')
fig.update_layout(title=f'Atom Type Distribution for NequIP v{version}, {channels} channels, beta = {beta}, initial {init}')
fig.show()

In [115]:
cond_model = model_stats['model'] == 'nequip'
cond_version = model_stats['version'] == 6
cond_channels = model_stats['channels'] == 128
cond_beta = model_stats['beta'] == 100
cond_init = model_stats['init'] == 'C6H5'
data = model_stats[cond_model & cond_version & cond_channels & cond_beta & cond_init].iloc[0, :]

fig = go.Figure(data=[
    go.Bar(
        x=distance_labels,
        y=normalize_histogram(distances_qm9),
        name='QM9'
    ),
    go.Bar(
        x=distance_labels,
        y=normalize_histogram(list(data[distance_labels])),
        name='Model'
    )
])
# Change the bar mode
fig.update_layout(barmode='group', title='Interatomic Distance Distribution')
fig.update_xaxes(title='Distance (A)')
fig.show()

In [116]:
cond_model = model_stats['model'] == 'nequip'
cond_version = model_stats['version'] == 6
cond_channels = model_stats['channels'] == 128
cond_beta = model_stats['beta'] == 100
cond_init = model_stats['init'] == 'C6H5'
data = model_stats[cond_model & cond_version & cond_channels & cond_beta & cond_init].iloc[0, :]

fig = go.Figure(data=[
    go.Bar(
        x=distance_labels,
        y=normalize_histogram(distances_CC_qm9),
        name='QM9'
    ),
    go.Bar(
        x=distance_labels,
        y=normalize_histogram(list(data[distance_CC_labels])),
        name='Model'
    )
])
# Change the bar mode
fig.update_layout(barmode='group', title='C-C Distance Distribution')
fig.update_xaxes(title='Distance (A)')
fig.show()

In [117]:
# fraction of valid atoms across different # of channels/beta

cond_model = model_stats['model'] == 'nequip-l2'
cond_version = model_stats['version'] == 6
cond_init = model_stats['init'] == 'C6H5'

data = model_stats[cond_model & cond_version & cond_init][['channels', 'frac_valid_atoms', 'beta']]

px.line(data, x='channels', y='frac_valid_atoms', color='beta')

In [118]:
# same thing, but for specific atom types

cond_model = model_stats['model'] == 'nequip-l2'
cond_version = model_stats['version'] == 6
cond_init = model_stats['init'] == 'CH3'

data = model_stats[cond_model & cond_version & cond_init][['channels', 'frac_valid_C', 'beta']]

px.line(data, x='channels', y='frac_valid_C', color='beta')