In [1]:
%pip install torch numpy ase torch_geometric scikit-learn pandas rdkit

Collecting torch
  Downloading torch-2.2.2-cp311-none-macosx_11_0_arm64.whl.metadata (25 kB)
Collecting numpy
  Using cached numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl.metadata (114 kB)
Collecting ase
  Using cached ase-3.22.1-py3-none-any.whl.metadata (3.1 kB)
Collecting torch_geometric
  Using cached torch_geometric-2.5.2-py3-none-any.whl.metadata (64 kB)
Collecting scikit-learn
  Downloading scikit_learn-1.4.2-cp311-cp311-macosx_12_0_arm64.whl.metadata (11 kB)
Collecting pandas
  Downloading pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl.metadata (19 kB)
Collecting rdkit
  Downloading rdkit-2023.9.5-cp311-cp311-macosx_11_0_arm64.whl.metadata (3.9 kB)
Collecting filelock (from torch)
  Using cached filelock-3.13.4-py3-none-any.whl.metadata (2.8 kB)
Collecting sympy (from torch)
  Using cached sympy-1.12-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Using cached networkx-3.3-py3-none-any.whl.metadata (5.1 kB)
Collecting jinja2 (from torch)
  Using cached Jin

In [13]:
from contextlib import contextmanager
import sys
import os
import torch
from ase import Atoms
from ase.data import chemical_symbols
from ase.calculators.morse import MorsePotential
from ase.optimize import QuasiNewton
import numpy as np

In [3]:
print(os.getcwd())
os.path.exists('../data')

/Users/varunhegde/Documents/Georgia Tech/Spring 2024/MLC/MLC_Final_Project


True

In [5]:
# Loading data
data_dir = os.path.join('..', 'data')
databio_path = os.path.join(data_dir, 'datalist_bio.pt')
datatox_path = os.path.join(data_dir, 'datalist_tox.pt')

datalist_bio = torch.load(databio_path)
datalist_tox = torch.load(datatox_path)

# test that it works
print("Biodegradability Data:")
for data in datalist_bio:
    print(data)
    break

print("Toxicity Data:")
for data in datalist_tox:
    print(data)
    break

Biodegradability Data:
Data(x=[9, 100], formula='C8H7Br', name='', positions=[9, 3], edge_index=[2, 9], edge_attr=[9, 4], bio=[1], tox=[0])
Toxicity Data:
Data(x=[34, 100], formula='C27H25ClN6', name='NCGC00178831-03', positions=[34, 3], edge_index=[2, 37], edge_attr=[37, 4], bio=[0], tox=[1])


In [53]:
datapoint = datalist_tox[0]
def get_molecule(datapoint):
    atomic_numbers = datapoint.x.argmax(dim=1).tolist()
    atomic_symbols = [chemical_symbols[number] for number in atomic_numbers]
    # positions = datapoint.positions.numpy()
    positions = np.random.rand(len(atomic_numbers), 3) * 10  # Random positions within a 10x10x10 Å box
    
    molecule = Atoms(numbers=atomic_numbers, positions=positions)
    return molecule

molecule = get_molecule(datapoint)
print(molecule)

Atoms(symbols='SBCB3CB8CB2CB3CB8CB2', pbc=False)


In [20]:
@contextmanager
def suppress_stdout():
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:  
            yield
        finally:
            sys.stdout = old_stdout

In [54]:
def get_dft(molecule, De=0.242, re=0.74, alpha=1.5):
    dft_calculator = MorsePotential(De=De, r0=re, alpha=alpha)
    molecule.set_calculator(dft_calculator)
    energy = molecule.get_potential_energy()
    print(energy)
    with suppress_stdout():
        opt = QuasiNewton(molecule)
        opt.run(fmax=0.02)  # Converge forces to less than 0.02 eV/Angstrom
    # Check the optimized bond length and energy
    optimized_energy = molecule.get_potential_energy()
    return optimized_energy

energy = get_dft(molecule)
print(f'The DFT calculated energy is: {energy} eV')

36.53020122411837
The DFT calculated energy is: -12.000322025798228 eV
