From 64667cc1eaa1e7e3c6c0583900b96f8bc0d6dee8 Mon Sep 17 00:00:00 2001 From: Benjamin Midtvedt Date: Sun, 30 Oct 2022 13:21:12 +0100 Subject: [PATCH] small fixes to lodestar --- deeptrack/datasets/__init__.py | 1 + .../__init__.py | 3 + .../checksums.tsv | 3 + .../detection_holography_nanoparticles.py | 77 +++++++++++++++++++ ...detection_holography_nanoparticles_test.py | 24 ++++++ .../TODO-add_fake_data_in_this_directory.txt | 0 deeptrack/models/lodestar/equivariances.py | 62 ++++++++------- deeptrack/models/lodestar/models.py | 8 +- setup.py | 2 +- 9 files changed, 150 insertions(+), 30 deletions(-) create mode 100644 deeptrack/datasets/detection_holography_nanoparticles/__init__.py create mode 100644 deeptrack/datasets/detection_holography_nanoparticles/checksums.tsv create mode 100644 deeptrack/datasets/detection_holography_nanoparticles/detection_holography_nanoparticles.py create mode 100644 deeptrack/datasets/detection_holography_nanoparticles/detection_holography_nanoparticles_test.py create mode 100644 deeptrack/datasets/detection_holography_nanoparticles/dummy_data/TODO-add_fake_data_in_this_directory.txt diff --git a/deeptrack/datasets/__init__.py b/deeptrack/datasets/__init__.py index ec675718a..9941d71cc 100644 --- a/deeptrack/datasets/__init__.py +++ b/deeptrack/datasets/__init__.py @@ -3,4 +3,5 @@ segmentation_ssTEM_drosophila, regression_holography_nanoparticles, segmentation_fluorescence_u2os, + detection_holography_nanoparticles, ) \ No newline at end of file diff --git a/deeptrack/datasets/detection_holography_nanoparticles/__init__.py b/deeptrack/datasets/detection_holography_nanoparticles/__init__.py new file mode 100644 index 000000000..476486051 --- /dev/null +++ b/deeptrack/datasets/detection_holography_nanoparticles/__init__.py @@ -0,0 +1,3 @@ +"""detection_holography_nanoparticles dataset.""" + +from .detection_holography_nanoparticles import DetectionHolographyNanoparticles diff --git a/deeptrack/datasets/detection_holography_nanoparticles/checksums.tsv b/deeptrack/datasets/detection_holography_nanoparticles/checksums.tsv new file mode 100644 index 000000000..cd65db4e9 --- /dev/null +++ b/deeptrack/datasets/detection_holography_nanoparticles/checksums.tsv @@ -0,0 +1,3 @@ +# TODO(detection_holography_nanoparticles): If your dataset downloads files, then the checksums +# will be automatically added here when running +# `tfds build --register_checksums`. diff --git a/deeptrack/datasets/detection_holography_nanoparticles/detection_holography_nanoparticles.py b/deeptrack/datasets/detection_holography_nanoparticles/detection_holography_nanoparticles.py new file mode 100644 index 000000000..440b106e1 --- /dev/null +++ b/deeptrack/datasets/detection_holography_nanoparticles/detection_holography_nanoparticles.py @@ -0,0 +1,77 @@ +"""detection_holography_nanoparticles dataset.""" + +import tensorflow_datasets as tfds +import tensorflow as tf +import numpy as np + +# TODO(detection_holography_nanoparticles): Markdown description that will appear on the catalog page. +_DESCRIPTION = """ +""" + +# TODO(detection_holography_nanoparticles): BibTeX citation +_CITATION = """ +""" + + +class DetectionHolographyNanoparticles(tfds.core.GeneratorBasedBuilder): + """DatasetBuilder for detection_holography_nanoparticles dataset.""" + + VERSION = tfds.core.Version("1.0.2") + RELEASE_NOTES = { + "1.0.0": "Initial release.", + } + + def _info(self) -> tfds.core.DatasetInfo: + """Returns the dataset metadata.""" + # TODO(detection_holography_nanoparticles): Specifies the tfds.core.DatasetInfo object + return tfds.core.DatasetInfo( + builder=self, + description=_DESCRIPTION, + features=tfds.features.FeaturesDict( + { + # These are the features of your dataset like images, labels ... + "image": tfds.features.Tensor( + shape=(972, 729, 2), dtype=tf.float64 + ), + "label": tfds.features.Tensor(shape=(None, 7), dtype=tf.float64), + } + ), + # If there's a common (input, target) tuple from the + # features, specify them here. They'll be used if + # `as_supervised=True` in `builder.as_dataset`. + supervised_keys=("image", "label"), # Set to `None` to disable + homepage="https://dataset-homepage/", + citation=_CITATION, + disable_shuffling=True, + ) + + def _split_generators(self, dl_manager: tfds.download.DownloadManager): + """Returns SplitGenerators.""" + # TODO(detection_holography_nanoparticles): Downloads the data and defines the splits + path = dl_manager.download_and_extract( + "https://drive.google.com/u/1/uc?id=1uAZVr9bldhZhxuXAXvdd1-Ks4m9HPRtM&export=download" + ) + + # TODO(detection_holography_nanoparticles): Returns the Dict[split names, Iterator[Key, Example]] + return { + "train": self._generate_examples(path), + } + + def _generate_examples(self, path): + """Yields examples.""" + # TODO(detection_holography_nanoparticles): Yields (key, example) tuples from the dataset + + fields = path.glob("f*.npy") + labels = path.glob("d*.npy") + + # sort the files + fields = sorted(fields, key=lambda x: int(x.stem[1:])) + labels = sorted(labels, key=lambda x: int(x.stem[1:])) + + for field, label in zip(fields, labels): + field_data = np.load(field) + field_data = np.stack((field_data.real, field_data.imag), axis=-1) + yield field.stem, { + "image": field_data, + "label": np.load(label), + } diff --git a/deeptrack/datasets/detection_holography_nanoparticles/detection_holography_nanoparticles_test.py b/deeptrack/datasets/detection_holography_nanoparticles/detection_holography_nanoparticles_test.py new file mode 100644 index 000000000..8c96c6b32 --- /dev/null +++ b/deeptrack/datasets/detection_holography_nanoparticles/detection_holography_nanoparticles_test.py @@ -0,0 +1,24 @@ +"""detection_holography_nanoparticles dataset.""" + +import tensorflow_datasets as tfds +from . import detection_holography_nanoparticles + + +class DetectionHolographyNanoparticlesTest(tfds.testing.DatasetBuilderTestCase): + """Tests for detection_holography_nanoparticles dataset.""" + # TODO(detection_holography_nanoparticles): + DATASET_CLASS = detection_holography_nanoparticles.DetectionHolographyNanoparticles + SPLITS = { + 'train': 3, # Number of fake train example + 'test': 1, # Number of fake test example + } + + # If you are calling `download/download_and_extract` with a dict, like: + # dl_manager.download({'some_key': 'http://a.org/out.txt', ...}) + # then the tests needs to provide the fake output paths relative to the + # fake data directory + # DL_EXTRACT_RESULT = {'some_key': 'output_file1.txt', ...} + + +if __name__ == '__main__': + tfds.testing.test_main() diff --git a/deeptrack/datasets/detection_holography_nanoparticles/dummy_data/TODO-add_fake_data_in_this_directory.txt b/deeptrack/datasets/detection_holography_nanoparticles/dummy_data/TODO-add_fake_data_in_this_directory.txt new file mode 100644 index 000000000..e69de29bb diff --git a/deeptrack/models/lodestar/equivariances.py b/deeptrack/models/lodestar/equivariances.py index f5055b7bc..a30832258 100644 --- a/deeptrack/models/lodestar/equivariances.py +++ b/deeptrack/models/lodestar/equivariances.py @@ -21,22 +21,32 @@ class Equivariance(Feature): Multiplicative equivariance add : float, array-like Additive equivariance - indexes : optional, int or slice + indices : optional, int or slice Index of related predicted value(s) """ - def __init__(self, mul, add, indexes=slice(None, None, 1), **kwargs): - super().__init__(mul=mul, add=add, indexes=indexes, **kwargs) + def __init__(self, mul, add, indices=slice(None, None, 1), indexes=None, **kwargs): + if indexes is not None: + indices = indexes - def get(self, matvec, mul, add, indexes, **kwargs): + super().__init__(mul=mul, add=add, indices=indices, **kwargs) + + def get(self, matvec, mul, add, indices, **kwargs): A, b = matvec._value mulf = np.eye(len(b)) addf = np.zeros((len(b), 1)) - mulf[indexes, indexes] = mul - addf[indexes] = add + + addf[indices] = add + + if isinstance(indices, (slice, int)): + mulf[indices, indices] = mul + else: + for i in indices: + for j in indices: + mulf[i, j] = mul[i, j] A = mulf @ A b = mulf @ b @@ -54,11 +64,11 @@ class TranslationalEquivariance(Equivariance): """ - def __init__(self, translation, indexes=None): - if indexes is None: - indexes = self.get_indexes + def __init__(self, translation, indices=None): + if indices is None: + indices = self.get_indices super().__init__( - translation=translation, add=self.get_add, mul=self.get_mul, indexes=indexes + translation=translation, add=self.get_add, mul=self.get_mul, indices=indices ) def get_add(self, translation): @@ -67,7 +77,7 @@ def get_add(self, translation): def get_mul(self, translation): return np.eye(len(translation)) - def get_indexes(self, translation): + def get_indices(self, translation): return slice(len(translation)) @@ -81,11 +91,11 @@ class Rotational2DEquivariance(Equivariance): """ - def __init__(self, rotate, indexes=None): - if indexes is None: - indexes = self.get_indexes + def __init__(self, rotate, indices=None): + if indices is None: + indices = self.get_indices super().__init__( - rotate=rotate, add=self.get_add, mul=self.get_mul, indexes=indexes + rotate=rotate, add=self.get_add, mul=self.get_mul, indices=indices ) def get_add(self): @@ -95,7 +105,7 @@ def get_mul(self, rotate): s, c = np.sin(rotate), np.cos(rotate) return np.array([[c, s], [-s, c]]) - def get_indexes(self): + def get_indices(self): return slice(2) @@ -109,11 +119,11 @@ class ScaleEquivariance(Equivariance): """ - def __init__(self, scale, indexes=None): - if indexes is None: - indexes = self.get_indexes + def __init__(self, scale, indices=None): + if indices is None: + indices = self.get_indices super().__init__( - scale=scale, add=self.get_add, mul=self.get_mul, indexes=indexes + scale=scale, add=self.get_add, mul=self.get_mul, indices=indices ) def get_add(self, scale): @@ -122,7 +132,7 @@ def get_add(self, scale): def get_mul(self, scale): return np.diag(scale) - def get_indexes(self, scale): + def get_indices(self, scale): return slice(len(scale)) @@ -138,11 +148,11 @@ class LogScaleEquivariance(Equivariance): """ - def __init__(self, scale, indexes=None): - if indexes is None: - indexes = self.get_indexes + def __init__(self, scale, indices=None): + if indices is None: + indices = self.get_indices super().__init__( - scale=scale, add=self.get_add, mul=self.get_mul, indexes=indexes + scale=scale, add=self.get_add, mul=self.get_mul, indices=indices ) def get_add(self, scale): @@ -151,5 +161,5 @@ def get_add(self, scale): def get_mul(self, scale): return np.eye(len(scale)) - def get_indexes(self, scale): + def get_indices(self, scale): return slice(len(scale)) \ No newline at end of file diff --git a/deeptrack/models/lodestar/models.py b/deeptrack/models/lodestar/models.py index 8f92df5bf..b84fa6ac9 100644 --- a/deeptrack/models/lodestar/models.py +++ b/deeptrack/models/lodestar/models.py @@ -282,7 +282,7 @@ def default_model(self, input_shape): return model def predict_and_detect( - self, data, alpha=0.5, beta=0.5, cutoff=0.98, mode="quantile" + self, data, alpha=0.5, beta=0.5, cutoff=0.98, mode="quantile", **predict_kwargs ): """Evaluates the model on a batch of data, and detects objects in each frame @@ -296,9 +296,11 @@ def predict_and_detect( Treshholding parameters. Mode can be either "quantile" or "ratio" or "constant". If "quantile", then `ratio` defines the quantile of scores to accept. If "ratio", then cutoff defines the ratio of the max score as threshhold. If constant, the cutoff is used directly as treshhold. + predict_kwargs: dict + Additional arguments to pass to the predict method. """ - pred, weight = self.predict(data) + pred, weight = self.predict(data, **predict_kwargs) detections = [ self.detect(p, w, alpha=alpha, beta=beta, cutoff=cutoff, mode=mode) for p, w in zip(pred, weight) @@ -307,7 +309,7 @@ def predict_and_detect( def predict_and_pool(self, data, mask=1): """Evaluates the model on a batch of data, and pools the predictions in each frame to a single value. - + Used when it's known a-priori that there is only one object per image. Parameters diff --git a/setup.py b/setup.py index 5c4a665f5..f0b657b6a 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ setuptools.setup( name="deeptrack", # Replace with your own username - version="1.4.1", + version="1.5.0a5", author="Benjamin Midtvedt", author_email="benjamin.midtvedt@physics.gu.se", description="A deep learning oriented microscopy image simulation package",