In [None]:
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.distributed import *
from fastai.vision.models.xresnet import *
from fastai.vision.models.unet import DynamicUnet
import pandas as pd

from bpho import *

In [None]:
# Paths are relative to the notebook's working directory.
data_path = Path('.')
datasetname = 'foo_003'
datasets, datasources = data_path/'datasets', data_path/'data'
dataset = datasets/datasetname

# Per-split folders of the selected dataset.
test_files, hr_tifs, lr_tifs, lr_up_tifs = [dataset/sub for sub in ('test', 'hr', 'lr', 'lr_up')]

# Multi-frame tile folders used for training.
# NOTE(review): the '05' / '0512' suffixes presumably mean 5-frame stacks (in_c=5 below)
# and 512px tiles — confirm against whatever script builds these folders.
hr_multi_tifs = dataset/'hr_mt_05_tiles_0512'
lr_up_multi_tifs = dataset/'lrup_mt_05_tiles_0512'

# Checkpoint base name and the Learner's model directory.
mname, model_dir = 'combo_multi_tile', 'models'

In [None]:
# CUDA device used for all training/inference below. Ordinals are 0-based, so index 1
# needs at least two visible GPUs; set this to 0 on a single-GPU machine.
gpu_id = 1
torch.cuda.set_device(gpu_id)

In [None]:
def get_src(x_data, y_data_):
    """Build the labelled (not yet transformed) item lists for training.

    Every `.npy` file found under `x_data` becomes an input item; its target is the
    file at the same relative path under `y_data_`.

    Args:
        x_data: folder of input `.npy` files (e.g. the upsampled low-res tiles).
        y_data_: folder of the matching target `.npy` files.

    Returns:
        The result of `label_from_func` on a randomly split `MultiImageImageList`,
        with targets built as `NpyRawImageList` items.
    """
    def target_path(src_path):
        # Mirror the input's location relative to `x_data` inside `y_data_`.
        return y_data_ / src_path.relative_to(x_data)

    inputs = MultiImageImageList.from_folder(x_data, extensions=['.npy'])
    # NOTE(review): no seed is passed, so the train/valid split differs between runs.
    split_inputs = inputs.split_by_rand_pct()
    return split_inputs.label_from_func(target_path, label_cls=NpyRawImageList)


def get_data(bs, size, x_data, y_data, max_zoom=1.1):
    """Create a DataBunch of input/target pairs with augmentation on both sides.

    Args:
        bs: batch size.
        size: output size passed to both the input and the target transforms.
        x_data: folder of input `.npy` files (see `get_src`).
        y_data: folder of target `.npy` files.
        max_zoom: upper zoom bound forwarded to `get_transforms`.

    Returns:
        The DataBunch, with its `c` attribute set to 3.
    """
    labelled = get_src(x_data, y_data)
    augment = get_transforms(flip_vert=True, max_lighting=None, max_zoom=max_zoom)
    # The same transform list is registered for the targets, presumably so inputs and
    # targets receive matching geometric augmentation.
    bunch = (labelled
             .transform(augment, size=size)
             .transform_y(augment, size=size)
             .databunch(bs=bs))
    # NOTE(review): `c` is hard-coded to 3 — presumably the model's output channel
    # count; confirm against the target files.
    bunch.c = 3
    return bunch


def do_fit(learn, save_name, lrs=slice(1e-3), pct_start=0.9, cycle_len=10):
    """Train `learn` with one-cycle in mixed precision, save it, then show results.

    Args:
        learn: the Learner to train.
        save_name: checkpoint name passed to `learn.save`.
        lrs: learning rate(s) forwarded to `fit_one_cycle`.
        pct_start: forwarded to `fit_one_cycle`.
        cycle_len: number of epochs forwarded to `fit_one_cycle`.

    The learner is always switched back to fp32, even if training raises or is
    interrupted. Otherwise it would be left in half-precision mode for later cells.
    A failed or interrupted fit is not saved and re-raises after the switch back.
    """
    learn.to_fp16()
    try:
        learn.fit_one_cycle(cycle_len, lrs, pct_start=pct_start)
        learn.save(save_name)
        print(f'saved: {save_name}')
    finally:
        learn.to_fp32()
    # Show at most 3 rows, fewer if the batch is smaller.
    num_rows = min(learn.data.batch_size, 3)
    learn.show_results(rows=num_rows, imgsize=5)


def get_model(in_c, out_c, arch):
    """Build a DynamicUnet whose encoder is `arch` without its last two children.

    Args:
        in_c: number of input channels, passed to `arch` as `c_in`.
        out_c: number of output channels (`n_classes` of the U-Net).
        arch: model constructor accepting a `c_in` keyword (e.g. `xresnet50`).

    Returns:
        The DynamicUnet module (blur, self-attention, weight norm, last cross, bottle).
    """
    backbone = arch(c_in=in_c)
    encoder_layers = list(backbone.children())[:-2]
    # NOTE(review): confirm that dropping exactly two trailing children removes the
    # whole classification head for the xresnet variants in use.
    unet_options = dict(
        blur=True, blur_final=True,
        self_attention=True, norm_type=NormType.Weight,
        last_cross=True, bottle=True,
    )
    return DynamicUnet(nn.Sequential(*encoder_layers), n_classes=out_c, **unet_options)


In [None]:
# Step 0: build a fresh learner (no checkpoint loaded) at size 256.
step, lr, cycles = 0, 1e-4, 2
loss, metrics = F.mse_loss, sr_metrics

bs, size, arch = 4, 256, xresnet50

data = get_data(bs, size, lr_up_multi_tifs, hr_multi_tifs)
learn = xres_unet_learner(
    data, arch,
    in_c=5, loss_func=loss, metrics=metrics, model_dir=model_dir,
)
gc.collect()

In [None]:
# Visual sanity check: display three rows of the current DataBunch.
data.show_batch(rows=3)

In [None]:
# Set to True to run the learning-rate finder on the current learner and plot its results.
run_lr_find = False
if run_lr_find:
    learn.lr_find()
    learn.recorder.plot()

In [None]:
# Train for the current `step` and checkpoint as '<mname>.<step, zero-padded to 2>'.
do_fit(learn, '{}.{:02d}'.format(mname, step), lrs=lr, cycle_len=cycles)

In [None]:
# Extra checkpoint of the current (step-0, when run in order) weights under the name 'huh'.
# The inference cell further down loads this exact name with learn.load, so renaming
# it here requires updating that cell too.
learn.save('huh')

In [None]:
# Step 1: continue from the step-0 checkpoint at size 512.
step = 1
lr = 1e-4
cycles = 2
loss = F.mse_loss
metrics = sr_metrics


bs = 4
size = 512
arch = xresnet50  # must match the architecture of the checkpoint loaded below

data = get_data(bs, size, lr_up_multi_tifs, hr_multi_tifs)
learn = xres_unet_learner(data, arch, in_c=5, loss_func=loss, metrics=metrics, model_dir=model_dir)
# Resume from the previous step's weights (saved by do_fit as '<mname>.<step:02d>').
# Without this, step 1 silently retrains from scratch and the step-0 run is wasted.
learn.load(f'{mname}.{(step-1):02d}')
gc.collect()

In [None]:
# Train for the current `step`; the checkpoint name is '<mname>.<NN>'.
do_fit(learn, '%s.%02d' % (mname, step), lrs=lr, cycle_len=cycles)

In [None]:
# Step 2: continue from the step-1 checkpoint at size 1024 with stronger zoom augmentation.
step = 2
lr = 1e-4
cycles = 2
loss = F.mse_loss
metrics = sr_metrics


bs = 2
size = 1024
max_zoom = 2
# The checkpoint loaded below was produced by step 1 with xresnet50. The U-Net must be
# built from the same architecture, or `learn.load` fails on mismatched state-dict
# keys/shapes (xresnet34 uses different residual blocks).
# NOTE(review): if xresnet50 at bs=2/size=1024 does not fit in GPU memory, lower bs
# rather than switching arch.
arch = xresnet50

data = get_data(bs, size, lr_up_multi_tifs, hr_multi_tifs, max_zoom=max_zoom)
learn = xres_unet_learner(data, arch, in_c=5, loss_func=loss, metrics=metrics, model_dir=model_dir)
learn.load(f'{mname}.{(step-1):02d}')
gc.collect()

In [None]:
# Train for the current `step` and save under the step-numbered checkpoint name.
do_fit(learn=learn, save_name=f'{mname}.{step:02d}', lrs=lr, cycle_len=cycles)

In [None]:
# Peek at the entries of the mitotracker test folder.
[entry for entry in (test_files / 'mitotracker').iterdir()]

In [None]:
# Collect the test movies (.czi and .tif, searched recursively under `test_files`) and
# record, for each, the first path component relative to `test_files` as its group.
fns = []
fns += list(test_files.glob('**/*.czi'))
fns += list(test_files.glob('**/*.tif'))
print(fns)
items = []
movies = set()
for fn in progress_bar(fns):
    group = fn.relative_to(test_files).parts[0]
    items.append(dict(fn=str(fn), group=group))
    movies.add((group, fn))

# Explicit columns keep `df.fn` valid even when no test files were found
# (a DataFrame built from an empty list would have no columns at all).
df = pd.DataFrame(items, columns=['fn', 'group'])
df.head()

In [None]:
# One Path per row of the test-file table, in table order.
movie_files = list(map(Path, df.fn.values))
len(movie_files)

In [None]:
# Inference setup: rebuild the learner and load trained weights for movie generation.
# The generation cell below uses `learn` and `size`; no training happens here, so the
# step/lr/cycles hyperparameters are intentionally not set.
loss = F.mse_loss
metrics = sr_metrics


bs = 1
size = 256
max_zoom = 1
arch = xresnet50  # must match the architecture that produced the loaded checkpoint

# NOTE(review): 'huh' is the extra checkpoint saved right after step-0 training, not the
# final step-2 model. Load f'{mname}.02' instead if that checkpoint was trained with the
# same `arch` and the fully trained weights are wanted.
weights_name = 'huh'

data = get_data(bs, size, lr_up_multi_tifs, hr_multi_tifs, max_zoom=max_zoom)
learn = xres_unet_learner(data, arch, in_c=5, loss_func=loss, metrics=metrics, model_dir=model_dir)
learn.load(weights_name).to_fp16()
gc.collect()

In [None]:
# Run the loaded learner over every test movie collected above. `generate_movies` comes
# from the local `bpho` module. NOTE(review): `size` and `wsize=5` presumably set the tile
# size and the number of frames per input stack (in_c=5 above) — confirm against bpho.
generate_movies(movie_files, learn, size, wsize=5)