In [1]:
# Yu Yamaoka
# Note: the output folder for the cropped patches (crop_data3232/day<N>) is
# chosen in the main cell, not passed to the crop function as a parameter.

# --- Parameters ---------------------------------------------------------
# model_type: Cellpose built-in model name ('cyto' or 'nuclei'), see
#   https://github.com/MouseLand/cellpose/blob/main/cellpose/models.py#L19~L20
# chan: `channels` argument for model.eval ([0, 0] = grayscale per the
#   Cellpose docs), see
#   https://github.com/MouseLand/cellpose/blob/main/cellpose/models.py#L209
model_type, chan = 'cyto', [0, 0]

In [2]:
#Function Define
from cellpose import models, io
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os

# DEFINE CELLPOSE MODEL
# model_type='cyto' or model_type='nuclei'

# Loaded Cellpose models keyed by (model_type, gpu). Building a model loads its
# network weights, so caching it avoids repeating that work for every image.
_CELLPOSE_MODEL_CACHE = {}


def img_to_cellpose(img_path, model_type, chan, gpu=False):
    """Segment one image with Cellpose and return its label mask.

    Input:
        img_path : (str) path of the image file to segment.
        model_type : (str) Cellpose model name, e.g. 'cyto' or 'nuclei'.
        chan : (list[int]) channel spec forwarded to ``model.eval(channels=...)``.
        gpu : (bool) run the model on GPU; defaults to False (the previous
            hard-coded value).
    Return:
        mask : label image returned by ``model.eval``; downstream code treats
            0 as background and every other value as one object.
    """
    key = (model_type, gpu)
    model = _CELLPOSE_MODEL_CACHE.get(key)
    if model is None:
        model = models.Cellpose(gpu=gpu, model_type=model_type)
        _CELLPOSE_MODEL_CACHE[key] = model

    img = io.imread(img_path)
    # diameter=None: let Cellpose estimate the cell diameter for this image.
    mask, _flows, _styles, _diams = model.eval(img, diameter=None, channels=chan)
    return mask

# Convert the label mask into per-object binary masks that a Mask R-CNN dataset can read.
def obj_detection(mask, class_id:int):
    """Split a label image into one binary mask per object.

    Input:
        mask : 2-D label image (rows, cols) ndarray; 0 is background and every
            other value identifies one object.
        class_id : int, class id given to every object (ex : 1day -> 1).
    Return:
        (stack, cls_idxs):
            stack : uint8 array (rows, cols, n) with one 0/1 mask per non-zero
                label, in ascending label order; n is the object count.
            cls_idxs : int32 array of length n, every entry equal to class_id.
        (None, None) when the image contains no non-zero label.
    """
    object_ids = np.unique(mask)
    object_ids = object_ids[object_ids != 0]  # label 0 is background
    if object_ids.size == 0:
        # No target objects in this image.
        return None, None

    # Broadcast (rows, cols, 1) against (n,) to compare every pixel with every label at once.
    stack = (mask[..., np.newaxis] == object_ids).astype(np.uint8)
    cls_idxs = np.ones(object_ids.size, dtype=np.int32) * class_id
    return stack, cls_idxs

def _centered_window(lo, hi, size, limit):
    """Return ``(start, stop)`` of a length-``size`` window centred on
    ``(lo + hi) // 2`` and shifted inward so that ``0 <= start`` and
    ``stop <= limit`` (requires ``limit >= size``)."""
    start = (lo + hi) // 2 - size // 2
    start = min(max(start, 0), limit - size)
    return start, start + size


def mask_to_patch(mask, img_path, size=32):
    """Crop a ``size`` x ``size`` patch around every object mask.

    Each patch is centred on the centre of the object's bounding box and is
    shifted inward when it would cross the image border, so every patch has
    exactly ``size`` x ``size`` pixels (odd sizes included).

    Input:
        mask : [n(objnum), rows, cols] array; pixels equal to 1 belong to the
            object. n is the object count.
        img_path : (str) image to crop from, read with ``cv2.imread``
            (default flags: 3-channel BGR).
        size : (int) side length of the square patch.
    Return:
        crop_imgs : uint8 array [n(objnum), size, size, 3].
    Raises:
        FileNotFoundError : OpenCV cannot read ``img_path``.
        ValueError : the image is smaller than ``size`` in either dimension,
            or an object mask contains no pixel equal to 1.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError("cannot read image: " + str(img_path))

    # ndarray shape is (rows, cols, channels), i.e. (height, width, channels).
    h, w = img.shape[:2]
    if h < size or w < size:
        raise ValueError(
            "image %s is %dx%d, smaller than patch size %d" % (img_path, h, w, size))

    crop_imgs = np.zeros((len(mask), size, size, 3), dtype=np.uint8)
    for i in range(len(mask)):
        rows, cols = np.nonzero(mask[i] == 1)
        if rows.size == 0:
            raise ValueError("object %d has an empty mask" % i)
        y1, y2 = _centered_window(rows.min(), rows.max(), size, h)
        x1, x2 = _centered_window(cols.min(), cols.max(), size, w)
        crop_imgs[i] = img[y1:y2, x1:x2]
    return crop_imgs
   

In [3]:
#main
# Segment every patch with Cellpose, crop a PATCH_SIZE x PATCH_SIZE window
# around each detected cell, and save the crops per day under OUTPUT_ROOT.
from glob import glob
from tqdm import tqdm

INPUT_ROOT = "../data/cut_patch512512"  # inputs are read from <INPUT_ROOT>/<day>day/*.png
OUTPUT_ROOT = "crop_data3232"           # crops are written to <OUTPUT_ROOT>/day<day>/
PATCH_SIZE = 32

days = ["0", "3", "5", "7"]

for day in days:
    out_dir = os.path.join(OUTPUT_ROOT, "day" + day)
    # makedirs also creates OUTPUT_ROOT if needed; os.mkdir would fail when it is missing.
    os.makedirs(out_dir, exist_ok=True)
    # sorted() gives a reproducible processing order (glob order is OS-dependent).
    files = sorted(glob(os.path.join(INPUT_ROOT, day + "day", "*.png")))
    for file in tqdm(files, desc="day " + day):
        try:
            label_mask = img_to_cellpose(file, model_type, chan)
            obj_masks, _ = obj_detection(label_mask, 1)
            if obj_masks is None:
                # obj_detection returns (None, None) when no cell was found.
                print("no cells detected:", file)
                continue
            crop_imgs = mask_to_patch(obj_masks.transpose(2, 0, 1), file, size=PATCH_SIZE)

            #save
            filename = os.path.basename(file)
            for i, crop_img in enumerate(crop_imgs):
                save_path = os.path.join(out_dir, str(i) + "_" + filename)
                # cv2.imwrite reports failure by returning False, not by raising.
                if not cv2.imwrite(save_path, crop_img):
                    print("failed to write:", save_path)
        except Exception as e:
            # Keep processing the remaining images but report why this one failed.
            # (A bare `except:` would also swallow KeyboardInterrupt.)
            print("failed:", file, repr(e))


NameError: name 'day' is not defined