In [20]:
"""
Python threedconf.py

Description: This file contains the proposed basic architecture to generate 3-D
            conformers. The code is imported from https://github.com/divelab/DIG
            and edited in order to include Sphernet as embeddings layer.

Author: Abanoub Abdelmalak

Date Created: May 1, 2023

"""


import torch
from torch import nn
from torch.nn import Linear, Embedding
from torch_geometric.nn.inits import glorot_orthogonal
from torch_geometric.nn import radius_graph
from torch_scatter import scatter
from math import sqrt

from dig.threedgraph.utils import xyz_to_dat
from dig.threedgraph.method.spherenet.features import dist_emb, angle_emb, torsion_emb

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def swish(x):
    return x * torch.sigmoid(x)

class emb(torch.nn.Module):
    def __init__(self, num_spherical, num_radial, cutoff, envelope_exponent):
        super(emb, self).__init__()
        self.dist_emb = dist_emb(num_radial, cutoff, envelope_exponent)
        self.angle_emb = angle_emb(num_spherical, num_radial, cutoff, envelope_exponent)
        self.torsion_emb = torsion_emb(num_spherical, num_radial, cutoff, envelope_exponent)
        self.reset_parameters()

    def reset_parameters(self):
        self.dist_emb.reset_parameters()

    def forward(self, dist, angle, torsion, idx_kj):
        dist_emb = self.dist_emb(dist)
        angle_emb = self.angle_emb(dist, angle, idx_kj)
        torsion_emb = self.torsion_emb(dist, angle, torsion, idx_kj)
        return dist_emb, angle_emb, torsion_emb

class ResidualLayer(torch.nn.Module):
    def __init__(self, hidden_channels, act=swish):
        super(ResidualLayer, self).__init__()
        self.act = act
        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        glorot_orthogonal(self.lin1.weight, scale=2.0)
        self.lin1.bias.data.fill_(0)
        glorot_orthogonal(self.lin2.weight, scale=2.0)
        self.lin2.bias.data.fill_(0)

    def forward(self, x):
        return x + self.act(self.lin2(self.act(self.lin1(x))))


class init(torch.nn.Module):
    def __init__(self, num_radial, hidden_channels, act=swish, use_node_features=True):
        super(init, self).__init__()
        self.act = act
        self.use_node_features = use_node_features
        if self.use_node_features:
            self.emb = Embedding(95, hidden_channels)
        else: # option to use no node features and a learned embedding vector for each node instead
            self.node_embedding = nn.Parameter(torch.empty((hidden_channels,)))
            nn.init.normal_(self.node_embedding)
        self.lin_rbf_0 = Linear(num_radial, hidden_channels)
        self.lin = Linear(3 * hidden_channels, hidden_channels)
        self.lin_rbf_1 = nn.Linear(num_radial, hidden_channels, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        if self.use_node_features:
            self.emb.weight.data.uniform_(-sqrt(3), sqrt(3))
        self.lin_rbf_0.reset_parameters()
        self.lin.reset_parameters()
        glorot_orthogonal(self.lin_rbf_1.weight, scale=2.0)

    def forward(self, x, emb, i, j):
        rbf,_,_ = emb
        if self.use_node_features:
            x = self.emb(x)
        else:
            x = self.node_embedding[None, :].expand(x.shape[0], -1)
        rbf0 = self.act(self.lin_rbf_0(rbf))
        e1 = self.act(self.lin(torch.cat([x[i], x[j], rbf0], dim=-1)))
        e2 = self.lin_rbf_1(rbf) * e1

        return e1, e2


class update_e(torch.nn.Module):
    def __init__(self, hidden_channels, int_emb_size, basis_emb_size_dist, basis_emb_size_angle, basis_emb_size_torsion, num_spherical, num_radial,
        num_before_skip, num_after_skip, act=swish):
        super(update_e, self).__init__()
        self.act = act
        self.lin_rbf1 = nn.Linear(num_radial, basis_emb_size_dist, bias=False)
        self.lin_rbf2 = nn.Linear(basis_emb_size_dist, hidden_channels, bias=False)
        self.lin_sbf1 = nn.Linear(num_spherical * num_radial, basis_emb_size_angle, bias=False)
        self.lin_sbf2 = nn.Linear(basis_emb_size_angle, int_emb_size, bias=False)
        self.lin_t1 = nn.Linear(num_spherical * num_spherical * num_radial, basis_emb_size_torsion, bias=False)
        self.lin_t2 = nn.Linear(basis_emb_size_torsion, int_emb_size, bias=False)
        self.lin_rbf = nn.Linear(num_radial, hidden_channels, bias=False)

        self.lin_kj = nn.Linear(hidden_channels, hidden_channels)
        self.lin_ji = nn.Linear(hidden_channels, hidden_channels)

        self.lin_down = nn.Linear(hidden_channels, int_emb_size, bias=False)
        self.lin_up = nn.Linear(int_emb_size, hidden_channels, bias=False)

        self.layers_before_skip = torch.nn.ModuleList([
            ResidualLayer(hidden_channels, act)
            for _ in range(num_before_skip)
        ])
        self.lin = nn.Linear(hidden_channels, hidden_channels)
        self.layers_after_skip = torch.nn.ModuleList([
            ResidualLayer(hidden_channels, act)
            for _ in range(num_after_skip)
        ])

        self.reset_parameters()

    def reset_parameters(self):
        glorot_orthogonal(self.lin_rbf1.weight, scale=2.0)
        glorot_orthogonal(self.lin_rbf2.weight, scale=2.0)
        glorot_orthogonal(self.lin_sbf1.weight, scale=2.0)
        glorot_orthogonal(self.lin_sbf2.weight, scale=2.0)
        glorot_orthogonal(self.lin_t1.weight, scale=2.0)
        glorot_orthogonal(self.lin_t2.weight, scale=2.0)

        glorot_orthogonal(self.lin_kj.weight, scale=2.0)
        self.lin_kj.bias.data.fill_(0)
        glorot_orthogonal(self.lin_ji.weight, scale=2.0)
        self.lin_ji.bias.data.fill_(0)

        glorot_orthogonal(self.lin_down.weight, scale=2.0)
        glorot_orthogonal(self.lin_up.weight, scale=2.0)

        for res_layer in self.layers_before_skip:
            res_layer.reset_parameters()
        glorot_orthogonal(self.lin.weight, scale=2.0)
        self.lin.bias.data.fill_(0)
        for res_layer in self.layers_after_skip:
            res_layer.reset_parameters()

        glorot_orthogonal(self.lin_rbf.weight, scale=2.0)

    def forward(self, x, emb, idx_kj, idx_ji):
        rbf0, sbf, t = emb
        x1,_ = x

        x_ji = self.act(self.lin_ji(x1))
        x_kj = self.act(self.lin_kj(x1))

        rbf = self.lin_rbf1(rbf0)
        rbf = self.lin_rbf2(rbf)
        x_kj = x_kj * rbf

        x_kj = self.act(self.lin_down(x_kj))

        sbf = self.lin_sbf1(sbf)
        sbf = self.lin_sbf2(sbf)
        x_kj = x_kj[idx_kj] * sbf

        t = self.lin_t1(t)
        t = self.lin_t2(t)
        x_kj = x_kj * t

        x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x1.size(0))
        x_kj = self.act(self.lin_up(x_kj))

        e1 = x_ji + x_kj
        for layer in self.layers_before_skip:
            e1 = layer(e1)
        e1 = self.act(self.lin(e1)) + x1
        for layer in self.layers_after_skip:
            e1 = layer(e1)
        e2 = self.lin_rbf(rbf0) * e1

        return e1, e2


class update_v(torch.nn.Module):
    def __init__(self, hidden_channels, out_emb_channels, out_channels, num_output_layers, act, output_init):
        super(update_v, self).__init__()
        self.act = act
        self.output_init = output_init

        self.lin_up = nn.Linear(hidden_channels, out_emb_channels, bias=True)
        self.lins = torch.nn.ModuleList()
        for _ in range(num_output_layers):
            self.lins.append(nn.Linear(out_emb_channels, out_emb_channels))
        self.lin = nn.Linear(out_emb_channels, out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        glorot_orthogonal(self.lin_up.weight, scale=2.0)
        for lin in self.lins:
            glorot_orthogonal(lin.weight, scale=2.0)
            lin.bias.data.fill_(0)
        if self.output_init == 'zeros':
            self.lin.weight.data.fill_(0)
        if self.output_init == 'GlorotOrthogonal':
            glorot_orthogonal(self.lin.weight, scale=2.0)

    def forward(self, e, i):
        _, e2 = e
        v = scatter(e2, i, dim=0)
        v = self.lin_up(v)
        for lin in self.lins:
            v = self.act(lin(v))
        v = self.lin(v)
        return v


class update_u(torch.nn.Module):
    def __init__(self):
        super(update_u, self).__init__()

    def forward(self, u, v, batch):
        u += scatter(v, batch, dim=0)
        return u


class SphereNet(torch.nn.Module):
    r"""
         The spherical message passing neural network SphereNet from the `"Spherical Message Passing for 3D Molecular Graphs" <https://openreview.net/forum?id=givsRXsOt9r>`_ paper.
        
        Args:
            energy_and_force (bool, optional): If set to :obj:`True`, will predict energy and take the negative of the derivative of the energy with respect to the atomic positions as predicted forces. (default: :obj:`False`)
            cutoff (float, optional): Cutoff distance for interatomic interactions. (default: :obj:`5.0`)
            num_layers (int, optional): Number of building blocks. (default: :obj:`4`)
            hidden_channels (int, optional): Hidden embedding size. (default: :obj:`128`)
            out_channels (int, optional): Size of each output sample. (default: :obj:`1`)
            int_emb_size (int, optional): Embedding size used for interaction triplets. (default: :obj:`64`)
            basis_emb_size_dist (int, optional): Embedding size used in the basis transformation of distance. (default: :obj:`8`)
            basis_emb_size_angle (int, optional): Embedding size used in the basis transformation of angle. (default: :obj:`8`)
            basis_emb_size_torsion (int, optional): Embedding size used in the basis transformation of torsion. (default: :obj:`8`)
            out_emb_channels (int, optional): Embedding size used for atoms in the output block. (default: :obj:`256`)
            num_spherical (int, optional): Number of spherical harmonics. (default: :obj:`7`)
            num_radial (int, optional): Number of radial basis functions. (default: :obj:`6`)
            envelop_exponent (int, optional): Shape of the smooth cutoff. (default: :obj:`5`)
            num_before_skip (int, optional): Number of residual layers in the interaction blocks before the skip connection. (default: :obj:`1`)
            num_after_skip (int, optional): Number of residual layers in the interaction blocks before the skip connection. (default: :obj:`2`)
            num_output_layers (int, optional): Number of linear layers for the output blocks. (default: :obj:`3`)
            act: (function, optional): The activation funtion. (default: :obj:`swish`)
            output_init: (str, optional): The initialization fot the output. It could be :obj:`GlorotOrthogonal` and :obj:`zeros`. (default: :obj:`GlorotOrthogonal`)
            
    """
    def __init__(
        self, energy_and_force=False, cutoff=5.0, num_layers=4,
        hidden_channels=128, out_channels=1, int_emb_size=64,
        basis_emb_size_dist=8, basis_emb_size_angle=8, basis_emb_size_torsion=8, out_emb_channels=256,
        num_spherical=7, num_radial=6, envelope_exponent=5,
        num_before_skip=1, num_after_skip=2, num_output_layers=3,
        act=swish, output_init='GlorotOrthogonal', use_node_features=True):
        super(SphereNet, self).__init__()

        self.cutoff = cutoff
        self.energy_and_force = energy_and_force

        self.init_e = init(num_radial, hidden_channels, act, use_node_features=use_node_features)
        self.init_v = update_v(hidden_channels, out_emb_channels, out_channels, num_output_layers, act, output_init)
        self.init_u = update_u()
        self.emb = emb(num_spherical, num_radial, self.cutoff, envelope_exponent)

        self.update_vs = torch.nn.ModuleList([
            update_v(hidden_channels, out_emb_channels, out_channels, num_output_layers, act, output_init) for _ in range(num_layers)])

        self.update_es = torch.nn.ModuleList([
            update_e(hidden_channels, int_emb_size, basis_emb_size_dist, basis_emb_size_angle, basis_emb_size_torsion, num_spherical, num_radial, num_before_skip, num_after_skip,act) for _ in range(num_layers)])

        self.update_us = torch.nn.ModuleList([update_u() for _ in range(num_layers)])
        
        self.reset_parameters()

    def reset_parameters(self):
        self.init_e.reset_parameters()
        self.init_v.reset_parameters()
        self.emb.reset_parameters()
        for update_e in self.update_es:
            update_e.reset_parameters()
        for update_v in self.update_vs:
            update_v.reset_parameters()


    def forward(self, batch_data):
        z, pos, batch = batch_data.z, batch_data.pos, batch_data.batch
        if self.energy_and_force:
            pos.requires_grad_()
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        num_nodes=z.size(0)
        dist, angle, torsion, i, j, idx_kj, idx_ji = xyz_to_dat(pos, edge_index, num_nodes, use_torsion=True)

        emb = self.emb(dist, angle, torsion, idx_kj)
        #print(batch_data.shape)
        #Initialize edge, node, graph features
        e = self.init_e(z, emb, i, j)
        v = self.init_v(e, i)
        u = self.init_u(torch.zeros_like(scatter(v, batch, dim=0)), v, batch) #scatter(v, batch, dim=0)

        for update_e, update_v, update_u in zip(self.update_es, self.update_vs, self.update_us):
            e = update_e(e, emb, idx_kj, idx_ji)
            v = update_v(e, i)
            u = update_u(u, v, batch) #u += scatter(v, batch, dim=0)

        return u, v, e


In [25]:
import torch
class SphernetWithTransformer(SphereNet):
    def __init__(self, energy_and_force=False, cutoff=5.0, num_layers=4, 
        hidden_channels=128, out_channels=1, int_emb_size=64, 
        basis_emb_size_dist=8, basis_emb_size_angle=8, basis_emb_size_torsion=8, out_emb_channels=256, 
        num_spherical=3, num_radial=6, envelope_exponent=5, 
        num_before_skip=1, num_after_skip=2, num_output_layers=3, use_node_features=True):
        super(SphernetWithTransformer, self).__init__()

        self.cutoff = cutoff
        self.energy_and_force = energy_and_force

        self.spherenet = SphereNet(energy_and_force=False, cutoff=5.0, num_layers=4, 
        hidden_channels=128, out_channels=1, int_emb_size=64, 
        basis_emb_size_dist=8, basis_emb_size_angle=8, basis_emb_size_torsion=8, out_emb_channels=256, 
        num_spherical=3, num_radial=6, envelope_exponent=5, 
        num_before_skip=1, num_after_skip=2, num_output_layers=3, use_node_features=True)

        self.transformer = nn.Transformer(d_model=256, nhead=8)

        self.linear_out = nn.Linear(256, 3)

    def forward(self, batch_data):
        batch_data.to(device)
        z, pos, batch = batch_data.z, batch_data.pos, batch_data.batch
        
        # Forward pass through the SphereNet
        u, z_pred, _ = self.spherenet(batch_data)

        # Encode the latent representation with the transformer
        z_enc = self.transformer(z_pred.unsqueeze(1))

        # Decode the latent representation with a linear layer
        pos_pred = self.linear_out(z_enc)

        return pos_pred, z_enc
def loss_fn(pos, pos_pred):
    """
    Computes the loss between the original positions and the predicted positions.

    Args:
        pos: The original positions.
        pos_pred: The predicted positions.

    Returns:
        The loss between the original positions and the predicted positions.
    """

    loss = torch.pow(pos - pos_pred, 2).mean()
    return loss

In [4]:
import importlib
import torch
from torch_geometric.datasets import QM9
import pickle
from rdkit import Chem
import torch_geometric.data.data
from torch_geometric.data import Dataset, Data, Batch, InMemoryDataset
from torch_geometric.loader import DataLoader
#from torch.utils.data import DataLoader
import os
import os.path as osp
import sys
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from torch_geometric.data.collate import collate
from tqdm import tqdm
import glob
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from rdkit import Chem
from torch.nn.functional import one_hot
from torch import scatter
import random
from rdkit.Chem.rdchem import ChiralType
import numpy as np
from torchvision import transforms
from torch import Tensor
#from dig.threedgraph.method import SphereNet
from dig.threedgraph import method
from dig.threedgraph.evaluation import ThreeDEvaluator
from dig.threedgraph.method import run
from qm9_dataset import qm9_data_geom

In [6]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

In [7]:
path_to_data = '../../../others_approaches/conformation_generation/GeoMol/data/QM9/'


In [8]:
qm9_set = qm9_data_geom(root= path_to_data)

y is included!!


In [9]:
print(qm9_set[2].y)

tensor([1.])


In [10]:
#target = 'boltzmann_weight' # mu, alpha, homo, lumo, gap, r2, zpve, U0, U, H, G, Cv
#qm9_set.data.y = qm9_set.data[target]

split_idx = qm9_set.get_idx_split(len(qm9_set.data.y)-2, train_size=300, valid_size=100, seed=42)

train_dataset, valid_dataset, test_dataset = qm9_set[split_idx['train']], qm9_set[split_idx['valid']], qm9_set[split_idx['test']]


In [30]:
import torch
import torch.optim as optim
def evaluation(pred, target):
    pred = pred.cpu().detach().numpy()
    target = target.cpu().detach().numpy()
    rmse = np.sqrt(np.mean((pred - target)**2))

    return rmse
def trainer(model, loss_func, train_dataset, valid_dataset, epochs=100):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(epochs):
        for batch in DataLoader(train_dataset, batch_size=4, shuffle=True):
            batch = batch.to(device)
            pred, z_pred = model(batch)
            loss = loss_func(pred, batch.pos)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate on validation set
        with torch.no_grad():
            pred, z_pred = model(valid_dataset.to(device))
            loss = loss_func(pred, valid_dataset.pos.to(device))
            rmse = evaluation(pred.cpu().detach().numpy(), valid_dataset.pos.cpu().detach().numpy())

        print(f"Epoch {epoch}: loss {loss.item():.4f}, RMSE {rmse:.4f}")
def generate_3d_positions(model, latent_representations):
    pred_positions = []
    for latent_representation in latent_representations:
        z_pred = model.autoencoder(latent_representation)
        pred_position = np.array(z_pred.cpu().detach().numpy())
        pred_positions.append(pred_position)

    return pred_positions

def train(model, dataloader, optimizer, loss_fn, num_epochs):
    """
    Trains the model on the given dataloader for the specified number of epochs.

    Args:
        model: The model to train.
        dataloader: The dataloader to use for training.
        optimizer: The optimizer to use for training.
        loss_fn: The loss function to use for training.
        num_epochs: The number of epochs to train for.
    """

    for epoch in range(num_epochs):
        for batch in dataloader:
            batch.to("cuda:0")
            pos, pos_pred = model(batch)
            loss = loss_fn(pos, pos_pred)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return model


In [26]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = SphernetWithTransformer().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#loss_fn = loss_fn().to(device)



In [29]:
#train_dataset.to(device)
dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

trained_model = train(model, dataloader, optimizer, loss_fn, 10)


AttributeError: 'GlobalStorage' object has no attribute 'shape'

In [28]:
train(model, loss_func, train_dataset, valid_dataset, epochs=5)

NameError: name 'loss_func' is not defined

In [None]:
print(qm9_data_geom[2].y)

In [9]:
model = SphereNet(energy_and_force=False, cutoff=5.0, num_layers=4, 
        hidden_channels=128, out_channels=1, int_emb_size=64, 
        basis_emb_size_dist=8, basis_emb_size_angle=8, basis_emb_size_torsion=8, out_emb_channels=256, 
        num_spherical=3, num_radial=6, envelope_exponent=5, 
        num_before_skip=1, num_after_skip=2, num_output_layers=3, use_node_features=True)
loss_func = torch.nn.L1Loss()
evaluation = ThreeDEvaluator()

In [None]:
run3d = run()
run3d.run(device, train_dataset, valid_dataset, test_dataset, 
        model, loss_func, evaluation, 
        epochs=2, batch_size=4, vt_batch_size=4, lr=0.0005, lr_decay_factor=0.5, lr_decay_step_size=15)

In [None]:
print(train_dataset[0])

In [None]:
loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

In [None]:
for batch in loader:
    print(batch[0])

In [None]:
from dig.threedgraph.dataset import QM93D

In [None]:
dataset = QM93D(root='../../../others_approaches/embeddings_nets/DIG-dig-stable/tutorials/KDD2022/dataset/')
target = 'U0' # mu, alpha, homo, lumo, gap, r2, zpve, U0, U, H, G, Cv
dataset.data.y = dataset.data[target]

split_idx = dataset.get_idx_split(len(dataset.data.y), train_size=110000, valid_size=10000, seed=42)

train_dataset, valid_dataset, test_dataset = dataset[split_idx['train']], dataset[split_idx['valid']], dataset[split_idx['test']]
print('train, validaion, test:', len(train_dataset), len(valid_dataset), len(test_dataset))
print(dataset)

In [None]:
model = SphereNet(energy_and_force=False, cutoff=5.0, num_layers=4, 
        hidden_channels=128, out_channels=1, int_emb_size=64, 
        basis_emb_size_dist=8, basis_emb_size_angle=8, basis_emb_size_torsion=8, out_emb_channels=256, 
        num_spherical=3, num_radial=6, envelope_exponent=5, 
        num_before_skip=1, num_after_skip=2, num_output_layers=3, use_node_features=True)
loss_func = torch.nn.L1Loss()
evaluation = ThreeDEvaluator()

In [None]:
run3d = run()
run3d.run(device, train_dataset, valid_dataset, test_dataset, 
        model, loss_func, evaluation, 
        epochs=2, batch_size=4, vt_batch_size=4, lr=0.0005, lr_decay_factor=0.5, lr_decay_step_size=15)