Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement online cnmf and register_multisession #332

Merged
merged 3 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions studio/app/common/dataclass/image.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import gc
import math
import os
from typing import Optional

import imageio
Expand All @@ -14,6 +16,7 @@
from studio.app.common.dataclass.base import BaseData
from studio.app.common.dataclass.utils import create_images_list
from studio.app.common.schemas.outputs import PlotMetaData
from studio.app.const import MAX_IMAGE_DATA_PART_SIZE
from studio.app.dir_path import DIRPATH


Expand Down Expand Up @@ -47,6 +50,36 @@ def __init__(
del data
gc.collect()

def split_image(self, output_dir: str):
image = self.data
size = image.nbytes
frames = image.shape[0]

if size > MAX_IMAGE_DATA_PART_SIZE:
frames_per_part = math.ceil(
frames / math.ceil(size / MAX_IMAGE_DATA_PART_SIZE)
)
else:
frames_per_part = frames // 2

file_name = self.path[0] if isinstance(self.path, list) else self.path
name, ext = os.path.splitext(os.path.basename(file_name))
save_paths = []

_dir = join_filepath([output_dir, "image_split", name])
create_directory(_dir)

for t in np.arange(0, frames, frames_per_part):
_path = join_filepath([_dir, f"{name}_{t//frames_per_part}{ext}"])
with tifffile.TiffWriter(_path, bigtiff=True) as tif:
if t == frames - 1:
tif.write(image[t:])
else:
tif.write(image[t : t + frames_per_part])
save_paths.append(_path)

return save_paths

@property
def data(self):
if isinstance(self.path, list):
Expand Down
2 changes: 2 additions & 0 deletions studio/app/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ class FILETYPE:
NOT_DISPLAY_ARGS_LIST = ["params", "output_dir", "nwbfile", "kwargs"]

DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

MAX_IMAGE_DATA_PART_SIZE = 1_000_000_000 # 1GB
21 changes: 13 additions & 8 deletions studio/app/optinist/core/nwb/nwb_creater.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,19 @@ def fluorescence(cls, nwbfile, function_id, roi_list):
roi["table_name"], region=roi["region"]
)

roi_resp_series = RoiResponseSeries(
name=roi["name"],
data=roi["data"],
rois=region_roi,
unit=roi["unit"],
timestamps=roi.get("timestamps"),
rate=float(roi.get("rate", 0.0)),
)
roi_resp_dict = {
"name": roi["name"],
"data": roi["data"],
"rois": region_roi,
"unit": roi["unit"],
"timestamps": roi.get("timestamps"),
"rate": float(roi.get("rate", 0.0)),
}
if "comments" in roi:
roi_resp_dict["comments"] = roi["comments"]

roi_resp_series = RoiResponseSeries(**roi_resp_dict)

fluo.add_roi_response_series(roi_resp_series)

nwbfile.processing["ophys"].add(fluo)
Expand Down
7 changes: 7 additions & 0 deletions studio/app/optinist/wrappers/caiman/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from studio.app.optinist.wrappers.caiman.cnmf import caiman_cnmf
from studio.app.optinist.wrappers.caiman.cnmf_multisession import (
caiman_cnmf_multisession,
)
from studio.app.optinist.wrappers.caiman.cnmfe import caiman_cnmfe
from studio.app.optinist.wrappers.caiman.motion_correction import caiman_mc

Expand All @@ -16,5 +19,9 @@
"function": caiman_cnmfe,
"conda_name": "caiman",
},
"cnmf_multisession": {
"function": caiman_cnmf_multisession,
"conda_name": "caiman",
},
}
}
79 changes: 51 additions & 28 deletions studio/app/optinist/wrappers/caiman/cnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,36 @@ def get_roi(A, roi_thr, thr_method, swap_dim, dims):
return ims


def util_get_memmap(images: np.ndarray, file_path: str):
"""
convert np.ndarray to mmap
"""
from caiman.mmapping import prepare_shape
from caiman.paths import memmap_frames_filename

order = "C"
dims = images.shape[1:]
T = images.shape[0]
shape_mov = (np.prod(dims), T)

dir_path = join_filepath(file_path.split("/")[:-1])
basename = file_path.split("/")[-1]
fname_tot = memmap_frames_filename(basename, dims, T, order)
mmap_path = join_filepath([dir_path, fname_tot])

mmap_images = np.memmap(
mmap_path,
mode="w+",
dtype=np.float32,
shape=prepare_shape(shape_mov),
order=order,
)

mmap_images = np.reshape(mmap_images.T, [T] + list(dims), order="F")
mmap_images[:] = images[:]
return mmap_images, dims, mmap_path


def util_recursive_flatten_params(params, result_params: dict, nest_counter=0):
"""
Recursively flatten node parameters (operation for CaImAn CNMFParams)
Expand All @@ -89,9 +119,7 @@ def caiman_cnmf(
) -> dict(fluorescence=FluoData, iscell=IscellData):
from caiman import local_correlations, stop_server
from caiman.cluster import setup_cluster
from caiman.mmapping import prepare_shape
from caiman.paths import memmap_frames_filename
from caiman.source_extraction.cnmf import cnmf
from caiman.source_extraction.cnmf import cnmf, online_cnmf
from caiman.source_extraction.cnmf.params import CNMFParams

function_id = output_dir.split("/")[-1]
Expand All @@ -111,33 +139,14 @@ def caiman_cnmf(
Ain = reshaped_params.pop("Ain", None)
do_refit = reshaped_params.pop("do_refit", None)
roi_thr = reshaped_params.pop("roi_thr", None)
use_online = reshaped_params.pop("use_online", False)

file_path = images.path
if isinstance(file_path, list):
file_path = file_path[0]

images = images.data

# np.arrayをmmapへ変換
order = "C"
dims = images.shape[1:]
T = images.shape[0]
shape_mov = (np.prod(dims), T)

dir_path = join_filepath(file_path.split("/")[:-1])
basename = file_path.split("/")[-1]
fname_tot = memmap_frames_filename(basename, dims, T, order)

mmap_images = np.memmap(
join_filepath([dir_path, fname_tot]),
mode="w+",
dtype=np.float32,
shape=prepare_shape(shape_mov),
order=order,
)

mmap_images = np.reshape(mmap_images.T, [T] + list(dims), order="F")
mmap_images[:] = images[:]
mmap_images, dims, mmap_path = util_get_memmap(images, file_path)

del images
gc.collect()
Expand All @@ -157,11 +166,25 @@ def caiman_cnmf(
backend="local", n_processes=None, single_thread=True
)

cnm = cnmf.CNMF(n_processes=n_processes, dview=dview, Ain=Ain, params=ops)
cnm = cnm.fit(mmap_images)
if use_online:
ops.change_params(
{
"fnames": [mmap_path],
# NOTE: These params uses np.inf as default in CaImAn.
# Yaml cannot serialize np.inf, so default value in yaml is None.
"max_comp_update_shape": reshaped_params["max_comp_update_shape"]
or np.inf,
"num_times_comp_updated": reshaped_params["update_num_comps"] or np.inf,
}
)
cnm = online_cnmf.OnACID(dview=dview, Ain=Ain, params=ops)
cnm.fit_online()
else:
cnm = cnmf.CNMF(n_processes=n_processes, dview=dview, Ain=Ain, params=ops)
cnm = cnm.fit(mmap_images)

if do_refit:
cnm = cnm.refit(mmap_images, dview=dview)
if do_refit:
cnm = cnm.refit(mmap_images, dview=dview)

cnm.estimates.evaluate_components(mmap_images, cnm.params, dview=dview)

Expand Down
Loading
Loading