In [1]:
#import fastai
from fastai import *          # Quick access to most common functionality
from fastai.vision import *   # Quick access to computer vision functionality
from fastai.callbacks import *

In [2]:
import pytorch_ssim as ssim
from superres import *
from torchvision.models import vgg16_bn
import czifile

In [3]:
import os

# Root of the BPHO dataset. Override it with the BPHO_DATA environment variable
# instead of editing the notebook; the default is the original location.
path = Path(os.environ.get('BPHO_DATA', '/DATA/WAMRI/salk/uri/BPHO/'))
path_processed = path/'processed'   # source .czi files (read below)
path_hr = path/'hires'              # full-resolution .npy slices (written by process_czi)
path_lr = path/'lores'              # downscaled .jpg inputs (written by resize_one)

# parents=True so the output folders can be created under a brand-new root.
path_hr.mkdir(parents=True, exist_ok=True)
path_lr.mkdir(parents=True, exist_ok=True)

In [4]:
def get_czi_shape_info(czi):
    """Map each axis label of a CZI-like object to its position and its size.

    Args:
        czi: object exposing ``axes`` (a sequence of axis labels, e.g. from a
            ``czifile.CziFile``) and ``shape`` (the matching per-axis sizes).

    Returns:
        ``(axes_dict, shape_dict)`` -- ``{axis: index}`` and ``{axis: size}``,
        both in first-appearance order of the axes. If a label repeats, its
        last occurrence determines both the index and the size.
    """
    shape = czi.shape  # read once; it may be a computed property
    axes_dict = {axis: idx for idx, axis in enumerate(czi.axes)}
    shape_dict = {axis: shape[idx] for axis, idx in axes_dict.items()}
    return axes_dict, shape_dict


def build_index(axes, ix_select):
    """Return a tuple index holding one entry per axis in `axes`.

    An axis that appears in `ix_select` contributes its mapped value (an int or
    a slice); every other axis is pinned to position 0.
    """
    picks = []
    for axis_name in axes:
        picks.append(ix_select[axis_name] if axis_name in ix_select else 0)
    return tuple(picks)


def process_czi(proc_fn, out_dir=None):
    """Export every (channel, depth) plane of a CZI file as its own ``.npy``.

    Args:
        proc_fn: ``Path`` of the ``.czi`` file to read.
        out_dir: destination directory; ``None`` (the default) means the global
            ``path_hr``, looked up at call time.

    Each plane is saved as ``<stem>_<channel:02d>_<depth:03d>.npy``. X and Y are
    kept whole. Every other axis except C and Z is pinned to index 0. A file
    without a C or Z axis is treated as having a single channel / depth instead
    of raising ``KeyError``.
    """
    if out_dir is None:
        out_dir = path_hr
    with czifile.CziFile(proc_fn) as proc_czf:
        proc_axes, proc_shape = get_czi_shape_info(proc_czf)
        # A missing C/Z axis means one plane along it. build_index ignores
        # selections for axes the file does not have, so index 0 is harmless.
        channels = proc_shape.get('C', 1)
        depths = proc_shape.get('Z', 1)
        data = proc_czf.asarray()
        for channel in range(channels):
            for depth in range(depths):
                idx = build_index(proc_axes, {'C': channel, 'Z': depth,
                                              'X': slice(None), 'Y': slice(None)})
                save_proc_fn = out_dir/f'{proc_fn.stem}_{channel:02d}_{depth:03d}.npy'
                np.save(save_proc_fn, data[idx])
        

In [8]:
# Convert every .czi in `path_processed` into per-(channel, depth) .npy slices
# under `path_hr`. Extraction is off by default; set RUN_CZI_EXTRACTION = True
# to (re)generate the slices.
RUN_CZI_EXTRACTION = False
proc_fns = list(path_processed.glob('*.czi'))
if RUN_CZI_EXTRACTION:
    for fn in progress_bar(proc_fns):
        process_czi(fn)

In [100]:
def resize_one(fn, i):
    """Write a downscaled 8-bit JPEG copy of one hi-res ``.npy`` slice.

    Args:
        fn: ``Path`` of a ``.npy`` file under ``path_hr``.
        i: unused; presumably the item index that fastai's ``parallel`` passes
            alongside each item.

    The array is rescaled so that 8000 maps to 255 (8000 is the same divisor
    ``ProcImageList.open`` applies to the targets), clipped to [0, 255], and
    resized with ``resize_to(img, 96, use_min=True)``. The JPEG is written at
    the same relative path under ``path_lr`` with the suffix changed to ``.jpg``.
    """
    # with_suffix only touches the final suffix; str.replace would also rewrite
    # any '.npy' that happens to appear in a directory name.
    dest = (path_lr/fn.relative_to(path_hr)).with_suffix('.jpg')
    dest.parent.mkdir(parents=True, exist_ok=True)

    scaled = np.load(fn).astype(float) / 8000.0 * 255
    # Clip before narrowing: casting out-of-range floats to uint8 is not
    # well-defined (typically wraps), which would turn pixels > 8000 dark.
    data = np.clip(scaled, 0, 255).astype(np.uint8)

    img = PIL.Image.fromarray(data, mode='L')
    targ_sz = resize_to(img, 96, use_min=True)
    img = img.resize(targ_sz, resample=PIL.Image.BILINEAR)
    img.save(str(dest), quality=100)

In [101]:
# Generate the low-res JPEG inputs from every hi-res slice via fastai's `parallel`.
hr_fns = [*path_hr.glob('*.npy')]
parallel(resize_one, hr_fns)

In [117]:
class ProcImageList(ImageImageList):
    "`ImageImageList` whose items are `.npy` arrays rather than encoded image files."

    def open(self, fn):
        """Load `fn` with `np.load` and wrap it as a float32 fastai `Image`.

        A leading channel axis is inserted (so the 2-D planes written by
        `process_czi` become 1-channel), and values are divided by 8000. No
        clipping is applied, so pixels above 8000 give values above 1.0.
        """
        planar = np.load(fn)[None, :, :].astype(np.float32)
        return Image(torch.from_numpy(planar).div_(8000.))


In [118]:
def get_basename(x):
    """Return the part of `x`'s file stem before the first underscore.

    e.g. ``Path('cellA_00_012.jpg')`` -> ``'cellA'``. Slice files are named
    ``<czi stem>_<channel>_<depth>``, so a CZI stem that itself contains ``_``
    is cut at its first underscore here.
    NOTE(review): confirm that grouping is intended (``rsplit('_', 2)[0]``
    would recover the full CZI stem).
    """
    head, _, _ = x.stem.partition('_')
    return head

# Split by source basename rather than by slice, so all slices sharing a
# basename land on the same side. Sorting removes the dependence on
# set-iteration order, which varies between interpreter runs (string hash
# randomization). The seed pins the random draw. This assumes random_split
# draws from numpy's global RNG, as fastai v1 does; confirm for this version.
SPLIT_SEED = 42
base_names = sorted({get_basename(x) for x in path_lr.iterdir()})
np.random.seed(SPLIT_SEED)
train_names, valid_names = random_split(0.15, base_names)
valid_names = list(valid_names[0])

def is_validation_basename(x):
    """Return True when the basename of `x` (per `get_basename`) is in the global `valid_names`."""
    return get_basename(x) in valid_names

# Low-res JPEGs under path_lr are the inputs; labels use ProcImageList (.npy).
# Items are routed to the validation set by basename via is_validation_basename.
# `extensions` is given as a list: a bare string is a sequence of characters,
# so membership tests against it match substrings (in some fastai versions even
# files with no suffix).
# NOTE(review): the sample printed below is 3-channel even though mode='L' is
# passed; confirm which kwarg this fastai version uses for grayscale loading
# (later fastai v1 releases call it `convert_mode`).
src = (ImageItemList
       .from_folder(path_lr, label_cls=ProcImageList, extensions=['.jpg'], mode='L')
       .split_by_valid_func(is_validation_basename))

def get_data(src, bs, size, **kwargs):
    """Pair each low-res item with its hi-res ``.npy`` and build a DataBunch.

    Each input ``x`` is labelled with ``path_hr/<x.stem>.npy``. The default
    ``get_transforms()`` augmentations are applied to input and target alike
    (``tfm_y=True``) at output `size`. Extra keyword args go to ``.databunch``.
    """
    def lr_to_hr_fn(x):
        return path_hr/f'{x.stem}.npy'

    labelled = src.label_from_func(lr_to_hr_fn)
    transformed = labelled.transform(get_transforms(), size=size, tfm_y=True)
    data = transformed.databunch(bs=bs, **kwargs)
    # NOTE(review): ProcImageList targets are 1-channel, yet c is set to 3;
    # confirm this matches the intended model output.
    data.c = 3
    return data

In [119]:
# Output size is 4 x the 96 passed to resize_to in resize_one.
bs, size = 16, 4 * 96
data = get_data(src, bs=bs, size=size)

In [121]:
# Inspect one training pair: (low-res input Image, hi-res target Image).
data.train_ds[0]

(Image (3, 384, 384), Image (1, 384, 384))

In [122]:
img = data.train_ds[0][1]