In [1]:
import os
import re
from datetime import datetime
from io import StringIO
from pathlib import Path

import pandas as pd

In [2]:
# Sample metadata table; read below with its first column as the index.
metadata_file = "downloads.tsv"

# Input data location and output location of the per-run result tables.
data_dir = Path("./data/sdmbench")
result_dir = Path("./results/sdmbench")

# Destination of the aggregated runtime/memory table written at the end.
performance_metric_file = "./results/runtime_memory.tsv"

# NOTE(review): the job-submission loops reuse the name `seed` as their loop
# variable, so this value is overwritten once those cells have run.
seed = 42

In [3]:
# Create the result directory (including missing parents); no error if it exists.
result_dir.mkdir(exist_ok=True, parents=True)

In [4]:
# Load the sample table, using its first column as the row index.
metadata = pd.read_table(metadata_file, index_col=0)

# Each sample's local file is named after the last component of its URL.
metadata = metadata.assign(
    path=lambda df: df["url"].map(lambda url: data_dir / url.split("/")[-1])
)

# Settings

In [5]:
# Number of random seeds (0 .. n_seeds - 1) to run per benchmark configuration.
n_seeds = 5

In [6]:
# Forwarded to benchmark.py as --n_pcs (all jobs) and --n_genes (Stereo-seq jobs only).
n_pcs = 30
n_genes = 3_000

In [7]:
# Per-technology value passed as --spatial_weight when the job is submitted
# with --neighbors delaunay. Keys are looked up by the sample's `technology`.
spatial_weights_delaunay = {
    "osmFISH": 0.8,
    "MERFISH": 1,
    "BARISTAseq": 1.2,
    "STARmap": 1,
    "STARmap*": 0.8,
}

In [8]:
# Per-technology value passed as --spatial_weight when the job is submitted
# with --neighbors 10. Keys are looked up by the sample's `technology`.
spatial_weights_kNN10 = {
    "osmFISH": 1.2,
    "MERFISH": 1.8,
    "BARISTAseq": 1.8,
    "STARmap": 1.6,
    "STARmap*": 1.4,
}

In [9]:
# Value passed as --spatial_weight for every Stereo-seq job.
spatial_weight_stereoseq = 0.8

In [10]:
# Shell snippet that activates the conda environment at the start of each job.
conda_env = "spatialleiden"
conda_path = "~/miniconda3/bin/activate"
conda_cmd = f"source {conda_path} {conda_env}"

# sbatch writes one log file per job into this directory.
log_dir = result_dir / "logs"
log_dir.mkdir(parents=True, exist_ok=True)

# Jobs invoke benchmark.py from the current working directory.
path = Path.cwd()

# Submit jobs

In [11]:
# Timestamp taken before any job is submitted; used later as the sacct --starttime.
start = datetime.now()

In [12]:
# Collects the sbatch output of every job whose runtime/memory is reported later.
job_ids = []

In [13]:
# (feature selection, dimensionality reduction) combinations submitted per
# Stereo-seq sample and seed.
preprocessing = [("HVG", "PCA"), ("SVG", "PCA"), ("SVG", "msPCA")]

In [14]:
# Stereo-seq -> only sequencing-based method

# Stereo-seq is the only sequencing-based technology; it gets its own loop
# because its jobs take --n_genes/--stereoseq and vary the preprocessing.
stereoseq_samples = metadata.loc[lambda df: df["technology"] == "Stereo-seq"]

for sample in stereoseq_samples.itertuples():
    for seed in range(n_seeds):
        for features, dimred in preprocessing:
            name = (
                f"{sample.Index}_seed{seed}_neighbors-grid4_FS-{features}_DR-{dimred}"
            )
            out = result_dir / f"{name}.tsv"

            # Fixed part of the benchmark call, followed by the optional flags.
            args = [
                f"{path/'benchmark.py'}",
                f"-i {sample.path}",
                f"-o {out}",
                f"--spatial_weight {spatial_weight_stereoseq}",
                f"--n_pcs {n_pcs}",
                f"--n_genes {n_genes}",
                f"--seed {seed}",
                "--stereoseq",
            ]
            if features == "SVG":
                args.append("--svg")
            if dimred == "msPCA":
                args.append("--mspca")
            # Every token, including the last one, is followed by one blank.
            cmd = " ".join(args) + " "

            sbatch_call = (
                f"sbatch -J {name} --mem=5G -n 8 -N 1 "
                f"-o {log_dir/name}.txt "
                f'--wrap="{conda_cmd} && {cmd}" '
            )
            id_string = os.popen(sbatch_call).read()

            # Only the HVG + PCA runs are tracked for the runtime metrics.
            if (features, dimred) == ("HVG", "PCA"):
                job_ids.append(id_string)

In [15]:
# All imaging-based technologies (every sample that is not Stereo-seq).
# One job is submitted per sample, seed, neighborhood definition and
# dimensionality reduction.

for sample in metadata.loc[lambda df: df["technology"] != "Stereo-seq"].itertuples():
    # Use the shared n_seeds setting (instead of a hard-coded 5) so that this
    # loop and the Stereo-seq loop always run the same number of seeds.
    for seed in range(n_seeds):
        for neighbors in ["delaunay", 10]:
            for dimred in ["PCA", "msPCA"]:
                name = f"{sample.Index}_seed{seed}_neighbors-{neighbors}_DR-{dimred}"
                out = result_dir / f"{name}.tsv"
                # The spatial weight depends on both the technology and the
                # neighborhood definition.
                if neighbors == "delaunay":
                    w = spatial_weights_delaunay[sample.technology]
                else:
                    w = spatial_weights_kNN10[sample.technology]
                cmd = (
                    f"{path/'benchmark.py'} "
                    f"-i {sample.path} "
                    f"-o {out} "
                    f"--spatial_weight {w} "
                    f"--n_pcs {n_pcs} "
                    f"--neighbors {neighbors} "
                    f"--seed {seed} "
                )
                cmd += "--mspca " if dimred == "msPCA" else ""

                # Submit through SLURM; sbatch prints the ID of the new job.
                id_string = os.popen(
                    f"sbatch -J {name} --mem=5G -n 8 -N 1 "
                    f"-o {log_dir/name}.txt "
                    f'--wrap="{conda_cmd} && {cmd}" '
                ).read()
                # Only the plain PCA runs are tracked for the runtime metrics.
                if dimred == "PCA":
                    job_ids.append(id_string)

# Runtime metrics

Wait until all submitted jobs have finished before running the cells below!

In [16]:
def parse_job_id(bsub_out):
    """Extract the numeric job ID from the text printed by a job submission.

    Parameters
    ----------
    bsub_out : str
        Submission output that ends with the job ID,
        e.g. ``"Submitted batch job 123"``.

    Returns
    -------
    int
        The trailing run of digits.

    Raises
    ------
    ValueError
        If the text does not end with digits, e.g. because the submission
        failed and printed nothing.
    """
    # Raw string: "\d" inside a normal literal is an invalid escape sequence.
    match = re.search(r"(\d+)$", bsub_out)
    if match is None:
        raise ValueError(f"No job ID found in submission output: {bsub_out!r}")
    return int(match.group(1))


# Convert the collected sbatch outputs to integer job IDs. Entries that are
# already integers are kept as they are, so re-running this cell does not fail.
job_ids = [
    j_id if isinstance(j_id, int) else parse_job_id(j_id.strip()) for j_id in job_ids
]

In [17]:
# Query the SLURM accounting records of all jobs started since `start`.
# -P together with a tab delimiter produces tab-separated output.
job_stats = os.popen(
    (
        "sacct "
        # sacct documents timestamps as YYYY-MM-DDTHH:MM:SS (note the "T").
        f"--starttime {start.strftime('%Y-%m-%dT%H:%M:%S')} "
        "--format='JobID,Jobname%50,TotalCPU,ElapsedRaw,MaxRSS' "
        "-P "
        # "\t" is already a literal tab in this Python string. Plain single
        # quotes work in every POSIX shell, whereas $'...' is a bash extension
        # and os.popen runs the command through sh.
        "--delimiter='\t' "
        "--units=M "
    )
).read()

In [18]:
# Parse the tab-separated sacct report. JobID and MaxRSS are read as strings
# because they are processed with the `.str` accessor afterwards; without the
# explicit dtype pandas may infer a numeric column (e.g. when no value carries
# a step suffix or a unit) and `.str` would raise.
job_stats = pd.read_table(StringIO(job_stats), dtype={"JobID": str, "MaxRSS": str})

In [19]:
from datetime import timedelta


def parse_sacct_time(time_str):
    """Convert a SLURM ``[D-][HH:]MM:SS[.fff]`` duration string into seconds.

    Parameters
    ----------
    time_str : str
        Duration as printed by sacct, e.g. ``"1-02:03:04"`` or ``"00:01.500"``.

    Returns
    -------
    float
        The duration in seconds, including the fractional part.

    Raises
    ------
    ValueError
        If ``time_str`` does not start with a valid SLURM duration.
    """
    pattern = (
        r"(?:(?P<day>\d+)-)?"
        r"(?:(?P<hour>\d{1,2}):)?"
        r"(?P<min>\d{2}):(?P<sec>\d{2})"
        # The dot is escaped so that only a real decimal point is accepted.
        r"(?:\.(?P<frac>\d+))?"
    )
    match = re.match(pattern, time_str)

    if not match:
        raise ValueError("Invalid SLURM time format")

    fraction = match.group("frac")

    time = timedelta(
        days=int(match.group("day") or 0),
        hours=int(match.group("hour") or 0),
        minutes=int(match.group("min")),
        seconds=int(match.group("sec")),
        # The digits after the dot are a decimal fraction of a second
        # (".5" means 500000 microseconds), so scale them to six digits
        # instead of reading them as a raw microsecond count.
        microseconds=int(fraction[:6].ljust(6, "0")) if fraction else 0,
    )

    return time.total_seconds()

In [20]:
# Keep only the job-level records: rows whose JobID contains no "." (step
# records look like "<id>.batch").
is_whole_job = ~job_stats["JobID"].str.contains(".", regex=False)

cpu_stats = (
    job_stats.loc[is_whole_job, ["JobID", "JobName", "TotalCPU", "ElapsedRaw"]]
    .astype({"JobID": int})
    # TotalCPU is a SLURM duration string; convert it to seconds.
    .assign(TotalCPU=lambda df: df["TotalCPU"].map(parse_sacct_time))
    .rename(columns={"TotalCPU": "CPU time [s]", "ElapsedRaw": "wall time [s]"})
    .set_index("JobID")
)

In [21]:
# MaxRSS is taken from the "<id>.batch" step rows.
memory_stats = (
    job_stats.loc[
        lambda df: df["JobID"].str.contains(".batch", regex=False), ["JobID", "MaxRSS"]
    ]
    .assign(
        # Raw-string patterns avoid the invalid "\d" escape sequence, and
        # expand=False yields a Series instead of a one-column DataFrame.
        # JobID keeps the leading digits; MaxRSS keeps the number without its unit.
        JobID=lambda df: df["JobID"].str.extract(r"(\d+)", expand=False).astype(int),
        MaxRSS=lambda df: (
            df["MaxRSS"].str.extract(r"([\d.]+)", expand=False).astype(float)
        ),
    )
    .set_index("JobID")
    .rename(columns={"MaxRSS": "max memory [MB]"})
)

In [22]:
# Combine CPU and memory figures per job, then keep only the tracked jobs.
joined_stats = cpu_stats.join(memory_stats)
stats = joined_stats.loc[joined_stats.index.isin(job_ids)]

In [23]:
# Write the per-job runtime and memory table as a tab-separated file.
stats.to_csv(performance_metric_file, sep="\t")