Skip to content

Commit

Permalink
adding tile_overlap as parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Nov 23, 2020
1 parent 5669246 commit d6cd00f
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 60 deletions.
109 changes: 76 additions & 33 deletions cellpose/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(self, gpu=False, model_type='cyto', net_avg=True, device=None):
self.sz.model_type = model_type

def eval(self, x, batch_size=8, channels=None, invert=False, normalize=True, diameter=30., do_3D=False, anisotropy=None,
net_avg=True, augment=False, tile=True, resample=False, flow_threshold=0.4, cellprob_threshold=0.0,
net_avg=True, augment=False, tile=True, tile_overlap=0.1, resample=False, flow_threshold=0.4, cellprob_threshold=0.0,
min_size=15, stitch_threshold=0.0, rescale=None, progress=None):
""" run cellpose and get masks
Expand Down Expand Up @@ -152,6 +152,12 @@ def eval(self, x, batch_size=8, channels=None, invert=False, normalize=True, dia
tile: bool (optional, default True)
tiles image to ensure GPU/CPU memory usage limited (recommended)
tile_overlap: float (optional, default 0.1)
fraction of overlap of tiles when computing flows
resample: bool (optional, default False)
run dynamics at original image size (will be slower but create more accurate boundaries)
flow_threshold: float (optional, default 0.4)
flow error threshold (all cells with errors below threshold are kept) (not used for 3D)
Expand Down Expand Up @@ -239,13 +245,22 @@ def eval(self, x, batch_size=8, channels=None, invert=False, normalize=True, dia
diams = self.diam_mean / rescale

tic = time.time()
masks, flows, styles = self.cp.eval(x, batch_size=batch_size, invert=invert, rescale=rescale, anisotropy=anisotropy,
channels=channels, augment=augment, tile=tile, do_3D=do_3D,
masks, flows, styles = self.cp.eval(x,
batch_size=batch_size,
invert=invert,
rescale=rescale,
anisotropy=anisotropy,
channels=channels,
augment=augment,
tile=tile,
do_3D=do_3D,
net_avg=net_avg, progress=progress,
tile_overlap=tile_overlap,
resample=resample,
flow_threshold=flow_threshold,
cellprob_threshold=cellprob_threshold,
min_size=min_size, stitch_threshold=stitch_threshold)
min_size=min_size,
stitch_threshold=stitch_threshold)
print('estimated masks for %d image(s) in %0.2f sec'%(nimg, time.time()-tic))
print('>>>> TOTAL TIME %0.2f sec'%(time.time()-tic0))

Expand Down Expand Up @@ -455,7 +470,7 @@ def eval(self, x, batch_size=8, channels=None, invert=False, normalize=True,



def _run_nets(self, img, net_avg=True, augment=False, tile=True, bsize=224, progress=None):
def _run_nets(self, img, net_avg=True, augment=False, tile=True, tile_overlap=0.1, bsize=224, progress=None):
""" run network (if more than one, loop over networks and average results
Parameters
Expand All @@ -472,6 +487,9 @@ def _run_nets(self, img, net_avg=True, augment=False, tile=True, bsize=224, prog
tile: bool (optional, default True)
tiles image to ensure GPU memory usage limited (recommended)
tile_overlap: float (optional, default 0.1)
fraction of overlap of tiles when computing flows
progress: pyqt progress bar (optional, default None)
to return progress bar status to GUI
Expand All @@ -494,7 +512,8 @@ def _run_nets(self, img, net_avg=True, augment=False, tile=True, bsize=224, prog
for j in range(len(self.pretrained_model)):
self.net.load_parameters(self.pretrained_model[j])
self.net.collect_params().grad_req = 'null'
y0, style = self._run_net(img, augment=augment, tile=tile, bsize=bsize)
y0, style = self._run_net(img, augment=augment, tile=tile,
tile_overlap=tile_overlap, bsize=bsize)

if j==0:
y = y0
Expand All @@ -505,7 +524,7 @@ def _run_nets(self, img, net_avg=True, augment=False, tile=True, bsize=224, prog
y = y / len(self.pretrained_model)
return y, style

def _run_net(self, imgs, augment=False, tile=True, bsize=224):
def _run_net(self, imgs, augment=False, tile=True, tile_overlap=0.1, bsize=224):
""" run network on image or stack of images
(faster if augment is False)
Expand All @@ -525,6 +544,9 @@ def _run_net(self, imgs, augment=False, tile=True, bsize=224):
tiles image to ensure GPU/CPU memory usage limited (recommended);
cannot be turned off for 3D segmentation
tile_overlap: float (optional, default 0.1)
fraction of overlap of tiles when computing flows
bsize: int (optional, default 224)
size of tiles to use in pixels [bsize x bsize]
Expand Down Expand Up @@ -558,7 +580,7 @@ def _run_net(self, imgs, augment=False, tile=True, bsize=224):

# run network
if tile or augment or imgs.ndim==4:
y,style = self._run_tiled(imgs, augment=augment, bsize=bsize)
y,style = self._run_tiled(imgs, augment=augment, bsize=bsize, tile_overlap=tile_overlap)
else:
imgs = nd.array(np.expand_dims(imgs, axis=0), ctx=self.device)
y,style = self.net(imgs)
Expand All @@ -575,7 +597,7 @@ def _run_net(self, imgs, augment=False, tile=True, bsize=224):

return y, style

def _run_tiled(self, imgi, augment=False, bsize=224):
def _run_tiled(self, imgi, augment=False, bsize=224, tile_overlap=0.1):
""" run network in tiles of size [bsize x bsize]
First image is split into overlapping tiles of size [bsize x bsize].
Expand All @@ -592,6 +614,9 @@ def _run_tiled(self, imgi, augment=False, bsize=224):
bsize: int (optional, default 224)
size of tiles to use in pixels [bsize x bsize]
tile_overlap: float (optional, default 0.1)
fraction of overlap of tiles when computing flows
Returns
------------------
Expand All @@ -606,37 +631,42 @@ def _run_tiled(self, imgi, augment=False, bsize=224):
"""

if imgi.ndim==4:
batch_size = self.batch_size
Lz, nchan = imgi.shape[:2]
IMG, ysub, xsub, Ly, Lx = transforms.make_tiles(imgi[0], bsize=bsize, augment=augment)
ny, nx, nchan = IMG.shape[:3]
yf = np.zeros((Lz, self.nclasses, imgi.shape[2], imgi.shape[3]), np.float32)
IMG, ysub, xsub, Ly, Lx = transforms.make_tiles(imgi[0], bsize=bsize,
augment=augment, tile_overlap=tile_overlap)
ny, nx, nchan, ly, lx = IMG.shape
batch_size *= max(4, (bsize**2 // (ly*lx))**0.5)
yf = np.zeros((Lz, self.nclasses, imgi.shape[-2], imgi.shape[-1]), np.float32)
styles = []
if ny*nx > self.batch_size:
if ny*nx > batch_size:
ziterator = trange(Lz)
for i in ziterator:
yfi, stylei = self._run_tiled(imgi[i], augment=augment, bsize=bsize)
yfi, stylei = self._run_tiled(imgi[i], augment=augment,
bsize=bsize, tile_overlap=tile_overlap)
yf[i] = yfi
styles.append(stylei)
else:
# run multiple slices at the same time
ntiles = ny*nx
nimgs = max(2, int(np.round(self.batch_size / ntiles)))
nimgs = max(2, int(np.round(batch_size / ntiles)))
niter = int(np.ceil(Lz/nimgs))
ziterator = trange(niter)
for k in ziterator:
IMGa = np.zeros((ntiles*nimgs, nchan, bsize, bsize), np.float32)
IMGa = np.zeros((ntiles*nimgs, nchan, ly, lx), np.float32)
for i in range(min(Lz-k*nimgs, nimgs)):
IMG, ysub, xsub, Ly, Lx = transforms.make_tiles(imgi[k*nimgs+i], bsize=bsize, augment=augment)
IMGa[i*ntiles:(i+1)*ntiles] = np.reshape(IMG, (ny*nx, nchan, bsize, bsize))
IMG, ysub, xsub, Ly, Lx = transforms.make_tiles(imgi[k*nimgs+i], bsize=bsize,
augment=augment, tile_overlap=tile_overlap)
IMGa[i*ntiles:(i+1)*ntiles] = np.reshape(IMG, (ny*nx, nchan, ly, lx))
y0, style = self.net(nd.array(IMGa, ctx=self.device))
ya = y0.asnumpy()
stylea = style.asnumpy()
for i in range(min(Lz-k*nimgs, nimgs)):
y = ya[i*ntiles:(i+1)*ntiles]
if augment:
y = np.reshape(y, (ny, nx, 3, bsize, bsize))
y = np.reshape(y, (ny, nx, 3, ly, lx))
y = transforms.unaugment_tiles(y, self.unet)
y = np.reshape(y, (-1, 3, bsize, bsize))
y = np.reshape(y, (-1, 3, ly, lx))
yfi = transforms.average_tiles(y, ysub, xsub, Ly, Lx)
yfi = yfi[:,:imgi.shape[2],:imgi.shape[3]]
yf[k*nimgs+i] = yfi
Expand All @@ -645,21 +675,21 @@ def _run_tiled(self, imgi, augment=False, bsize=224):
styles.append(stylei)
return yf, np.array(styles)
else:
IMG, ysub, xsub, Ly, Lx = transforms.make_tiles(imgi, bsize=bsize, augment=augment)
ny, nx, nchan = IMG.shape[:3]
IMG = np.reshape(IMG, (ny*nx, nchan, bsize, bsize))
nbatch = self.batch_size
niter = int(np.ceil(IMG.shape[0]/nbatch))
y = np.zeros((IMG.shape[0], self.nclasses, bsize, bsize))
IMG, ysub, xsub, Ly, Lx = transforms.make_tiles(imgi, bsize=bsize,
augment=augment, tile_overlap=tile_overlap)
ny, nx, nchan, ly, lx = IMG.shape
IMG = np.reshape(IMG, (ny*nx, nchan, ly, lx))
batch_size = self.batch_size
niter = int(np.ceil(IMG.shape[0] / batch_size))
y = np.zeros((IMG.shape[0], self.nclasses, ly, lx))
for k in range(niter):
irange = np.arange(nbatch*k, min(IMG.shape[0], nbatch*k+nbatch))
irange = np.arange(batch_size*k, min(IMG.shape[0], batch_size*k+batch_size))
y0, style = self.net(nd.array(IMG[irange], ctx=self.device))
y0 = y0.asnumpy()
y[irange] = y0
if k==0:
styles = style.asnumpy()[0]
styles += style.asnumpy().sum(axis=0)

styles /= IMG.shape[0]
if augment:
y = np.reshape(y, (ny, nx, self.nclasses, bsize, bsize))
Expand All @@ -672,7 +702,8 @@ def _run_tiled(self, imgi, augment=False, bsize=224):
return yf, styles

def _run_3D(self, imgs, rsz=1.0, anisotropy=None, net_avg=True,
augment=False, tile=True, bsize=224, progress=None):
augment=False, tile=True, tile_overlap=0.1,
bsize=224, progress=None):
""" run network on stack of images
(faster if augment is False)
Expand All @@ -698,6 +729,9 @@ def _run_3D(self, imgs, rsz=1.0, anisotropy=None, net_avg=True,
tiles image to ensure GPU/CPU memory usage limited (recommended);
cannot be turned off for 3D segmentation
tile_overlap: float (optional, default 0.1)
fraction of overlap of tiles when computing flows
bsize: int (optional, default 224)
size of tiles to use in pixels [bsize x bsize]
Expand Down Expand Up @@ -733,7 +767,8 @@ def _run_3D(self, imgs, rsz=1.0, anisotropy=None, net_avg=True,
xsl = transforms.resize_image(xsl, rsz=rescaling[p])
# per image
print('\n running %s: %d planes of size (%d, %d) \n\n'%(sstr[p], shape[0], shape[1], shape[2]))
y, style = self._run_nets(xsl, net_avg=net_avg, augment=augment, tile=tile, bsize=bsize)
y, style = self._run_nets(xsl, net_avg=net_avg, augment=augment, tile=tile,
bsize=bsize, tile_overlap=tile_overlap)
y = transforms.resize_image(y, shape[1], shape[2])
yf[p] = y.transpose(ipm[p])
if progress is not None:
Expand Down Expand Up @@ -1037,7 +1072,7 @@ def __init__(self, gpu=False, pretrained_model=False,


def eval(self, imgs, batch_size=8, channels=None, normalize=True, invert=False, rescale=None,
do_3D=False, anisotropy=None, net_avg=True, augment=False, tile=True,
do_3D=False, anisotropy=None, net_avg=True, augment=False, tile=True, tile_overlap=0.1,
resample=False, flow_threshold=0.4, cellprob_threshold=0.0, compute_masks=True,
min_size=15, stitch_threshold=0.0, progress=None):
"""
Expand Down Expand Up @@ -1084,6 +1119,12 @@ def eval(self, imgs, batch_size=8, channels=None, normalize=True, invert=False,
tile: bool (optional, default True)
tiles image to ensure GPU/CPU memory usage limited (recommended)
tile_overlap: float (optional, default 0.1)
fraction of overlap of tiles when computing flows
resample: bool (optional, default False)
run dynamics at original image size (will be slower but create more accurate boundaries)
flow_threshold: float (optional, default 0.4)
flow error threshold (all cells with errors below threshold are kept) (not used for 3D)
Expand Down Expand Up @@ -1152,7 +1193,8 @@ def eval(self, imgs, batch_size=8, channels=None, normalize=True, invert=False,
# rescale image for flow computation
img = transforms.resize_image(img, rsz=rescale[i])
y, style = self._run_nets(img, net_avg=net_avg,
augment=augment, tile=tile)
augment=augment, tile=tile,
tile_overlap=tile_overlap)
net_time += time.time() - tic
if progress is not None:
progress.setValue(55)
Expand Down Expand Up @@ -1192,7 +1234,8 @@ def eval(self, imgs, batch_size=8, channels=None, normalize=True, invert=False,
tic=time.time()
shape = x[i].shape
yf, style = self._run_3D(x[i], rsz=rescale[i], anisotropy=anisotropy,
net_avg=net_avg, augment=augment, tile=tile, progress=progress)
net_avg=net_avg, augment=augment, tile=tile,
tile_overlap=tile_overlap, progress=progress)
cellprob = yf[0][-1] + yf[1][-1] + yf[2][-1]
dP = np.stack((yf[1][0] + yf[2][0], yf[0][0] + yf[2][1], yf[0][1] + yf[1][1]),
axis=0) # (dZ, dY, dX)
Expand Down
59 changes: 32 additions & 27 deletions cellpose/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def average_tiles(y, ysub, xsub, Ly, Lx):
yf /= Navg
return yf

def make_tiles(imgi, bsize=224, augment=True):
def make_tiles(imgi, bsize=224, augment=False, tile_overlap=0.1):
""" make tiles of image to run at test-time
there are 4 versions of tiles
if augmented, tiles are flipped and tile_overlap=2.
* original
* flipped vertically
* flipped horizontally
Expand All @@ -99,6 +99,15 @@ def make_tiles(imgi, bsize=224, augment=True):
imgi : float32
array that's nchan x Ly x Lx
bsize : float (optional, default 224)
size of tiles
augment : bool (optional, default False)
flip tiles and set tile_overlap=2.
tile_overlap: float (optional, default 0.1)
fraction of overlap of tiles
Returns
-------
IMG : float32
Expand All @@ -110,27 +119,19 @@ def make_tiles(imgi, bsize=224, augment=True):
xsub : list
list of arrays with start and end of tiles in X of length ntiles
Ly : int
size of total image pre-tiling in Y (may be larger than original image if
image size is less than bsize)
Lx : int
size of total image pre-tiling in X (may be larger than original image if
image size is less than bsize)
"""

bsize = np.int32(bsize)
nchan, Ly0, Lx0 = imgi.shape
# pad if image smaller than bsize
if Ly0<bsize:
imgi = np.concatenate((imgi, np.zeros((nchan,bsize-Ly0, Lx0))), axis=1)
Ly0 = bsize
if Lx0<bsize:
imgi = np.concatenate((imgi, np.zeros((nchan,Ly0, bsize-Lx0))), axis=2)
Ly, Lx = imgi.shape[-2:]

nchan, Ly, Lx = imgi.shape
if augment:
bsize = np.int32(bsize)
# pad if image smaller than bsize
if Ly<bsize:
imgi = np.concatenate((imgi, np.zeros((nchan, bsize-Ly, Lx))), axis=1)
Ly = bsize
if Lx<bsize:
imgi = np.concatenate((imgi, np.zeros((nchan, Ly, bsize-Lx))), axis=2)
Ly, Lx = imgi.shape[-2:]
# tiles overlap by half of tile size
ny = max(2, int(np.ceil(2. * Ly / bsize)))
nx = max(2, int(np.ceil(2. * Lx / bsize)))
Expand All @@ -155,19 +156,23 @@ def make_tiles(imgi, bsize=224, augment=True):
elif j%2==1 and i%2==1:
IMG[j,i] = IMG[j,i,:, ::-1, ::-1]
else:
tile_overlap = min(0.5, max(0.05, tile_overlap))
bsizeY, bsizeX = min(bsize, Ly), min(bsize, Lx)
bsizeY = np.int32(bsizeY)
bsizeX = np.int32(bsizeX)
# tiles overlap by 10% tile size
ny = 1 if Ly<=bsize else int(np.ceil(1.2 * Ly / bsize))
nx = 1 if Lx<=bsize else int(np.ceil(1.2 * Lx / bsize))
ystart = np.linspace(0, Ly-bsize, ny).astype(int)
xstart = np.linspace(0, Lx-bsize, nx).astype(int)
ny = 1 if Ly<=bsize else int(np.ceil((1.+2*tile_overlap) * Ly / bsize))
nx = 1 if Lx<=bsize else int(np.ceil((1.+2*tile_overlap) * Lx / bsize))
ystart = np.linspace(0, Ly-bsizeY, ny).astype(int)
xstart = np.linspace(0, Lx-bsizeX, nx).astype(int)

ysub = []
xsub = []
IMG = np.zeros((len(ystart), len(xstart), nchan, bsize, bsize), np.float32)
IMG = np.zeros((len(ystart), len(xstart), nchan, bsizeY, bsizeX), np.float32)
for j in range(len(ystart)):
for i in range(len(xstart)):
ysub.append([ystart[j], ystart[j]+bsize])
xsub.append([xstart[i], xstart[i]+bsize])
ysub.append([ystart[j], ystart[j]+bsizeY])
xsub.append([xstart[i], xstart[i]+bsizeX])
IMG[j, i] = imgi[:, ysub[-1][0]:ysub[-1][1], xsub[-1][0]:xsub[-1][1]]

return IMG, ysub, xsub, Ly, Lx
Expand Down

0 comments on commit d6cd00f

Please sign in to comment.