In [1]:
import json
import numpy as np
import torch
import torch.nn as nn
import torch_geometric as tg
from torch_geometric.loader import DataLoader
from pymongo import MongoClient
from chespex.encoding import Encoder, Decoder, Autoencoder, graph_loss, prediction_to_molecules

In [None]:
level = 2 # Change here: 0, 1, or 2
n_bead_types = [15, 45, 96][level]
bead_types = [
    ["Q", "P", "N", "C"],
    ["Q2", "Q1", "P3", "P2", "P1", "N3", "N2", "N1", "C3", "C2", "C1", "X2", "X1"],
    ["Q5", "Q4", "Q3", "Q2", "Q1", "P6", "P5", "P4", "P3", "P2", "P1", "N6", "N5", "N4", "N3", "N2", "N1", "C6", "C5", "C4", "C3", "C2", "C1", "X4", "X3", "X2", "X1"]
][level]
n_bead_classes = len(bead_types)

In [35]:
client = MongoClient("mongodb://localhost:27017")
database = client.get_database(f"molecules-4")
collection = database.get_collection(f"level-{level}")
molecule_list = list(collection.aggregate([{"$sample": {"size": 10}}]))

In [36]:
model_name = f"model-4-{n_bead_types}-5d.pt"
print(f"Loading model {model_name}")
model = torch.load(model_name, weights_only=False)

Loading model model-4-96-5d.pt


In [37]:
def as_torch_geometric_data(db_molecule):
    return tg.data.Data(x=torch.Tensor(db_molecule["node_features"]), edge_index=torch.Tensor(db_molecule["edge_index"]).to(torch.int64))

def transform(batch):
    node_type = nn.functional.one_hot(batch.x[:, 1].to(torch.int64), n_bead_classes)
    node_size = nn.functional.one_hot(batch.x[:, 0].to(torch.int64), 3)
    node_charge = nn.functional.one_hot(batch.x[:, 2].to(torch.int64), 3)[:,::2]
    node_oco_w_tfe = batch.x[:, 3:] / 10 # Octanol-water transfer free energy
    batch.x = torch.cat([node_type, node_size, node_charge, node_oco_w_tfe], dim=1)
    return batch.cuda()

In [38]:
with open("../bead_types/bead-types.json", "r") as bead_types_file:
    bead_class_names = np.array(list(set([b["name"].replace("+", "").replace("-", "") for b in json.load(bead_types_file)[1]])))
print(bead_class_names)
torch_molecules = [as_torch_geometric_data(molecule) for molecule in molecule_list]
torch_molecules = next(iter(DataLoader(torch_molecules, batch_size=len(torch_molecules), shuffle=False)))
torch_molecules = transform(torch_molecules)
predicted_nodes, predicted_adjacency_matrix = model(torch_molecules)
predicted_molecules = prediction_to_molecules(
    predicted_nodes, predicted_adjacency_matrix, bead_types
)
for prediction, molecule in zip(predicted_molecules, molecule_list):
    print("Prediction:", prediction, "Truth:", molecule["name"])

['X1' 'SP1' 'SP3' 'TP2' 'TX2' 'SC3' 'TQ2' 'SC1' 'TQ1' 'TC3' 'N1' 'SX1'
 'SN1' 'TC2' 'TC1' 'SN3' 'X2' 'C2' 'SQ2' 'SP2' 'SC2' 'SX2' 'TN3' 'TN1'
 'TX1' 'P2' 'TP1' 'P1' 'TP3' 'P3' 'SN2' 'C3' 'N2' 'Q1' 'Q2' 'N3' 'SQ1'
 'C1' 'TN2']
Prediction: SC6 SP2 SP3 TC5,0-1 1-2 2-3 Truth: SC6 SP2 SP3 TC5,0-1 1-2 2-3
Prediction: N1 SC3 SC5 X2,0-3 1-2 1-3 Truth: N1 SC3 SC5 X2,0-3 1-2 1-3
Prediction: Q5+ SQ4- TQ1+ TQ1+,0-1 0-3 1-2 1-3 2-3 Truth: Q4+ Q5+ TQ1- TQ1-,0-1 0-2 0-3 1-3
Prediction: N1 Q1+ SP3 TP5,0-1 1-2 1-3 2-3 Truth: N1 Q1- SP3 TP5,0-1 1-2 1-3 2-3
Prediction: Q5- SQ5- TC3 TP4,0-1 1-3 2-3 Truth: Q5- SQ5- TC3 TP4,0-1 1-3 2-3
Prediction: Q2+ Q4- TP1 TQ4-,0-1 0-3 2-3 Truth: Q2- Q4- SQ4- TP1,0-1 0-2 2-3
Prediction: C3 SC6 SX4 X4,0-1 0-2 0-3 1-2 1-3 Truth: C3 SC6 SX4 X4,0-2 0-3 1-3 2-3
Prediction: Q5- SN3 SQ4- TQ1+,0-1 1-2 2-3 Truth: Q5- SN3 SQ4+ TQ1+,0-1 1-2 2-3
Prediction: C1 SP1 TQ5+ TX4,0-1 0-3 2-3 Truth: C1 SP1 TQ5+ TX4,0-1 0-3 2-3
Prediction: SC4 SN5 TP1 TQ2-,0-1 0-2 0-3 2-3 Truth: SC4 SN5 TP1 

In [39]:
collection = database.get_collection(f"level-{level}")
molecule_list = list(collection.aggregate([{"$sample": {"size": 100_000}}]) if level > 0 else collection.find())
torch_molecules = [as_torch_geometric_data(molecule) for molecule in molecule_list]
dataloader = DataLoader(torch_molecules, batch_size=16384, shuffle=False, num_workers=2)
accurracy = np.zeros(5)
for batch in dataloader:
    batch = transform(batch)
    predicted_nodes, predicted_adjacency_matrix = model(batch)
    loss, accuracy = graph_loss(batch, predicted_nodes, predicted_adjacency_matrix, return_accuracy=True)
    accurracy[:-1] += accuracy * len(batch)
    accurracy[-1] += len(batch)
accurracy = accurracy / accurracy[-1]
print(f'Total accuracy: Edges: {accurracy[0]:.3f}, Class: {accurracy[1]:.3f}, Size: {accurracy[2]:.3f}, Charge: {accurracy[3]:.3f}')

Total accuracy: Edges: 0.986, Class: 0.979, Size: 0.986, Charge: 0.998
