# Dataset generation

This notebook is used to generate the final dataset used to train our machine learning models.

In [1]:
import os
import math
import pandas as pd
from tqdm import tqdm
from rdkit import Chem

In [2]:
FINAL_DIR = "../data/processed"

The main aim is to generate a standard dataset for each resource to facilitate easy merging. Each source thus has the following columns:
1. cmp_id
1. smiles
1. inchikey
1. bact
1. strain
1. val
1. val_type
1. source

In [3]:
def _bindingdb_log_value(raw_val):
    """Convert one raw BindingDB measurement (nM column) to ``9 - log10(value)``.

    Args:
        raw_val: Cell content of a ``"... (nM)"`` column. May be a string
            (optionally carrying a ``>`` or ``<`` qualifier), a number, or a
            missing value.

    Returns:
        The transformed value rounded to two decimals, or ``None`` when the
        measurement must be skipped (missing, ``<``-qualified, non-numeric,
        zero or negative).
    """
    if pd.isna(raw_val):
        return None

    # Only strings can carry a qualifier; numeric cells go straight to float().
    if isinstance(raw_val, str):
        if ">" in raw_val:
            # Keep the number following ">" and use it as the measured value.
            raw_val = raw_val.split(">")[1]
        elif "<" in raw_val:
            # Values reported as "less than" are skipped.
            return None

    try:
        return round(9 - math.log10(float(raw_val)), 2)
    except ValueError:
        # Non-numeric text, or log10 of a value <= 0.
        return None


def process_bindingdb(df):
    """Standardize the BindingDB dataset.

    Every non-missing Ki, IC50, Kd and EC50 measurement of a row becomes its
    own record, with the nM value converted to ``9 - log10(value)``.

    Args:
        df: BindingDB frame containing at least the columns listed in ``cols``.

    Returns:
        A DataFrame with the columns ``cmp_id``, ``smiles``, ``inchikey``,
        ``bact``, ``strain``, ``val``, ``val_type`` and ``source``. It is
        empty (no columns) when no measurement could be converted.
    """
    cols = [
        "Ligand SMILES",
        "Ligand InChI Key",
        "Organism",
        "Ki (nM)",
        "IC50 (nM)",
        "Kd (nM)",
        "EC50 (nM)",
        "PubChem CID",
        "strain_type",
        "source",
    ]
    df = df[cols]

    formatted_list = []

    for row in tqdm(df.values, desc="Processing BindingDB"):
        (
            smiles,
            inchikey,
            organism,
            ki_val,
            ic50_val,
            kd_val,
            ec50_val,
            pubchem_idx,
            strain_type,
            source,
        ) = row

        # Same record layout for every measurement type; only the label differs.
        measurements = (
            ("Ki", ki_val),
            ("IC50", ic50_val),
            ("Kd", kd_val),
            ("EC50", ec50_val),
        )

        for val_type, raw_val in measurements:
            log_val = _bindingdb_log_value(raw_val)
            if log_val is None:
                continue

            formatted_list.append(
                {
                    "cmp_id": f"pubchem.compound:{pubchem_idx}",
                    "smiles": smiles,
                    "inchikey": inchikey,
                    "bact": organism,
                    "strain": strain_type,
                    "val": log_val,
                    "val_type": val_type,
                    "source": source,
                }
            )

    return pd.DataFrame(formatted_list)


def process_chembl(df):
    """Standardize the ChEMBL dataset.

    Each input row is mapped one-to-one onto a record with the shared column
    layout; ``pchembl_value`` is copied unchanged into ``val`` and
    ``assay_type`` into ``val_type``.

    Args:
        df: ChEMBL frame containing at least the columns listed in ``wanted``.

    Returns:
        A DataFrame with the columns ``cmp_id``, ``smiles``, ``inchikey``,
        ``bact``, ``strain``, ``val``, ``val_type`` and ``source``.
    """
    wanted = [
        "chembl_idx",
        "inchi_key",
        "smiles",
        "assay_organism",
        "pchembl_value",
        "assay_type",
        "strain_type",
        "source",
    ]
    subset = df[wanted]

    # Tuple targets follow the order of ``wanted``.
    records = [
        {
            "cmp_id": f"chembl:{chembl_id}",
            "smiles": smi,
            "inchikey": key,
            "bact": bacterium,
            "strain": strain,
            "val": pchembl,
            "val_type": assay,
            "source": origin,
        }
        for (
            chembl_id,
            key,
            smi,
            bacterium,
            pchembl,
            assay,
            strain,
            origin,
        ) in tqdm(subset.values, desc="Processing ChEMBL")
    ]

    return pd.DataFrame(records)


def process_coadd(df):
    """Standardize the CO-ADD dataset.

    Activities are converted with ``6 - log10(value)`` (the original notebook
    states that all values are normalized to uM). Rows with a missing or
    ``<``-qualified activity are skipped. The InChIKey is derived from the
    SMILES with RDKit and is ``None`` when the SMILES cannot be parsed.

    Args:
        df: CO-ADD frame holding a ``DRVAL_UNIT`` column plus eight further
            columns that are read by position in this order: compound id,
            compound name, SMILES, organism, activity type, activity, strain
            type, source. The caller's frame is not modified.

    Returns:
        A DataFrame with the columns ``cmp_id``, ``smiles``, ``inchikey``,
        ``bact``, ``strain``, ``val``, ``val_type`` and ``source``.

    Raises:
        ValueError: If an activity is not numeric or is not strictly positive.
    """
    # All values normalized to uM units.
    # Non-inplace drop: the caller's frame keeps its column, so the function
    # can be called again on the same input.
    df = df.drop(columns=["DRVAL_UNIT"])

    formatted_list = []

    for row in tqdm(df.values, desc="Processing CO-ADD"):
        (
            drug_id,
            _drug_name,
            smiles,
            organism,
            activity_type,
            activity,
            strain_type,
            source,
        ) = row

        idx = drug_id.split(":")[1]

        if pd.isna(activity):
            continue

        # Qualifiers only occur in text cells; numeric cells are used as-is.
        if isinstance(activity, str):
            # NOTE(review): "=" is stripped first, so a "<=x" value is kept as
            # x while a plain "<x" value is skipped - confirm this is intended.
            if "=" in activity:
                activity = activity.split("=")[1]

            if ">" in activity:
                activity = activity.split(">")[1]
            elif "<" in activity:
                continue

        log_val = round(6 - math.log10(float(activity)), 2)

        # MolFromSmiles returns None for an unparseable SMILES; do not pass
        # that on to MolToInchiKey.
        mol = Chem.MolFromSmiles(smiles)
        inchikey = Chem.MolToInchiKey(mol) if mol is not None else None

        formatted_list.append(
            {
                "cmp_id": f"coadd:{idx}",
                "smiles": smiles,
                "inchikey": inchikey,
                "bact": organism,
                "strain": strain_type,
                "val": log_val,
                "val_type": activity_type,
                "source": source,
            }
        )

    return pd.DataFrame(formatted_list)


def process_drugcentral(df: pd.DataFrame):
    """Standardize the DrugCentral dataset.

    Args:
        df: DrugCentral frame holding the columns ``ACT_UNIT``, ``RELATION``
            and ``ACT_SOURCE`` plus eleven further columns that are read by
            position in this order: drug name, drug id, target name, target
            id, activity, activity type, organism, SMILES, InChIKey, strain
            type, source. The caller's frame is not modified.

    Returns:
        A DataFrame with the columns ``cmp_id``, ``smiles``, ``inchikey``,
        ``bact``, ``strain``, ``val``, ``val_type`` and ``source``; ``val`` is
        the activity cast to ``float``.
    """
    # All values normalized to pUnit values and relations are all "=".
    # Non-inplace drop: the caller's frame keeps its columns, so the function
    # can be called again on the same input.
    df = df.drop(columns=["ACT_UNIT", "RELATION", "ACT_SOURCE"])

    formatted_list = []

    for row in tqdm(df.values, desc="Processing DrugCentral"):
        (
            _drug_name,
            drug_id,
            _target_name,
            _target_id,
            activity,
            activity_type,
            organism,
            smiles,
            inchikey,
            strain_type,
            source,
        ) = row

        formatted_list.append(
            {
                "cmp_id": f"drugcentral:{drug_id}",
                "smiles": smiles,
                "inchikey": inchikey,
                "bact": organism,
                "strain": strain_type,
                "val": float(activity),
                "val_type": activity_type,
                "source": source,
            }
        )

    return pd.DataFrame(formatted_list)


def process_spark(df):
    """Standardize the SPARK dataset.

    Each row yields up to two records, one for its MIC value and one for its
    IC50 value; missing values and values equal to ``0.0`` are skipped. The
    InChIKey is derived from the SMILES with RDKit and is ``None`` when that
    derivation raises.

    Args:
        df: SPARK frame whose seven columns are read by position in this
            order: compound id, SMILES, organism, strain type, MIC value,
            IC50 value, source.

    Returns:
        A DataFrame with the columns ``cmp_id``, ``smiles``, ``inchikey``,
        ``bact``, ``strain``, ``val``, ``val_type`` and ``source``.
    """
    formatted_list = []

    for row in tqdm(df.values, desc="Processing SPARK"):
        (
            drug_id,
            smiles,
            organism,
            strain_type,
            mic_val,
            ic50_val,
            source,
        ) = row

        try:
            inchikey = Chem.MolToInchiKey(Chem.MolFromSmiles(smiles))
        except Exception:
            # Reset explicitly: otherwise the key computed for the previous
            # row would be attached to this compound (or the name would be
            # unbound if the very first row fails).
            inchikey = None

        for val_type, raw_val in (("MIC", mic_val), ("IC50", ic50_val)):
            if pd.isna(raw_val) or raw_val == 0.0:
                continue

            formatted_list.append(
                {
                    "cmp_id": f"spark:{drug_id}",
                    "smiles": smiles,
                    "inchikey": inchikey,
                    "bact": organism,
                    "strain": strain_type,
                    "val": float(raw_val),
                    "val_type": val_type,
                    "source": source,
                }
            )

    return pd.DataFrame(formatted_list)

In [4]:
# Standardizer for each source tag; the tag is taken from the file name.
processors = {
    "drugcentral": process_drugcentral,
    "coadd": process_coadd,
    "bindingdb": process_bindingdb,
    "spark": process_spark,
    "chembl": process_chembl,
}

df_list = []

for file in os.listdir(FINAL_DIR):
    if "bacterial" not in file:
        continue

    # The source tag is the second "_"-separated token of the file stem.
    name_parts = file.split(".")[0].split("_")
    file_name = name_parts[1] if len(name_parts) > 1 else None

    if file_name not in processors:
        # Without a matching standardizer the raw frame would be appended with
        # its original, non-standard columns and end up in the merged dataset.
        print(f"Skipping {file}: no processor for source {file_name!r}")
        continue

    df = pd.read_csv(os.path.join(FINAL_DIR, file), sep="\t")
    df["source"] = file_name

    df_list.append(processors[file_name](df))

Processing DrugCentral: 100%|██████████| 433/433 [00:00<00:00, 915389.94it/s]
Processing BindingDB: 100%|██████████| 29372/29372 [00:00<00:00, 534278.90it/s]
Processing CO-ADD: 100%|██████████| 25290/25290 [00:05<00:00, 4312.00it/s]
Processing SPARK:  85%|████████▌ | 28216/33030 [00:07<00:01, 3547.68it/s][11:28:21] bond type above 3 (17) is treated as unspecified!
[11:28:21] bond type above 3 (17) is treated as unspecified!
[11:28:21] bond type above 3 (17) is treated as unspecified!
[11:28:21] Invalid InChI prefix in generating InChI Key
[11:28:21] bond type above 3 (17) is treated as unspecified!
[11:28:21] bond type above 3 (17) is treated as unspecified!
[11:28:21] bond type above 3 (17) is treated as unspecified!
[11:28:21] Invalid InChI prefix in generating InChI Key
[11:28:21] bond type above 3 (17) is treated as unspecified!
[11:28:21] bond type above 3 (17) is treated as unspecified!
[11:28:21] bond type above 3 (17) is treated as unspecified!
[11:28:21] Invalid InChI prefix i

In [5]:
# Stack the per-source frames. ignore_index=True renumbers the rows so that
# every record gets a unique label instead of each source restarting at 0.
final_df = pd.concat(df_list, ignore_index=True)
final_df.head(2)

Unnamed: 0,cmp_id,smiles,inchikey,bact,strain,val,val_type,source
0,drugcentral:21,NC1=NC2=NC=C(CNC3=CC=C(C=C3)C(=O)N[C@@H](CCC(O...,TVZGACDUOSZQKY-LBPRGKRZSA-N,Lactobacillus casei,acid-fast,8.3,IC50,drugcentral
1,drugcentral:21,NC1=NC2=NC=C(CNC3=CC=C(C=C3)C(=O)N[C@@H](CCC(O...,TVZGACDUOSZQKY-LBPRGKRZSA-N,Escherichia coli,gram-negative,7.96,IC50,drugcentral


In [6]:
# Number of standardized records contributed by each source.
source_counts = final_df["source"].value_counts()
source_counts

chembl         44543
bindingdb      29183
coadd          25290
spark          23280
drugcentral      433
Name: source, dtype: int64

In [7]:
# Write next to the per-source inputs; the path is built from FINAL_DIR so the
# output location follows the configured directory.
final_df.to_csv(os.path.join(FINAL_DIR, "combined.tsv"), sep="\t", index=False)

In [8]:
# Number of records per strain category in the merged dataset.
strain_counts = final_df["strain"].value_counts()
strain_counts

gram-negative    74677
gram-positive    33480
acid-fast        14572
Name: strain, dtype: int64