Skip to content

Commit

Permalink
Merge pull request #332 from arayabrain/feature/caiman-memory-usage
Browse files Browse the repository at this point in the history
implement online cnmf and register_multisession
  • Loading branch information
ReiHashimoto committed Apr 12, 2024
2 parents 49a1046 + 0d21e48 commit 1986a6e
Show file tree
Hide file tree
Showing 8 changed files with 475 additions and 37 deletions.
33 changes: 33 additions & 0 deletions studio/app/common/dataclass/image.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import gc
import math
import os
from typing import Optional

import imageio
Expand All @@ -14,6 +16,7 @@
from studio.app.common.dataclass.base import BaseData
from studio.app.common.dataclass.utils import create_images_list
from studio.app.common.schemas.outputs import PlotMetaData
from studio.app.const import MAX_IMAGE_DATA_PART_SIZE
from studio.app.dir_path import DIRPATH


Expand Down Expand Up @@ -47,6 +50,36 @@ def __init__(
del data
gc.collect()

def split_image(self, output_dir: str):
image = self.data
size = image.nbytes
frames = image.shape[0]

if size > MAX_IMAGE_DATA_PART_SIZE:
frames_per_part = math.ceil(
frames / math.ceil(size / MAX_IMAGE_DATA_PART_SIZE)
)
else:
frames_per_part = frames // 2

file_name = self.path[0] if isinstance(self.path, list) else self.path
name, ext = os.path.splitext(os.path.basename(file_name))
save_paths = []

_dir = join_filepath([output_dir, "image_split", name])
create_directory(_dir)

for t in np.arange(0, frames, frames_per_part):
_path = join_filepath([_dir, f"{name}_{t//frames_per_part}{ext}"])
with tifffile.TiffWriter(_path, bigtiff=True) as tif:
if t == frames - 1:
tif.write(image[t:])
else:
tif.write(image[t : t + frames_per_part])
save_paths.append(_path)

return save_paths

@property
def data(self):
if isinstance(self.path, list):
Expand Down
2 changes: 2 additions & 0 deletions studio/app/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ class FILETYPE:
NOT_DISPLAY_ARGS_LIST = ["params", "output_dir", "nwbfile", "kwargs"]

DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

MAX_IMAGE_DATA_PART_SIZE = 1_000_000_000 # 1GB
21 changes: 13 additions & 8 deletions studio/app/optinist/core/nwb/nwb_creater.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,19 @@ def fluorescence(cls, nwbfile, function_id, roi_list):
roi["table_name"], region=roi["region"]
)

roi_resp_series = RoiResponseSeries(
name=roi["name"],
data=roi["data"],
rois=region_roi,
unit=roi["unit"],
timestamps=roi.get("timestamps"),
rate=float(roi.get("rate", 0.0)),
)
roi_resp_dict = {
"name": roi["name"],
"data": roi["data"],
"rois": region_roi,
"unit": roi["unit"],
"timestamps": roi.get("timestamps"),
"rate": float(roi.get("rate", 0.0)),
}
if "comments" in roi:
roi_resp_dict["comments"] = roi["comments"]

roi_resp_series = RoiResponseSeries(**roi_resp_dict)

fluo.add_roi_response_series(roi_resp_series)

nwbfile.processing["ophys"].add(fluo)
Expand Down
7 changes: 7 additions & 0 deletions studio/app/optinist/wrappers/caiman/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from studio.app.optinist.wrappers.caiman.cnmf import caiman_cnmf
from studio.app.optinist.wrappers.caiman.cnmf_multisession import (
caiman_cnmf_multisession,
)
from studio.app.optinist.wrappers.caiman.cnmfe import caiman_cnmfe
from studio.app.optinist.wrappers.caiman.motion_correction import caiman_mc

Expand All @@ -16,5 +19,9 @@
"function": caiman_cnmfe,
"conda_name": "caiman",
},
"cnmf_multisession": {
"function": caiman_cnmf_multisession,
"conda_name": "caiman",
},
}
}
79 changes: 51 additions & 28 deletions studio/app/optinist/wrappers/caiman/cnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,36 @@ def get_roi(A, roi_thr, thr_method, swap_dim, dims):
return ims


def util_get_memmap(images: np.ndarray, file_path: str):
"""
convert np.ndarray to mmap
"""
from caiman.mmapping import prepare_shape
from caiman.paths import memmap_frames_filename

order = "C"
dims = images.shape[1:]
T = images.shape[0]
shape_mov = (np.prod(dims), T)

dir_path = join_filepath(file_path.split("/")[:-1])
basename = file_path.split("/")[-1]
fname_tot = memmap_frames_filename(basename, dims, T, order)
mmap_path = join_filepath([dir_path, fname_tot])

mmap_images = np.memmap(
mmap_path,
mode="w+",
dtype=np.float32,
shape=prepare_shape(shape_mov),
order=order,
)

mmap_images = np.reshape(mmap_images.T, [T] + list(dims), order="F")
mmap_images[:] = images[:]
return mmap_images, dims, mmap_path


def util_recursive_flatten_params(params, result_params: dict, nest_counter=0):
"""
Recursively flatten node parameters (operation for CaImAn CNMFParams)
Expand All @@ -89,9 +119,7 @@ def caiman_cnmf(
) -> dict(fluorescence=FluoData, iscell=IscellData):
from caiman import local_correlations, stop_server
from caiman.cluster import setup_cluster
from caiman.mmapping import prepare_shape
from caiman.paths import memmap_frames_filename
from caiman.source_extraction.cnmf import cnmf
from caiman.source_extraction.cnmf import cnmf, online_cnmf
from caiman.source_extraction.cnmf.params import CNMFParams

function_id = output_dir.split("/")[-1]
Expand All @@ -111,33 +139,14 @@ def caiman_cnmf(
Ain = reshaped_params.pop("Ain", None)
do_refit = reshaped_params.pop("do_refit", None)
roi_thr = reshaped_params.pop("roi_thr", None)
use_online = reshaped_params.pop("use_online", False)

file_path = images.path
if isinstance(file_path, list):
file_path = file_path[0]

images = images.data

# np.arrayをmmapへ変換
order = "C"
dims = images.shape[1:]
T = images.shape[0]
shape_mov = (np.prod(dims), T)

dir_path = join_filepath(file_path.split("/")[:-1])
basename = file_path.split("/")[-1]
fname_tot = memmap_frames_filename(basename, dims, T, order)

mmap_images = np.memmap(
join_filepath([dir_path, fname_tot]),
mode="w+",
dtype=np.float32,
shape=prepare_shape(shape_mov),
order=order,
)

mmap_images = np.reshape(mmap_images.T, [T] + list(dims), order="F")
mmap_images[:] = images[:]
mmap_images, dims, mmap_path = util_get_memmap(images, file_path)

del images
gc.collect()
Expand All @@ -157,11 +166,25 @@ def caiman_cnmf(
backend="local", n_processes=None, single_thread=True
)

cnm = cnmf.CNMF(n_processes=n_processes, dview=dview, Ain=Ain, params=ops)
cnm = cnm.fit(mmap_images)
if use_online:
ops.change_params(
{
"fnames": [mmap_path],
# NOTE: These params uses np.inf as default in CaImAn.
# Yaml cannot serialize np.inf, so default value in yaml is None.
"max_comp_update_shape": reshaped_params["max_comp_update_shape"]
or np.inf,
"num_times_comp_updated": reshaped_params["update_num_comps"] or np.inf,
}
)
cnm = online_cnmf.OnACID(dview=dview, Ain=Ain, params=ops)
cnm.fit_online()
else:
cnm = cnmf.CNMF(n_processes=n_processes, dview=dview, Ain=Ain, params=ops)
cnm = cnm.fit(mmap_images)

if do_refit:
cnm = cnm.refit(mmap_images, dview=dview)
if do_refit:
cnm = cnm.refit(mmap_images, dview=dview)

cnm.estimates.evaluate_components(mmap_images, cnm.params, dview=dview)

Expand Down
Loading

0 comments on commit 1986a6e

Please sign in to comment.