In [1]:
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.distributed import *
from fastai.vision.models.xresnet import *
from fastai.vision.models.unet import DynamicUnet
from bpho import *

  data = yaml.load(f.read()) or {}


In [2]:
# Make the CUDA device with index 1 the current device for this process.
# NOTE(review): the index is hard-coded; presumably the machine has at least
# two GPUs - confirm before running this notebook elsewhere.
torch.cuda.set_device(1)

In [3]:
# --- Experiment identity ---------------------------------------------------
mname = 'combo'        # prefix used when naming saved checkpoints
model_dir = 'models'   # model folder name handed to the learner
datasetname = 'combo_001'

# --- Folder layout (relative to the working directory) ---------------------
data_path = Path('.')
datasources, datasets = data_path/'data', data_path/'datasets'
dataset = datasets/datasetname
test_files, hr_tifs, lr_tifs, lr_up_tifs = (
    dataset/sub for sub in ('test', 'hr', 'lr', 'lr_up')
)

# --- Training objective ------------------------------------------------------
loss, metrics = F.mse_loss, sr_metrics

In [4]:
def get_src(x_data, y_data_):
    """Build the labelled, not yet transformed, image-to-image item list.

    Items are collected from the folder `x_data` (with convert_mode='L').
    Each item's target is the file at the same relative location under
    `y_data_`. A validation subset is split off with the default arguments
    of `split_by_rand_pct`.
    """
    def map_to_hr(x):
        # Re-root the input path from x_data onto y_data_.
        return y_data_/x.relative_to(x_data)

    inputs = ImageImageList.from_folder(x_data, convert_mode='L')
    splits = inputs.split_by_rand_pct()
    return splits.label_from_func(map_to_hr, convert_mode='L')


def get_data(bs, size, x_data, y_data, max_zoom=1.1):
    """Create the data bunch used for training.

    bs       -- batch size passed to `databunch`
    size     -- size passed to both the input and the target transforms
    x_data   -- folder of input images
    y_data   -- folder of target images
    max_zoom -- forwarded to `get_transforms` (vertical flips are enabled)

    The same transform set is applied to inputs and targets, and the
    returned object has its `c` attribute set to 3.
    """
    labelled = get_src(x_data, y_data)
    augmentations = get_transforms(flip_vert=True, max_zoom=max_zoom)

    with_x_tfms = labelled.transform(augmentations, size=size)
    with_xy_tfms = with_x_tfms.transform_y(augmentations, size=size)
    bunch = with_xy_tfms.databunch(bs=bs)

    bunch.c = 3
    return bunch

def do_fit(learn, save_name, lrs=slice(1e-3), pct_start=0.9, cycle_len=10):
    """Train for one one-cycle schedule, save a checkpoint, then preview results.

    learn     -- learner to train; switched to fp16 for fitting and back to
                 fp32 before showing results
    save_name -- name passed to `learn.save`
    lrs       -- learning rate(s) passed to `fit_one_cycle`
    pct_start -- forwarded to `fit_one_cycle`
    cycle_len -- number of epochs passed to `fit_one_cycle`
    """
    half_precision = learn.to_fp16()
    half_precision.fit_one_cycle(cycle_len, lrs, pct_start=pct_start)

    learn.save(save_name)
    print(f'saved: {save_name}')

    # Never ask for more preview rows than one batch holds (capped at 3).
    preview_rows = min(learn.data.batch_size, 3)
    full_precision = learn.to_fp32()
    full_precision.show_results(rows=preview_rows, imgsize=5)

In [5]:
# Learning-rate finder, switched off by the constant-False guard, so the body
# never executes.
# NOTE(review): `learn` is only assigned in later cells; turning the guard on
# at this position presumably fails on a fresh kernel - confirm before enabling.
if False:
    learn.lr_find()
    learn.recorder.plot()

In [6]:
# Stage 0 settings: checkpoint index, schedule and objective.
step, cycles = 0, 2
lr = 1e-3
loss, metrics = F.mse_loss, sr_metrics

# Batch / image geometry and backbone for this stage.
arch = xresnet34
bs, size, max_zoom = 16, 256, 6

# Inputs come from lr_up_tifs, targets from hr_tifs.
data = get_data(bs, size, lr_up_tifs, hr_tifs, max_zoom=max_zoom)
learn = xres_unet_learner(
    data, arch,
    loss_func=loss, metrics=metrics, model_dir=model_dir,
)
gc.collect()

0

In [None]:
# Train with the settings from the setup cell and save the checkpoint under
# the name '<mname>.<step as two digits>'.
do_fit(learn, f'{mname}.{step:02d}', lrs=lr, cycle_len=cycles)

epoch,train_loss,valid_loss,ssim,psnr,time


In [None]:
# Stage 1 settings: smaller batch, larger images, lower learning rate.
step = 1
lr = 1e-4
cycles = 2
loss = F.mse_loss
metrics = sr_metrics

bs = 8
size = 512
max_zoom = 4
arch = xresnet34

# Use the folders from the configuration cell. The previously referenced
# `lr_up_multi_tifs` / `hr_multi_tifs` are never assigned in this notebook,
# so this cell raised NameError on a fresh kernel.
data = get_data(bs, size, lr_up_tifs, hr_tifs, max_zoom=max_zoom)
learn = xres_unet_learner(data, arch, loss_func=loss, metrics=metrics, model_dir=model_dir)

# Continue from the checkpoint saved by the previous stage.
learn.load(f'{mname}.{(step-1):02d}')
gc.collect()

In [None]:
# Train with the settings from the setup cell and save the checkpoint under
# the name '<mname>.<step as two digits>'.
do_fit(learn, f'{mname}.{step:02d}', lrs=lr, cycle_len=cycles)

In [None]:
# Stage 2 settings: smallest batch, largest images.
step = 2
lr = 1e-4
cycles = 2
loss = F.mse_loss
metrics = sr_metrics

bs = 2
size = 1024
max_zoom = 2
arch = xresnet34

# Use the folders from the configuration cell. The previously referenced
# `lr_up_multi_tifs` / `hr_multi_tifs` are never assigned in this notebook,
# so this cell raised NameError on a fresh kernel.
data = get_data(bs, size, lr_up_tifs, hr_tifs, max_zoom=max_zoom)
learn = xres_unet_learner(data, arch, loss_func=loss, metrics=metrics, model_dir=model_dir)

# Continue from the checkpoint saved by the previous stage.
learn.load(f'{mname}.{(step-1):02d}')
gc.collect()

In [None]:
# Train with the settings from the setup cell and save the checkpoint under
# the name '<mname>.<step as two digits>'.
do_fit(learn, f'{mname}.{step:02d}', lrs=lr, cycle_len=cycles)

In [None]:
# Single-cell variant: build the learner, resume from the previous checkpoint
# when one is available, then train.
step = 1
lr = 1e-3
cycles = 3

bs = 20
size = 256
max_zoom = 8
arch = xresnet34

# get_data takes (x_data, y_data): inputs are lr_up_tifs and targets are
# hr_tifs, matching the stage-0 cell. The two were previously passed reversed.
data = get_data(bs, size, lr_up_tifs, hr_tifs, max_zoom=max_zoom)
learn = xres_unet_learner(data, arch, loss_func=loss, metrics=metrics, model_dir=model_dir)
gc.collect()

# Look for the checkpoint in the learner's model folder, with the '.pth'
# suffix. The old check looked for a bare 'combo.00' in the working directory.
# NOTE(review): assumes xres_unet_learner returns a fastai Learner exposing
# `path` and `model_dir` - confirm.
prev_name = f'{mname}.{(step-1):02d}'
prev_file = Path(learn.path)/learn.model_dir/f'{prev_name}.pth'
if prev_file.exists():
    print('loading', prev_name)
    learn.load(prev_name)

do_fit(learn, f'{mname}.{step:02d}', lrs=lr, cycle_len=cycles)

In [None]:
# Every .tif and then every .czi that sits in a 'test' folder under datasources.
test_fns = [
    *datasources.glob('**/test/*.tif'),
    *datasources.glob('**/test/*.czi'),
]

In [None]:
# Show how many test files were collected.
len(test_fns)

In [None]:
# Replace the list with the files under the dataset's test folder:
# all .tif matches first, then all .czi matches.
test_fns = [
    found
    for pattern in ('**/*.tif', '**/*.czi')
    for found in test_files.glob(pattern)
]

In [None]:
# Peek at the first test file only and split its path into the pieces used
# for naming. Nothing is assigned when test_fns is empty.
for fn in test_fns[:1]:
    category, group = fn.parts[-4], fn.parts[-3]
    name = fn.stem

In [None]:
# Display the path components taken from the first test file.
category, group, name

In [None]:
# Display the path of the file those components were taken from.
fn

In [None]:
# Movie inputs for prediction: .czi files in the movies_001 test folder whose
# name contains '05'.
movie_files = [*Path('/scratch/bpho/datasets/movies_001/test').glob('*05*.czi')]
# Alternative sources, currently disabled:
#movie_files += list(Path('/scratch/bpho/datasources/low_res_test/').glob('low res confocal*.czi'))
#movie_files += list(Path('/scratch/bpho/datasources/neuron_movies2/').glob('low*.*'))
#movie_files += list(Path('/DATA/WAMRI/salk/uri/bpho/datasources/neuron_movies/').glob('low res 300 time points 2*.czi'))
#movie_files += list(Path('/DATA/WAMRI/salk/uri/bpho/datasources/neuron_movies/').glob('*time points 2*.tif'))
#movie_files = list(Path('/DATA/donow/').glob('*.czi'))

In [None]:
# Rebuild a learner with batch size 1 and 1536-pixel images, then load saved
# weights into it and switch it to fp16.
# NOTE(review): step / lr / cycles are assigned here but not used in this cell.
step = 1
lr = 1e-3
cycles = 3

bs = 1
size = 256*6
max_zoom = 1
arch = xresnet34

# NOTE(review): the folders are passed as (hr_tifs, lr_up_tifs), the reverse of
# the order used by the first training cell - confirm this is intended.
data = get_data(bs, size, hr_tifs, lr_up_tifs, max_zoom=max_zoom)
learn = xres_unet_learner(data, arch, loss_func=loss, metrics=metrics, model_dir=model_dir)
gc.collect()

print(model_dir)
print(mname)
# NOTE(review): absolute, machine-specific checkpoint path. The file name
# 'combo.1' also differs from the two-digit '<mname>.<step>' names written by
# do_fit - confirm which checkpoint is meant.
learn = learn.load('/home/fredmonroe/repos/salk/uri/paper/datasets/combo_001/lr_up/models/combo.1')
learn = learn.to_fp16()

In [None]:
def tif_predict_images(learn, czi_in, dest, category, tag=None, size=128, wsize=3):
    """Run the model over a multi-frame TIFF and write prediction/original stacks.

    learn    -- learner handed to `unet_image_from_tiles`
    czi_in   -- path of the input TIFF (the parameter name is kept for
                compatibility with existing callers)
    dest     -- root output folder; files are written to `dest/category`
    category -- sub-folder name under `dest`
    tag      -- optional label inserted into the output file names
    size     -- tile size passed to `unet_image_from_tiles`
    wsize    -- number of consecutive frames given to the model per
                prediction. The default of 3 makes the saved "orig" frame
                the one at index 1, which the code previously hard-coded.

    Writes `<stem>_[<tag>_]pred.tif` and `<stem>_[<tag>_]orig.tif`. Nothing is
    written when the file has fewer than `wsize` frames.

    Fixes over the earlier version: the input is read from `czi_in` (not an
    undefined `tif_in`), the output folder is derived from `dest`/`category`
    (not an undefined `dest_folder`), `wsize` is a parameter (not an undefined
    global), the image file is closed after reading, and an all-zero stack no
    longer divides by zero.
    """
    under_tag = '_' if tag is None else f'_{tag}_'
    dest_folder = Path(dest)/category
    dest_folder.mkdir(exist_ok=True, parents=True)
    pred_out = dest_folder/f'{czi_in.stem}{under_tag}pred.tif'
    orig_out = dest_folder/f'{czi_in.stem}{under_tag}orig.tif'

    # Read every frame into memory, then release the file handle.
    with PIL.Image.open(czi_in) as im:
        im.load()
        times = im.n_frames
        x, y = im.size
        frames = []
        for i in range(times):
            im.seek(i)
            im.load()
            frames.append(np.array(im).astype(np.float32)/255.)
    img_data = np.stack(frames)

    # Normalise by the stack maximum; fall back to 1 for an all-zero stack.
    img_max = img_data.max()
    scale = img_max if img_max > 0 else 1.0

    preds, origs = [], []
    print(f'tif: x:{x} y:{y} t:{times}')
    for t in progress_bar(list(range(0, times-wsize+1))):
        img = img_data[t:(t+wsize)].copy()
        img /= scale

        out_img = unet_image_from_tiles(learn, img, tile_sz=size, wsize=wsize)
        pred = (out_img*255).cpu().numpy().astype(np.uint8)
        preds.append(pred)

        # Keep the middle frame of the window as the matching original.
        orig = (img[wsize//2][None]*255).astype(np.uint8)
        origs.append(orig)

    if len(preds) > 0:
        imageio.mimwrite(pred_out, np.concatenate(preds), bigtiff=True)
        imageio.mimwrite(orig_out, np.concatenate(origs), bigtiff=True)


def czi_predict_images(learn, czi_in, dest, category, tag=None, size=128):
    """Run the model over each time point of a CZI file and write TIFF stacks.

    learn    -- learner handed to `unet_image_from_tiles`
    czi_in   -- path of the input CZI file
    dest     -- root output folder (a Path); files go to `dest/category`
    category -- sub-folder name under `dest`
    tag      -- optional label inserted into the output file names
    size     -- tile size passed to `unet_image_from_tiles`

    Only channel 0 and depth 0 are processed. Writes
    `<stem>_[<tag>_]pred.tif` and `<stem>_[<tag>_]orig.tif`.

    Fixes over the earlier version: the frame window is a local constant
    instead of an undefined global `wsize`, and an all-zero volume no longer
    divides by zero.
    """
    # One time point is fed to the model per prediction.
    wsize = 1
    under_tag = '_' if tag is None else f'_{tag}_'

    with czifile.CziFile(czi_in) as czi_f:
        proc_axes, proc_shape = get_czi_shape_info(czi_f)
        data = czi_f.asarray().astype(np.float32)/255.

    times = proc_shape['T']
    x, y = proc_shape['X'], proc_shape['Y']

    dest_folder = Path(dest/category)
    dest_folder.mkdir(exist_ok=True, parents=True)
    pred_out = dest_folder/f'{czi_in.stem}{under_tag}pred.tif'
    orig_out = dest_folder/f'{czi_in.stem}{under_tag}orig.tif'

    # Normalise by the volume maximum; fall back to 1 for an all-zero volume.
    img_max = data.max()
    print(img_max)
    scale = img_max if img_max > 0 else 1.0

    preds, origs = [], []
    for t in progress_bar(list(range(0, times))):
        idx = build_index(proc_axes, {'T': t, 'C': 0, 'Z': 0, 'X': slice(0, x), 'Y': slice(0, y)})
        img = data[idx].copy()
        img /= scale

        out_img = unet_image_from_tiles(learn, img, tile_sz=size, wsize=wsize)
        pred = (out_img*255).cpu().numpy().astype(np.uint8)
        preds.append(pred)

        # NOTE(review): assumes the first axis of `img` is the frame window,
        # so index wsize//2 (0 here) selects the frame - confirm against
        # what build_index returns.
        orig = (img[wsize//2][None]*255).astype(np.uint8)
        origs.append(orig)

    if len(preds) > 0:
        imageio.mimwrite(pred_out, np.concatenate(preds), bigtiff=True)
        imageio.mimwrite(orig_out, np.concatenate(origs), bigtiff=True)



def generate_tifs(src, dest, learn, size, tag=None):
    """Write prediction TIFFs for every .czi / .tif file in `src`.

    src   -- iterable of input Paths; the category is taken from the
             third-from-last path component of each file
    dest  -- root output folder
    learn -- learner forwarded to the per-format predict function
    size  -- tile size forwarded to the per-format predict function
    tag   -- optional label forwarded to the predict function and used in
             the output file name

    A file is skipped when its prediction already exists at
    `dest/<category>/<stem>_[<tag>_]pred.tif`, the location
    `czi_predict_images` writes to. The earlier check looked for
    `<stem>_pred.tif` in the working directory, so existing outputs were
    never detected, and its skip message wrongly said "doesn't exist".
    Files with any other suffix are ignored.
    """
    under_tag = '_' if tag is None else f'_{tag}_'
    for fn in progress_bar(src):
        category = fn.parts[-3]
        pred_path = Path(dest)/category/f'{fn.stem}{under_tag}pred.tif'
        if pred_path.exists():
            print(f'skip: {fn.stem} - prediction already exists')
            continue
        if fn.suffix == '.czi':
            czi_predict_images(learn, fn, dest, category, size=size, tag=tag)
        elif fn.suffix == '.tif':
            tif_predict_images(learn, fn, dest, category, size=size, tag=tag)
