Skip to content

Commit

Permalink
Merge eb8787a into e9dfcc5
Browse files Browse the repository at this point in the history
  • Loading branch information
srivarra committed Jan 9, 2023
2 parents e9dfcc5 + eb8787a commit 329c78d
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 15 deletions.
4 changes: 4 additions & 0 deletions ark/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,7 @@
"ARCHIVE": ["tar", "gz", "zip"],
"DATA": ["csv", "feather", "bin", "json"],
}

# Git revision of the Example Dataset on HuggingFace. Leave as "main" to pull
# the latest version; change it to a specific HuggingFace commit ID to test a
# pinned version of the Example Dataset.
EXAMPLE_DATASET_REVISION: str = "main"
15 changes: 8 additions & 7 deletions ark/utils/data_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import os
import shutil
import tempfile
from random import randint
from shutil import rmtree

import feather
import numpy as np
import pandas as pd
import pytest
import skimage.io as io
import xarray as xr

from shutil import rmtree
from random import randint

from ark import settings
from ark.utils import data_utils, test_utils, io_utils, load_utils
from ark.utils import data_utils, io_utils, load_utils, test_utils
from ark.utils.data_utils import (generate_and_save_cell_cluster_masks,
generate_and_save_pixel_cluster_masks,
generate_and_save_neighborhood_cluster_masks,
generate_and_save_pixel_cluster_masks,
label_cells_by_cluster, relabel_segmentation)

parametrize = pytest.mark.parametrize
Expand Down Expand Up @@ -484,7 +484,8 @@ def test_generate_and_save_neighborhood_cluster_masks(sub_dir, name_suffix):
for fov in fovs:
io.imsave(
os.path.join(temp_dir, 'seg_dir', fov + '_whole_cell.tiff'),
sample_label_maps.loc[fov, ...].values
sample_label_maps.loc[fov, ...].values,
check_contrast=False
)

generate_and_save_neighborhood_cluster_masks(
Expand Down Expand Up @@ -755,7 +756,7 @@ def test_stitch_images_by_shape(segmentation, clustering, subdir, fovs):
data_utils.stitch_images_by_shape(data_dir, stitched_dir, img_sub_folder=subdir,
segmentation=segmentation, clustering=clustering)
assert sorted(io_utils.list_files(stitched_dir)) == \
[chan + '_stitched.tiff' for chan in chans]
[chan + '_stitched.tiff' for chan in chans]

# stitched image is 3 x 2 fovs with max_img_size = 10
stitched_data = load_utils.load_imgs_from_dir(stitched_dir,
Expand Down
23 changes: 18 additions & 5 deletions ark/utils/example_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import datasets

from ark.settings import EXAMPLE_DATASET_REVISION
from ark.utils.misc_utils import verify_in_list


Expand Down Expand Up @@ -127,7 +128,7 @@ def move_example_dataset(self, move_dir: Union[str, pathlib.Path]):
[f.unlink() for f in dst_path.glob("*") if f.is_file()]
# Fill destination path
shutil.copytree(src_path, dst_path, dirs_exist_ok=True,
ignore=shutil.ignore_patterns(r".!*"))
ignore=shutil.ignore_patterns(r"\.\!*"))
else:
if empty_dst_path:
warnings.warn(UserWarning(f"Files do not exist in {dst_path}. \
Expand All @@ -148,7 +149,17 @@ def get_example_dataset(dataset: str, save_dir: Union[str, pathlib.Path],
Args:
dataset (str): The dataset to download for a particular notebook.
dataset (str): The name of the dataset to download. Can be one of
* `"segment_image_data"`
* `"cluster_pixels"`
* `"cluster_cells"`
* `"post_clustering"`
* `"fiber_segmentation"`
* `"LDA_preprocessing"`
* `"LDA_training_inference"`
* `"neighborhood_analysis"`
* `"pairwise_spatial_enrichment"`
save_dir (Union[str, pathlib.Path]): The path to save the dataset files in.
overwrite_existing (bool): The option to overwrite existing configs of the `dataset`
downloaded. Defaults to True.
Expand All @@ -168,11 +179,13 @@ def get_example_dataset(dataset: str, save_dir: Union[str, pathlib.Path],
try:
verify_in_list(dataset=dataset, valid_datasets=valid_datasets)
except ValueError:
ValueError(f"The dataset <{dataset}> is not one of the valid datasets available. \
The following are available: { {*valid_datasets} }")
err_str: str = f"""The dataset \"{dataset}\" is not one of the valid datasets available.
The following are available: {*valid_datasets,}"""
raise ValueError(err_str) from None

example_dataset = ExampleDataset(dataset=dataset, overwrite_existing=overwrite_existing,
cache_dir=None,
revision="main")
revision=EXAMPLE_DATASET_REVISION)

# Download the dataset
example_dataset.download_example_dataset()
Expand Down
3 changes: 2 additions & 1 deletion ark/utils/example_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest

from ark.settings import EXAMPLE_DATASET_REVISION
from ark.utils import test_utils
from ark.utils.example_dataset import ExampleDataset, get_example_dataset

Expand Down Expand Up @@ -32,7 +33,7 @@ def dataset_download(request) -> Iterator[ExampleDataset]:
example_dataset: ExampleDataset = ExampleDataset(
dataset=request.param,
cache_dir=None,
revision="main"
revision=EXAMPLE_DATASET_REVISION
)
# Download example data for a particular notebook
example_dataset.download_example_dataset()
Expand Down
5 changes: 4 additions & 1 deletion ark/utils/notebooks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,10 @@ def test_create_output(self):
deepcell_output_dir = self.tb.ref("deepcell_output_dir")
fovs = self.tb.ref("fovs")
# Generate the sample feature_0, feature_1 tiffs
notebooks_test_utils.generate_sample_feature_tifs(fovs, deepcell_output_dir, (1024, 1024))
# Account for the fact that fov0 is 512 x 512
for fov, dim in zip(fovs, [512, 1024]):
notebooks_test_utils.generate_sample_feature_tifs(
[fov], deepcell_output_dir=deepcell_output_dir, img_shape=(dim, dim))

def test_overlay_mask(self):
self.tb.execute_cell("overlay_mask")
Expand Down
2 changes: 1 addition & 1 deletion ark/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def _write_labels(base_dir, fov_names, comp_names, shape, sub_dir, fills, dtype,

for i, fov in enumerate(fov_names):
tiffpath = os.path.join(base_dir, f'{fov}{suffix}.tiff')
io.imsave(tiffpath, label_data[i, :, :, 0], plugin='tifffile')
io.imsave(tiffpath, label_data[i, :, :, 0], plugin='tifffile', check_contrast=False)
filelocs[fov] = tiffpath

return filelocs, label_data
Expand Down

0 comments on commit 329c78d

Please sign in to comment.