In [0]:
MOLECULES_PATH = "/Volumes/mols_storage/default/mol_data/molecules.csv"
GROUPS_PATH    = "/Volumes/mols_storage/default/mol_data/groups.csv"

molecules_df = (spark.read
    .option("header", "true")
    .option("inferSchema", "true")
    .csv(MOLECULES_PATH))

groups_df = (spark.read
    .option("header", "true")
    .option("inferSchema", "true")
    .csv(GROUPS_PATH))

print("Mols schema:")
molecules_df.printSchema()
print("Groups schema:")
groups_df.printSchema()

display(molecules_df)
display(groups_df)


In [0]:
%pip install rdkit

In [0]:
dbutils.library.restartPython()


In [0]:
from rdkit import Chem
from rdkit.Chem import Descriptors

def calculate_descriptors(smiles: str):
    """Compute a fixed set of RDKit descriptors for a single SMILES string.

    Args:
        smiles: SMILES text for one molecule. May be None (e.g. a null CSV
            cell reaching the Spark UDF) or blank.

    Returns:
        A dict with keys "mol_weight", "logp", "tpsa", "hbd" and "hba", or
        None when the input is missing, blank, not a string, or cannot be
        parsed by RDKit.
    """
    # Chem.MolFromSmiles raises (Boost.Python ArgumentError) on None/non-str
    # input instead of returning None, which would fail the whole Spark task.
    # A blank string parses to an empty, zero-mass molecule, so treat it as invalid too.
    if not isinstance(smiles, str):
        return None
    text = smiles.strip()
    if not text:
        return None

    mol = Chem.MolFromSmiles(text)
    if mol is None:
        # Unparseable SMILES.
        return None

    return {
        "mol_weight": Descriptors.MolWt(mol),
        "logp": Descriptors.MolLogP(mol),
        "tpsa": Descriptors.TPSA(mol),
        "hbd": Descriptors.NumHDonors(mol),
        "hba": Descriptors.NumHAcceptors(mol),
    }


In [0]:
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, DoubleType, IntegerType

# Ordered (field name, Spark type class, Python coercion) triples. Both the UDF
# schema and the tuple builder read from this one table, so field order and
# types cannot drift apart.
_DESC_FIELDS = (
    ("mol_weight", DoubleType, float),
    ("logp", DoubleType, float),
    ("tpsa", DoubleType, float),
    ("hbd", IntegerType, int),
    ("hba", IntegerType, int),
)

desc_schema = StructType(
    [StructField(field_name, spark_type(), True) for field_name, spark_type, _ in _DESC_FIELDS]
)


def calculate_descriptors_struct(smiles: str):
    """Return descriptor values as a tuple ordered to match ``desc_schema``.

    When ``calculate_descriptors`` yields None (input it could not turn into a
    molecule), every field is None so the row still fits the struct schema.
    """
    computed = calculate_descriptors(smiles)
    if computed is None:
        return (None,) * len(_DESC_FIELDS)
    return tuple(coerce(computed[field_name]) for field_name, _, coerce in _DESC_FIELDS)


desc_udf = F.udf(calculate_descriptors_struct, desc_schema)


In [0]:
# Attach the descriptor struct from the UDF, then flatten its fields into
# top-level columns next to the molecule identifiers.
mols_enr_df = (
    molecules_df
    .select(
        "molecule_id",
        "smiles",
        "group_id",
        desc_udf(F.col("smiles")).alias("desc"),
    )
    .select("molecule_id", "smiles", "group_id", "desc.*")
)

display(mols_enr_df)


In [0]:
# A null mol_weight marks rows whose descriptors could not be computed (the
# struct builder fills every field with None in that case). invalid_df is not
# consumed by any later cell shown here; it is kept for inspection.
descriptors_missing = F.col("mol_weight").isNull()
invalid_df = mols_enr_df.where(descriptors_missing)
valid_df = mols_enr_df.where(~descriptors_missing)


In [0]:
# Per-group min / max / mean of every descriptor over valid molecules only.
# Output columns are <descriptor>_min, <descriptor>_max, <descriptor>_avg,
# one triple per descriptor in the order listed below.
stat_descriptor_cols = ["mol_weight", "logp", "tpsa", "hbd", "hba"]
stat_aggregators = [("min", F.min), ("max", F.max), ("avg", F.avg)]

valid_stats_df = valid_df.groupBy("group_id").agg(
    *[
        aggregate(col_name).alias(f"{col_name}_{suffix}")
        for col_name in stat_descriptor_cols
        for suffix, aggregate in stat_aggregators
    ]
)

display(valid_stats_df)


In [0]:
VALID_MOLECULES_PATH = "/Volumes/mols_storage/default/mol_data/tables/valid_molecules"

# Persist the per-molecule rows whose descriptors were computed (valid_df).
# Previously this wrote valid_stats_df, the per-group aggregates, into the
# valid_molecules table; those aggregates go to the group_statistics table instead.
# overwriteSchema lets the overwrite succeed even when an earlier run stored a
# different column layout at this path (Delta otherwise rejects the schema change).
(valid_df
 .write
 .mode("overwrite")
 .option("overwriteSchema", "true")
 .format("delta")
 .save(VALID_MOLECULES_PATH))



In [0]:
# Aggregate column names in the same order valid_stats_df produces them:
# min/max/avg for each descriptor.
group_stat_columns = [
    f"{descriptor}_{stat}"
    for descriptor in ("mol_weight", "logp", "tpsa", "hbd", "hba")
    for stat in ("min", "max", "avg")
]

# Left join keeps every group that has statistics, even without metadata in groups_df.
stats_with_groups = valid_stats_df.join(groups_df, on="group_id", how="left")
group_statistics_df = stats_with_groups.select(
    "group_id", "owner", "description", *group_stat_columns
)

display(group_statistics_df)



In [0]:
GROUP_STATISTICS_PATH = "/Volumes/mols_storage/default/mol_data/tables/group_statistics"

# Store the group-level statistics as a Delta table, replacing any previous contents.
group_statistics_writer = group_statistics_df.write.format("delta").mode("overwrite")
group_statistics_writer.save(GROUP_STATISTICS_PATH)


In [0]:
# Read the Delta table back from storage to eyeball that the write landed.
group_stats_saved_check = spark.read.load(GROUP_STATISTICS_PATH, format="delta")
display(group_stats_saved_check)
