Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Efficiency Improvements + Protect Data Temp Directory #238

Merged
merged 14 commits (source and target branch names missing from this capture)
Mar 27, 2024
7 changes: 6 additions & 1 deletion examples/fmri/biobank/2_train_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@
inputs = file.read().split("\n")

# Create Data object for training
data = Data(inputs, load_memmaps=False, n_jobs=8)
data = Data(
inputs,
use_tfrecord=True,
store_dir=f"tmp_{id}",
n_jobs=8,
)

# Prepare data
data.standardize()
Expand Down
7 changes: 6 additions & 1 deletion examples/fmri/biobank/4_dual_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@
inputs = file.read().split("\n")

# Create Data object for training
data = Data(inputs, load_memmaps=False, n_jobs=8)
data = Data(
inputs,
use_tfrecord=True,
store_dir=f"tmp_{id}",
n_jobs=16,
)

# Prepare data
data.standardize()
Expand Down
9 changes: 7 additions & 2 deletions examples/fmri/biobank/submit_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os

def write_job_script(run, queue="gpu_short"):
def write_job_script(run, queue="short", n_gpus=1, n_cpus=12):
"""Create a job script to submit."""

with open("job.sh", "w") as file:
Expand All @@ -14,12 +14,17 @@ def write_job_script(run, queue="gpu_short"):
file.write(f"#SBATCH -o logs/{name}.out\n")
file.write(f"#SBATCH -e logs/{name}.err\n")
file.write(f"#SBATCH -p {queue}\n")
file.write("#SBATCH --gres gpu:1\n")
if "gpu" in queue:
file.write(f"#SBATCH --gres gpu:{n_gpus}\n")
else:
file.write(f"#SBATCH -c {n_cpus}\n")
file.write("source activate osld\n")
file.write(f"python 2_train_hmm.py {run}\n")

# Create directory to hold log/error files
os.makedirs("logs", exist_ok=True)

# Submit jobs
for run in range(1, 11):
write_job_script(run)
os.system("sbatch job.sh")
Expand Down
6 changes: 5 additions & 1 deletion examples/fmri/hcp/2_train_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
inputs = file.read().split("\n")

# Create Data object for training
data = Data(inputs, load_memmaps=False, n_jobs=8)
data = Data(
inputs,
store_dir=f"tmp_{id}",
n_jobs=8,
)

# Prepare data
data.standardize()
Expand Down
6 changes: 5 additions & 1 deletion examples/fmri/hcp/4_dual_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@
inputs = file.read().split("\n")

# Create Data object for training
data = Data(inputs, load_memmaps=False, n_jobs=8)
data = Data(
inputs,
store_dir=f"tmp_{id}",
n_jobs=8,
)

# Prepare data
data.standardize()
Expand Down
9 changes: 7 additions & 2 deletions examples/fmri/hcp/submit_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os

def write_job_script(run, queue="gpu_short"):
def write_job_script(run, queue="short", n_gpus=1, n_cpus=12):
"""Create a job script to submit."""

with open("job.sh", "w") as file:
Expand All @@ -14,12 +14,17 @@ def write_job_script(run, queue="gpu_short"):
file.write(f"#SBATCH -o logs/{name}.out\n")
file.write(f"#SBATCH -e logs/{name}.err\n")
file.write(f"#SBATCH -p {queue}\n")
file.write("#SBATCH --gres gpu:1\n")
if "gpu" in queue:
file.write(f"#SBATCH --gres gpu:{n_gpus}\n")
else:
file.write(f"#SBATCH -c {n_cpus}\n")
file.write("source activate osld\n")
file.write(f"python 2_train_hmm.py {run}\n")

# Create directory to hold log/error files
os.makedirs("logs", exist_ok=True)

# Submit jobs
for run in range(1, 11):
write_job_script(run)
os.system("sbatch job.sh")
Expand Down
5 changes: 3 additions & 2 deletions osl_dynamics/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
/tutorials_build/data_preparation.html>`_
"""

from osl_dynamics.data.base import Data, SessionLabels, load_tfrecord_dataset
from osl_dynamics.data.base import Data, SessionLabels
from osl_dynamics.data.tf import load_tfrecord_dataset

__all__ = ["Data"]
__all__ = ["Data", "SessionLabels", "load_tfrecord_dataset"]
Loading