diff --git a/deeptrack/datasets/__init__.py b/deeptrack/datasets/__init__.py index c2f15216a..6c1d19218 100644 --- a/deeptrack/datasets/__init__.py +++ b/deeptrack/datasets/__init__.py @@ -6,4 +6,5 @@ detection_holography_nanoparticles, detection_linking_hela, dmdataset, + regression_diffusion_landscape, ) \ No newline at end of file diff --git a/deeptrack/datasets/dmdataset/dmdataset.py b/deeptrack/datasets/dmdataset/dmdataset.py index 1026aa3c0..2d0ea9543 100644 --- a/deeptrack/datasets/dmdataset/dmdataset.py +++ b/deeptrack/datasets/dmdataset/dmdataset.py @@ -97,7 +97,6 @@ def _generate_examples(self, path): data = [{}, {}, {}] for i, subdict in enumerate(self.get_features().values()): files = (*subdict.keys(),) - print(files) for file in files: data_elem = scipy.sparse.load_npz( diff --git a/deeptrack/datasets/regression_diffusion_landscape/__init__.py b/deeptrack/datasets/regression_diffusion_landscape/__init__.py new file mode 100644 index 000000000..55c8b4b38 --- /dev/null +++ b/deeptrack/datasets/regression_diffusion_landscape/__init__.py @@ -0,0 +1,3 @@ +"""regression_diffusion_landscape dataset.""" + +from .regression_diffusion_landscape import RegressionDiffusionLandscape diff --git a/deeptrack/datasets/regression_diffusion_landscape/checksums.tsv b/deeptrack/datasets/regression_diffusion_landscape/checksums.tsv new file mode 100644 index 000000000..def432726 --- /dev/null +++ b/deeptrack/datasets/regression_diffusion_landscape/checksums.tsv @@ -0,0 +1 @@ +https://drive.google.com/u/1/uc?id=1hiBGuJ0OdcHx6XaNEOqttaw_OmculCXY&export=download 505299704 155269ff2291c4b0e975f939c3e2d719c86098b32893eb9087282e0a0ce0a172 DiffusionLandscapeDataset.zip diff --git a/deeptrack/datasets/regression_diffusion_landscape/dummy_data/TODO-add_fake_data_in_this_directory.txt b/deeptrack/datasets/regression_diffusion_landscape/dummy_data/TODO-add_fake_data_in_this_directory.txt new file mode 100644 index 000000000..e69de29bb diff --git 
a/deeptrack/datasets/regression_diffusion_landscape/regression_diffusion_landscape.py b/deeptrack/datasets/regression_diffusion_landscape/regression_diffusion_landscape.py new file mode 100644 index 000000000..0e34e74b1 --- /dev/null +++ b/deeptrack/datasets/regression_diffusion_landscape/regression_diffusion_landscape.py @@ -0,0 +1,134 @@ +"""regression_diffusion_landscape dataset.""" + +import tensorflow_datasets as tfds +import tensorflow as tf + +import os +import scipy + +_DESCRIPTION = """ +""" + +_CITATION = """ +""" + + + +class RegressionDiffusionLandscape(tfds.core.GeneratorBasedBuilder): + VERSION = tfds.core.Version("1.0.0") + RELEASE_NOTES = { + "1.0.0": "Initial release.", + } + + def _info(self) -> tfds.core.DatasetInfo: + """Returns the dataset metadata.""" + + NODE_FEATURES = self.get_features() + + return tfds.core.DatasetInfo( + builder=self, + description=_DESCRIPTION, + features=tfds.features.FeaturesDict( + { + "graph": tfds.features.FeaturesDict( + { + **{ + key: tfds.features.Tensor( + shape=feature[0], + dtype=feature[1], + ) + for key, feature in NODE_FEATURES[ + "graph" + ].items() + }, + } + ), + "labels": tfds.features.FeaturesDict( + { + **{ + key: tfds.features.Tensor( + shape=feature[0], + dtype=feature[1], + ) + for key, feature in NODE_FEATURES[ + "labels" + ].items() + }, + } + ), + "sets": tfds.features.FeaturesDict( + { + **{ + key: tfds.features.Tensor( + shape=feature[0], + dtype=feature[1], + ) + for key, feature in NODE_FEATURES[ + "sets" + ].items() + }, + } + ), + } + ), + supervised_keys=None, + homepage="https://dataset-homepage/", + citation=_CITATION, + ) + + def _split_generators(self, dl_manager: tfds.download.DownloadManager): + """Returns SplitGenerators.""" + path = dl_manager.download_and_extract( + "https://drive.google.com/u/1/uc?id=1hiBGuJ0OdcHx6XaNEOqttaw_OmculCXY&export=download" + ) + + return { + "train": self._generate_examples( + os.path.join(path, "DiffusionLandscapeDataset", "training") + ), + "test": 
self._generate_examples( + os.path.join(path, "DiffusionLandscapeDataset", "validation") + ), + } + + def _generate_examples(self, path): + """Yields examples.""" + data = [{}, {}, {}] + for i, subdict in enumerate(self.get_features().values()): + files = (*subdict.keys(),) + + for file in files: + data_elem = scipy.sparse.load_npz( + os.path.join(path, file + ".npz") + ).toarray() + data_elem = ( + data_elem[0] if data_elem.shape[0] == 1 else data_elem + ) + + data[i][file] = data_elem + + yield "key", { + "graph": data[0], + "labels": data[1], + "sets": data[2], + } + + def get_features(self): + return { + "graph": { + "frame": [(None, 1), tf.float64], + "node_features": [(None, 3), tf.float64], + "edge_features": [(None, 1), tf.float64], + "edge_indices": [(None, 2), tf.int64], + "edge_dropouts": [(None, 2), tf.float64], + }, + "labels": { + "node_labels": [(None,), tf.float64], + "edge_labels": [(None,), tf.float64], + "global_labels": [(None,), tf.float64], + }, + "sets": { + "node_sets": [(None, 2), tf.int64], + "edge_sets": [(None, 3), tf.int64], + }, + } diff --git a/deeptrack/models/gnns/augmentations.py b/deeptrack/models/gnns/augmentations.py index 06c585278..d23b7c129 100644 --- a/deeptrack/models/gnns/augmentations.py +++ b/deeptrack/models/gnns/augmentations.py @@ -37,6 +37,33 @@ def inner(data): return inner +def GetSubGraph(num_nodes, node_start): + def inner(data): + graph, labels, *_ = data + + edge_connects_removed_node = np.any( + (graph[2] < node_start) | (graph[2] >= node_start + num_nodes), + axis=-1, + ) + + node_features = graph[0][node_start : node_start + num_nodes] + edge_features = graph[1][~edge_connects_removed_node] + edge_connections = graph[2][~edge_connects_removed_node] - node_start + weights = graph[3][~edge_connects_removed_node] + + node_labels = labels[0][node_start : node_start + num_nodes] + edge_labels = labels[1][~edge_connects_removed_node] + global_labels = labels[2] + + return (node_features, edge_features, 
edge_connections, weights), ( + node_labels, + edge_labels, + global_labels, + ) + + return inner + + def GetSubGraphFromLabel(samples): """ Returns a function that takes a graph and returns a subgraph