In [140]:
from contextlib import contextmanager
import sys
import os
import torch
from ase import Atoms
from ase.data import chemical_symbols
from ase.calculators.morse import MorsePotential
from ase.optimize import QuasiNewton
import numpy as np
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from ase import Atoms
from ase.io import read
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import to_networkx
# from torch_geometric.data import DataLoader, TensorDataset
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
from ase.build import molecule
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
import random
from torch_geometric.nn import global_add_pool, GATConv, CGConv, GCNConv, RGCNConv
from torch_geometric.nn.models.schnet import GaussianSmearing
from sklearn.metrics import roc_auc_score, precision_score, confusion_matrix
import matplotlib.pyplot as plt
import pickle
from collections import Counter
import seaborn as sns
import time
from torch.optim.lr_scheduler import StepLR

In [141]:
with open('data/bio_data_rcgn.pkl', 'rb') as file:
    bio_data = pickle.load(file)
# 0 is non-biodegradable, 1 is biodegradable
with open('data/bio_labels_rcgn.pkl', 'rb') as file:
    bio_labels = pickle.load(file)

In [142]:
datapoint = bio_data[0]
label = bio_labels[0]
print(datapoint)
print(label)

Data(x=[15, 10], edge_index=[2, 14], edge_attr=[14], atom_data=[15, 35])
1


In [147]:
def get_molecule(datapoint):
    atomic_numbers = datapoint.x.argmax(dim=1).tolist()
    # positions = datapoint.positions.numpy()
    positions = np.random.rand(len(atomic_numbers), 3) * 10  # Random positions within a 10x10x10 Å box
    
    molecule = Atoms(numbers=atomic_numbers, positions=positions)
    return molecule

@contextmanager
def suppress_stdout():
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:  
            yield
        finally:
            sys.stdout = old_stdout

def get_dft(datapoint, De=0.242, re=0.74, alpha=1.5):
    molecule = get_molecule(datapoint)
    print(molecule)
    dft_calculator = MorsePotential(De=De, r0=re, alpha=alpha)
    molecule.set_calculator(dft_calculator)
    # energy = molecule.get_potential_energy()
    with suppress_stdout():
        opt = QuasiNewton(molecule)
        opt.run(fmax=0.02)
    optimized_energy = molecule.get_potential_energy()
    pos = molecule.get_positions()
    return optimized_energy

energy = get_dft(datapoint)

print(f'DFT Energy: {energy} eV')

Atoms(symbols='X8He3X4', pbc=False)
[[5.662457   1.87452294 8.77553164]
 [1.63353332 5.91587123 6.12384318]
 [6.6504955  5.36089796 2.06869462]
 [3.5966397  0.4325894  8.22617301]
 [0.87490489 2.50358868 3.23672986]
 [9.88519522 4.77008411 9.33492312]
 [7.08698532 7.50168302 6.51080645]
 [3.22368708 7.32855699 9.79528747]
 [5.86759739 6.17763367 8.44311945]
 [7.3522386  7.58838951 1.29180116]
 [1.85930173 4.67549701 7.83789269]
 [5.12497022 2.66141813 4.82120803]
 [9.08052918 8.21423991 6.18530514]
 [4.93075361 5.3464214  5.46693267]
 [4.26764875 1.92683295 0.86395847]]
DFT Energy: 0.0 eV


In [144]:
class GraphEncoder(nn.Module):
    def __init__(self, input_dim, num_relations):
        super(GraphEncoder, self).__init__()
        self.rgcnconv1 = RGCNConv(input_dim, 512, num_relations=num_relations)
        self.rgcnconv2 = RGCNConv(512, 256, num_relations=num_relations)
        self.rgcnconv3 = RGCNConv(256, 128, num_relations=num_relations)

    def forward(self, x, edge_index, edge_attr):
        x = self.rgcnconv1(x, edge_index, edge_attr)
        x = nn.ReLU()(x)
        x = self.rgcnconv2(x, edge_index, edge_attr)
        x = nn.ReLU()(x)
        x = self.rgcnconv3(x, edge_index, edge_attr)
        x = nn.ReLU()(x)
        return x

class BioClassifier(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(BioClassifier, self).__init__()
        
        self.encoder = GraphEncoder(input_dim, num_heads)
        self.fc1 = nn.Linear(128, 64)
        self.bn1 = nn.BatchNorm1d(64)
        self.fc2 = nn.Linear(64, 1)
        self.dropout = nn.Dropout(0.4)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = x.float()

        x = self.encoder(x, edge_index, edge_attr)
        x = global_add_pool(x, data.batch)
        x = self.dropout(x)
        x = self.fc1(x)
        x = nn.ReLU()(x)
        x = self.bn1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = torch.sigmoid(x)
        return x.squeeze(dim=1)

In [145]:
model = BioClassifier(10, 4)

state_dict = torch.load('models/rcgn_model_7428.pt')
model.load_state_dict(state_dict)

model.eval()
def get_model_prediction(datapoint):
    pred = model(datapoint)
    return pred.item()

In [146]:
dft = get_dft(datapoint)
y = get_model_prediction(datapoint) * 100

print(f'DFT (remember negative is good): {dft}')
print(f'Chance of Biodgradability: {y:.2f}%')

Atoms(symbols='X8He3X4', pbc=False)
DFT (remember negative is good): -6.110760815807202e-06
Chance of Biodgradability: 68.77%
