In [1]:
import abc
import logging
import ssl
import sys
import urllib

from multiprocess.pool import Pool
from torch_geometric.data.data import BaseData
# Add the directory two levels above the current working directory to the import path.
# NOTE(review): presumably this is the repository root, so the `ptgnn` imports in the following
# cells resolve; the path is relative to the kernel's CWD, so it only works when the notebook is
# started from its own folder. Must run before those imports.
sys.path.append("../../")

In [2]:
import torch
import torch_geometric as pyg
from ptgnn.features.chiro.embedding_functions import embedConformerWithAllPaths

In [3]:
from ptgnn.transform import PRE_TRANSFORM_MAPPING
from ptgnn.masking import MASKING_MAPPING
import os

In [4]:
from pathlib import Path

In [5]:
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np

In [6]:
import pickle
from tqdm import tqdm
import pandas as pd

In [7]:
from ptgnn.dataset.utils_chienn import download_url_to_path

In [8]:
from ptgnn.features.chienn.molecule3d import smiles_to_3d_mol
from ptgnn.dataset.utils_chienn import get_chiro_data_from_mol

In [11]:
class RSDataset(pyg.data.InMemoryDataset):
    """
    R/S chirality dataset adapted from ChiENN/GraphGPS:
    https://github.com/gmum/ChiENN/blob/master/experiments/graphgps/dataset/rs_dataset.py

    Only the requested split is processed, into ``<processed_dir>/<split>.pt`` (collated graphs) and
    ``<processed_dir>/<split>.csv`` (the source dataframe restricted to the kept molecules, row-aligned
    with the graphs).
    """

    def __init__(
            self,
            root: str = "src",
            single_conformer: bool = True,
            mask_chiral_tags: bool = False,
            split: str = "train",
            graph_mode: str = "edge",
            max_atoms: int = 100,
            max_attempts: int = 100  # significantly decreased - 5000 is way too much!
    ):
        """
        Init of the RS dataset class

        :param root: Path to which the dataset should be saved
        :param single_conformer: If True, keep only one row per ``ID`` (stereo SMILES); also selects the
            ``single_conformer`` / ``all_conformers`` sub-directory of ``root``.
        :param mask_chiral_tags: If True, apply ``MASKING_MAPPING[graph_mode]`` to every graph returned by indexing.
        :param split: Split to load/process: ``'train'``, ``'val'`` or ``'test'``.
        :param graph_mode: Key into ``PRE_TRANSFORM_MAPPING`` / ``MASKING_MAPPING``; also a sub-directory of ``root``.
        :param max_atoms: Forwarded to ``smiles_to_3d_mol`` as ``max_number_of_atoms``.
        :param max_attempts: Forwarded to ``smiles_to_3d_mol`` as ``max_number_of_attempts``.
        """
        # link storage
        self.link_storage = {
            'train': 'https://figshare.com/ndownloader/files/30975694?private_link=e23be65a884ce7fc8543',
            'val': 'https://figshare.com/ndownloader/files/30975703?private_link=e23be65a884ce7fc8543',
            'test': 'https://figshare.com/ndownloader/files/30975679?private_link=e23be65a884ce7fc8543'
        }

        # set internal parameters (before super().__init__, which may call download()/process())
        self.single_conformer = single_conformer
        self.mask_chiral_tags = mask_chiral_tags
        self.split = split
        self.graph_mode = graph_mode
        self.pre_transform = PRE_TRANSFORM_MAPPING.get(self.graph_mode)
        self.masking = MASKING_MAPPING.get(self.graph_mode)
        self.max_atoms = max_atoms
        self.max_attempts = max_attempts

        super().__init__(
            root=root,
            transform=None,
            pre_transform=self.pre_transform,
            pre_filter=None
        )
        self.dataframe = pd.read_csv(os.path.join(self.processed_dir, f'{split}.csv'))
        self.data, self.slices = torch.load(os.path.join(self.processed_dir, f"{split}.pt"))

    def __getitem__(self, item):
        data = super().__getitem__(item)
        # Integer indexing yields a single graph. Slices / index lists yield a dataset view (a copy of
        # this object), whose own __getitem__ applies the masking lazily, so only mask real graphs here.
        if self.mask_chiral_tags and isinstance(data, BaseData):
            data = self.masking(data)
        return data

    @property
    def raw_file_names(self):
        return ['train.pickle', 'val.pickle', 'test.pickle']

    @property
    def processed_dir(self) -> str:
        name = 'single_conformer' if self.single_conformer else 'all_conformers'
        graph_mode = self.graph_mode if self.graph_mode else ''
        return os.path.join(self.root, name, graph_mode, 'processed')

    @property
    def processed_file_names(self):
        # Split-specific, so constructing a split that has not been processed yet triggers process().
        return [f'{self.split}.pt', f'{self.split}.csv']

    def download(self):
        for split, link in self.link_storage.items():
            split_pickle_path = os.path.join(self.raw_dir, f'{split}.pickle')
            download_url_to_path(link, split_pickle_path)

    def process(self):
        """
        Processes and saves the datapoints of ``self.split``. It additionally saves the original dataframe from
        the downloaded pickle (restricted to the kept molecules and row-aligned with the saved graphs), which is
        then used in `SingleConformerBatchSampler` in `get_custom_loader`.

        A molecule (identified by ``SMILES_nostereo``) is dropped entirely - all of its stereoisomers - as soon as
        any of its stereoisomers cannot be embedded.

        :raises RuntimeError: if no molecule of the split could be processed.
        """
        split = self.split
        # NOTE(review): pickle.load can execute arbitrary code - only load pickles from the trusted links above.
        with open(os.path.join(self.raw_dir, f'{split}.pickle'), 'rb') as f:
            split_df = pickle.load(f)

        if self.single_conformer:
            split_df = split_df.drop_duplicates(subset='ID')

        # (SMILES_nostereo, graph) pairs: the key lets us drop graphs of stereoisomers that were processed
        # successfully before a sibling stereoisomer failed, keeping .pt and .csv aligned.
        keyed_data = []
        to_remove = set()
        for _, row in tqdm(split_df.iterrows(), desc=f'Processing {split} dataset', total=len(split_df)):
            smiles_nonstereo = row['SMILES_nostereo']
            if smiles_nonstereo in to_remove:
                # a stereoisomer of this molecule already failed; skip the expensive embedding
                continue

            smiles = row['ID']
            mol = smiles_to_3d_mol(smiles, max_number_of_atoms=self.max_atoms, max_number_of_attempts=self.max_attempts)
            if mol is None:
                to_remove.add(smiles_nonstereo)
                continue
            try:
                data = get_chiro_data_from_mol(mol)
            except Exception as e:
                logging.warning(f"Omitting molecule {smiles} as cannot be properly embedded. The original error message was: {e}.")
                to_remove.add(smiles_nonstereo)
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data.y = torch.tensor(row['RS_label_binary']).long()
            keyed_data.append((smiles_nonstereo, data))

        data_list = [data for key, data in keyed_data if key not in to_remove]
        if not data_list:
            raise RuntimeError(f"No molecule of the '{split}' split could be processed.")

        torch.save(self.collate(data_list),
                   os.path.join(self.processed_dir, f'{split}.pt'))
        split_df = split_df.drop(columns='rdkit_mol_cistrans_stereo')
        split_df = split_df[~split_df['SMILES_nostereo'].isin(to_remove)]
        split_df.to_csv(os.path.join(self.processed_dir, f'{split}.csv'), index=None)

    def len(self) -> int:
        # Delegate to InMemoryDataset.len. PyG's Dataset.__len__ is itself implemented via self.len(),
        # so returning super().__len__() here recursed without end.
        return super().len()

    def get(self, idx: int) -> BaseData:
        # Delegate to InMemoryDataset.get. PyG's Dataset.__getitem__ calls self.get(), so returning
        # super().__getitem__(idx) here recursed without end.
        return super().get(idx)

In [12]:
# Build the validation split with RSDataset (PyG runs download()/process() if the processed files are missing).
data = RSDataset(split="val")

In [1]:
# NOTE(review): this cell ran in a fresh kernel (execution count 1), so `data` from the cell above was not
# defined and it raised NameError - run the cells above first (Restart & Run All).
data[0]

NameError: name 'data' is not defined

In [9]:
import torch_geometric as pyg
import torch
import os
from typing import Union, List, Tuple

In [21]:
class Test(pyg.data.InMemoryDataset):
    """
    Sequential, split-specific variant of `RSDataset`: only the requested split is processed, into
    ``<processed_dir>/<split>.pt`` and ``<processed_dir>/<split>.csv`` (row-aligned with the graphs).

    NOTE(review): a class with the same name is defined again (with multiprocessing) in a later cell,
    which shadows this one.
    """

    def __init__(
            self,
            root: str = "src",
            single_conformer: bool = True,
            mask_chiral_tags: bool = False,
            split: str = "train",
            graph_mode: str = "edge",
            max_atoms: int = 100,
            max_attempts: int = 100  # significantly decreased - 5000 is way too much!
    ):
        """
        Init of the RS dataset class

        :param root: Path to which the dataset should be saved
        :param single_conformer: If True, keep only one row per ``ID`` (stereo SMILES).
        :param mask_chiral_tags: If True, apply ``MASKING_MAPPING[graph_mode]`` to every graph returned by indexing.
        :param split: Split to load/process: ``'train'``, ``'val'`` or ``'test'``.
        :param graph_mode: Key into ``PRE_TRANSFORM_MAPPING`` / ``MASKING_MAPPING``; also a sub-directory of ``root``.
        :param max_atoms: Forwarded to ``smiles_to_3d_mol`` as ``max_number_of_atoms``.
        :param max_attempts: Forwarded to ``smiles_to_3d_mol`` as ``max_number_of_attempts``.
        """
        # link storage
        self.link_storage = {
            'train': 'https://figshare.com/ndownloader/files/30975694?private_link=e23be65a884ce7fc8543',
            'val': 'https://figshare.com/ndownloader/files/30975703?private_link=e23be65a884ce7fc8543',
            'test': 'https://figshare.com/ndownloader/files/30975679?private_link=e23be65a884ce7fc8543'
        }

        # set internal parameters (before super().__init__, which may call download()/process())
        self.single_conformer = single_conformer
        self.mask_chiral_tags = mask_chiral_tags
        self.split = split
        self.graph_mode = graph_mode
        self.pre_transform = PRE_TRANSFORM_MAPPING.get(self.graph_mode)
        self.masking = MASKING_MAPPING.get(self.graph_mode)
        self.max_atoms = max_atoms
        self.max_attempts = max_attempts

        super().__init__(
            root=root,
            transform=None,
            pre_transform=self.pre_transform,
            pre_filter=None
        )

        self.data, self.slices = torch.load(os.path.join(self.processed_dir, f"{split}.pt"))

    def __getitem__(self, item):
        data = super().__getitem__(item)
        # mask_chiral_tags was previously stored but never applied; only mask single graphs (slices return
        # a dataset view whose own __getitem__ masks lazily).
        if self.mask_chiral_tags and isinstance(data, BaseData):
            data = self.masking(data)
        return data

    @property
    def raw_file_names(self):
        return [f'{self.split}.pickle']

    @property
    def processed_dir(self) -> str:
        name = 'single_conformer' if self.single_conformer else 'all_conformers'
        graph_mode = self.graph_mode if self.graph_mode else ''
        return os.path.join(self.root, name, graph_mode, 'processed')

    @property
    def processed_file_names(self):
        return [f'{self.split}.pt', f'{self.split}.csv']

    def download(self):
        for split, link in self.link_storage.items():
            split_pickle_path = os.path.join(self.raw_dir, f'{split}.pickle')
            download_url_to_path(link, split_pickle_path)

    def process(self):
        """
        Processes ``self.split`` and saves the collated graphs plus the filtered dataframe.

        A molecule (identified by ``SMILES_nostereo``) is dropped entirely - all of its stereoisomers, from both
        the ``.pt`` and the ``.csv`` - as soon as any of its stereoisomers cannot be embedded.

        :raises RuntimeError: if no molecule of the split could be processed.
        """
        # NOTE(review): pickle.load can execute arbitrary code - only load pickles from the trusted links above.
        with open(os.path.join(self.raw_dir, f'{self.split}.pickle'), 'rb') as f:
            split_df = pickle.load(f)

        if self.single_conformer:
            split_df = split_df.drop_duplicates(subset="ID")

        # (SMILES_nostereo, graph) pairs, so graphs of a molecule that later fails can be dropped again
        keyed_data = []
        to_remove = set()

        for _, row in tqdm(
                split_df.iterrows(),
                desc=f"Processing {self.split} dataset",
                total=len(split_df)
        ):
            smiles_nonstereo = row["SMILES_nostereo"]

            # a stereoisomer of this molecule already failed; skip the expensive embedding
            if smiles_nonstereo in to_remove:
                continue

            smiles = row['ID']
            mol = smiles_to_3d_mol(
                smiles,
                max_number_of_attempts=self.max_attempts,
                max_number_of_atoms=self.max_atoms
            )

            if mol is None:
                to_remove.add(smiles_nonstereo)
                continue

            try:
                data = get_chiro_data_from_mol(mol)
            except Exception as e:
                logging.warning(f"Omitting molecule {smiles} as cannot be properly embedded. The original error message was: {e}.")
                # also drop the dataframe row(s) so the .csv stays aligned with the saved graphs
                to_remove.add(smiles_nonstereo)
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data.y = torch.tensor(row['RS_label_binary']).long()
            keyed_data.append((smiles_nonstereo, data))

        data_list = [data for key, data in keyed_data if key not in to_remove]
        if not data_list:
            raise RuntimeError(f"No molecule of the '{self.split}' split could be processed.")

        # save processed data
        torch.save(
            self.collate(data_list),
            os.path.join(self.processed_dir, f"{self.split}.pt")
        )
        split_df = split_df.drop(columns="rdkit_mol_cistrans_stereo")
        # the filtered frame was previously computed and discarded; assign it so the filter takes effect
        split_df = split_df[~split_df['SMILES_nostereo'].isin(to_remove)]
        split_df.to_csv(os.path.join(self.processed_dir, f"{self.split}.csv"), index=None)

In [27]:
class Test(pyg.data.InMemoryDataset):
    """
    Split-specific variant of `RSDataset` that embeds molecules in parallel with a `multiprocess` pool.

    Only the requested split is processed, into ``<processed_dir>/<split>.pt`` and ``<processed_dir>/<split>.csv``
    (row-aligned with the graphs).
    """

    def __init__(
            self,
            root: str = "src",
            single_conformer: bool = True,
            mask_chiral_tags: bool = False,
            split: str = "train",
            graph_mode: str = "edge",
            max_atoms: int = 100,
            max_attempts: int = 100  # significantly decreased - 5000 is way too much!
    ):
        """
        Init of the RS dataset class

        :param root: Path to which the dataset should be saved
        :param single_conformer: If True, keep only one row per ``ID`` (stereo SMILES).
        :param mask_chiral_tags: If True, apply ``MASKING_MAPPING[graph_mode]`` to every graph returned by indexing.
        :param split: Split to load/process: ``'train'``, ``'val'`` or ``'test'``.
        :param graph_mode: Key into ``PRE_TRANSFORM_MAPPING`` / ``MASKING_MAPPING``; also a sub-directory of ``root``.
        :param max_atoms: Forwarded to ``smiles_to_3d_mol`` as ``max_number_of_atoms``.
        :param max_attempts: Forwarded to ``smiles_to_3d_mol`` as ``max_number_of_attempts``.
        """
        # link storage
        self.link_storage = {
            'train': 'https://figshare.com/ndownloader/files/30975694?private_link=e23be65a884ce7fc8543',
            'val': 'https://figshare.com/ndownloader/files/30975703?private_link=e23be65a884ce7fc8543',
            'test': 'https://figshare.com/ndownloader/files/30975679?private_link=e23be65a884ce7fc8543'
        }

        # set internal parameters (before super().__init__, which may call download()/process())
        self.single_conformer = single_conformer
        self.mask_chiral_tags = mask_chiral_tags
        self.split = split
        self.graph_mode = graph_mode
        self.pre_transform = PRE_TRANSFORM_MAPPING.get(self.graph_mode)
        self.masking = MASKING_MAPPING.get(self.graph_mode)
        self.max_atoms = max_atoms
        self.max_attempts = max_attempts

        super().__init__(
            root=root,
            transform=None,
            pre_transform=self.pre_transform,
            pre_filter=None
        )

        self.data, self.slices = torch.load(os.path.join(self.processed_dir, f"{split}.pt"))

    def __getitem__(self, item):
        data = super().__getitem__(item)
        # mask_chiral_tags was previously stored but never applied; only mask single graphs (slices return
        # a dataset view whose own __getitem__ masks lazily).
        if self.mask_chiral_tags and isinstance(data, BaseData):
            data = self.masking(data)
        return data

    @property
    def raw_file_names(self):
        return [f'{self.split}.pickle']

    @property
    def processed_dir(self) -> str:
        name = 'single_conformer' if self.single_conformer else 'all_conformers'
        graph_mode = self.graph_mode if self.graph_mode else ''
        return os.path.join(self.root, name, graph_mode, 'processed')

    @property
    def processed_file_names(self):
        return [f'{self.split}.pt', f'{self.split}.csv']

    def download(self):
        for split, link in self.link_storage.items():
            split_pickle_path = os.path.join(self.raw_dir, f'{split}.pickle')
            download_url_to_path(link, split_pickle_path)

    def process(self):
        """
        Embeds every row of ``self.split`` in a process pool and saves the collated graphs plus the filtered
        dataframe.

        A molecule (identified by ``SMILES_nostereo``) is dropped entirely - all of its stereoisomers, from both
        the ``.pt`` and the ``.csv`` - if any of its stereoisomers cannot be embedded.

        :raises RuntimeError: if no molecule of the split could be processed.
        """
        # NOTE(review): pickle.load can execute arbitrary code - only load pickles from the trusted links above.
        with open(os.path.join(self.raw_dir, f'{self.split}.pickle'), 'rb') as f:
            split_df = pickle.load(f)

        if self.single_conformer:
            split_df = split_df.drop_duplicates(subset="ID")

        def worker(entry):
            """
            Converts one ``(index, row)`` item of ``split_df.iterrows()`` into ``(smiles_nonstereo, data)``,
            where ``data`` is None if the molecule could not be embedded.
            """
            # Local imports keep the worker self-contained when it is serialized to the pool's processes
            # (presumably the reason they are here).
            from ptgnn.features.chienn.molecule3d import smiles_to_3d_mol
            from ptgnn.dataset.utils_chienn import get_chiro_data_from_mol
            import logging
            import torch

            _, row = entry
            smiles_nonstereo = row["SMILES_nostereo"]
            smiles = row['ID']

            mol = smiles_to_3d_mol(
                smiles,
                max_number_of_attempts=self.max_attempts,
                max_number_of_atoms=self.max_atoms
            )
            if mol is None:
                return smiles_nonstereo, None

            try:
                data = get_chiro_data_from_mol(mol)
            except Exception as e:
                logging.warning(f"Omitting molecule {smiles} as cannot be properly embedded. The original error message was: {e}.")
                return smiles_nonstereo, None

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data.y = torch.tensor(row['RS_label_binary']).long()
            return smiles_nonstereo, data

        with Pool(processes=os.cpu_count()) as pool:
            # Wrap the *results* iterator (not the input) so the progress bar tracks finished molecules;
            # imap preserves input order, keeping results aligned with split_df rows.
            results = list(tqdm(
                pool.imap(worker, split_df.iterrows()),
                desc=f"Processing {self.split} dataset",
                total=len(split_df)
            ))

        # every molecule with at least one failed stereoisomer is dropped entirely
        to_remove = {smiles_nonstereo for smiles_nonstereo, data in results if data is None}
        data_list = [
            data
            for smiles_nonstereo, data in results
            if data is not None and smiles_nonstereo not in to_remove
        ]
        if not data_list:
            raise RuntimeError(f"No molecule of the '{self.split}' split could be processed.")

        # save processed data
        torch.save(
            self.collate(data_list),
            os.path.join(self.processed_dir, f"{self.split}.pt")
        )
        split_df = split_df.drop(columns="rdkit_mol_cistrans_stereo")
        # the filtered frame was previously computed and discarded; assign it so the filter takes effect
        split_df = split_df[~split_df['SMILES_nostereo'].isin(to_remove)]
        split_df.to_csv(os.path.join(self.processed_dir, f"{self.split}.csv"), index=None)

In [28]:
# Build the validation split with the (multiprocessing) Test dataset; processing runs only if the
# processed files are missing.
t1 = Test(split='val')

In [29]:
# Inspect the first graph of the validation split.
t1[0]

Data(x=[34, 52], edge_index=[2, 72], edge_attr=[72, 14], pos=[34, 3], bond_distances=[36], bond_distance_index=[2, 36], bond_angles=[60], bond_angle_index=[3, 60], dihedral_angles=[91], dihedral_angle_index=[4, 91], y=[1])

In [30]:
# The dataset repr includes the number of graphs.
t1

Test(11740)

In [31]:
# Number of graphs kept in the validation split.
len(t1)

11740

In [32]:
# Build the test split with the same dataset class.
t2 = Test(split="test")

In [33]:
# Inspect the first graph of the test split.
t2[0]

Data(x=[17, 52], edge_index=[2, 36], edge_attr=[36, 14], pos=[17, 3], bond_distances=[18], bond_distance_index=[2, 18], bond_angles=[34], bond_angle_index=[3, 34], dihedral_angles=[49], dihedral_angle_index=[4, 49], y=[1])

In [34]:
# Number of graphs kept in the test split.
len(t2)

11677

In [None]:
# Build the train split; NOTE(review): this cell had not finished when the notebook was saved
# (execution count "None", partial progress output below).
t3 = Test(split="train")

Processing...
34524it [03:09, 200.66it/s]