# Converting the Cambridge Structural Database (CSD) to compressed SDF files

Before running this notebook, you'll need to:

1. **Obtain a CSD license** - Contact the Cambridge Crystallographic Data Centre (CCDC) to get access to the Cambridge Structural Database.

2. **Download the CSD data** - After getting your license, download the database files from the CCDC portal.

3. **Set up the environment** - Create a conda environment using the provided environment file:
    ```bash
    conda env create -f env/csd-api-env.yml
    conda activate csd-api-env
    ```

This notebook converts the CSD (stored as an SQLite database) into batches of gzip-compressed SDF files, writing a Parquet log of per-entry errors for each batch.

In [1]:
import functools
import gzip
import multiprocessing as mp
import shutil
import signal
import sys
import time
from contextlib import redirect_stderr
from functools import partial
from io import StringIO
from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq
from ccdc.entry import Entry
from ccdc.io import EntryReader
from rdkit import Chem
from rdkit.Chem import SaltRemover

In [2]:
def compress2gzip(input_file: str | Path, output_file: str | Path, remove_original: bool = False) -> None:
    """
    Compress a file to gzip format.

    Parameters
    ----------
    input_file : str or Path
        Path to the input file that will be compressed.
    output_file : str or Path
        Path to the output compressed gzip file.
    remove_original : bool, optional
        If True, the original file will be removed after compression. Default is False.

    Returns
    -------
    None
        Function compresses a file but doesn't return any value.

    Raises
    ------
    TypeError
        If input_file or output_file are not str or Path objects.

    Notes
    -----
    Uses gzip and shutil to efficiently compress the input file.
    The original file is preserved unless remove_original is set to True.
    """
    # check if inputs are str or Path
    if not (isinstance(input_file, (str, Path)) and isinstance(output_file, (str, Path))):
        raise TypeError("input_file and output_file must be str or Path")

    # convert str to Path if needs be
    input_path = Path(input_file) if isinstance(input_file, str) else input_file
    output_path = Path(output_file) if isinstance(output_file, str) else output_file

    # compress file
    with input_path.open("rb") as f_in:
        with gzip.open(output_path, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)

    # remove original file if requested
    if remove_original:
        input_path.unlink()

In [3]:
# Create a PyArrow schema for error logs
# PyArrow schema for the per-entry processing log written by process_batch.
# process_batch appends one row per processed entry, successful ones included,
# so the table also records per-entry timings.
error_log_schema = pa.schema(
    [
        pa.field("entry_index", pa.int32()),
        pa.field("identifier", pa.string()),
        pa.field("error_log", pa.string()),
        pa.field("processing_time", pa.float32()),
    ],
    metadata={
        "description": "Schema for logging errors encountered during CSD to SDF conversion",
        "entry_index": "Index of the CSD entry",
        "identifier": "Identifier of the CSD entry",
        "error_log": "Error message; empty when the entry was converted successfully",
        "processing_time": "Time taken to process the entry, in seconds",
    },
)

In [4]:
# takes 2 min and 15 sec roughly
# Answer is : 1394755
# counter=0

# for _ in csd_reader.entries():
#     counter+=1
# print(counter)
# divmod(counter,10_000)

In [5]:
# divmod(1394755, 10_000)  # (139, 755)

In [6]:
def timeout(max_timeout, default=None):
    """
    Build a decorator that aborts the wrapped call after ``max_timeout`` seconds.

    Parameters
    ----------
    max_timeout : float
        Wall-clock limit in seconds; fractional values are supported.
    default : object, optional
        Value returned instead of the function's result when the limit is hit.

    Returns
    -------
    callable
        Decorator to apply to the function that should be time-limited.

    Notes
    -----
    - Relies on ``SIGALRM`` and ``setitimer(ITIMER_REAL)``, so it only works on
      Unix and only in the main thread of a process.
    - Python signal handlers run between bytecode instructions, so a call that
      is busy inside a long C-extension function is interrupted only once
      control returns to the interpreter.
    - The timeout exception derives from ``BaseException``, so a broad
      ``except Exception`` inside the wrapped function cannot swallow it.
    - The ``SIGALRM`` handler that was installed before the call is restored
      afterwards.
    """

    class _CallTimeoutError(BaseException):
        """Raised by the SIGALRM handler when the time limit expires."""

    def _on_alarm(signum, frame):
        raise _CallTimeoutError

    def timeout_decorator(func):
        """Wrap ``func`` so that each call is time-limited."""

        @functools.wraps(func)
        def func_wrapper(*args, **kwargs):
            """Call ``func`` under an interval timer; return ``default`` on timeout."""
            previous_handler = signal.signal(signal.SIGALRM, _on_alarm)
            if previous_handler is None:
                # Handler was installed outside Python; SIG_DFL is the closest value we can restore.
                previous_handler = signal.SIG_DFL
            result = default
            try:
                # Arm the timer inside ``try`` so that even an immediate expiry is handled.
                signal.setitimer(signal.ITIMER_REAL, max_timeout)
                result = func(*args, **kwargs)
            except _CallTimeoutError:
                print(f"{func.__name__} timed out after {max_timeout} s")
            finally:
                # Disarm first, then reinstate whatever handler was active before.
                signal.setitimer(signal.ITIMER_REAL, 0)
                signal.signal(signal.SIGALRM, previous_handler)

            return result

        return func_wrapper

    return timeout_decorator

In [7]:
@timeout(0.2, default=(None, "timeout"))
def process_entry(entry: Entry, salt_remover: SaltRemover.SaltRemover = None) -> tuple[Chem.Mol | None, str]:
    """
    Convert a CSD entry into a salt-stripped RDKit molecule.

    The entry's molecule is serialised to an SDF mol block with the CSD API and
    parsed with RDKit. Salts are then stripped and the result is tagged with the
    entry identifier. The ``timeout`` decorator limits each call to 0.2 s.

    Parameters
    ----------
    entry : Entry
        The CSD database entry to process.
    salt_remover : SaltRemover.SaltRemover, optional
        Salt remover used to strip counter-ions. If None (or not a SaltRemover
        instance), a new default instance is created for this call.

    Returns
    -------
    tuple[Chem.Mol | None, str]
        On success, ``(mol, "")``, where ``mol`` carries a ``CSD_Entry_Name``
        property. Otherwise ``(None, message)``, where ``message`` is one of:

        - the text written to ``sys.stderr`` during conversion;
        - ``"empty"`` when RDKit returned no molecule or one with no atoms;
        - ``"Error removing salt: ..."`` when salt removal raised;
        - ``"timeout"`` when the time limit was exceeded (from the decorator).

    Notes
    -----
    Any output on Python's ``sys.stderr`` during conversion is treated as a
    failure. Whether RDKit warnings land there depends on how RDKit logging is
    routed (TODO confirm for the environment in use). ``sys.stderr`` is restored
    afterwards, even if the conversion raises or is interrupted by the timeout.
    """
    if salt_remover is None or not isinstance(salt_remover, SaltRemover.SaltRemover):
        salt_remover = SaltRemover.SaltRemover()
    mol = entry.molecule

    # Capture stderr only for the duration of the conversion.
    # redirect_stderr puts the previous stream back on exit in every case.
    with redirect_stderr(StringIO()) as captured:
        mol_block = mol.to_string("sdf")
        rdkit_mol = Chem.MolFromMolBlock(mol_block)
    error_log = captured.getvalue()

    # any stderr output during conversion counts as a failure
    if error_log:
        return None, error_log

    # parsing failed silently, or produced a molecule without atoms
    if rdkit_mol is None or rdkit_mol.GetNumAtoms() == 0:
        return None, "empty"

    # strip salts, but never remove every fragment
    try:
        rdkit_mol = salt_remover.StripMol(rdkit_mol, dontRemoveEverything=True)
    except Exception as e:
        return None, f"Error removing salt: {e}"

    # keep the CSD identifier with the molecule in the SDF output
    rdkit_mol.SetProp("CSD_Entry_Name", entry.identifier)

    # error_log is "" here because nothing was written to stderr
    return rdkit_mol, error_log


def process_batch(
    batch_idx: int, output_dir: Path, error_log_schema: pa.Schema, batch_size: int = 10_000, database_size: int = 1394755
) -> int:
    """
    Convert one batch of CSD entries into a gzip-compressed SDF file plus a Parquet log.

    Entries ``[batch_idx * batch_size, min((batch_idx + 1) * batch_size, database_size))``
    are read from the CSD and converted with ``process_entry``, which also strips
    salts. They are written to ``CSD_batch_{batch_idx:03d}.sdf``, which is then
    compressed to ``CSD_batch_{batch_idx:03d}.sdf.gz``; the uncompressed file is
    removed. One log row per entry goes to ``CSD_batch_{batch_idx:03d}_errors.parquet``.

    Parameters
    ----------
    batch_idx : int
        Index of the batch to process.
    output_dir : Path
        Directory for the SDF archive and the log file; created if missing.
    error_log_schema : pa.Schema
        Schema of the log table, whose rows have the keys ``entry_index``,
        ``identifier``, ``error_log`` and ``processing_time``.
    batch_size : int, optional
        Number of entries per batch. Defaults to 10_000.
    database_size : int, optional
        Total number of entries in the database. Defaults to 1394755.

    Returns
    -------
    int
        The index of the processed batch.

    Notes
    -----
    - A failure on a single entry (reading, converting or writing it) is
      recorded in that entry's log row instead of aborting the whole batch.
    - The CSD reader is closed even if the batch fails.
    - A batch starting at or beyond ``database_size`` still produces an empty
      SDF archive, but no log file.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    stem = f"CSD_batch_{batch_idx:03d}"
    sdf_file = output_dir / f"{stem}.sdf"
    error_log_file = output_dir / f"{stem}_errors.parquet"

    # entry range covered by this batch (clipped at the end of the database)
    start_idx = batch_idx * batch_size
    end_idx = min(start_idx + batch_size, database_size)

    error_logs = []
    salt_remover = SaltRemover.SaltRemover()
    entry_reader = EntryReader("CSD")
    try:
        with Chem.SDWriter(sdf_file) as sdf_writer:
            for i in range(start_idx, end_idx):
                identifier = None
                start_time = time.perf_counter()
                try:
                    entry = entry_reader[i]
                    identifier = entry.identifier
                    # time the conversion itself, not the database lookup
                    start_time = time.perf_counter()
                    rdkit_mol, error_log = process_entry(entry, salt_remover)
                except Exception as e:
                    # one unreadable or unconvertible entry must not cost the whole batch
                    rdkit_mol, error_log = None, f"Error processing entry: {e}"

                if rdkit_mol is not None:
                    try:
                        sdf_writer.write(rdkit_mol)
                    except Exception as e:
                        error_log = f"Error writing molecule to SDF: {e}"

                error_logs.append(
                    {
                        "entry_index": i,
                        "identifier": identifier,
                        "error_log": error_log,
                        "processing_time": time.perf_counter() - start_time,
                    }
                )
    finally:
        entry_reader.close()

    # write the per-entry log
    if error_logs:
        table = pa.Table.from_pylist(error_logs, schema=error_log_schema)
        pq.write_table(table, error_log_file)

    # compress the SDF and drop the uncompressed copy
    compress2gzip(sdf_file, output_dir / f"{stem}.sdf.gz", remove_original=True)
    return batch_idx

In [None]:
# define partial function
task_counter = 0
partial_process_batch = partial(
    process_batch,
    output_dir=Path("../data/processed/csd_sdf_batches"),
    error_log_schema=error_log_schema,
    batch_size=1_000,
    database_size=1394755,
)
# multiprocessing
with mp.Pool(processes=mp.cpu_count() - 2) as pool:
    results = pool.map(partial_process_batch, range(1400))

In [11]:
# Directory holding the CSD_batch_*.sdf.gz archives and their Parquet logs.
results_dir = Path("..", "data", "processed", "csd_sdf_batches")

In [12]:
# Indices of batches that already have a compressed SDF, parsed from file names:
# "CSD_batch_042.sdf.gz" -> stem "CSD_batch_042.sdf" -> "042" -> 42.
batch_idx_set = {
    int(sdf_gz.stem.replace(".sdf", "").split("_")[-1])
    for sdf_gz in results_dir.glob("*.sdf.gz")
}

In [13]:
# Batch indices (same range(1400) as the pool run) that produced no archive.
missing_batches = {idx for idx in range(1400) if idx not in batch_idx_set}

In [15]:
# Display the batch indices that have no .sdf.gz archive yet.
missing_batches

{443,
 444,
 445,
 446,
 447,
 448,
 449,
 450,
 451,
 452,
 453,
 454,
 455,
 456,
 457,
 458,
 459,
 762,
 763,
 764,
 765,
 766,
 767,
 768,
 769,
 770,
 771,
 772,
 773,
 774,
 775,
 776,
 777,
 778,
 779}

In [14]:
# Re-run the missing batches sequentially in this process, in ascending order
# so the retry sequence is reproducible. A batch that still fails is reported
# and recorded in failed_batches (index -> repr of the exception), and the
# loop moves on so the remaining batches are still retried.
failed_batches = {}
for missing_idx in sorted(missing_batches):
    try:
        process_batch(
            missing_idx,
            output_dir=Path("../data/processed/csd_sdf_batches"),
            error_log_schema=error_log_schema,
            batch_size=1_000,
            database_size=1394755,
        )
    except Exception as e:
        failed_batches[missing_idx] = repr(e)
        print(f"Batch {missing_idx} failed again: {e!r}")



