Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions deeptrack/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
detection_holography_nanoparticles,
detection_linking_hela,
dmdataset,
regression_diffusion_landscape,
)
1 change: 0 additions & 1 deletion deeptrack/datasets/dmdataset/dmdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def _generate_examples(self, path):
data = [{}, {}, {}]
for i, subdict in enumerate(self.get_features().values()):
files = (*subdict.keys(),)
print(files)

for file in files:
data_elem = scipy.sparse.load_npz(
Expand Down
3 changes: 3 additions & 0 deletions deeptrack/datasets/regression_diffusion_landscape/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""regresion_diffusion_landscape dataset."""

from .regression_diffusion_landscape import RegressionDiffusionLandscape
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
https://drive.google.com/u/1/uc?id=1hiBGuJ0OdcHx6XaNEOqttaw_OmculCXY&export=download 505299704 155269ff2291c4b0e975f939c3e2d719c86098b32893eb9087282e0a0ce0a172 DiffusionLandscapeDataset.zip
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""dmdataset dataset."""

import tensorflow_datasets as tfds
import tensorflow as tf

import os
import scipy

_DESCRIPTION = """
"""

_CITATION = """
"""



class RegressionDiffusionLandscape(tfds.core.GeneratorBasedBuilder):
VERSION = tfds.core.Version("1.0.0")
RELEASE_NOTES = {
"1.0.0": "Initial release.",
}

def _info(self) -> tfds.core.DatasetInfo:
"""Returns the dataset metadata."""

NODE_FEATURES = self.get_features()

return tfds.core.DatasetInfo(
builder=self,
description=_DESCRIPTION,
features=tfds.features.FeaturesDict(
{
"graph": tfds.features.FeaturesDict(
{
**{
key: tfds.features.Tensor(
shape=feature[0],
dtype=feature[1],
)
for key, feature in NODE_FEATURES[
"graph"
].items()
},
}
),
"labels": tfds.features.FeaturesDict(
{
**{
key: tfds.features.Tensor(
shape=feature[0],
dtype=feature[1],
)
for key, feature in NODE_FEATURES[
"labels"
].items()
},
}
),
"sets": tfds.features.FeaturesDict(
{
**{
key: tfds.features.Tensor(
shape=feature[0],
dtype=feature[1],
)
for key, feature in NODE_FEATURES[
"sets"
].items()
},
}
),
}
),
supervised_keys=None,
homepage="https://dataset-homepage/",
citation=_CITATION,
)

def _split_generators(self, dl_manager: tfds.download.DownloadManager):
"""Returns SplitGenerators."""
path = dl_manager.download_and_extract(
"https://drive.google.com/u/1/uc?id=1hiBGuJ0OdcHx6XaNEOqttaw_OmculCXY&export=download"
)

return {
"train": self._generate_examples(
os.path.join(path, "DiffusionLandscapeDataset", "training")
),
"test": self._generate_examples(
os.path.join(path, "DiffusionLandscapeDataset", "validation")
),
}

def _generate_examples(self, path):
"""Yields examples."""
data = [{}, {}, {}]
for i, subdict in enumerate(self.get_features().values()):
files = (*subdict.keys(),)

for file in files:
data_elem = scipy.sparse.load_npz(
os.path.join(path, file + ".npz")
).toarray()
data_elem = (
data_elem[0] if data_elem.shape[0] == 1 else data_elem
)

data[i][file] = data_elem

yield "key", {
"graph": data[0],
"labels": data[1],
"sets": data[2],
}

def get_features(self):
return {
"graph": {
"frame": [(None, 1), tf.float64],
"node_features": [(None, 3), tf.float64],
"edge_features": [(None, 1), tf.float64],
"edge_indices": [(None, 2), tf.int64],
"edge_dropouts": [(None, 2), tf.float64],
},
"labels": {
"node_labels": [(None,), tf.float64],
"edge_labels": [(None,), tf.float64],
"global_labels": [(None,), tf.float64],
},
"sets": {
"node_sets": [(None, 2), tf.int64],
"edge_sets": [(None, 3), tf.int64],
},
}
27 changes: 27 additions & 0 deletions deeptrack/models/gnns/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,33 @@ def inner(data):
return inner


def GetSubGraph(num_nodes, node_start):
def inner(data):
graph, labels, *_ = data

edge_connects_removed_node = np.any(
(graph[2] < node_start) | (graph[2] >= node_start + num_nodes),
axis=-1,
)

node_features = graph[0][node_start : node_start + num_nodes]
edge_features = graph[1][~edge_connects_removed_node]
edge_connections = graph[2][~edge_connects_removed_node] - node_start
weights = graph[3][~edge_connects_removed_node]

node_labels = labels[0][node_start : node_start + num_nodes]
edge_labels = labels[1][~edge_connects_removed_node]
global_labels = labels[2]

return (node_features, edge_features, edge_connections, weights), (
node_labels,
edge_labels,
global_labels,
)

return inner


def GetSubGraphFromLabel(samples):
"""
Returns a function that takes a graph and returns a subgraph
Expand Down