In [1]:
import os

from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.distributed import *
from fastai.vision.models.xresnet import *
from fastai.vision.models.unet import DynamicUnet
from bpho import *

  data = yaml.load(f.read()) or {}


In [2]:
# Index of the CUDA device used for training/inference; change to match the machine.
GPU_ID = 1
torch.cuda.set_device(GPU_ID)

In [3]:
# ---- dataset layout (relative to the working directory) --------------------
datasetname = 'combo_001'
data_path = Path()
datasources = data_path.joinpath('data')
datasets = data_path.joinpath('datasets')
dataset = datasets.joinpath(datasetname)

# Sub-folders of the prepared dataset.
test_files, hr_tifs, lr_tifs, lr_up_tifs = (dataset/sub for sub in ('test', 'hr', 'lr', 'lr_up'))

# ---- model naming and default loss/metrics ---------------------------------
mname, model_dir = 'combo', 'models'
loss, metrics = F.mse_loss, sr_metrics

In [4]:
def get_src(x_data, y_data_, valid_pct=0.2, seed=42):
    """Build the paired input -> target image list with a train/validation split.

    Each image found under ``x_data`` is labelled with the file at the same
    relative path under ``y_data_``; both sides are opened with
    ``convert_mode='L'``.

    Args:
        x_data: Path of the folder holding the input images.
        y_data_: Path of the folder holding the target images; must mirror the
            layout of ``x_data``.
        valid_pct: fraction of items held out for validation.
        seed: seed for the random split. It is fixed by default because every
            training stage rebuilds its data with this function and resumes from
            the previous stage's weights. With an unseeded split, images used for
            validation in one stage may have been trained on in an earlier one,
            which makes validation metrics optimistic. Pass ``None`` to get the
            old unseeded behaviour. NOTE(review): fastai v1 applies the seed via
            numpy's global RNG, so downstream numpy randomness is also seeded;
            verify for the installed fastai version.

    Returns:
        The split, labelled fastai item lists.
    """
    def map_to_hr(x):
        # Same relative path, rooted at the target folder.
        return y_data_/x.relative_to(x_data)

    src = (ImageImageList
            .from_folder(x_data, convert_mode='L')
            .split_by_rand_pct(valid_pct=valid_pct, seed=seed)
            .label_from_func(map_to_hr, convert_mode='L'))
    return src


def get_data(bs, size, x_data, y_data, max_zoom=1.1, **kwargs):
    """Create a DataBunch of (input, target) image pairs with augmentation.

    Args:
        bs: batch size.
        size: size argument given to the transforms of both inputs and targets.
        x_data, y_data: input / target folders, forwarded to ``get_src``.
        max_zoom: upper bound of the random zoom passed to ``get_transforms``.
        **kwargs: extra keyword arguments for ``.databunch`` (e.g. ``num_workers``).

    Returns:
        The DataBunch, with its ``c`` attribute set to 3.
    """
    labelled = get_src(x_data, y_data)
    augmentations = get_transforms(flip_vert=True, max_zoom=max_zoom)
    # The same transform list is attached to the targets, presumably so each
    # target is augmented consistently with its input.
    labelled = labelled.transform(augmentations, size=size)
    labelled = labelled.transform_y(augmentations, size=size)
    bunch = labelled.databunch(bs=bs, **kwargs)
    # NOTE(review): get_src opens inputs and targets with convert_mode='L'
    # (single channel), yet c is set to 3. fastai learners typically use c as
    # the number of output channels. Changing it may make existing checkpoints
    # incompatible, so it is left as-is — confirm intent.
    bunch.c = 3
    return bunch

def do_fit(learn, save_name, lrs=slice(1e-3), pct_start=0.9, cycle_len=10):
    """Train ``learn`` with one-cycle in fp16, save it, then preview results in fp32.

    Args:
        learn: the learner to train (switched to fp16 for fitting, fp32 afterwards).
        save_name: checkpoint name passed to ``learn.save``.
        lrs: learning rate(s) for ``fit_one_cycle``.
        pct_start: fraction of the cycle spent increasing the learning rate.
        cycle_len: number of epochs passed to ``fit_one_cycle``.
    """
    fp16_learner = learn.to_fp16()
    fp16_learner.fit_one_cycle(cycle_len, lrs, pct_start=pct_start)

    learn.save(save_name)
    print(f'saved: {save_name}')

    # Show at most three rows, fewer if the batch is smaller.
    batch_size = learn.data.batch_size
    preview_rows = 3 if batch_size > 3 else batch_size
    learn.to_fp32().show_results(rows=preview_rows, imgsize=5)

In [5]:
# Learning-rate finder. Off by default. NOTE: `learn` is only created in the
# stage cells further down, so enable this only after one of them has run.
RUN_LR_FIND = False
if RUN_LR_FIND:
    learn.lr_find()
    learn.recorder.plot()

In [None]:
# Stage 0: build a fresh learner (no checkpoint is loaded) on 256px crops,
# batch size 16, with max_zoom=6 augmentation.
step, lr, cycles = 0, 1e-3, 2
loss, metrics = F.mse_loss, sr_metrics

arch = xresnet34
bs, size, max_zoom = 16, 256, 6

data = get_data(bs, size, lr_up_tifs, hr_tifs, max_zoom=max_zoom)
learn = xres_unet_learner(data, arch, model_dir=model_dir, loss_func=loss, metrics=metrics)
gc.collect()

In [None]:
# Train stage 0 and checkpoint it as '<mname>.00'.
do_fit(learn, save_name='{}.{:02d}'.format(mname, step), lrs=lr, cycle_len=cycles)

In [None]:
# Stage 1: resume from the stage-0 checkpoint and continue at 512px,
# batch size 12, lr 1e-4, with max_zoom=4 augmentation.
step, lr, cycles = 1, 1e-4, 2
loss, metrics = F.mse_loss, sr_metrics

arch = xresnet34
bs, size, max_zoom = 12, 512, 4

data = get_data(bs, size, lr_up_tifs, hr_tifs, max_zoom=max_zoom)
learn = xres_unet_learner(data, arch, model_dir=model_dir, loss_func=loss, metrics=metrics)
learn.load('{}.{:02d}'.format(mname, step - 1))
gc.collect()

In [None]:
# First stage-1 run; checkpoints as '<mname>.01'.
do_fit(learn, save_name='{}.{:02d}'.format(mname, step), lrs=lr, cycle_len=cycles)

In [None]:
# Longer stage-1 fine-tune at lr/50 for 5x the cycles. NOTE: this saves under
# the same '<mname>.01' name as the previous cell and overwrites that
# checkpoint; the stage-2 fine-tune instead appends '.1'.
do_fit(learn, save_name='{}.{:02d}'.format(mname, step), lrs=lr / 50, cycle_len=5 * cycles)

In [None]:
# Stage 2: resume from the stage-1 checkpoint at 1024px. Batch size drops to 2,
# max_zoom to 2, and 4 dataloader workers are requested.
step, lr, cycles = 2, 1e-4, 2
loss, metrics = F.mse_loss, sr_metrics

arch = xresnet34
bs, size, max_zoom = 2, 1024, 2

data = get_data(bs, size, lr_up_tifs, hr_tifs, max_zoom=max_zoom, num_workers=4)
learn = xres_unet_learner(data, arch, model_dir=model_dir, loss_func=loss, metrics=metrics)
learn.load('{}.{:02d}'.format(mname, step - 1))
gc.collect()

In [None]:
# First stage-2 run; checkpoints as '<mname>.02'.
do_fit(learn, save_name='{}.{:02d}'.format(mname, step), lrs=lr, cycle_len=cycles)

In [None]:
# Stage-2 fine-tune at lr/50 for 5x the cycles, saved separately as '<mname>.02.1'.
do_fit(learn, save_name='{}.{:02d}.1'.format(mname, step), lrs=lr / 50, cycle_len=5 * cycles)

In [6]:
# Raw microscopy files to run inference on. Each glob is sorted because glob
# order depends on the filesystem, and a fixed processing order keeps runs
# reproducible.
# NOTE(review): the saved output of the generation cell below lists .tif files,
# so it appears to predate disabling tifs here.
INCLUDE_TEST_TIFS = False
test_fns = []
if INCLUDE_TEST_TIFS:
    test_fns += sorted(datasources.glob('**/test/*.tif'))
test_fns += sorted(datasources.glob('**/test/*.czi'))

In [7]:
# Fail fast if the globs matched nothing (e.g. wrong working directory), rather
# than finding out only after the expensive learner/generation cells.
assert test_fns, f'no test files found under {datasources}'
len(test_fns)

54

In [8]:
# Alternative (disabled): take the test inputs from the prepared dataset's
# 'test' folder (test_files) instead of the raw data folders globbed above.
#test_fns = []
#test_fns += list(test_files.glob('**/*.tif'))
#test_fns += list(test_files.glob('**/*.czi'))

In [9]:
# Rebuild a learner for inference and load the stage-1 weights ('<mname>.01') in fp16.
# The DataBunch appears to be needed only to construct the learner (bs=1).
# lr and cycles are not used by this cell.
# NOTE(review): `size` (440*4 = 1760) is also passed to generate_tifs in the
# next cell — confirm it suits the larger inputs listed in that cell's output.
step, lr, cycles = 1, 1e-4, 2
loss, metrics = F.mse_loss, sr_metrics

arch = xresnet34
bs, size, max_zoom = 1, 440 * 4, 2

data = get_data(bs, size, lr_up_tifs, hr_tifs, max_zoom=max_zoom)
learn = xres_unet_learner(data, arch, model_dir=model_dir, loss_func=loss, metrics=metrics)
learn.load('{}.{:02d}'.format(mname, step)).to_fp16()
gc.collect()

0

In [10]:
# Output root for generated movies. Set the BPHO_RESULTS_DIR environment
# variable to redirect it instead of editing this machine-specific absolute path;
# the default is unchanged.
results_root = Path(os.environ.get('BPHO_RESULTS_DIR', '/DATA/Dropbox/bpho_movie_results'))
dest = results_root/'combo_xres_unet'
dest.mkdir(exist_ok=True, parents=True)
# `learn` and `size` come from the inference-learner cell above.
generate_tifs(test_fns, dest, learn, size, tag=mname, max_imgs=10)

tif: x:512 y:512 t:7


tif: x:512 y:512 t:7


tif: x:512 y:512 t:7


tif: x:512 y:512 t:7


tif: x:512 y:512 t:7


tif: x:512 y:512 t:7


czi: x:228 y:230 t:10 c:1 z:1 7.823529243469238


czi: x:228 y:230 t:10 c:1 z:1 5.713725566864014


czi: x:228 y:230 t:10 c:1 z:1 6.356862545013428


czi: x:228 y:230 t:10 c:1 z:1 7.819607734680176


czi: x:228 y:230 t:10 c:1 z:1 7.098039150238037


czi: x:228 y:230 t:10 c:1 z:1 5.176470756530762


czi: x:380 y:380 t:10 c:1 z:1 1.0


czi: x:380 y:380 t:10 c:1 z:1 0.7254902124404907


czi: x:380 y:380 t:10 c:1 z:1 1.0


czi: x:380 y:380 t:10 c:1 z:1 0.9215686321258545


czi: x:380 y:380 t:10 c:1 z:1 1.0


czi: x:380 y:380 t:10 c:1 z:1 1.0


czi: x:440 y:440 t:10 c:1 z:1 0.3450980484485626


czi: x:440 y:440 t:10 c:1 z:1 1.0


czi: x:440 y:440 t:10 c:1 z:1 0.20392157137393951


czi: x:440 y:440 t:10 c:1 z:1 0.4274509847164154


czi: x:380 y:380 t:10 c:1 z:1 0.8705882430076599


czi: x:380 y:380 t:10 c:1 z:1 0.6784313917160034


czi: x:440 y:440 t:10 c:1 z:1 0.3137255012989044


czi: x:380 y:380 t:10 c:1 z:1 1.0


czi: x:440 y:440 t:10 c:1 z:1 0.21960784494876862


czi: x:380 y:380 t:10 c:1 z:1 0.9254902005195618


exception with low res 2000 time points 1
czi: x:440 y:440 t:10 c:1 z:1 0.2862745225429535


czi: x:380 y:380 t:10 c:1 z:1 0.2980392277240753


czi: x:440 y:440 t:10 c:1 z:1 0.4470588266849518


czi: x:256 y:256 t:1 c:3 z:28 1.0


czi: x:196 y:198 t:1 c:3 z:11 1.0


czi: x:196 y:198 t:1 c:3 z:18 1.0


czi: x:256 y:256 t:1 c:3 z:18 1.0


czi: x:316 y:318 t:1 c:1 z:23 1.0


czi: x:316 y:318 t:1 c:1 z:33 0.686274528503418


czi: x:316 y:318 t:1 c:1 z:26 1.0


czi: x:316 y:318 t:1 c:1 z:18 1.0


czi: x:400 y:400 t:1 c:1 z:24 0.23137255012989044


czi: x:316 y:318 t:1 c:1 z:30 0.5843137502670288


czi: x:316 y:318 t:1 c:1 z:27 1.0


czi: x:316 y:318 t:1 c:1 z:18 1.0


czi: x:316 y:318 t:1 c:1 z:32 1.0


czi: x:800 y:800 t:1 c:1 z:37 0.6000000238418579


czi: x:316 y:318 t:1 c:1 z:41 0.5411764979362488


czi: x:256 y:256 t:1 c:2 z:16 0.8156862854957581


czi: x:256 y:256 t:1 c:2 z:21 1.0


czi: x:256 y:256 t:1 c:2 z:24 1.0


czi: x:256 y:256 t:1 c:2 z:24 1.0


czi: x:256 y:256 t:1 c:1 z:11 1.0


czi: x:256 y:256 t:1 c:1 z:23 1.0


czi: x:256 y:256 t:1 c:1 z:15 1.0


In [None]:
#%debug