In [50]:
import fastai
from fastai import *          # Quick access to most common functionality
from fastai.vision import *   # Quick access to computer vision functionality
from fastai.layers import Lambda
from fastai.callbacks import *
import pytorch_ssim as ssim
from superres import *
from torchvision.models import vgg16_bn
import czifile
import shutil
import numbers
from fastai.vision.image import TfmPixel

In [51]:
# Root of the raw data. NOTE(review): machine-specific absolute path; consider
# making it configurable (e.g. via an environment variable).
PATH = Path('/DATA/WAMRI/salk/uri/')
movie_path = PATH/'movie_src'
# Raw .czi movies, one folder per split.
train_src = movie_path/'train'
valid_src = movie_path/'valid'
# Generated tiles: full-resolution (hr) and downsampled (lr), per split.
train_hr = movie_path/'train_hr'
train_lr = movie_path/'train_lr'
valid_hr = movie_path/'valid_hr'
valid_lr = movie_path/'valid_lr'

# When True, delete previously generated tiles so the generation cell (which
# only runs when `train_hr` is missing) rebuilds them from the .czi sources.
# Set to False to reuse existing tiles instead of regenerating on every run.
rebuild_tiles = True
if rebuild_tiles:
    for folder in [train_hr, train_lr, valid_hr, valid_lr]:
        if folder.exists():
            shutil.rmtree(folder)

In [52]:
def process_czi(czi_fn, hr_dir, lr_dir, scale=4):
    """Split a .czi movie into normalized HR tiles and downsampled LR tiles.

    Every (channel, depth, time) plane of ``czi_fn`` is min-max normalized to
    [0, 1] and saved as a float TIFF in ``hr_dir``. A bicubic-downsampled copy
    (each side integer-divided by ``scale``) with the same file name is saved
    in ``lr_dir``.

    Output names are ``{czi_fn.stem}_{channel:02d}_{depth:03d}_{time:03d}.tif``.

    Args:
        czi_fn: Path of the .czi file; its stem prefixes every output name.
        hr_dir: Directory that receives the full-resolution tiles.
        lr_dir: Directory that receives the downsampled tiles.
        scale: Integer downsampling factor for the LR tiles (default 4).
    """
    with czifile.CziFile(czi_fn) as czi_f:
        proc_axes, proc_shape = get_czi_shape_info(czi_f)
        channels = proc_shape['C']
        depths = proc_shape['Z']
        times = proc_shape['T']
        x, y = proc_shape['X'], proc_shape['Y']
        data = czi_f.asarray()
        for channel in range(channels):
            for depth in range(depths):
                for time_col in range(times):
                    idx = build_index(proc_axes, {'T': time_col, 'C': channel, 'Z': depth,
                                                  'X': slice(0, x), 'Y': slice(0, y)})
                    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
                    img = data[idx].astype(np.float64)
                    save_fn = hr_dir/f'{czi_fn.stem}_{channel:02d}_{depth:03d}_{time_col:03d}.tif'
                    img -= img.min()
                    peak = img.max()
                    # A constant plane would otherwise divide 0/0 and be saved as NaN.
                    if peak > 0:
                        img /= peak
                    pimg = PIL.Image.fromarray(img)
                    pimg.save(save_fn)
                    width, height = pimg.size
                    new_size = (width // scale, height // scale)
                    pimg.resize(new_size, resample=PIL.Image.BICUBIC).save(lr_dir/save_fn.name)

In [53]:
# Convert the raw .czi movies into paired HR/LR .tif tiles. Skipped when
# `train_hr` already exists (delete it to force a rebuild).
if not train_hr.exists():
    for out_dir in (train_hr, train_lr, valid_hr, valid_lr):
        out_dir.mkdir(exist_ok=True)

    # Training split: every movie in train_src -> train_hr / train_lr.
    train_files = list(train_src.glob('*.czi'))
    for czi_fn in progress_bar(train_files):
        process_czi(czi_fn, train_hr, train_lr)

    # Validation split: every movie in valid_src -> valid_hr / valid_lr.
    valid_files = list(valid_src.glob('*.czi'))
    for czi_fn in progress_bar(valid_files):
        process_czi(czi_fn, valid_hr, valid_lr)

In [91]:
# Ratio of HR to LR tile side length. It should match the downsampling factor
# process_czi uses when it generates the LR tiles.
lr_scale = 4
sz_lr = 128
sz_hr = lr_scale*sz_lr
bs = 8
# NOTE(review): presumably 0 means data is loaded in the main process
# (handy for debugging); confirm against the DataBunch call that uses it.
num_workers = 0
# NOTE(review): unused in the visible cells; presumably forwarded to a later call.
kwargs = {}

def match_hr_func(x):
    """Return the HR tile path that pairs with the LR tile at ``x``.

    LR and HR tiles are written with identical file names into sibling
    folders (e.g. ``train_lr`` / ``train_hr``), so only the parent folder name
    needs to change. Rewriting just that component keeps other occurrences of
    ``'_lr'`` intact, for example in the file name, which is derived from the
    source movie's stem.

    Args:
        x: Path (or anything whose ``str`` is a path) of an LR tile.

    Returns:
        pathlib.Path of the matching HR tile.
    """
    lr_path = Path(str(x))
    lr_dir = lr_path.parent
    if lr_dir.name.endswith('_lr'):
        hr_dir = lr_dir.with_name(lr_dir.name[:-len('_lr')] + '_hr')
        return hr_dir/lr_path.name
    # Fallback for paths not inside a *_lr folder: keep the original behaviour.
    return Path(str(x).replace('_lr', '_hr'))

# Build the LR -> HR image-to-image dataset. Inputs are the LR tiles, and each
# input is labelled with the HR tile that match_hr_func maps it to.
# NOTE(review): '*.tif' is passed positionally. With fastai v1's
# from_folder(path, extensions, ...) signature it lands in `extensions`
# rather than acting as a glob pattern; confirm it filters as intended
# (e.g. pass extensions=['.tif'] explicitly).
# NOTE(review): presumably from_folder scans all subfolders of movie_path
# (including the *_hr folders), and split_by_folder then keeps only items
# under train_lr / valid_lr; verify.
src = (GrayImageItemList
       .from_folder(movie_path, '*.tif', label_class=GrayImageItemList)
       .split_by_folder(train='train_lr', valid='valid_lr')
       .label_from_func(match_hr_func))