In [9]:
import os
import sys
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from torchvision.models import vgg16_bn
import PIL
import imageio
from utils import *
import skimage
import libtiff

In [10]:
# Make CUDA device 0 the current device for everything created in the cells below.
torch.cuda.set_device(0)

In [28]:
# --- Training data and checkpoint locations --------------------------------
img_data = Path('/scratch/bpho/datasets/transfer_learning_neuron_002/')
model_path = Path('/scratch/bpho/models')
model_name = "transfer_learning_neuron_002_unet_mse"
model_version = "7"
# Checkpoint name handed to `learn.load`: "<model_name>.<model_version>".
model = f'{model_name}.{model_version}'

# --- Inference inputs ------------------------------------------------------
test_path = Path('/home/bpho/Documents/test/ddd/')
test_files = [tif for tif in test_path.glob('*.tif')]

# --- Inference output folder -----------------------------------------------
results = test_path / f'{model_name}_{model_version}_pred'

# WARNING: any existing results folder is deleted so each run starts clean.
if results.exists():
    shutil.rmtree(results)
results.mkdir(parents=True, mode=0o775, exist_ok=True)

In [13]:
def get_src(size=128):
    """Build the labelled image-to-image item list used for training.

    Inputs are read from ``img_data/'lr_up'`` and each one is labelled with
    the file at the same relative path under ``img_data/'hr'``. The
    train/valid split follows the folder layout (``split_by_folder``).

    Args:
        size: accepted for call compatibility; not used by this function.

    Returns:
        The labelled ``ImageImageList`` (no transforms applied yet).
    """
    hr_dir = img_data/'hr'
    lr_dir = img_data/'lr_up'

    def hr_counterpart(lr_file):
        # Same relative location, but rooted at the high-res folder.
        return hr_dir/lr_file.relative_to(lr_dir)

    print(lr_dir)
    items = ImageImageList.from_folder(lr_dir)
    split = items.split_by_folder()
    return split.label_from_func(hr_counterpart)

def _gaussian_noise_gray(x, gauss_sigma=1.):
    """Add the same Gaussian noise plane to every channel of an image tensor.

    A single ``(1, h, w)`` noise plane is drawn from N(0, gauss_sigma) and
    broadcast over all channels, so a gray image stored as identical channels
    stays gray. The result is clamped to ``[0, img_max]`` where
    ``img_max = min(1.1 * x.max(), 1.)``.

    Args:
        x: float image tensor of shape ``(c, h, w)``.
        gauss_sigma: standard deviation of the noise.

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.
    """
    _, h, w = x.shape
    # Draw the noise with x's dtype/device so the addition never mixes devices.
    noise = x.new_empty((1, h, w)).normal_(0, gauss_sigma)
    # Allow at most 10% headroom above the brightest input pixel, capped at 1.
    img_max = min(1.1 * float(x.max()), 1.)
    # Broadcasting applies the single plane to every channel, whatever c is.
    # Lower bound first, then upper bound: if img_max is below 0 the upper
    # bound wins, matching the original np.maximum -> np.minimum ordering.
    return (x + noise).clamp(min=0.).clamp(max=img_max)

# Wrap the noise function as a fastai pixel transform; get_data appends it to
# the input-side training transforms when a `noise` sigma is supplied.
gaussian_noise_gray = TfmPixel(_gaussian_noise_gray)


def get_data(bs, size, tile_size=None, noise=None, max_zoom=8.):
    """Create the DataBunch used to build (and reload) the learner.

    Args:
        bs: batch size.
        size: image size passed to both the input and target transforms.
        tile_size: forwarded to ``get_src``; defaults to ``size``.
        noise: if not None, sigma of the Gaussian noise transform appended to
            the input-side training transforms only.
        max_zoom: ``max_zoom`` argument for ``get_transforms``.

    Returns:
        A DataBunch normalised with ``imagenet_stats`` (inputs and targets),
        with ``data.c`` set to 3.
    """
    if tile_size is None: tile_size = size
    src = get_src(tile_size)

    tfms = get_transforms(flip_vert=True, max_zoom=max_zoom)
    # Copy the lists BEFORE the noise transform is appended so the targets
    # never receive noise. The copies are shallow: both sides still hold the
    # same transform objects.
    y_tfms = [list(tfms[0]), list(tfms[1])]

    if noise is not None:
        tfms[0].append(gaussian_noise_gray(gauss_sigma=noise))
    data = (src
            .transform(tfms, size=size)
            .transform_y(y_tfms, size=size)
            .databunch(bs=bs).normalize(imagenet_stats, do_y=True))
    data.c = 3
    return data

In [14]:
# Loss handed to unet_learner: plain mean-squared error, kept under the name feat_loss.
feat_loss = F.mse_loss

In [15]:
# Base learning rate; do_fit bakes it into its default `lrs=slice(lr)` at definition time.
lr = 5e-4

In [16]:
def do_fit(save_name, lrs=slice(lr), pct_start=0.9, cycle_len=10):
    """Run one one-cycle training round on the global ``learn`` and preview results.

    Args:
        save_name: checkpoint name given to ``SaveModelCallback``.
        lrs: learning rate(s) passed to ``fit_one_cycle``.
        pct_start: ``pct_start`` argument for ``fit_one_cycle``.
        cycle_len: number of epochs.
    """
    checkpoint_cb = SaveModelCallback(learn, name=save_name)
    learn.fit_one_cycle(cycle_len, lrs, pct_start=pct_start, callbacks=[checkpoint_cb])

    # Show at most three result rows, fewer when the batch is smaller.
    rows_to_show = min(learn.data.batch_size, 3)
    learn.show_results(rows=rows_to_show, imgsize=5)

In [18]:
def image_from_tiles(learn, img, tile_sz=128, scale=4):
    """Upsample a gray image and run the model over it tile by tile.

    The image is converted to 8-bit, expanded to RGB, bicubically resized by
    ``scale``, then cut into ``tile_sz`` x ``tile_sz`` tiles. Each tile is
    passed through ``learn.predict`` and the prediction is pasted back at the
    same location.

    Args:
        learn: learner whose ``predict`` returns an image of the tile's size
            as its first element.
        img: 2-D float array; it is multiplied by 255 and cast to uint8, so
            values are expected to lie in [0, 1].
        tile_sz: edge length of the square tiles fed to the model.
        scale: integer upsampling factor applied before prediction.

    Returns:
        Float tensor of shape ``(3, rows*scale, cols*scale)`` with the
        stitched prediction.
    """
    pimg = PIL.Image.fromarray((img*255).astype(np.uint8), mode='L').convert('RGB')
    width, height = pimg.size
    upsampled = pimg.resize((width*scale, height*scale), resample=PIL.Image.BICUBIC)
    in_img = Image(pil2tensor(upsampled, np.float32).div_(255))
    c, rows, cols = in_img.shape

    out_img = torch.zeros((c, rows, cols))

    for row_start in range(0, rows, tile_sz):
        row_end = min(row_start + tile_sz, rows)
        for col_start in range(0, cols, tile_sz):
            col_end = min(col_start + tile_sz, cols)
            tile_rows, tile_cols = row_end - row_start, col_end - col_start

            # Fresh zero tile every iteration: partial tiles at the right and
            # bottom edges must be padded with zeros, not with pixels left
            # over from the previously processed tile.
            in_tile = torch.zeros((c, tile_sz, tile_sz))
            in_tile[:, :tile_rows, :tile_cols] = in_img.data[:, row_start:row_end, col_start:col_end]

            out_tile, _, _ = learn.predict(Image(in_tile))

            # Only the valid (unpadded) part of the prediction is copied back.
            out_img[:, row_start:row_end, col_start:col_end] = out_tile.data[:, :tile_rows, :tile_cols]
    return out_img


def tif_predict_movie(learn, tif_in, orig_out='orig.tif', pred_out='pred.tif', size=128):
    """Predict every slice of a TIFF stack and save each prediction as its own file.

    Slices are normalised by the maximum of the first non-blank slice, so the
    whole stack shares one intensity scale. Each prediction is written to
    ``f'{pred_out}_slice{depth}.tif'``.

    Args:
        learn: learner passed on to ``image_from_tiles``.
        tif_in: path of the input TIFF stack.
        orig_out: unused; kept for call compatibility.
        pred_out: output file name prefix (str or Path).
        size: tile size passed to ``image_from_tiles``.
    """
    data = libtiff.TiffFile(tif_in).get_tiff_array()
    depths = data.shape[0]
    img_max = None
    for depth in progress_bar(list(range(depths))):
        img = data[depth].astype(np.float32)
        # Lock in the scale from the first slice that has signal; a blank
        # first slice would otherwise cause a division by zero for the stack.
        if img_max is None and img.max() > 0:
            img_max = img.max() * 1.0
        if img_max is not None:
            img /= img_max
        # Later slices can be brighter than the reference slice; without
        # clipping they would exceed 1.0 and wrap around in the uint8 cast.
        img = np.clip(img, 0., 1.)

        out_img = image_from_tiles(learn, img, tile_sz=size).permute([1,2,0])
        pred = (out_img*255).cpu().numpy().astype(np.uint8)
        # f-string instead of `+` so pred_out may be a str or a Path.
        pred_img_out = f'{pred_out}_slice{depth}.tif'
        skimage.io.imsave(pred_img_out, pred)


def czi_predict_movie(learn, czi_in, orig_out='orig.tif', pred_out='pred.tif', size=128):
    """Predict every time point of a CZI movie and save each prediction as its own file.

    Only channel 0 and Z-plane 0 are processed. Frames are normalised by the
    maximum of the first non-blank frame, so the whole movie shares one
    intensity scale. Each prediction is written to
    ``f'{pred_out}_slice{t}.tif'``.

    Args:
        learn: learner passed on to ``image_from_tiles``.
        czi_in: path of the input CZI file.
        orig_out: unused; kept for call compatibility.
        pred_out: output file name prefix (str or Path).
        size: tile size passed to ``image_from_tiles``.
    """
    with czifile.CziFile(czi_in) as czi_f:
        proc_axes, proc_shape = get_czi_shape_info(czi_f)
        times = proc_shape['T']
        x, y = proc_shape['X'], proc_shape['Y']
        data = czi_f.asarray()
        img_max = None
        for t in progress_bar(list(range(times))):
            idx = build_index(proc_axes, {'T': t, 'C': 0, 'Z':0, 'X':slice(0,x),'Y':slice(0,y)})
            img = data[idx].astype(np.float32)
            # Lock in the scale from the first frame that has signal; a blank
            # first frame would otherwise cause a division by zero.
            if img_max is None and img.max() > 0:
                img_max = img.max() * 1.0
            if img_max is not None:
                img /= img_max
            # Later frames can be brighter than the reference frame; without
            # clipping they would exceed 1.0 and wrap around in the uint8 cast.
            img = np.clip(img, 0., 1.)

            out_img = image_from_tiles(learn, img, tile_sz=size).permute([1,2,0])
            pred = (out_img*255).cpu().numpy().astype(np.uint8)
            # Frames are written one at a time; nothing is accumulated in
            # memory, so long movies do not grow the process footprint.
            pred_img_out = f'{pred_out}_slice{t}.tif'
            skimage.io.imsave(pred_img_out, pred)

In [29]:
# Inference configuration. `size` is used twice: as the image size for the
# DataBunch here, and later as the tile size passed to the *_predict_movie calls.
bs=1
size=380*4
# Not referenced by the visible cells; image_from_tiles uses its own default scale=4.
scale = 4

# The DataBunch is only needed to construct the learner before loading weights.
data = get_data(bs, size, tile_size=128)

arch = models.resnet18
wd = 1e-3
# NOTE(review): these architecture flags presumably have to match the ones used
# when the checkpoint was trained, otherwise learn.load will not fit — confirm.
learn = unet_learner(data, arch, wd=wd, loss_func=feat_loss, 
                     metrics=superres_metrics, callback_fns=LossMetrics, 
                     blur=True, blur_final=True, norm_type=NormType.Weight, 
                     self_attention=True, last_cross=True, bottle=True,
                     #y_range=(0.,1.),
                     model_dir=model_path)
gc.collect()
# Load the checkpoint named `model` from model_dir, then switch the learner to fp16.
learn = learn.load(model).to_fp16()

/scratch/bpho/datasets/transfer_learning_neuron_002/lr_up


In [None]:
# Run the model on every TIFF stack found in the test folder. Outputs go into
# the `results` folder prepared in the config cell rather than the current
# working directory. Names are passed as str because the predict function
# builds the per-slice file name by string concatenation.
for fn in test_files:
    pred_name = str(results/f'{fn.stem}_pred')
    orig_name = str(results/f'{fn.stem}_orig.tif')
    tif_predict_movie(learn, fn, size=size, orig_out=orig_name, pred_out=pred_name)

In [None]:
# Run the model on every CZI movie found in the test folder. `test_files` only
# holds '*.tif' paths, which the CZI reader cannot open, so the CZI files are
# globbed here directly. Outputs go into the `results` folder prepared in the
# config cell; names are passed as str because the predict function builds the
# per-frame file name by string concatenation.
for fn in test_path.glob('*.czi'):
    pred_name = str(results/f'{fn.stem}_pred')
    orig_name = str(results/f'{fn.stem}_orig.tif')
    czi_predict_movie(learn, fn, size=size, orig_out=orig_name, pred_out=pred_name)