In [1]:
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.distributed import *
from superres import superres_metrics, ssim, psnr
from superres import RRDB_Net, MultiImageToMultiChannel, TwoXModel, TwoYLoss, MultiToMultiImageList, TransformableLists
import czifile
from superres.helpers import get_czi_shape_info, build_index
import imageio
import adabound

# Root of the experiment data and the name of the dataset used in this notebook.
data_path = Path('/scratch/bpho')
datasetname = 'multiframe_001'

# Sibling folders directly under the data root.
datasets, datasources = (data_path/sub for sub in ('datasets', 'datasources'))
dataset = datasets/datasetname

# Per-dataset folders: high-res targets, low-res inputs, and the `lr_up` inputs.
hr_path, lr_path, lr_up_path = (dataset/sub for sub in ('hr', 'lr', 'lr_up'))

In [2]:
def map_to_hr(x):
    """Return the high-res counterpart of the low-res path `x`.

    The part of `x` below `lr_path` is re-rooted under `hr_path`.
    """
    return hr_path / x.relative_to(lr_path)

def get_src(size=128, scale=1):
    """Build the labelled item list that pairs input `.npy` files with `hr_path` targets.

    Args:
        size: accepted for symmetry with `get_data`; not used inside this function.
        scale: when 1 the inputs are read from `lr_up_path`, otherwise from `lr_path`.

    Returns:
        The result of `from_folder(...).split_by_folder().label_from_func(...)`.
    """
    use_lr_path = lr_up_path if scale == 1 else lr_path

    def _lr_to_hr(x):
        # Same relative location, rooted under the high-res folder.
        return hr_path / x.relative_to(use_lr_path)

    items = MultiToMultiImageList.from_folder(use_lr_path, extensions=['.npy'])
    return items.split_by_folder().label_from_func(_lr_to_hr)

def get_data(bs, size, scale=1, max_zoom=1.):
    """Create the normalized data bunch used for training and inference.

    Args:
        bs: batch size passed to `databunch`.
        size: size applied to the inputs; targets use `size*scale`.
        scale: factor between input and target size, also forwarded to `get_src`.
        max_zoom: forwarded to `get_transforms`.

    Returns:
        The data bunch after `.normalize(do_y=True)`.
    """
    tfms = get_transforms(flip_vert=True, max_zoom=max_zoom)
    labelled = get_src(size, scale=scale)
    # The same transform set is applied to inputs and targets, at different sizes.
    augmented = labelled.transform(tfms, size=size).transform_y(tfms, size=size*scale)
    return augmented.databunch(bs=bs).normalize(do_y=True)

In [3]:
# --- Training configuration -------------------------------------------------
bs = 4
size = 256
data = get_data(bs, size, scale=4, max_zoom=1.2)

# Network hyper-parameters. The values in trailing comments are alternatives
# left by the author.
in_c = 3   # first positional argument of RRDB_Net (presumably input channels)
out_c = 1  # second positional argument of RRDB_Net (presumably output channels)
nf = 64
gcval = 32 # 32
nb = 8 # 23
loss = TwoYLoss(stable_wt=0.1)
# Use the declared in_c/out_c instead of repeating the literals 3 and 1, so the
# channel counts are defined in exactly one place.
model = TwoXModel(MultiImageToMultiChannel(RRDB_Net(in_c, out_c, nf, nb, gc=gcval)))
model = nn.DataParallel(model)
#opt_func = partial(adabound.AdaBound,  lr=1e-6)

In [None]:
# Build the learner, switch it to mixed precision, then load the 'rrdb.2' checkpoint.
learn = Learner(data, model, loss_func=loss, callback_fns=LossMetrics) #, opt_func=opt_func)
learn = learn.to_fp16(loss_scale=64)
learn = learn.load('rrdb.2')
gc.collect()

In [None]:
# Run the learning-rate finder, then plot what the recorder captured.
learn.lr_find()
learn.recorder.plot()

In [None]:
# One-cycle training run (first argument 1) with max_lr=2e-6.
learn.fit_one_cycle(1, max_lr=2e-6)
#learn.fit(4, m)

In [None]:
#learn.save('rrdb.3')


In [4]:
# Movies to super-resolve: every 'low*.*' file in the neuron_movies folder.
# Alternative sources, kept for reference:
#movie_files = list(Path('/scratch/bpho/datasets/movies_001/test').glob('*.czi'))
#movie_files += list(Path('/scratch/bpho/datasources/low_res_test/').glob('low res confocal*.czi'))
movie_files = [*Path('/scratch/bpho/datasources/neuron_movies/').glob('low*.*')]

In [5]:
# Display the collected movie paths.
movie_files

[PosixPath('/scratch/bpho/datasources/neuron_movies/low res 1500 time points 8 redo best.czi'),
 PosixPath('/scratch/bpho/datasources/neuron_movies/low res 300 time points 6.czi'),
 PosixPath('/scratch/bpho/datasources/neuron_movies/low res 1200 time points 11 high signal.czi'),
 PosixPath('/scratch/bpho/datasources/neuron_movies/low res 300 time points 4 aligned.tif'),
 PosixPath('/scratch/bpho/datasources/neuron_movies/low res 1500 time points 7.czi'),
 PosixPath('/scratch/bpho/datasources/neuron_movies/low res 300 time points 2 aligned.tif'),
 PosixPath('/scratch/bpho/datasources/neuron_movies/low res 300 time points 5.czi'),
 PosixPath('/scratch/bpho/datasources/neuron_movies/low res 1200 time points 9 high signal.czi'),
 PosixPath('/scratch/bpho/datasources/neuron_movies/low res 300 time points 3.czi'),
 PosixPath('/scratch/bpho/datasources/neuron_movies/low res 300 time points 3 aligned.tif'),
 PosixPath('/scratch/bpho/datasources/neuron_movies/low res 300 time points 5-1 aligned

In [6]:
def image_from_tiles(learn, in_img, tile_sz=128, scale=4):
    """Predict `in_img` tile by tile with `learn` and stitch the upscaled result.

    The image is cut into `tile_sz` x `tile_sz` tiles. Tiles on the far edges
    may be smaller; they are zero-padded up to the full tile size before being
    passed to the model, and only the valid region of the prediction is copied
    into the output.

    Args:
        learn: object whose `predict` accepts a `TransformableLists` and returns
            a 3-tuple; the `.data` of the first element is used as the tile.
        in_img: array indexed as (frame, dim1, dim2).
        tile_sz: side length of the square tile fed to the model.
        scale: integer factor between input and output spatial size.

    Returns:
        A torch tensor of shape `(1, dim1*scale, dim2*scale)`.
    """
    c = in_img.shape[0]
    w, h = in_img.shape[1:3]

    in_tile = torch.zeros((c, tile_sz, tile_sz))
    out_img = torch.zeros((1, w*scale, h*scale))

    for x_tile in range(math.ceil(w/tile_sz)):
        for y_tile in range(math.ceil(h/tile_sz)):
            x_start = x_tile*tile_sz
            x_end = min(x_start+tile_sz, w)
            y_start = y_tile*tile_sz
            y_end = min(y_start+tile_sz, h)
            tile_w, tile_h = x_end-x_start, y_end-y_start

            # The buffer is reused across tiles. Clear it first so a partial
            # edge tile is padded with zeros and not with leftover pixels
            # from the previously processed tile.
            in_tile.zero_()
            in_tile[:, :tile_w, :tile_h] = tensor(in_img[:, x_start:x_end, y_start:y_end])

            # NOTE(review): exactly the first three frames are used regardless
            # of `c`; presumably this matches the 3-frame model input - confirm.
            img_list = [Image(in_tile[i][None]) for i in range(3)]
            tlist = TransformableLists([img_list, img_list])

            out_tile,_,_ = learn.predict(tlist)

            # Copy only the valid (unpadded) part of the prediction into place.
            out_img[:, x_start*scale:x_end*scale, y_start*scale:y_end*scale] = \
                out_tile.data[:, :tile_w*scale, :tile_h*scale]
    return out_img

In [7]:
def _predict_windows(learn, get_window, times, img_max, orig_out, pred_out, size, wsize):
    """Shared sliding-window prediction loop for the TIFF and CZI entry points.

    Args:
        learn: learner forwarded to `image_from_tiles`.
        get_window: callable `t -> array` returning the `wsize` frames starting at `t`.
        times: total number of frames in the movie.
        img_max: global maximum used to normalize each window.
        orig_out: output path for the reference frames.
        pred_out: output path for the predicted frames.
        size: tile size forwarded to `image_from_tiles`.
        wsize: number of consecutive frames per window.

    Raises:
        ValueError: if the movie has fewer than `wsize` frames.
    """
    n_windows = times - wsize + 1
    if n_windows < 1:
        # Previously this surfaced as an opaque np.concatenate error on an empty list.
        raise ValueError(f'movie has {times} frame(s) but wsize={wsize} frames are needed')

    # An all-zero movie would otherwise be divided by zero and become NaN.
    norm = img_max if img_max > 0 else 1.

    preds, origs = [], []
    for t in progress_bar(list(range(n_windows))):
        img = get_window(t) / norm  # new array; the source data is left untouched

        out_img = image_from_tiles(learn, img, tile_sz=size)
        # Clamp before the uint8 cast: model output can fall outside [0, 1] and
        # an unclamped cast wraps around instead of saturating.
        pred = (out_img.clamp(0, 1)*255).cpu().numpy().astype(np.uint8)
        preds.append(pred)
        # Frame 1 of the window is written as the reference frame.
        origs.append((img[1][None]*255).astype(np.uint8))

    imageio.mimwrite(pred_out, np.concatenate(preds)) #, fps=30, macro_block_size=None) # for mp4
    imageio.mimwrite(orig_out, np.concatenate(origs)) #, fps=30, macro_block_size=None)


def tif_predict_movie(learn, tif_in, orig_out='orig.tif', pred_out='pred.tif', size=128, wsize=3):
    """Super-resolve a multi-frame TIFF and write prediction and reference movies.

    Every frame is read into memory as float32 scaled by 1/255, the stack is
    normalized by its global maximum, and each window of `wsize` consecutive
    frames is passed through `image_from_tiles`.

    Args:
        learn: learner used for prediction.
        tif_in: path of the input TIFF.
        orig_out: output path for the reference frames.
        pred_out: output path for the predicted frames.
        size: tile size forwarded to `image_from_tiles`.
        wsize: number of consecutive frames per window.

    Raises:
        ValueError: if the movie has fewer than `wsize` frames.
    """
    # The context manager closes the file handle once all frames are in memory.
    with PIL.Image.open(tif_in) as im:
        im.load()
        times = im.n_frames
        frames = []
        for i in range(times):
            im.seek(i)
            # NOTE(review): /255 assumes 8-bit pixel data - confirm for these files.
            frames.append(np.array(im).astype(np.float32)/255.)
    img_data = np.stack(frames)

    img_max = img_data.max()
    print('max: ', img_max)
    _predict_windows(learn, lambda t: img_data[t:(t+wsize)], times, img_max,
                     orig_out, pred_out, size, wsize)


def czi_predict_movie(learn, czi_in, orig_out='orig.tif', pred_out='pred.tif', size=128, wsize=3):
    """Super-resolve a CZI movie and write prediction and reference movies.

    The whole file is read as float32 scaled by 1/255 and normalized by the
    maximum over the full array. Windows of `wsize` consecutive time points are
    taken at channel 0 and depth 0 and passed through `image_from_tiles`.

    Args:
        learn: learner used for prediction.
        czi_in: path of the input CZI file.
        orig_out: output path for the reference frames.
        pred_out: output path for the predicted frames.
        size: tile size forwarded to `image_from_tiles`.
        wsize: number of consecutive frames per window.

    Raises:
        ValueError: if the movie has fewer than `wsize` time points.
    """
    with czifile.CziFile(czi_in) as czi_f:
        proc_axes, proc_shape = get_czi_shape_info(czi_f)
        times = proc_shape['T']
        x,y = proc_shape['X'], proc_shape['Y']

        # NOTE(review): /255 assumes 8-bit pixel data - confirm for these files.
        data = czi_f.asarray().astype(np.float32)/255.

        img_max = data.max()
        print(img_max)

        def get_window(t):
            # Frames t .. t+wsize-1 at channel 0, depth 0, full X/Y extent.
            idx = build_index(proc_axes, {'T': slice(t,t+wsize), 'C': 0, 'Z':0, 'X':slice(0,x),'Y':slice(0,y)})
            return data[idx]

        _predict_windows(learn, get_window, times, img_max,
                         orig_out, pred_out, size, wsize)
        

In [8]:
# --- Inference configuration ------------------------------------------------
# Same network constructor arguments as the training cell, with a larger `size`.
# `size` is also passed as the tile size in the movie-prediction loop.
bs = 4
size = 380
data = get_data(bs, size, scale=4)

loss = TwoYLoss()
# Reuse the in_c/out_c declared with the other hyper-parameters instead of
# repeating the literals 3 and 1.
model = TwoXModel(MultiImageToMultiChannel(RRDB_Net(in_c, out_c, nf, nb, gc=gcval)))
model = nn.DataParallel(model)

learn = Learner(data, model, callback_fns=LossMetrics, loss_func=loss)
learn = learn.load('rrdb.2').to_fp16()
gc.collect()

0

In [9]:
# Super-resolve every collected movie. A movie is skipped when its prediction
# file already exists in the working directory, so the cell can be re-run.
for fn in progress_bar(movie_files):
    pred_name = f'{fn.stem}_pred.tif'
    orig_name = f'{fn.stem}_orig.tif'
    if Path(pred_name).exists():
        print(f'skip: {fn.stem}')
        continue
    try:
        if fn.suffix == '.czi':
            print(f'czi {fn.stem}')
            czi_predict_movie(learn, fn, size=size, orig_out=orig_name, pred_out=pred_name)
        elif fn.suffix == '.tif':
            # Announce before processing, matching the czi branch.
            print(f'tif {fn.stem}')
            tif_predict_movie(learn, fn, size=size, orig_out=orig_name, pred_out=pred_name)
            tif_fn = fn  # last successfully processed TIFF, used by the inspection cells below
        else:
            print(f'skip (unsupported suffix {fn.suffix!r}): {fn.stem}')
    except Exception as e:
        # Keep going on per-file failures and report the cause. Unlike a bare
        # `except:`, this lets KeyboardInterrupt / SystemExit stop the run.
        print(f'err: {fn.stem}: {e!r}')

czi low res 1500 time points 8 redo best
1.0


czi low res 300 time points 6
0.7254902


czi low res 1200 time points 11 high signal
1.0


0.6784314
max:  0.6784314


tif low res 300 time points 4 aligned
czi low res 1500 time points 7
0.92156863


0.27450982
max:  0.27450982


tif low res 300 time points 2 aligned
czi low res 300 time points 5
1.0


czi low res 1200 time points 9 high signal
1.0


czi low res 300 time points 3
0.87058824


0.63529414
max:  0.63529414


tif low res 300 time points 3 aligned
1.0
max:  1.0


tif low res 300 time points 5-1 aligned
czi low res 300 time points 4
0.6784314


czi low res 1200 time points 10 high signal
1.0


czi low res 1500 time points 8
0.9254902


czi low res 300 time points 2
0.29803923


czi low res 1200 time points 11 high signal aligned
err: low res 1200 time points 11 high signal aligned


In [None]:
# Scratch check of range() bounds (the end value is exclusive).
list( range(3,5))

In [None]:
# Inspect the TIFF remembered as `tif_fn` by the prediction loop: open it,
# load the first frame, and show the frame count.
# NOTE(review): the file handle opened here is never closed.
im = PIL.Image.open(tif_fn)
im.load()
im.n_frames

In [None]:
# Jump to frame index 3 and show that frame's array shape.
im.seek(3)
np.array(im).shape

In [None]:
# Collect the items yielded by iterating over the frames of `im`.
images = list(PIL.ImageSequence.Iterator(im))

In [None]:
# First three entries of `images`, i.e. one window of three frames.
w_imgs = images[0:3]

In [None]:
# NOTE(review): because of the trailing `.shape`, `img` ends up holding the
# shape tuple of the stacked window, not the stacked array itself.
img = np.stack([np.array(img) for img in w_imgs]).shape

In [None]:
# Display the third entry of the window.
w_imgs[2]

In [None]:
# Show the last value bound to the loop variable `fn` by the prediction loop.
fn

In [None]:
# List the .tif files in the working directory via the shell.
!ls -hatlr *.tif

In [None]:
# NOTE(review): leftover post-mortem debugger magic; remove before sharing the notebook.
%debug

In [None]:
# NOTE(review): `x` is not assigned at top level in any visible cell, so this
# will likely fail after Restart & Run All.
x.shape

In [None]:
# Look up `conv_layer` with `doc` (presumably fastai's documentation helper - confirm).
doc(conv_layer)

In [None]:
# NOTE(review): leftover post-mortem debugger magic; remove before sharing the notebook.
%debug