In [18]:
from molalchemy.rdkit import types, functions, index
from sqlalchemy import (
    Column,
    Integer,
    String,
    Float,
    Boolean,
    ForeignKey,
    engine,
    select,
    text,
)
from sqlalchemy.orm import sessionmaker, Mapped
from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass, mapped_column

eng = engine.create_engine(
    "postgresql+psycopg://postgres:example@localhost:5432/postgres"
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=eng)
with SessionLocal() as session:
    session.execute(text("CREATE EXTENSION IF NOT EXISTS rdkit"))
    session.commit()
    print(
        session.execute(text("SELECT rdkit_version(), rdkit_toolkit_version()")).all()
    )

[('0.76.0', '2024.03.1')]


If you are unable to get the version of the RDKit toolkit, you might need to activate the extension:

```sql
"CREATE EXTENSION IF NOT EXISTS rdkit"
```

In [19]:
class Base(MappedAsDataclass, DeclarativeBase):
    pass


class Molecule(Base):
    __tablename__ = "molecules"
    __table_args__ = (index.RdkitIndex("mol_gist_idx", "mol"),)
    id: Mapped[int] = mapped_column(
        Integer, primary_key=True, autoincrement=True, init=False
    )
    name: Mapped[str] = mapped_column(String(100), unique=True)
    mol: Mapped[bytes] = mapped_column(types.RdkitMol)
    is_nsaid: Mapped[bool] = mapped_column(Boolean, default=False)


Molecule.__table__.drop(eng, checkfirst=True)
Molecule.metadata.create_all(eng, checkfirst=False)

In [20]:
data = [
    {"name": "Aspirin", "mol": "CC(=O)OC1=CC=CC=C1C(=O)O", "is_nsaid": True},
    {
        "name": "Loratadine",
        "mol": "O=C(OCC)N4CC/C(=C2/c1ccc(Cl)cc1CCc3cccnc23)CC4",
        "is_nsaid": False,
    },
    {
        "name": "Rofecoxib",
        "mol": "O=C2OC\C(=C2\c1ccccc1)c3ccc(cc3)S(C)(=O)=O",
        "is_nsaid": True,
    },
    {"name": "Captopril", "mol": "C[C@H](CS)C(=O)N1CCC[C@H]1C(=O)O", "is_nsaid": False},
    {
        "name": "Talidomide",
        "mol": "O=C1c2ccccc2C(=O)N1C3CCC(=O)NC3=O",
        "is_nsaid": False,
    },
]
mols = [Molecule(**d) for d in data]
session.add_all(mols)
session.commit()

In [21]:
# simple query to get all non-nsaid molecules
session.execute(select(Molecule).where(Molecule.is_nsaid == False)).all()

[(Molecule(id=2, name='Loratadine', mol='CCOC(=O)N1CCC(=C2c3ccc(Cl)cc3CCc3cccnc32)CC1', is_nsaid=False),),
 (Molecule(id=4, name='Captopril', mol='C[C@H](CS)C(=O)N1CCC[C@H]1C(=O)O', is_nsaid=False),),
 (Molecule(id=5, name='Talidomide', mol='O=C1CCC(N2C(=O)c3ccccc3C2=O)C(=O)N1', is_nsaid=False),)]

In [22]:
session.execute(
    select(Molecule).where(
        functions.mol.equals(Molecule.mol, "CC(=O)OC1=CC=CC=C1C(=O)O")
    )
).all()

[(Molecule(id=1, name='Aspirin', mol='CC(=O)Oc1ccccc1C(=O)O', is_nsaid=True),)]

In [23]:
session.execute(
    select(Molecule).where(functions.mol.has_substructure(Molecule.mol, "S"))
).all()

[(Molecule(id=3, name='Rofecoxib', mol='CS(=O)(=O)c1ccc(C2=C(c3ccccc3)C(=O)OC2)cc1', is_nsaid=True),),
 (Molecule(id=4, name='Captopril', mol='C[C@H](CS)C(=O)N1CCC[C@H]1C(=O)O', is_nsaid=False),)]

In [24]:
from sqlalchemy import Computed


class Base(MappedAsDataclass, DeclarativeBase):
    pass


class MoleculeFP(Base):
    __tablename__ = "molecules_fp"
    __table_args__ = (
        index.RdkitIndex("mol_gist_idx_2", "mol"),
        index.RdkitIndex("fp_gist_idx_2", "fp"),
    )
    id: Mapped[int] = mapped_column(
        Integer, primary_key=True, autoincrement=True, init=False
    )
    name: Mapped[str] = mapped_column(String(100), unique=True)
    mol: Mapped[bytes] = mapped_column(types.RdkitMol)
    fp: Mapped[bytes] = mapped_column(
        types.RdkitSparseFingerprint,
        Computed(functions.mol.morgan_fp(mol, 2), persisted=True),
        init=False,
    )
    is_nsaid: Mapped[bool] = mapped_column(Boolean, default=False)


MoleculeFP.__table__.drop(eng, checkfirst=True)
MoleculeFP.__table__.create(eng, checkfirst=True)

In [25]:
session = SessionLocal()
session.add_all([MoleculeFP(**d) for d in data])
session.commit()

In [26]:
session.execute(select(MoleculeFP).limit(1)).all()

[(MoleculeFP(id=1, name='Aspirin', mol='CC(=O)Oc1ccccc1C(=O)O', fp='\\x0100000004000000ffffffff190000004034df0502000000177ce7070100000050d6601e01000000b ... (218 characters truncated) ... 00000c76fa289010000000b9cd89e01000000c831f8a501000000a7d50bb2010000006455c5bf02000000515fd9bf04000000f9fb51d301000000afbc69ee02000000', is_nsaid=True),)]

In [27]:
target_fp = functions.mol.morgan_fp("Clc4cc2c(C(c1ncccc1CC2)=C3CCNCC3)cc4", 2)
sim_expr = functions.fp.tanimoto(MoleculeFP.fp, target_fp).label("similarity")
final_query = select(sim_expr, MoleculeFP).order_by(sim_expr.desc())
session.execute(final_query).all()

[(0.6436781609195402, MoleculeFP(id=2, name='Loratadine', mol='CCOC(=O)N1CCC(=C2c3ccc(Cl)cc3CCc3cccnc32)CC1', fp='\\x0100000004000000ffffffff3400000043892d0201000000727bf20 ... (677 characters truncated) ... 0000ac3e61ce01000000269925d3010000004aff1ee10100000030fb63e50100000056c98fe6010000003abe53ed0100000060a959ed01000000d6f410ee01000000', is_nsaid=False)),
 (0.18, MoleculeFP(id=5, name='Talidomide', mol='O=C1CCC(N2C(=O)c3ccccc3C2=O)C(=O)N1', fp='\\x0100000004000000ffffffff1e0000003a39a100040000004034df0502000000 ... (316 characters truncated) ... 00006455c5bf06000000515fd9bf040000002786c4c00100000015ba53cf0100000095c2ecd901000000af7c8fda0100000060a959ed02000000afbc69ee02000000', is_nsaid=False)),
 (0.1651376146788991, MoleculeFP(id=3, name='Rofecoxib', mol='CS(=O)(=O)c1ccc(C2=C(c3ccccc3)C(=O)OC2)cc1', fp='\\x0100000004000000ffffffff220000003a39a100010000004034df0503 ... (385 characters truncated) ... 00000cb1288e7020000004044bae9010000003abe53ed01000000afbc69ee02000000a42c92f10100000

In [28]:
print(final_query.compile(eng, compile_kwargs={"literal_binds": True}))

SELECT tanimoto_sml(molecules_fp.fp, morgan_fp('Clc4cc2c(C(c1ncccc1CC2)=C3CCNCC3)cc4', 2)) AS similarity, molecules_fp.id, molecules_fp.name, molecules_fp.mol, molecules_fp.fp, molecules_fp.is_nsaid 
FROM molecules_fp ORDER BY similarity DESC


In [29]:
class Base(MappedAsDataclass, DeclarativeBase):
    pass


class RawMoleculeFP(Base):
    __tablename__ = "raw_molecules_fp"
    __table_args__ = (
        index.RdkitIndex("mol_gist_idx_3", "mol"),
        index.RdkitIndex("fp_gist_idx_3", "fp"),
    )
    id: Mapped[int] = mapped_column(
        Integer, primary_key=True, autoincrement=True, init=False
    )
    name: Mapped[str] = mapped_column(String(100), unique=True)
    mol: Mapped[bytes] = mapped_column(types.RdkitMol(return_type="bytes"))
    fp: Mapped[bytes] = mapped_column(
        types.RdkitSparseFingerprint,
        Computed(functions.mol.morgan_fp(mol, 2), persisted=True),
        init=False,
    )
    is_nsaid: Mapped[bool] = mapped_column(Boolean, default=False)


RawMoleculeFP.__table__.drop(eng, checkfirst=True)
RawMoleculeFP.__table__.create(eng, checkfirst=True)

In [30]:
session = SessionLocal()
session.add_all([RawMoleculeFP(**d) for d in data])
session.commit()

In [31]:
session.execute(select(RawMoleculeFP).limit(2)).all()

[(RawMoleculeFP(id=1, name='Aspirin', mol=b'\xef\xbe\xad\xde\x00\x00\x00\x00\x10\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x00\x00\x00\r\x00\x00\x00 ... (970 characters truncated) ... 00000c76fa289010000000b9cd89e01000000c831f8a501000000a7d50bb2010000006455c5bf02000000515fd9bf04000000f9fb51d301000000afbc69ee02000000', is_nsaid=True),),
 (RawMoleculeFP(id=2, name='Loratadine', mol=b'\xef\xbe\xad\xde\x00\x00\x00\x00\x10\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1b\x00\x00\x00\x1e\x00\ ... (2084 characters truncated) ... 0000ac3e61ce01000000269925d3010000004aff1ee10100000030fb63e50100000056c98fe6010000003abe53ed0100000060a959ed01000000d6f410ee01000000', is_nsaid=False),)]