4 changes: 2 additions & 2 deletions python/coverage.svg
157 changes: 106 additions & 51 deletions python/ouroboros/helpers/mem.py
@@ -1,7 +1,7 @@
from dataclasses import astuple, replace, asdict, fields
from functools import cached_property
from multiprocessing.shared_memory import SharedMemory, ShareableList
from multiprocessing.managers import SharedMemoryManager, BaseManager
from multiprocessing.managers import SharedMemoryManager, BaseManager, ListProxy
from sys import stdout
from time import sleep
from typing import TextIO
@@ -20,42 +20,6 @@ def is_advanced_index(index):
return isinstance(index, (list, np.ndarray)) and not isinstance(index, slice)


class SharedNPManager(SharedMemoryManager):
""" Manages shared memory numpy arrays. """
def __init__(self, *args,
queue_mem: list[tuple[DataShape, np.dtype] | tuple[tuple[DataShape, np.dtype]]] = [],
**kwargs):
SharedMemoryManager.__init__(self, *args, **kwargs)

self.__mem_queue = []
for mem in queue_mem:
if isinstance(mem[0], tuple):
self.__mem_queue.append([mem[0][0], mem[0][1]] + list(mem[1:]))
else:
self.__mem_queue.append([mem[0], mem[1]])

def SharedNPArray(self, shape: DataShape, dtype: np.dtype, *create_with: tuple[DataShape, np.dtype]):
full_set = [(shape, dtype)] + list(create_with)
size = max([np.prod(astuple(shape), dtype=object) * np.dtype(dtype).itemsize for (shape, dtype) in full_set])
mem = self.SharedMemory(int(size))
result = [SharedNPArray(mem.name, shape, dtype) for (shape, dtype) in full_set]
return result[0] if len(result) == 1 else result

def clear_queue(self):
ar_mem = []
while len(self.__mem_queue) > 0:
new_mem = self.SharedNPArray(*self.__mem_queue.pop(0))
ar_mem += new_mem if isinstance(new_mem, list) else [new_mem]
return ar_mem

def __enter__(self):
this = [BaseManager.__enter__(self)]
return tuple(this + self.clear_queue())

def __exit__(self, *args, **kwargs):
super().__exit__(*args, **kwargs)


class SharedNPArray:
def __init__(self, name: str, shape: DataShape, dtype: np.dtype,
views: list = None, *, allocate: bool = False):
@@ -148,21 +112,95 @@ def shape(self):
return shape


def cleanup_mem(*shm_objects):
    """ Close and unlink shared memory objects.

    :param shm_objects: Shared memory objects to shut down.
    """
    for shm in shm_objects:
        if isinstance(shm, SharedNPArray):
            shm.shutdown()
            del shm
        elif isinstance(shm, ShareableList):
            shm.shm.close()
            shm.shm.unlink()
        elif isinstance(shm, SharedMemory):
            shm.close()
            shm.unlink()


_termed_mem = []


def get_termed_mem():
    """Returns the existing global list rather than creating a new one."""
    return _termed_mem


class SharedNPManager(SharedMemoryManager):
    """ Manages shared memory numpy arrays. """

    SharedMemoryManager.register(
        '_TermedMem',
        callable=get_termed_mem,
        proxytype=ListProxy
    )

def __init__(self, *args,
queue_mem: list[tuple[DataShape, np.dtype] | tuple[tuple[DataShape, np.dtype]]] = [],
**kwargs):
SharedMemoryManager.__init__(self, *args, **kwargs)

self.__mem_queue = []
for mem in queue_mem:
if isinstance(mem[0], tuple):
self.__mem_queue.append([mem[0][0], mem[0][1]] + list(mem[1:]))
else:
self.__mem_queue.append([mem[0], mem[1]])
self.__termed_mem = None

def SharedNPArray(self, shape: DataShape, dtype: np.dtype, *create_with: tuple[DataShape, np.dtype]):
full_set = [(shape, dtype)] + list(create_with)
size = max([np.prod(astuple(shape), dtype=object) * np.dtype(dtype).itemsize for (shape, dtype) in full_set])
mem = self.SharedMemory(int(size))
result = [SharedNPArray(mem.name, shape, dtype) for (shape, dtype) in full_set]
return result[0] if len(result) == 1 else result

def TermedNPArray(self, shape: DataShape, dtype: np.dtype, *create_with: tuple[DataShape, np.dtype]):
full_set = [(shape, dtype)] + list(create_with)
size = max([np.prod(astuple(shape), dtype=object) * np.dtype(dtype).itemsize for (shape, dtype) in full_set])
mem = SharedMemory(create=True, size=int(size))
result = [SharedNPArray(mem.name, shape, dtype) for (shape, dtype) in full_set]
self.__termed_mem.append(mem.name)
return result[0] if len(result) == 1 else result

def clear_queue(self):
ar_mem = []
while len(self.__mem_queue) > 0:
new_mem = self.SharedNPArray(*self.__mem_queue.pop(0))
ar_mem += new_mem if isinstance(new_mem, list) else [new_mem]
return ar_mem

def remove_termed(self, mem):
if isinstance(mem, SharedNPArray):
name = mem.name
mem.shutdown()
else:
name = mem
if name in self.__termed_mem:
self.__termed_mem.pop(self.__termed_mem.index(name))
t = SharedMemory(name)
t.close()
t.unlink()
else:
raise FileNotFoundError(f"{name} is not a termed shared memory array. {self.__termed_mem}")

def shutdown(self):
for name in self.__termed_mem:
t = SharedMemory(name)
t.close()
t.unlink()
super().shutdown()

def start(self, *args, **kwargs):
super().start(*args, **kwargs)
# Initialize the proxy immediately upon start
self.__termed_mem = self._TermedMem()

def connect(self):
super().connect()
# Initialize the proxy immediately upon connect
self.__termed_mem = self._TermedMem()

def __enter__(self):
this = [BaseManager.__enter__(self)]
return tuple(this + self.clear_queue())

def __exit__(self, *args, **kwargs):
print("Exiting! SHM!")
super().__exit__(*args, **kwargs)


def exit_cleanly(step: str, *shm_objects, return_code: int = 0, statement: str = '', log_level: LOG = LOG.TIME,
@@ -191,3 +229,20 @@ def mem_monitor(mem_file, mem_store, pid):
last_step = last_step_arr.tobytes().decode()
log.write(last_step, out=out, pid=pid)
sleep(MEM_INTERVAL_TIMER)


def cleanup_mem(*shm_objects):
""" Close and unlink shared memory objects.

:param shm_objects: Shared memory objects to shut down.
"""
for shm in shm_objects:
if isinstance(shm, SharedNPArray):
shm.shutdown()
del shm
elif isinstance(shm, ShareableList):
shm.shm.close()
shm.shm.unlink()
elif isinstance(shm, SharedMemory):
shm.close()
shm.unlink()
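
For context on the `_TermedMem` registration above: it uses the standard `BaseManager.register` pattern of exposing a single server-side object through a proxy. A minimal, stdlib-only sketch of that pattern follows; the class and variable names are illustrative, not code from this PR.

# Illustrative sketch only: register a callable that returns one server-side
# list, exposed to every connected process through a ListProxy.
from multiprocessing.managers import SharedMemoryManager, ListProxy

_segment_names = []                      # lives in the manager's server process


def _get_segment_names():
    # Always return the same list so every proxy shares one referent.
    return _segment_names


class TrackingManager(SharedMemoryManager):
    pass


TrackingManager.register('segment_names',
                         callable=_get_segment_names,
                         proxytype=ListProxy)

if __name__ == "__main__":
    mgr = TrackingManager()
    mgr.start()                          # spawns the manager's server process
    names = mgr.segment_names()          # ListProxy over the server-side list
    names.append("psm_example")          # visible to any process that connects
    print(list(names))                   # ['psm_example']
    mgr.shutdown()

Because `get_termed_mem` always returns the same module-level list, every process that connects to a `SharedNPManager` appends segment names to one shared list, and `shutdown()` can walk that list to close and unlink anything still outstanding.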
108 changes: 41 additions & 67 deletions python/ouroboros/helpers/volume_cache.py
@@ -1,7 +1,6 @@
from dataclasses import astuple
import os
import sys
import traceback
import time

from cloudvolume import CloudVolume, VolumeCutout, Bbox
import numpy as np
@@ -37,10 +36,7 @@ def __init__(
self.volumes = [None] * len(bounding_boxes)

self.use_shared = use_shared

if self.use_shared:
self.shm_host = SharedNPManager()
self.shm_host.__enter__()
self.__shm_host = None

# Indicates whether a volume should be cached after the last slice to request it is processed
self.cache_volume = [False] * len(bounding_boxes)
@@ -59,6 +55,11 @@ def to_dict(self) -> dict:
"flush_cache": self.flush_cache,
}

def connect_shm(self, address: str, authkey: str):
self.__shm_host = SharedNPManager(address=address, authkey=authkey)
self.__shm_host.connect()
self.__authkey = authkey

@staticmethod
def from_dict(data: dict) -> "VolumeCache":
bounding_boxes = [BoundingBox.from_dict(bb) for bb in data["bounding_boxes"]]
@@ -131,7 +132,7 @@ def request_volume_for_slice(self, slice_index: int):

# Download the volume if it is not already cached
if self.volumes[vol_index] is None:
self.download_volume(vol_index, bounding_box)
self.volumes[vol_index] = download_volume(self.cv, bounding_box, mip=self.mip)[0]  # keep only the volume from the returned tuple

# Remove the last requested volume if it is not to be cached
if (
@@ -153,68 +154,9 @@ def remove_volume(self, volume_index: int, destroy_shared: bool = False):
if not self.use_shared:
self.volumes[volume_index] = None
elif destroy_shared:
self.volumes[volume_index].shutdown()
self.__shm_host.remove_termed(self.volumes[volume_index])
self.volumes[volume_index] = None

def download_volume(
self, volume_index: int, bounding_box: BoundingBox, parallel=False
) -> VolumeCutout:
bbox = bounding_box.to_cloudvolume_bbox().astype(int)
vol_shape = NGOrder(*bbox.size3(), self.cv.cv.num_channels)

# Limit size of area we are grabbing, in case we go out of bounds.
dl_box = Bbox.intersection(self.cv.cv.bounds, bbox)
local_min = [int(start) for start in np.subtract(dl_box.minpt, bbox.minpt)]

local_bounds = np.s_[*[slice(start, stop) for start, stop in
zip(local_min, np.sum([local_min, dl_box.size3()], axis=0))],
:]

# Download the bounding box volume
if self.use_shared:
volume = self.shm_host.SharedNPArray(vol_shape, np.float32)
with volume as volume_data:
volume_data[:] = 0 # Prob not most efficient but makes math much easier
volume_data[local_bounds] = self.cv.cv.download(dl_box, mip=self.mip, parallel=parallel)
else:
volume = np.zeros(astuple(vol_shape))
volume[local_bounds] = self.cv.cv.download(dl_box, mip=self.mip, parallel=parallel)

# Store the volume in the cache
self.volumes[volume_index] = volume

def create_processing_data(self, volume_index: int, parallel=False):
"""
Generate a data packet for processing a volume.

Suitable for parallel processing.

Parameters:
----------
volume_index (int): The index of the volume to process.
parallel (bool): Whether to download the volume in parallel (only do parallel if downloading in one thread).

Returns:
-------
tuple: A tuple containing the volume data, the bounding box of the volume,
the slice indices associated with the volume, and a function to remove the volume from the cache.
"""

bounding_box = self.bounding_boxes[volume_index]

# Download the volume if it is not already cached
if self.volumes[volume_index] is None:
try:
self.download_volume(volume_index, bounding_box, parallel=parallel)
except BaseException as be:
traceback.print_tb(be.__traceback__, file=sys.stderr)
return f"Error downloading data: {be}"

# Get all slice indices associated with this volume
slice_indices = self.get_slice_indices(volume_index)

return self.volumes[volume_index], bounding_box, slice_indices, volume_index

def get_slice_indices(self, volume_index: int):
return [i for i, v in enumerate(self.link_rects) if v == volume_index]

@@ -264,6 +206,38 @@ def flush_cache(self):
self.cv.cache.flush()


def download_volume(
cv: CloudVolumeInterface, bounding_box: BoundingBox, mip, parallel=False,
use_shared=False, shm_address: str = None, shm_authkey: str = None, **kwargs
) -> VolumeCutout:
start = time.perf_counter()
bbox = bounding_box.to_cloudvolume_bbox().astype(int)
vol_shape = NGOrder(*bbox.size3(), cv.cv.num_channels)

# Limit size of area we are grabbing, in case we go out of bounds.
dl_box = Bbox.intersection(cv.cv.bounds, bbox)
local_min = [int(start) for start in np.subtract(dl_box.minpt, bbox.minpt)]

local_bounds = np.s_[*[slice(start, stop) for start, stop in
zip(local_min, np.sum([local_min, dl_box.size3()], axis=0))],
:]

# Download the bounding box volume
if use_shared:
shm_host = SharedNPManager(address=shm_address, authkey=shm_authkey)
shm_host.connect()
volume = shm_host.TermedNPArray(vol_shape, np.float32)
with volume as volume_data:
volume_data[:] = 0 # Prob not most efficient but makes math much easier
volume_data[local_bounds] = cv.cv.download(dl_box, mip=mip, parallel=parallel)
else:
volume = np.zeros(astuple(vol_shape))
volume[local_bounds] = cv.cv.download(dl_box, mip=mip, parallel=parallel)

# Return the volume along with its bounding box, elapsed time, and any pass-through kwargs
return volume, bounding_box, time.perf_counter() - start, *kwargs.values()
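
With `download_volume` now a module-level function, the intended flow appears to be that one process owns the `SharedNPManager` and workers receive only its address and authkey (via `VolumeCache.connect_shm` or the `shm_address`/`shm_authkey` arguments). Below is a hedged sketch of that flow: `cv` and `bounding_box` stand in for the pipeline's existing CloudVolumeInterface and BoundingBox objects, and the address, authkey, and import paths are assumptions rather than values taken from this PR.

from ouroboros.helpers.mem import SharedNPManager           # assumed import path
from ouroboros.helpers.volume_cache import download_volume  # assumed import path

SHM_ADDRESS = ("127.0.0.1", 50000)       # placeholder manager address
SHM_AUTHKEY = b"ouroboros"               # placeholder authkey (bytes)

manager = SharedNPManager(address=SHM_ADDRESS, authkey=SHM_AUTHKEY)
manager.start()                          # also initializes the _TermedMem proxy

# cv and bounding_box come from the existing pipeline (CloudVolumeInterface,
# BoundingBox); they are not constructed in this sketch.
volume, bbox, elapsed, volume_index = download_volume(
    cv, bounding_box, mip=0,
    use_shared=True,                     # write into a manager-tracked shared array
    shm_address=SHM_ADDRESS,
    shm_authkey=SHM_AUTHKEY,
    volume_index=0,                      # pass-through kwarg, echoed in the return tuple
)

manager.remove_termed(volume)            # close and unlink this one array when done
manager.shutdown()                       # unlinks any termed arrays still tracked

Echoing the pass-through kwargs (here `volume_index`) lets a caller that dispatches many downloads match each result back to its slot in `VolumeCache.volumes` without extra bookkeeping.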


def get_mip_volume_sizes(source_url: str) -> dict:
"""
Get the volume sizes for all available MIPs.