In [5]:
import torch
import torch.nn as nn
# from numpy import exp, sqrt
from numpy.random import normal
from torch import exp, sqrt, randn_like
from rdkit import Chem
import torch.nn.functional as F

In [2]:
LATENT_DIM = 4      # size of the latent space (arbitrary choice for now)
MAX_SIZE = 22*3     # flattened (x, y, z) for each atom; 22 presumably = max atom count — TODO confirm


class MoleculeVAE(nn.Module):
    """Fully-connected VAE over flattened 3D transition-state (TS) coordinates.

    The input is reshaped to ``(-1, MAX_SIZE)``. The encoder emits the mean and
    log-variance of a ``LATENT_DIM``-dimensional Gaussian. The decoder maps one
    latent sample back to ``MAX_SIZE`` values.

    TODO: allow encoder/decoder, in/out channels, depth, hidden size, dropout and
    GNN type to be passed in (the reaction frame could also be supplied here).
    """

    def __init__(self):
        super().__init__()
        # Encoder outputs 2*LATENT_DIM values. forward() splits them into the mean
        # (first LATENT_DIM) and the log-variance (last LATENT_DIM).
        self.encoder = nn.Sequential(
            nn.Linear(MAX_SIZE, 32),
            nn.ReLU(),
            nn.Linear(32, 2*LATENT_DIM)
        )
        # The decoder consumes a single latent sample z of width LATENT_DIM,
        # not the concatenated mean/log-variance.
        self.decoder = nn.Sequential(
            nn.Linear(LATENT_DIM, 32),
            nn.ReLU(),
            nn.Linear(32, MAX_SIZE)
        )
        # Sketch of a possible GNN encoder:
        # self.convs = torch.nn.ModuleList()
        # for _ in range(self.depth):
        #     if self.gnn_type == ...:
        #         self.convs.append(GCNConv(...))

    def reparameterise(self, mean_z, log_var_z):
        """Sample z ~ N(mean_z, exp(log_var_z)) in training mode; return the mean in eval mode.

        Args:
            mean_z: latent means, same shape as ``log_var_z``.
            log_var_z: latent log-variances.
        """
        if not self.training:
            return mean_z
        eps = randn_like(log_var_z)
        # std = sqrt(exp(log_var)) = exp(0.5 * log_var); the latter avoids overflow for large log_var
        return mean_z + eps * exp(0.5 * log_var_z)

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Args:
            x: tensor whose elements can be reshaped to ``(-1, MAX_SIZE)``.

        Returns:
            ``(reconstruction, mean_z, log_var_z)`` with shapes
            ``(B, MAX_SIZE)``, ``(B, LATENT_DIM)``, ``(B, LATENT_DIM)``.
        """
        # reshape (not view) so non-contiguous batches are accepted too
        params_z = self.encoder(x.reshape(-1, MAX_SIZE)).view(-1, 2, LATENT_DIM)
        mean_z = params_z[:, 0, :]
        log_var_z = params_z[:, 1, :]
        z = self.reparameterise(mean_z, log_var_z)
        return self.decoder(z), mean_z, log_var_z

    @staticmethod
    def loss_func(generated_TS, real_TS, mean_z, log_var_z, beta=1):
        """beta-VAE loss: summed reconstruction error + beta * KL(q(z|x) || N(0, I)).

        The decoder ends in an unbounded Linear layer and the targets are xyz
        coordinates. Binary cross-entropy would therefore reject the inputs, so the
        reconstruction term is the summed squared error.

        Args:
            generated_TS: decoder output, shape ``(B, MAX_SIZE)``.
            real_TS: target geometry, reshapeable to ``(B, MAX_SIZE)``.
            mean_z, log_var_z: encoder outputs, shape ``(B, LATENT_DIM)``.
            beta: weight of the KL term.
        """
        recon = F.mse_loss(generated_TS, real_TS.reshape(-1, MAX_SIZE), reduction='sum')
        # closed-form KL divergence between N(mean, var) and N(0, 1), summed over batch and latent dims
        KLD = 0.5 * torch.sum(exp(log_var_z) - log_var_z - 1 + mean_z**2)
        return recon + beta * KLD

# define device and model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = MoleculeVAE().to(device)

In [10]:
rxn_dataset = ReactionDataset(root='data/')  # raw SDFs are processed on first use; processed files are reused afterwards

  0%|          | 0/6739 [00:00<?, ?it/s]Processing...
100%|██████████| 6739/6739 [00:06<00:00, 1093.34it/s]
100%|██████████| 6739/6739 [00:06<00:00, 976.26it/s]
100%|██████████| 6739/6739 [00:06<00:00, 982.86it/s] 
100%|██████████| 842/842 [00:00<00:00, 1047.57it/s]
100%|██████████| 842/842 [00:00<00:00, 892.46it/s]
100%|██████████| 842/842 [00:00<00:00, 1050.84it/s]
Done!


In [255]:
rxn_dataset  # display the loaded dataset's repr

ReactionDataset(22743)

In [9]:
from torch_scatter import scatter
from torch_geometric.data import InMemoryDataset, Data
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from tqdm import tqdm

class ReactionDataset(InMemoryDataset):
    """PyTorch Geometric dataset of reactant, transition-state and product geometries.

    ``process`` reads the six SDF files named in ``raw_file_names`` from
    ``<root>/raw/``. It writes one collated file per SDF to the matching entry of
    ``processed_file_names``.

    NOTE(review): ``__init__`` only loads ``processed_paths[0]`` (train reactants).
    NOTE(review): a later cell defines another class named ``ReactionDataset``,
    which shadows this one once it runs.
    """

    # atom symbol -> index for the one-hot atom-type feature
    types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
    # RDKit bond type -> index for the one-hot edge feature
    bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

    def __init__(self, root, transform=None, pre_transform=None):
        # zero-argument super() keeps working even after the global name
        # `ReactionDataset` is rebound by a later cell
        super().__init__(root, transform, pre_transform)
        # NOTE(review): newer torch versions default torch.load to weights_only=True,
        # which may refuse collated Data objects — confirm against the installed version.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """SDF files expected in ``<root>/raw/``.

        The names are relative: PyG joins them onto ``raw_dir``, so a leading '/'
        would make them absolute and point outside ``root``.
        """
        return ['train_reactants.sdf', 'train_ts.sdf', 'train_products.sdf',
                'test_reactants.sdf', 'test_ts.sdf', 'test_products.sdf']

    @property
    def processed_file_names(self):
        """ If files already in processed folder, this processing is skipped. """
        return ['train_r.pt', 'train_ts.pt', 'train_p.pt', 'test_r.pt', 'test_ts.pt', 'test_p.pt']

    def download(self):
        """ Not required in this project. """
        pass

    def process(self):
        """Convert each raw SDF file into a collated file at the same index of ``processed_paths``.

        Each molecule becomes one ``Data`` graph, whose ``idx`` is its position in the SDF.
        Featurisation is adapted from the PyG QM9 dataset:
        https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/qm9.html

        Raises:
            ValueError: if RDKit cannot parse a record.
        """
        for g_idx, raw_path in enumerate(self.raw_paths):
            geometries = Chem.SDMolSupplier(raw_path, removeHs=False, sanitize=False)
            data_list = []
            for i, mol in enumerate(tqdm(geometries)):
                if mol is None:
                    raise ValueError(f'RDKit could not parse molecule {i} in {raw_path}')
                data = self._mol_to_data(mol, geometries.GetItemText(i), i)
                # pre_filter is deliberately not applied. Dropping a record from one file would
                # misalign the (presumably index-aligned) R / TS / P files.
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                data_list.append(data)
            torch.save(self.collate(data_list), self.processed_paths[g_idx])

    def _mol_to_data(self, mol, mol_block, idx):
        """Build a PyG ``Data`` graph (node/edge features and positions) for one molecule."""
        N = mol.GetNumAtoms()

        # atom positions [num_atoms, 3], read from the SDF atom block (record lines 4 .. 4+N)
        atom_lines = mol_block.split('\n')[4:4 + N]
        pos = torch.tensor([[float(v) for v in line.split()[:3]] for line in atom_lines],
                           dtype=torch.float)

        # atom/node features
        type_idx, atomic_number, aromatic, sp, sp2, sp3 = [], [], [], [], [], []
        for atom in mol.GetAtoms():
            type_idx.append(self.types[atom.GetSymbol()])
            atomic_number.append(atom.GetAtomicNum())
            aromatic.append(1 if atom.GetIsAromatic() else 0)
            hybridisation = atom.GetHybridization()
            sp.append(1 if hybridisation == HybridizationType.SP else 0)
            sp2.append(1 if hybridisation == HybridizationType.SP2 else 0)
            sp3.append(1 if hybridisation == HybridizationType.SP3 else 0)

        # bond/edge features: every bond is stored in both directions
        row, col, edge_type = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            edge_type += 2 * [self.bonds[bond.GetBondType()]]

        # edge_index: connectivity in COO format, shape [2, num_edges]
        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_type = torch.tensor(edge_type, dtype=torch.long)
        # edge_attr: one-hot bond type, shape [num_edges, num_bond_types]
        edge_attr = F.one_hot(edge_type, num_classes=len(self.bonds)).to(torch.float)

        # sort edges by (source, target)
        perm = (edge_index[0] * N + edge_index[1]).argsort()
        edge_index = edge_index[:, perm]
        edge_attr = edge_attr[perm]

        row, col = edge_index
        z = torch.tensor(atomic_number, dtype=torch.long)
        hs = (z == 1).to(torch.float)  # hydrogens
        # number of hydrogen neighbours of each atom
        num_hs = scatter(hs[row], col, dim_size=N).tolist()

        x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(self.types))
        x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, num_hs], dtype=torch.float).t().contiguous()
        x = torch.cat([x1.to(torch.float), x2], dim=-1)

        # no direct y, since the plan is to decode to the TS
        return Data(x=x, z=z, pos=pos, edge_index=edge_index, edge_attr=edge_attr, idx=idx)

In [20]:
class ReactionDataset(torch.utils.data.Dataset):
    """Reactant (R), transition-state (TS) and product (P) geometries for the train and test sets.

    Each split is read from three SDF files (``base_folder + <name>``) into lists of
    RDKit molecules via ``sdf_to_rdmol``.

    NOTE(review): this redefines ``ReactionDataset`` and shadows the
    ``InMemoryDataset``-based class of the same name from an earlier cell.
    TODO: implement ``__len__`` / ``__getitem__`` so the dataset can be indexed and
    batched, and convert the molecules to PyG data.
    """

    def __init__(self, train_r_name, train_ts_name, train_p_name,
                       test_r_name,  test_ts_name,  test_p_name, base_folder='data/'):
        self.base_folder = base_folder
        # plain concatenation: base_folder is expected to end with a path separator
        self.train_r, self.train_ts, self.train_p = self.create_training_set(
            base_folder + train_r_name, base_folder + train_ts_name, base_folder + train_p_name)
        self.test_r, self.test_ts, self.test_p = self.create_test_set(
            base_folder + test_r_name, base_folder + test_ts_name, base_folder + test_p_name)

    def _load_triplet(self, r_file, ts_file, p_file):
        """Read one split's R, TS and P SDF files, returned in that order."""
        return self.sdf_to_rdmol(r_file), self.sdf_to_rdmol(ts_file), self.sdf_to_rdmol(p_file)

    def create_training_set(self, train_r_file, train_ts_file, train_p_file):
        """Return ``(train_r, train_ts, train_p)`` molecule lists."""
        return self._load_triplet(train_r_file, train_ts_file, train_p_file)

    def create_test_set(self, test_r_file, test_ts_file, test_p_file):
        """Return ``(test_r, test_ts, test_p)`` molecule lists."""
        return self._load_triplet(test_r_file, test_ts_file, test_p_file)

    def sdf_to_rdmol(self, geometry_file):
        """Read every record of an SDF file into a list of RDKit molecules.

        Hydrogens are kept and sanitisation is skipped, so the geometries are used as written.
        """
        return list(Chem.ForwardSDMolSupplier(geometry_file, removeHs=False, sanitize=False))

    def process_rdmol(self, geometry_file):
        """Open ``geometry_file`` with an ``SDMolSupplier`` (hydrogens kept, no sanitisation).

        TODO: turn the molecules into model inputs. For now, the supplier is returned.
        """
        return Chem.SDMolSupplier(geometry_file, removeHs=False, sanitize=False)

    def coordinate_to_interatomic_dist(self):
        """TODO: convert coordinates to interatomic distances.

        Maybe generalise with a flag for other inputs (Z-matrix, nuclear charge, ...).
        If the variants differ enough, use an abstract base class with one implementation per input type.
        """
        return

    def visualise_feature_dynamics(self):
        """TODO: visualise how features (e.g. interatomic distances) change along the reaction.

        This may need a helper to find the reaction centre. Open questions: how much do
        distances change, and what model precision is needed to capture it?
        """
        return

    def dataset_properties(self):
        """TODO: RDKit molecule properties for comparing datasets."""
        return

    # other funcs: scaffold bias on train-test split - how would that work in 3D?

class ModelRun():
    """Holds the train/test reactions and the model for one experiment.

    Planned: loss plotting and evaluation helpers.
    """

    def __init__(self, training_rxns, test_rxns, model):
        self.training_rxns = training_rxns
        self.test_rxns = test_rxns
        # keep the model so later evaluation/plotting methods can use it
        self.model = model

    # plot loss
    # plotting evaluation here

    def preprocess_data(self):
        """TODO: preprocess the data in whatever form each model needs."""
        return

    # compare model TS: average, initial guesses, final estimates, reals
        # will allow me to compare several models against each other

    # will have funcs to evaluate structure of output
        # for generative models: evaluate structure of latent space
        # evaluate how R and P are combined

    # how are TS formed?



In [9]:
### put R-TS-P into dataloader
# The full split is really train_r, train_p, val_ts, then test_r, test_ts, test_p.
# NOTE(review): `train_r` / `train_p` are not bound at module level in any visible cell — TODO.

# pair every reactant with its product: [(r0, p0), (r1, p1), ...]
train_rxn_endpoints = [(reactant, product) for reactant, product in zip(train_r, train_p)]


# class ReactionDataset()


In [13]:
def create_dataloader(args, modes=('train', 'val')):
    """Placeholder for a reaction-specific DataLoader factory (not implemented yet).

    Args:
        args: run configuration namespace.
        modes: a split name or an iterable of split names.

    Raises:
        NotImplementedError: always, until the body is written.
    """
    # `modes` used to be read before it was ever assigned (UnboundLocalError)
    if isinstance(modes, str):
        modes = [modes]
    raise NotImplementedError(f'create_dataloader is not implemented yet (requested modes: {list(modes)})')

def construct_loader(args, modes=('train', 'val')):
    """Build one DataLoader per requested split from a SMILES/label CSV.

    Args:
        args: namespace with ``data_path`` (CSV, column 0 = SMILES, column 1 = label),
            ``batch_size``, ``no_shuffle``, ``num_workers`` and ``shuffle_pairs``.
        modes: a split name or an iterable of split names.

    Returns:
        A single loader if exactly one mode was requested. Otherwise, a list of loaders
        in the order of ``modes``.

    NOTE(review): ``MolDataset``, ``DataLoader`` and ``StereoSampler`` are not defined or
    imported in any visible cell. They must exist in the namespace before calling.
    """
    import numpy as np
    import pandas as pd

    if isinstance(modes, str):
        modes = [modes]

    data_df = pd.read_csv(args.data_path)

    smiles = data_df.iloc[:, 0].to_numpy()
    labels = data_df.iloc[:, 1].to_numpy().astype(np.float32)

    loaders = []
    for mode in modes:
        dataset = MolDataset(smiles, labels, args, mode)
        # DataLoader rejects shuffle=True combined with a custom sampler. When a sampler is
        # requested it controls the ordering, so shuffle is disabled.
        sampler = StereoSampler(dataset) if args.shuffle_pairs else None
        shuffle = sampler is None and mode == 'train' and not args.no_shuffle
        loader = DataLoader(dataset=dataset,
                            batch_size=args.batch_size,
                            shuffle=shuffle,
                            num_workers=args.num_workers,
                            pin_memory=True,
                            sampler=sampler)
        loaders.append(loader)

    return loaders[0] if len(loaders) == 1 else loaders

list

In [13]:
# setting optimiser
learning_rate = 1e-3
optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)

# training and testing the VAE
# NOTE(review): `train_loader` / `test_loader` are not created in any visible cell. Both are
# assumed to yield (input, label) batches, where the input is the TS geometry to reconstruct.
epochs = 5
codes = dict(mean=list(), log_var=list(), y=list())
for epoch in range(0, epochs+1):
    # epoch 0 skips training, so the first test pass is an untrained baseline
    if epoch > 0:
        model.train()
        train_loss = 0
        for real_TS, _ in train_loader:
            real_TS = real_TS.to(device)
            # === forward === (the VAE reconstructs its own input)
            generated_TS, mean_z, log_var_z = model(real_TS)
            # call through the class so no instance gets bound as the first argument
            loss = MoleculeVAE.loss_func(generated_TS, real_TS, mean_z, log_var_z)
            train_loss += loss.item()
            # === backward ===
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()
        # === log ===
        print(f'====> Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')

    # testing
    means, log_vars, labels = list(), list(), list()
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            # === forward ===
            x_hat, mean, log_var = model(x)
            test_loss += MoleculeVAE.loss_func(x_hat, x, mean, log_var).item()
            # === log ===
            means.append(mean.detach())
            log_vars.append(log_var.detach())
            labels.append(y.detach())
    # === log ===
    codes['mean'].append(torch.cat(means))
    codes['log_var'].append(torch.cat(log_vars))
    codes['y'].append(torch.cat(labels))
    test_loss /= len(test_loader.dataset)
    print(f'===> Test set loss: {test_loss:.4f}')
    # TODO: per-epoch visualisation. `display_images` (an image-VAE leftover) is not defined
    # in this notebook, and calling it aborted training after the first epoch.

NameError: name 'test_loader' is not defined

In [None]:
class VAE(nn.Module):
    """Generic fully-connected VAE (ported from an image VAE) over flattened inputs.

    Args:
        input_size: length of the flattened input. Defaults to the notebook's ``MAX_SIZE``
            (the largest molecule's flattened size).
        d: latent dimension; the hidden layer has ``d**2`` units. Defaults to the
            notebook's ``LATENT_DIM``.
    """

    def __init__(self, input_size=None, d=None):
        super().__init__()
        # `input_size` and `d` were previously undefined globals. The config constants
        # are looked up here, at construction time, only when no value is passed.
        self.input_size = MAX_SIZE if input_size is None else input_size
        self.d = LATENT_DIM if d is None else d

        # encoder outputs 2*d values: d means followed by d log-variances
        self.encoder = nn.Sequential(
            nn.Linear(self.input_size, self.d ** 2),
            nn.ReLU(),
            nn.Linear(self.d ** 2, self.d * 2)
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.d, self.d ** 2),
            nn.ReLU(),
            nn.Linear(self.d ** 2, self.input_size)
            # would use sigmoid here if input was between 0 and 1
        )

    def reparameterise(self, mean_z, log_var_z):
        """Sample z ~ N(mean_z, exp(log_var_z)) in training mode; return the mean in eval mode.

        Predicting log-variance lets the encoder output any real number while the variance stays positive.
        """
        if self.training:
            # torch noise with the same shape/device/dtype as the latent parameters
            eps = randn_like(log_var_z)
            return mean_z + eps * exp(0.5 * log_var_z)
        return mean_z

    def forward(self, x):
        """Return ``(reconstruction, mean_z, log_var_z)`` for input reshapeable to ``(-1, input_size)``."""
        # flatten each sample, then split the encoder output into (mean, log_var) along dim 1
        params_z = self.encoder(x.reshape(-1, self.input_size)).view(-1, 2, self.d)

        mean_z = params_z[:, 0, :]
        log_var_z = params_z[:, 1, :]
        z = self.reparameterise(mean_z, log_var_z)
        return self.decoder(z), mean_z, log_var_z

model = VAE().to(device=device)  # rebinds `model`: the cells below now train the generic VAE

In [None]:
# Adam over the parameters of whatever is currently bound to `model`
learning_rate = 1e-3
optimiser = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# reconstruction + KL divergence losses summed over all elements
def loss_function(x_hat, x, mean_z, log_var_z, beta=1):
    """beta-VAE loss: summed squared reconstruction error + beta * KL(q(z|x) || N(0, I)).

    The decoder ends in an unbounded Linear layer (no sigmoid). Binary cross-entropy
    would reject its outputs, so the reconstruction term is the summed squared error.

    Args:
        x_hat: decoder output.
        x: target, reshapeable to ``x_hat``'s shape (previously hard-coded to 784 = MNIST).
        mean_z, log_var_z: encoder outputs.
        beta: weight of the KL term (defaults to 1, as every caller assumes).
    """
    recon = F.mse_loss(x_hat, x.reshape(x_hat.shape), reduction='sum')
    # kl divergence: var is linear, - log var is logarithmic, mean is squared
    KLD = 0.5 * torch.sum(exp(log_var_z) - log_var_z - 1 + mean_z**2)
    return recon + beta * KLD

In [None]:
# training and testing the VAE
# NOTE(review): this duplicates the MoleculeVAE training cell; consider one shared
# train/evaluate function. `train_loader` / `test_loader` are not created in any visible cell.
epochs = 5
codes = dict(mean=list(), log_var=list(), y=list())
for epoch in range(0, epochs+1):
    # epoch 0 skips training, so the first test pass is an untrained baseline
    if epoch > 0:
        model.train()
        train_loss = 0
        for x, _ in train_loader:
            x = x.to(device)
            # === forward ===
            x_hat, mean, log_var = model(x)
            loss = loss_function(x_hat, x, mean, log_var, beta=1)
            train_loss += loss.item()
            # === backward ===
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()
        # === log ===
        # computed outside the f-string: a line break inside `{...}` is a SyntaxError before Python 3.12
        avg_train_loss = train_loss / len(train_loader.dataset)
        print(f'====> Epoch: {epoch} Average loss: {avg_train_loss:.4f}')

    # testing
    means, log_vars, labels = list(), list(), list()
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            # === forward ===
            x_hat, mean, log_var = model(x)
            test_loss += loss_function(x_hat, x, mean, log_var, beta=1).item()
            # === log ===
            means.append(mean.detach())
            log_vars.append(log_var.detach())
            labels.append(y.detach())
    # === log ===
    codes['mean'].append(torch.cat(means))
    codes['log_var'].append(torch.cat(log_vars))
    codes['y'].append(torch.cat(labels))
    test_loss /= len(test_loader.dataset)
    print(f'===> Test set loss: {test_loss:.4f}')
    # TODO: per-epoch visualisation. `display_images` is not defined in this notebook,
    # and calling it aborted training after the first epoch.


In [None]:
# generating a few samples
# NOTE(review): `display_images` is not defined in any visible cell. TODO: define it (or use a
# molecular-geometry viewer) before running this cell.
model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    N = 16
    # latent width read from the decoder's first Linear layer (`d` was never defined)
    latent_dim = model.decoder[0].in_features
    z = torch.randn((N, latent_dim), device=device)
    sample = model.decoder(z)
    display_images(None, sample, N//4, count=True)

    # Choose starting and ending point for the interpolation -> shows original and reconstructed
    # `mean` and `x` are the last test batch from the training cell, so A and B must be
    # smaller than that batch's size.
    A, B = 1, 14
    sample = model.decoder(torch.stack((mean[A], mean[B]), 0))
    display_images(None, torch.stack((
        x[A].reshape(-1),
        x[B].reshape(-1),
        sample[0],
        sample[1],
    ), 0))