In [1]:
%load_ext autoreload
%autoreload 2
from torch_geometric.data import Data, Batch
from ogb.utils.features import (atom_to_feature_vector,
 bond_to_feature_vector) 
from rdkit import Chem
from rdkit.Chem import rdDetermineBonds
import numpy as np
import torch
import os 
import sys
import matplotlib.pyplot as plt

In [2]:
def get_mol_objects(filename):
    mol_list = []
    energy_list = []
   # energies = 
    with open(filename, "r") as f:
        lines = f.readlines()
        file_str  = "".join(lines)
        atom_num = lines[0]
    xyz_list = file_str.split(atom_num)[1:]
    for i in range(len(xyz_list)):
        x = Chem.MolFromXYZBlock(f'{atom_num.strip()}\n' + xyz_list[i])
        Chem.rdDetermineBonds.DetermineConnectivity(x)
        energy = float(xyz_list[i].split('\n')[0].strip())
        energy_list.append(energy)
        mol_list.append(x)
    return mol_list, energy_list

In [3]:
import contextlib
import os
import re
import subprocess
import warnings

def _get_energy(file):
    normal_termination = False
    with open(file) as f:
        for l in f:
            if "TOTAL ENERGY" in l:
                try:
                    energy = float(re.search(r"[+-]?(?:\d*\.)?\d+", l).group())
                except:
                    return np.nan
            if "normal termination of xtb" in l:
                normal_termination = True
    if normal_termination:
        return energy
    else:
        return np.nan

def run_gfn_xtb(
    filepath,
    filename,
    gfn_version="gfnff",
    opt=False,
    gfn_xtb_config: str = None,
    remove_scratch=True,
):
    """
    Runs GFN_XTB/FF given a directory and either a coord file or all coord files will be run

    :param filepath: Directory containing the coord file
    :param filename: if given, the specific coord file to run
    :param gfn_version: GFN_xtb version (default is 2)
    :param opt: optimization or single point (default is opt)
    :param gfn_xtb_config: additional xtb config (default is None)
    :param remove_scratch: remove xtb files
    :return:
    """
    xyz_file = os.path.join(filepath, filename)

    # optimization vs single point
    if opt:
        opt = "--opt"
    else:
        opt = ""

    # cd to filepath
    starting_dir = os.getcwd()
    os.chdir(filepath)

    file_name = str(xyz_file.split(".")[0])
    cmd = "xtb --{} {} {} {}".format(
        str(gfn_version), xyz_file, opt, str(gfn_xtb_config or "")
    )

    # run XTB
    with open(file_name + ".out", "w") as fd:
        subprocess.run(cmd, shell=True, stdout=fd, stderr=subprocess.STDOUT)

    # check XTB results
    if os.path.isfile(os.path.join(filepath, "NOT_CONVERGED")):
        # optimization not converged
        warnings.warn(
            "xtb --{} for {} is not converged, using last optimized step instead; proceed with caution".format(
                str(gfn_version), file_name
            )
        )

        # remove files
        if remove_scratch:
            os.remove(os.path.join(filepath, "NOT_CONVERGED"))
            os.remove(os.path.join(filepath, "xtblast.xyz"))
            os.remove(os.path.join(filepath, file_name + ".out"))
        energy = np.nan

    elif opt and not os.path.isfile(os.path.join(filepath, "xtbopt.xyz")):
        # other abnormal optimization convergence
        warnings.warn(
            "xtb --{} for {} abnormal termination, likely scf issues, using initial geometry instead; proceed with caution".format(
                str(gfn_version), file_name
            )
        )
        if remove_scratch:
            os.remove(os.path.join(filepath, file_name + ".out"))
        energy = np.nan

    else:
        # normal convergence
        # get energy
        energy = _get_energy(file_name + ".out")
        if remove_scratch:
            with contextlib.suppress(FileNotFoundError):
                os.remove(os.path.join(filepath, file_name + ".out"))
                os.remove(os.path.join(filepath, "gfnff_charges"))
                os.remove(os.path.join(filepath, "gfnff_adjacency"))
                os.remove(os.path.join(filepath, "gfnff_topo"))
                os.remove(os.path.join(filepath, "xtbopt.log"))
                os.remove(os.path.join(filepath, "xtbopt.xyz"))
                os.remove(os.path.join(filepath, "xtbtopo.mol"))
                os.remove(os.path.join(filepath, "wbo"))
                os.remove(os.path.join(filepath, "charges"))
                os.remove(os.path.join(filepath, "xtbrestart"))
    os.chdir(starting_dir)
    return energy

In [4]:
def get_mol_path(mol_idx, solvation=False):
    if solvation:
        return os.path.join(os.getcwd(), 'conformers', f'molecule_{mol_idx}', 'solvation')
    else:
        return os.path.join(os.getcwd(), 'conformers', f'molecule_{mol_idx}', 'vacuum')
get_mol_path(0)

'/home/radoslavralev/Documents/Thesis/gfn-diffusion/energy_sampling/notebooks/conformers/molecule_0/vacuum'

In [5]:
torch.ones(32, 69).reshape(-1, 23, 3).shape

torch.Size([32, 23, 3])

In [6]:
def xyz_mol2graph(xyz_mol):
    """
    Converts SMILES string to graph Data object
    :input: SMILES string (str)
    :return: graph object
    """

    mol = xyz_mol
    mol = Chem.AddHs(mol)
    #mol = Chem.RemoveHs(mol)
    pos = mol.GetConformer().GetPositions()
    #print(mol)
    # atoms
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_features_list.append([
            atom.GetAtomicNum(),
            # int(atom.GetChiralTag()),
            atom.GetTotalDegree(),
            # atom.GetFormalCharge(),
            atom.GetTotalNumHs(),
            atom.GetNumRadicalElectrons(),
            int(atom.GetHybridization()),
            # int(atom.GetIsAromatic()),
            # int(atom.IsInRing())
        ])
    x = np.array(atom_features_list, dtype = np.int64)

    # bonds
    num_bond_features = 3  # bond type, bond stereo, is_conjugated
    if len(mol.GetBonds()) > 0: # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()

            edge_feature = bond_to_feature_vector(bond)

            # add edges in both directions
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = np.array(edges_list, dtype = np.int64).T

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = np.array(edge_features_list, dtype = np.int64)

    else:   # mol has no bonds
        print('Mol has no bonds :()')
        edge_index = np.empty((2, 0), dtype = np.int64)
        edge_attr = np.empty((0, num_bond_features), dtype = np.int64)

    graph = dict()
    graph['edge_index'] = edge_index
    graph['edge_feat'] = edge_attr
    graph['node_feat'] = x
    graph['pos'] = pos
    graph['num_nodes'] = len(x)
    graph['num_bonds'] = len(mol.GetBonds())

    return graph 

In [7]:
def prep_input(graph, pos=None, device=None):
    datalist = []
    for xyz in pos:
        if pos is not None:
            graph['pos'] = xyz  
        #print('Number of atoms:', graph['node_feat'].shape[0])
        data = Data(
            atoms=torch.from_numpy(graph['node_feat']), 
            edge_index=torch.from_numpy(graph['edge_index']), 
            edge_attr=torch.from_numpy(graph['edge_feat']), 
            pos=graph['pos'],).to(device)
        data.validate(raise_on_error=True)
        datalist.append(data)
    #batch = Batch.from_data_list(datalist).to(device)
    return datalist

In [8]:
def extract_graphs(filename):
    mol_objects, mol_ens = get_mol_objects(filename)
    datalist = []
    for mol, en in zip(mol_objects, mol_ens):
        graph = xyz_mol2graph(mol)
        if graph['num_bonds'] == 0:
            continue
        data = Data(
            atoms=torch.from_numpy(graph['node_feat']), 
            edge_index=torch.from_numpy(graph['edge_index']), 
            edge_attr=torch.from_numpy(graph['edge_feat']), 
            pos=torch.from_numpy(graph['pos']),
            y=torch.tensor(en))
        data.validate(raise_on_error=True)
        datalist.append(data)
    return datalist

In [20]:
# calculate free energy for all pairs of conformers
from openmm import unit

kB = unit.BOLTZMANN_CONSTANT_kB.value_in_unit(unit.hartree/unit.kelvin)
T = 298.15
beta = 1/(T)

def solvation_en(index=0):
    solvent_dir = os.path.join('../..', 'conformation_sampling', 'conformers', f'molecule_{index}', 'solvation', 'crest_conformers.xyz')
    vacuum_dir = os.path.join('../..', 'conformation_sampling', 'conformers', f'molecule_{index}', 'vacuum', 'crest_conformers.xyz')
    _, solv_en = get_mol_objects(solvent_dir)
    _, vac_en = get_mol_objects(vacuum_dir)
    #print(solv_en, vac_en)
    # avg_solv_en = np.mean(np.exp(np.array(solv_en)))
    # avg_vac_en = np.mean(np.exp(np.array(vac_en))) 
    # free_en = avg_vac_en - avg_solv_en
    if len(solv_en) < 100:
        return None
    free_en_vac = -1/beta*(np.logaddexp.reduce(-np.array(vac_en) * beta ) - np.log(len(vac_en)) )
    free_en_solv = -1/beta*(np.logaddexp.reduce(-np.array(solv_en) * beta ) - np.log(len(solv_en)))
    free_en = (free_en_solv-free_en_vac) * 627.503
    # free_en = -1/627.583*np.logaddexp.reduce(-(np.array(solv_en[:l])-np.array(vac_en[:l])) * 627.503) - np.log(l)
    return free_en, free_en_solv, free_en_vac
    
for i in range(600):
    x = solvation_en(i)
    if x:
        free_en, free_en_solv, free_en_vac = x 
        print(f'Molecule_{i}, FED: {free_en}')
        print(f'Molecule_{i}, FE_solv: {free_en_solv}')
        print(f'Molecule_{i}, FE_vac: {free_en_vac}')
    
    

Molecule_0, FED: -4.3662996425853695
Molecule_0, FE_solv: -30.282859291173285
Molecule_0, FE_vac: -30.275901077997275
Molecule_19, FED: -2.0286040824863227
Molecule_19, FE_solv: -30.379585941689914
Molecule_19, FE_vac: -30.376353121954413
Molecule_20, FED: -8.962879412756019
Molecule_20, FE_solv: -39.451050858896366
Molecule_20, FE_vac: -39.43676745401582
Molecule_23, FED: -6.259477393408549
Molecule_23, FE_solv: -42.41013694423078
Molecule_23, FE_vac: -42.40016172914271
Molecule_30, FED: -4.2088378125435
Molecule_30, FE_solv: -32.52345043068814
Molecule_30, FE_vac: -32.51674315149977
Molecule_35, FED: -4.615489767255099
Molecule_35, FE_solv: -33.45115739806113
Molecule_35, FE_vac: -33.443802071047145
Molecule_50, FED: -8.23995955714413
Molecule_50, FE_solv: -40.57952742980173
Molecule_50, FE_vac: -40.56639608292826
Molecule_61, FED: -8.161039439660186
Molecule_61, FE_solv: -56.77390663237165
Molecule_61, FE_vac: -56.76090105400842
Molecule_79, FED: -3.735943377537626
Molecule_79, FE_s

KeyboardInterrupt: 

In [10]:
solvent_data = {}
vacuum_data = {}
for dir in os.listdir(os.path.join(os.getcwd(),'..', '..', 'conformation_sampling', 'conformers')):
    solvent_data[dir] = []
    vacuum_data[dir] = []

for dir in os.listdir(os.path.join(os.getcwd(),'..', '..', 'conformation_sampling', 'conformers')):
    solvent_dir = os.path.join(os.getcwd(), '..', '..','conformation_sampling', 'conformers', dir, 'solvation', 'crest_conformers.xyz')
    vacuum_dir = os.path.join(os.getcwd(),'..', '..','conformation_sampling', 'conformers', dir, 'vacuum', 'crest_conformers.xyz')
    solvent_graphs = extract_graphs(solvent_dir)
    vacuum_graphs = extract_graphs(vacuum_dir)
    solvent_data[dir].extend(solvent_graphs)
    vacuum_data[dir].extend(vacuum_graphs)

In [21]:
import sys
import os
import torch
import torch.nn.functional as F

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

from models.ImplicitSGNN.GNN_Trainer import Trainer
from models.ImplicitSGNN.GNN_Models import GNN3_scale_64

pt_file = '../weights/test_set.pt' # Needs to be downloaded from ETH Research Collection
gnn_file = '../weights/GNN_state_dict.pt'

def calculate_force_loss_only(pre_energy, pre_forces, ldata):
    loss = F.mse_loss(pre_forces, ldata.forces)
    return loss


trainer = Trainer(verbose=False,name='GNN3_pub_vis',path='trained_models',force_mode=True,enable_tmp_dir=False,random_state=3)

device = 'cuda'
trainer.explicit = True
print('load data',flush=True)
gbneck_parameters, unique_radii = trainer.prepare_training_data_from_pt_file(pt_file)
print('data loaded',flush=True)
trainer.model = GNN3_scale_64(device='cuda', unique_radii=unique_radii, parameters=gbneck_parameters)
trainer.model.load_state_dict(torch.load(gnn_file))
trainer.set_lossfunction(calculate_force_loss_only)
# loss = trainer.test_model()
# pre,exp = trainer.return_test_set_predictions()
model_solv = trainer.model.to('cuda')

load data
data loaded


In [None]:
trainer.

In [22]:
g = list(solvent_data.values())[0][0]
g

Data(edge_index=[2, 32], edge_attr=[32, 3], y=-23.131166458129883, pos=[16, 3], atoms=[16, 5])

In [23]:

g, model_solv(g.to('cuda'))

IndexError: index 6 is out of bounds for dimension 1 with size 5

In [None]:
# count total graphs in each dataset
solvent_count = 0
vacuum_count = 0
for key in solvent_data.keys():
    solvent_count += len(solvent_data[key])
    vacuum_count += len(vacuum_data[key])
print(solvent_count, vacuum_count)

In [None]:
vacuum_data['molecule_0'][0]

In [None]:
#calculate max bond length
def max_atom_dist(graphs):
    max_len = 0
    for graph in graphs:
        pos = graph.pos
        dist = torch.cdist(pos, pos)
        if dist.max() > max_len:
            max_len = dist.max()
        # make cdist into list of distances
        distances = []
        for i in range(dist.shape[0]):
            for j in range(dist.shape[1]):
                if i != j:
                    distances.append(dist[i,j])
    return distances
lens = []
for i in range(640):
    distances = max_atom_dist(solvent_data[f'molecule_{i}'])
    lens.extend(distances)


In [None]:
np.mean(lens), np.std(lens), np.min(lens), np.max(lens)


In [None]:
def plot_graph(graph, pos):
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    pos = pos.reshape(-1, len(graph.atoms), 3)
    #node_features = 

    # Generate a color map based on node features
    colors = graph.atoms[:, 0]

    ax.scatter(pos[0, :, 0], pos[0, :, 1], pos[0, :, 2], c=colors, s=100)
    for edge in graph['edge_index'].T:
        ax.plot(pos[0, edge, 0], pos[0, edge, 1], pos[0, edge, 2], color='black')
    plt.show()

mol = vacuum_data['molecule_2'][0]
plot_graph(mol, mol.pos)

In [None]:
from sklearn.preprocessing import StandardScaler

# import standard scaler, scale solvent_data.y and vacuum_data.y and print the mean and variance of the scaler
scaler = StandardScaler()
solvent_y = [[data for data in data_list] for data_list in solvent_data.values()]
vacuum_y = [[data for data in data_list] for data_list in vacuum_data.values()]
solvent_y = [item.y for sublist in solvent_y for item in sublist]
vacuum_y = [item.y for sublist in vacuum_y for item in sublist]

scaled_solvent_y = scaler.fit_transform(np.array(solvent_y).reshape(-1, 1))
scaled_vacuum_y = scaler.fit_transform(np.array(vacuum_y).reshape(-1, 1))
print("Mean of solvent_data.y:", np.array(solvent_y).reshape(-1, 1).mean())
print("Variance of solvent_data.y:", np.array(solvent_y).reshape(-1, 1).var())
print("Mean of vacuum_data.y:", np.array(vacuum_y).reshape(-1, 1).mean())
print("Variance of vacuum_data.y:", np.array(vacuum_y).reshape(-1, 1).var())
print('=============================')
print("Mean of scaled solvent_data.y:", scaled_solvent_y.mean())
print("Variance of scaled solvent_data.y:", scaled_solvent_y.var())
print("Mean of scaled vacuum_data.y:", scaled_vacuum_y.mean())
print("Variance of scaled vacuum_data.y:", scaled_vacuum_y.var())

In [None]:
import sys
import os 
import torch
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
from models import EGNNModel, MACEModel
import json

def load_model(model, filename):
    with open(filename + '.json', 'r') as f:
        model_args = json.load(f)
        if model == 'egnn':
            model = EGNNModel(in_dim=model_args['in_dim'][0], emb_dim=model_args['emb_dim'], out_dim=model_args['out_dim'], num_layers=model_args['num_layers'], num_atom_features=model_args['in_dim'], equivariant_pred=True)
        elif model == 'mace':
            model = MACEModel(in_dim=model_args['in_dim'], emb_dim=model_args['emb_dim'], out_dim=model_args['out_dim'], num_layers=model_args['num_layers'], equivariant_pred=True)
    model.load_state_dict(torch.load(filename + '.pt'))
    return model, model_args

In [None]:
vac_file = '/home/radoslavralev/Documents/Thesis/gfn-diffusion/energy_sampling/weights/egnn_vacuum_small_with_hs_32'
sol_file = '/home/radoslavralev/Documents/Thesis/gfn-diffusion/energy_sampling/weights/egnn_solvation_small_with_hs_32'
vac_model, vac_args = load_model('egnn', vac_file)
sol_model, sol_args = load_model('egnn', sol_file)

In [None]:
import matplotlib.pyplot as plt

sol_model.eval()
solvent_energies = {}
vacuum_energies = {}
with torch.no_grad():
    for key, values in solvent_data.items():
        solvent_energies[key] = []
        vacuum_energies[key] = []
        for x in values:
            solvent_energy = sol_model(x)
            solvent_energies[key].append(solvent_energy.item())
    for key, values in vacuum_data.items():
        for x in values:
            vacuum_energy = vac_model(x)
            vacuum_energies[key].append(vacuum_energy.item())

In [None]:
import matplotlib.pyplot as plt
mol_sample = list(vacuum_data.values())[0][0].to('cuda')
# run vac_model on mol_sample 100 times, add increasing levels of gaussian noise to the input features and plot the output
vac_model.eval()
vac_model = vac_model.to('cuda')
energies = []
for i in range(10000):
    mol_sample.pos = mol_sample.pos + torch.randn_like(mol_sample.pos)
    energy = vac_model(mol_sample)
    energies.append(energy.item())
# add axis descriptions
plt.ylabel('Energy')
plt.xlabel('Noise level')
plt.plot(energies)
plt.show()  

In [None]:
# same for solvation model
sol_model.eval()
sol_model = sol_model.to('cuda')
energies = []
for i in range(10000):
    mol_sample.pos = mol_sample.pos + torch.randn_like(mol_sample.pos)
    energy = sol_model(mol_sample)
    energies.append(energy.item())
# add axis descriptions
plt.ylabel('Energy')
plt.xlabel('Noise level')
plt.plot(energies)
plt.show()

In [None]:
# find molecule with most conformers
conformer_num = []
max_conformers = 0
for key, values in solvent_data.items():
    if len(values) > max_conformers:
        max_conformers = len(values)
        print(key, max_conformers)
    conformer_num.append(len(values))

In [None]:
sum(conformer_num), sum(np.array(conformer_num)[np.array(conformer_num) > 20])

In [None]:
# plot conformer num with ylog
plt.plot(sorted(conformer_num))
plt.yscale('log')
plt.ylabel('Number of conformers')
plt.show()

In [None]:
tar_188

In [None]:
# get the y values for all molecules not just 188
tar = []
for key, values in vacuum_data.items():
    if key != 'molecule_188':
        continue
    for x in values:
        tar.append(x.y.item())

tar = np.array(tar)
plt.hist(tar, bins=30)
# plot mean and add the mean to the legend
plt.legend(['Energy', f'Mean={np.mean(tar)}'])
plt.yscale('linear')
plt.show()
np.mean(tar), np.std(tar)

In [None]:
p = vacuum_data['molecule_188'][0].pos 
# calculate pairwise distances of atoms in the molecule
dist = torch.cdist(p, p, p=2)
# plot the pairwise distances
plt.imshow(dist.cpu().numpy())
plt.colorbar()
plt.show()

In [None]:
q = torch.randn_like(p) / 10
# calculate pairwise distances of atoms in the molecule
dist = torch.cdist(q, q, p=2)
# plot the pairwise distances
plt.imshow(dist.cpu().numpy())
plt.colorbar()
plt.show()

In [None]:
from copy import copy
x = Batch.from_data_list(copy(vacuum_data['molecule_188'][:64]))
# calculate pairwise distances of atoms in the molecule
energies = []
for i in range(16):
    x.pos = (torch.randn(64*33, 3))
    x = x.to('cuda')
    e = vac_model(x).squeeze()
    energies.extend(e.detach().cpu().numpy().tolist())
# plot energy
plt.hist(energies, bins=30)
plt.hist(x.y.cpu().numpy(), bins=30)
plt.legend(['Energy', 'True Energy'])
plt.show()

In [None]:
energies = []
for mol in vacuum_data['molecule_0']:
    e = vac_model(mol.to('cuda')).squeeze()
    energies.append(e.detach().cpu().numpy())
# plot energy
plt.hist(np.array(energies)*627.503, bins=30)
plt.hist([mol.y.cpu()*627.503 for mol in vacuum_data['molecule_0']], bins=30)
plt.legend(['Predicted nergy', 'True Energy'])
plt.show()
np.unique(energies)

In [None]:
errors_solvent = []
for i, (key, values) in enumerate(solvent_energies.items()):
    pred_energy = np.array(values) * 627.503
    true_energy = np.array([point.y for point in solvent_data[key]]) * 627.503
    error = np.abs(pred_energy - true_energy)
    errors_solvent.extend(error)
np.mean(errors_solvent)

In [None]:
errors_vacuum = []
for i, (key, values) in enumerate(vacuum_energies.items()):
    pred_energy = np.array(values) * 627.503
    true_energy = np.array([point.y.cpu() for point in vacuum_data[key]]) * 627.503
    if pred_energy.shape[0] != true_energy.shape[0]:
        print(i, key, pred_energy, true_energy)
    error = np.abs(pred_energy - true_energy)
    errors_vacuum.extend(error)
np.mean(errors_vacuum)

In [None]:
# loop over "solvent_data" and "vacuum_data", take the mean of the energies and plot them (it's the .y) in a single plot
fig = plt.figure()
gt_solvent = []
gt_vacuum = []
for i, (key, values) in enumerate(solvent_data.items()):
    solvent_energy = np.array([x.y for x in values])
    vacuum_energy = np.array([x.y.cpu() for x in vacuum_data[key]])
    gt_solvent.extend(solvent_energy)
    gt_vacuum.extend(vacuum_energy)

# make the lines transparent 
plt.plot((np.array([np.mean(x) for x in solvent_energies.values()]) - np.array([np.mean(x) for x in vacuum_energies.values()]) ) * 627.503, label='Predicted Solvent', alpha=0.5)
plt.show()

In [None]:
# do the same plot for the predicted ones
pred_solvent = []
pred_vacuum = []
for i, (key, values) in enumerate(solvent_data.items()):
    pred_solvent.extend(solvent_energies[key])
    pred_vacuum.extend(vacuum_energies[key])

In [None]:
#  set nans to 0
gt_vacuum = np.nan_to_num(gt_vacuum) * 627.503
gt_solvent = np.nan_to_num(gt_solvent) * 627.503
pred_vacuum = np.nan_to_num(pred_vacuum) * 627.503
pred_solvent = np.nan_to_num(pred_solvent) * 627.503


In [None]:
np.abs(np.array(gt_vacuum) - np.array(pred_vacuum)).mean(), np.abs(np.array(gt_solvent) - np.array(pred_solvent)).mean()

In [None]:
from sklearn.linear_model import LinearRegression

import matplotlib.pyplot as plt


# Regression for Vacuum Energies
vacuum_regression = LinearRegression()
vacuum_regression.fit(np.array(gt_vacuum).reshape(-1, 1), np.array(pred_vacuum))
vacuum_regression_line = vacuum_regression.predict(np.array(gt_vacuum).reshape(-1, 1))
vacuum_errors = np.abs(np.array(pred_vacuum) - vacuum_regression_line)

plt.figure(figsize=(21, 7))
plt.subplot(1, 3, 1)
plt.scatter(gt_vacuum, pred_vacuum, alpha=0.5)
plt.plot(gt_vacuum, vacuum_regression_line, 'r--')
plt.errorbar(gt_vacuum, pred_vacuum, yerr=vacuum_errors, fmt='o', alpha=0.5)
plt.xlabel('Ground Truth Vacuum Energies')
plt.ylabel('Predicted Vacuum Energies')
plt.title('Vacuum: Predicted vs Ground Truth (MAE: {:.4f})'.format(np.mean(vacuum_errors)))
plt.axis('equal')

# Regression for Solvent Energies
solvent_regression = LinearRegression()
solvent_regression.fit(np.array(gt_solvent).reshape(-1, 1), np.array(pred_solvent))
solvent_regression_line = solvent_regression.predict(np.array(gt_solvent).reshape(-1, 1))
solvent_errors = np.abs(np.array(pred_solvent) - solvent_regression_line)

plt.subplot(1, 3, 2)
plt.scatter(gt_solvent, pred_solvent, alpha=0.5)
plt.plot(gt_solvent, solvent_regression_line, 'r--')
plt.errorbar(gt_solvent, pred_solvent, yerr=solvent_errors, fmt='o', alpha=0.5)
plt.xlabel('Ground Truth Solvent Energies')
plt.ylabel('Predicted Solvent Energies')
plt.title('Solvent: Predicted vs Ground Truth (MAE: {:.4f})'.format(np.mean(solvent_errors)))
plt.axis('equal')




plt.tight_layout()
plt.show()

In [None]:
vac_model.eval()
diffs = {}
with torch.no_grad():
    for key, values in vacuum_data.items():
        diffs[key] = []
        for x in values:
            predicted_y = vac_model(x.to('cuda'))
            diffs[key].append(predicted_y.item())

In [None]:
# plot the correlation between predicted and true values with labels of the molecule
fig, ax = plt.subplots(figsize=(10, 10))
for key, values in diffs.items():
    pred_mean = np.mean(values)
    true_mean = np.mean([x.y.item() for x in vacuum_data[key]])
    if np.sqrt((pred_mean-true_mean)**2) <= 0.3:
        continue
    ax.scatter([x*627.503 for x in values], [x.y.item()*627.503 for x in vacuum_data[key]], label=key)
#add diagonal 
ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls="--", c="red")
ax.legend()
plt.show()