From 75d13ea9d347e3e4627694a1c5e05a2c5bdae5b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Wed, 30 Nov 2022 15:45:02 +0100 Subject: [PATCH 1/2] Create dmdataset (dataset for graph-level regression tasks) --- deeptrack/datasets/__init__.py | 1 + deeptrack/datasets/dmdataset/__init__.py | 3 + deeptrack/datasets/dmdataset/checksums.tsv | 3 + deeptrack/datasets/dmdataset/dmdataset.py | 136 +++++++++++++++++++++ 4 files changed, 143 insertions(+) create mode 100644 deeptrack/datasets/dmdataset/__init__.py create mode 100644 deeptrack/datasets/dmdataset/checksums.tsv create mode 100644 deeptrack/datasets/dmdataset/dmdataset.py diff --git a/deeptrack/datasets/__init__.py b/deeptrack/datasets/__init__.py index 9941d71cc..264091f9e 100644 --- a/deeptrack/datasets/__init__.py +++ b/deeptrack/datasets/__init__.py @@ -4,4 +4,5 @@ regression_holography_nanoparticles, segmentation_fluorescence_u2os, detection_holography_nanoparticles, + dmdataset, ) \ No newline at end of file diff --git a/deeptrack/datasets/dmdataset/__init__.py b/deeptrack/datasets/dmdataset/__init__.py new file mode 100644 index 000000000..f84cf036f --- /dev/null +++ b/deeptrack/datasets/dmdataset/__init__.py @@ -0,0 +1,3 @@ +"""dmdataset dataset.""" + +from .dmdataset import Dmdataset diff --git a/deeptrack/datasets/dmdataset/checksums.tsv b/deeptrack/datasets/dmdataset/checksums.tsv new file mode 100644 index 000000000..065db4ead --- /dev/null +++ b/deeptrack/datasets/dmdataset/checksums.tsv @@ -0,0 +1,3 @@ +# TODO(dmdataset): If your dataset downloads files, then the checksums +# will be automatically added here when running +# `tfds build --register_checksums`. 
diff --git a/deeptrack/datasets/dmdataset/dmdataset.py b/deeptrack/datasets/dmdataset/dmdataset.py
new file mode 100644
index 000000000..1026aa3c0
--- /dev/null
+++ b/deeptrack/datasets/dmdataset/dmdataset.py
@@ -0,0 +1,113 @@
+"""dmdataset dataset."""
+
+import tensorflow_datasets as tfds
+import tensorflow as tf
+
+import os
+import scipy.sparse
+
+_DESCRIPTION = """
+"""
+
+_CITATION = """
+"""
+
+
+class Dmdataset(tfds.core.GeneratorBasedBuilder):
+    """DatasetBuilder for dmdataset dataset.
+
+    Every split yields a single example whose arrays are grouped under
+    "graph", "labels" and "sets"; `get_features` describes each array.
+    """
+
+    VERSION = tfds.core.Version("1.0.0")
+    RELEASE_NOTES = {
+        "1.0.0": "Initial release.",
+    }
+
+    def _info(self) -> tfds.core.DatasetInfo:
+        """Returns the dataset metadata."""
+        # Derive one nested FeaturesDict per group directly from the schema
+        # so the metadata cannot drift from `get_features`.
+        features = {
+            group: tfds.features.FeaturesDict(
+                {
+                    name: tfds.features.Tensor(shape=shape, dtype=dtype)
+                    for name, (shape, dtype) in group_features.items()
+                }
+            )
+            for group, group_features in self.get_features().items()
+        }
+
+        return tfds.core.DatasetInfo(
+            builder=self,
+            description=_DESCRIPTION,
+            features=tfds.features.FeaturesDict(features),
+            supervised_keys=None,
+            homepage="https://dataset-homepage/",
+            citation=_CITATION,
+        )
+
+    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
+        """Returns SplitGenerators."""
+        path = dl_manager.download_and_extract(
+            "https://drive.google.com/u/1/uc?id=19vplN2lbKo4KAmv4NRU2qr3NSlzxFzrx&export=download"
+        )
+
+        return {
+            "train": self._generate_examples(
+                os.path.join(path, "dmdataset", "training")
+            ),
+            "test": self._generate_examples(
+                os.path.join(path, "dmdataset", "validation")
+            ),
+        }
+
+    def _generate_examples(self, path):
+        """Yields examples.
+
+        Args:
+            path: Directory holding one `<feature name>.npz` file (a SciPy
+                sparse matrix) per feature listed by `get_features`.
+
+        Yields:
+            A single `("key", example)` pair; `example` maps every group
+            name to a dict of dense arrays.
+        """
+        example = {}
+        for group, group_features in self.get_features().items():
+            example[group] = {}
+            for name, (shape, _) in group_features.items():
+                array = scipy.sparse.load_npz(
+                    os.path.join(path, name + ".npz")
+                ).toarray()
+                # Flatten only the features declared as 1-D, so that a 2-D
+                # feature that happens to have a single row keeps its
+                # declared rank.
+                if len(shape) == 1:
+                    array = array.reshape(-1)
+
+                example[group][name] = array
+
+        yield "key", example
+
+    def get_features(self):
+        """Returns [shape, dtype] per array, keyed by group and name."""
+        return {
+            "graph": {
+                "frame": [(None, 1), tf.float64],
+                "node_features": [(None, 3), tf.float64],
+                "edge_features": [(None, 1), tf.float64],
+                "edge_indices": [(None, 2), tf.int64],
+                "edge_dropouts": [(None, 2), tf.float64],
+            },
+            "labels": {
+                "node_labels": [(None,), tf.float64],
+                "edge_labels": [(None,), tf.float64],
+                "global_labels": [(None, 3), tf.float64],
+            },
+            "sets": {
+                "node_sets": [(None, 2), tf.int64],
+                "edge_sets": [(None, 3), tf.int64],
+            },
+        }

From 7c67ea8c162237155dfe8a56398468082412982a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Wed, 30 Nov 2022 15:45:38 +0100 Subject: [PATCH 2/2] Update gnns/__init__.py --- deeptrack/models/gnns/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deeptrack/models/gnns/__init__.py b/deeptrack/models/gnns/__init__.py index 99dd1803f..9dded5dbd 100644 --- a/deeptrack/models/gnns/__init__.py +++ b/deeptrack/models/gnns/__init__.py @@ -1,4 +1,5 @@ from .models import * from .graphs import * +from .generators import * from .utils import *