In [127]:
# %pip install torch numpy ase torch_geometric scikit-learn pandas rdkit

In [103]:
from contextlib import contextmanager
import sys
import os
import torch
from ase import Atoms
from ase.data import chemical_symbols
from ase.calculators.morse import MorsePotential
from ase.optimize import QuasiNewton

In [8]:
print(os.getcwd())
os.path.exists('../../data')

/Users/varunhegde/Documents/Georgia Tech/Spring 2024/MLC/MLC_GAN_ToxBio/models/stability


True

In [14]:
# Loading data
data_dir = os.path.join('..', '..', 'data')
databio_path = os.path.join(data_dir, 'datalist_bio.pt')
datatox_path = os.path.join(data_dir, 'datalist_tox.pt')

datalist_bio = torch.load(databio_path)
datalist_tox = torch.load(datatox_path)

# test that it works
print("Biodegradability Data:")
for data in datalist_bio:
    print(data)
    break

print("Toxicity Data:")
for data in datalist_tox:
    print(data)
    break

Biodegradability Data:
Data(x=[9, 100], formula='C8H7Br', name='', positions=[9, 3], edge_index=[2, 9], edge_attr=[9, 4], bio=[1], tox=[0])
Toxicity Data:
Data(x=[34, 100], formula='C27H25ClN6', name='NCGC00178831-03', positions=[34, 3], edge_index=[2, 37], edge_attr=[37, 4], bio=[0], tox=[1])


In [124]:
datapoint = datalist_tox[500]
def get_molecule(datapoint):
    atomic_numbers = datapoint.x.argmax(dim=1).tolist()
    atomic_symbols = [chemical_symbols[number] for number in atomic_numbers]
    # print('Atomic symbols:', atomic_symbols)

    positions = datapoint.positions.numpy()
    # print('Positions:', positions.shape)

    # molecule = Atoms(symbols=atomic_symbols, positions=positions)
    molecule = Atoms(numbers=atomic_numbers, positions=positions)
    return molecule

molecule = get_molecule(datapoint)
print(molecule)

Atoms(symbols='C11B62N12', pbc=False)


In [125]:
@contextmanager
def suppress_stdout():
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:  
            yield
        finally:
            sys.stdout = old_stdout

In [126]:
def get_dft(molecule, De=0.242, re=0.74, alpha=1.5):
    dft_calculator = MorsePotential(De=De, r0=re, alpha=alpha)
    molecule.set_calculator(dft_calculator)
    energy = molecule.get_potential_energy()
    with suppress_stdout():
        opt = QuasiNewton(molecule)
        opt.run(fmax=0.02)  # Converge forces to less than 0.02 eV/Angstrom
        # Check the optimized bond length and energy
        optimized_energy = molecule.get_potential_energy()
    return optimized_energy

energy = get_dft(molecule)
print(f'The DFT calculated energy is: {energy} eV')

The DFT calculated energy is: -183.669265825246 eV
