In [None]:
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from torchvision.models import vgg16_bn
import PIL
import imageio
from superres import *
from superres.helpers import czi_predict_movie
import skimage
from skimage.util import random_noise
from skimage import filters
from scipy.ndimage.interpolation import zoom as npzoom

In [None]:
# Where the paired data lives (lr/ and hr/ subfolders, see get_src), where checkpoints
# are written, and the prefix shared by every checkpoint this notebook saves.
img_data = Path('/scratch/bpho/datasets') / 'synth_newcrap_001'
model_path = Path('/scratch/bpho') / 'models'
nb_name = "synth_newcrap_0001_unet_mse_stability"

In [None]:

from numbers import Integral

class MultiImage(ItemBase):
    """Several fastai `Image`s carried around as one item.

    Transforms are applied to every member (only the first one is called with
    ``do_resolve=True``, presumably so all members share the same random transform
    parameters), `data` stacks the member tensors along a new leading axis, and the
    notebook repr draws the members side by side.
    """
    def __init__(self, img_list):
        self.img_list = img_list

    def __repr__(self):
        names = [str(img) for img in self.img_list]
        return f'MultiImage: {names}'

    @property
    def size(self):
        "Sizes of the member images, in order."
        return [member.size for member in self.img_list]

    @property
    def data(self):
        "Member image tensors stacked on a new first dimension."
        stacked = torch.stack([member.data for member in self.img_list])
        return tensor(stacked)

    def apply_tfms(self, tfms, **kwargs):
        "Transform every member in place; only the first member resolves the transforms."
        self.img_list = [member.apply_tfms(tfms, do_resolve=(pos == 0), **kwargs)
                         for pos, member in enumerate(self.img_list)]
        return self

    def _repr_png_(self):
        return self._repr_image_format('png')

    def _repr_jpeg_(self):
        return self._repr_image_format('jpeg')

    def _repr_image_format(self, format_str):
        # Lay the members out next to each other (numpy axis 1) in one picture.
        panels = [image2np(member.px) for member in self.img_list]
        strip = np.concatenate(panels, axis=1)
        with BytesIO() as buf:
            plt.imsave(buf, strip, format=format_str)
            return buf.getvalue()

    def show(self, **kwargs):
        "Display only the first member image."
        self.img_list[0].show(**kwargs)

        
class MultiImageList(ImageList):
    """`ImageList` whose items are `MultiImage`s: several images opened per item.

    Each function in `map_fns` maps an item's file name to the path of one member
    image. The default (a single identity function) yields one-image `MultiImage`s.
    """
    _bunch,_square_show,_square_show_res = ImageDataBunch,True,True
    def __init__(self, *args, map_fns=None,  **kwargs):
        super().__init__(*args, **kwargs)
        if map_fns is None: map_fns = [lambda x: x]
        self.map_fns =  map_fns
        
    def open(self, fn):
        "Open every image that `map_fns` derives from `fn` and bundle them in a `MultiImage`."
        member_fns = [map_fn(fn) for map_fn in self.map_fns]
        img_lists = [open_image(member_fn, convert_mode=self.convert_mode, after_open=self.after_open)
                     for member_fn in member_fns]
        return MultiImage(img_lists)
    
    def __getitem__(self,idxs:int)->Any:
        "Integer index -> one item; any other index -> a new list that keeps `map_fns`."
        idxs = try_int(idxs)
        if isinstance(idxs, Integral): return self.get(idxs)
        else: return self.new(self.items[idxs], inner_df=index_row(self.inner_df, idxs), map_fns=self.map_fns)
        
    def reconstruct(self, t:Tensor):
        """Rebuild a displayable `Image` from one item's tensor of shape (n, c, h, w).

        The n frames are stacked vertically, channel by channel, into a (c, n*h, w)
        image, clamped to [0, 1].
        """
        n, c, h, w = t.shape
        # Move the channel axis in front of the frame axis before merging frames into
        # rows. A plain `view(c, n*h, w)` on an (n, c, h, w) tensor mixes the channels
        # of different frames together, and it fails on non-contiguous tensors.
        one_img = t.transpose(0, 1).reshape(c, n*h, w)
        return Image(one_img.float().clamp(min=0,max=1))

                

class MultiImageImageList(MultiImageList):
    "`ItemList` suitable for `Image` to `Image` tasks: `MultiImage` inputs, single `Image` targets."
    _label_cls,_square_show,_square_show_res = ImageList,False,False

    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Draw one row per sample: input in column 0, target in column 1."
        axs = subplots(len(xs), 2, imgsize=imgsize, figsize=figsize)
        for row, pair in enumerate(zip(xs, ys)):
            for col, item in enumerate(pair):
                item.show(ax=axs[row, col], **kwargs)
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Draw one row per sample: input, prediction (`zs`) and target (`ys`), left to right."
        title = 'Input / Prediction / Target'
        axs = subplots(len(xs), 3, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=14)
        for row, (x, y, z) in enumerate(zip(xs, ys, zs)):
            # Drawn in x, y, z order; the target goes in the last column, the prediction in the middle.
            for item, col in ((x, 0), (y, 2), (z, 1)):
                item.show(ax=axs[row, col], **kwargs)
            
class MultiXModel(nn.Module):
    """Run a single-image model separately on every frame of a multi-frame batch.

    Frames are taken along dim 1 (``x[:, i]`` is frame ``i`` for the whole batch).
    Each frame goes through `model` on its own, and the outputs are re-stacked on dim 1.
    """
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        per_frame = [self.model(frame) for frame in x.unbind(dim=1)]
        return torch.stack(per_frame, dim=1)
    
class MultiYStabilityLoss(nn.Module):
    """Loss for models that produce one prediction per input frame (e.g. `MultiXModel`).

    Frame ``i`` of the prediction is ``input[:, i]``, and every frame is compared with
    the same `target`:

        loss = (1 - stable_wt) * mean_i base_loss(frame_i, target)
               + stable_wt * mean_i mse(frame_{i-1}, frame_i)

    With a single frame there are no neighbouring frames, so the stability term is 0.
    Each call refreshes `self.metrics` with the 'pixel' and 'stable' terms plus
    ssim/psnr of the last frame against `target` (keys listed in `metric_names`).
    """
    def __init__(self, base_loss=F.mse_loss, stable_wt=0.15):
        super().__init__()
        self.base_loss = base_loss
        self.stable_wt = stable_wt
        self.base_loss_wt = (1-stable_wt)
        self.metric_names = ['pixel','stable','ssim','psnr']

    def forward(self, input, target):
        num_y = input.shape[1]
        base_loss = 0.
        stable_loss = 0.
        last_y = None
        for i in range(num_y):
            y = input[:,i]
            base_loss += self.base_loss(y, target)
            if last_y is not None:
                # penalize differences between consecutive frame predictions
                stable_loss += F.mse_loss(last_y, y)
            last_y = y

        base_loss /= num_y
        if num_y > 1:
            stable_loss /= (num_y-1)
        else:
            # No frame pairs: the stability term is zero instead of a division by zero.
            # Keep it a tensor like `base_loss` so metric consumers can treat it uniformly.
            stable_loss = torch.zeros_like(base_loss)
        loss = base_loss * self.base_loss_wt + self.stable_wt * stable_loss
        # ssim/psnr are only reported, never backpropagated, so don't record a graph for them.
        with torch.no_grad():
            ssim_val = ssim.ssim(last_y, target)
            psnr_val = psnr(last_y, target)
        self.metrics = {
            'pixel': base_loss,
            'stable': stable_loss,
            'ssim': ssim_val,
            'psnr': psnr_val
        }
        return loss

In [None]:
def get_src(size=128):
    """Build the labelled low-res -> high-res source from `img_data`.

    Each input is a `MultiImage` holding two copies of the same file under
    ``img_data/'lr'`` (two identity map_fns). Its label is the file at the same
    relative path under ``img_data/'hr'``. The split uses `split_by_folder()`.
    `size` is accepted but not used here. The lr folder path is printed.
    """
    lr_tifs = img_data/'lr'
    hr_tifs = img_data/'hr'
    print(lr_tifs)

    def _same(path):
        return path

    return (MultiImageImageList
            .from_folder(lr_tifs, map_fns=[_same, _same])
            .split_by_folder()
            .label_from_func(lambda lr_fn: hr_tifs/lr_fn.relative_to(lr_tifs)))


def new_crappify(x, scale=4):
    """Make two independently degraded copies of channel 0 of the image tensor `x`.

    Each copy is downsampled by `scale` (linear interpolation), then corrupted with
    salt and pepper noise and with local-variance noise whose variance follows a
    Gaussian-blurred copy of the image. Finally it is upsampled back to exactly the
    original (h, w).

    `x` is indexed as ``x[0, :, :]`` and must be 3-D; only channel 0 is used.
    Returns a float32 tensor of shape (2, 1, h, w).
    """
    c, h, w = x.shape

    def apply_crap(img):
        img = npzoom(img, 1/scale, order=1)
        img = random_noise(img, mode='salt', amount=0.005)
        img = random_noise(img, mode='pepper', amount=0.005)
        lvar = filters.gaussian(img, sigma=1) + 1e-6
        img = random_noise(img, mode='localvar', local_vars=lvar*0.05)
        # Zoom back to exactly (h, w). Zooming by `scale` again gives
        # round(round(h/scale)*scale), which is not h when h is not a multiple of
        # `scale`, and the reshape below would then fail.
        small_h, small_w = img.shape
        img = npzoom(img, (h/small_h, w/small_w), order=1)
        return img.reshape(1,h,w).astype(np.float32)

    chan0 = x[0,:,:].numpy().reshape(h,w)
    x1 = apply_crap(chan0)
    x2 = apply_crap(chan0)
    return tensor(np.stack([x1,x2]))

# Wrap new_crappify as a fastai pixel transform.
# NOTE(review): currently unused -- the only uses (in get_data) are commented out.
crappify = TfmPixel(new_crappify)


def get_data(bs, size, tile_size=None, noise=None, max_zoom=4., scale=4):
    """Build the normalized DataBunch of (MultiImage input, Image target) pairs.

    Inputs and targets both get `get_transforms(flip_vert=True, max_zoom=max_zoom)`
    at `size`, and both are normalized with `imagenet_stats`. `tile_size` (default:
    `size`) is forwarded to `get_src`. `noise` and `scale` are currently unused; `scale`
    is only referenced by the commented-out crappify hooks. Sets `data.c = 3`.
    """
    src = get_src(size if tile_size is None else tile_size)

    x_tfms = get_transforms(flip_vert=True, max_zoom=max_zoom)
    # Separate list objects for y so that inserting crappify into the x lists leaves y untouched.
    y_tfms = [list(x_tfms[0]), list(x_tfms[1])]
    #x_tfms[0].insert(0, crappify(scale=scale))
    #x_tfms[0].append(crappify(scale=scale))

    data = (src.transform(x_tfms, size=size)
               .transform_y(y_tfms, size=size)
               .databunch(bs=bs)
               .normalize(imagenet_stats, do_y=True))
    data.c = 3
    return data

In [None]:
# Stage 1 of progressive resizing: 128 px, batch size 4, max_zoom=4 for the augmentations.
bs = 4
size = 128
data = get_data(bs, size, max_zoom=4)

In [None]:
# Visual sanity check: 3 rows of (input, target) pairs.
data.show_batch(3)    

In [None]:
class FrameStableLoss(nn.Module):
    """Companion loss for `FrameStableModel`.

    While ``model.training`` is true, `input` holds two predictions side by side along
    dim 1. They are split into y1 and y2, each is scored against `target` with
    `base_loss`, and a stability term ``mse(y1, y2)`` is added:

        loss = (1 - stable_wt)/2 * (base(y1) + base(y2)) + stable_wt * mse(y1, y2)

    Otherwise the loss is just ``base_loss(input, target)``. `metric_names` and, after
    each call, `metrics` are taken from `base_loss` when it has them.
    """
    def __init__(self, model, base_loss=F.mse_loss, stable_wt=0.15):
        super().__init__()
        self.model = model
        self.base_loss = base_loss
        self.stable_wt = stable_wt
        self.base_loss_wt = (1-stable_wt)/2
        # Store the wrapped loss's metric names on the instance so metric callbacks can
        # find them. They used to be assigned to a local variable and thrown away.
        self.metric_names = getattr(base_loss, 'metric_names', [])

    def forward(self, input, target):
        base_loss = self.base_loss
        
        if self.model.training:
            # first half of the channels is prediction 1, second half prediction 2
            num_chan = input.shape[1] // 2
            y1 = input[:,0:num_chan,:,:]
            y2 = input[:,num_chan:,:,:]
            base_2 = base_loss(y2, target)
            base_1 = base_loss(y1, target)
            stable_err = F.mse_loss(y1,y2)
            loss = (base_1 * self.base_loss_wt +
                    base_2 * self.base_loss_wt +
                    stable_err * self.stable_wt)
        else:
            loss = base_loss(input, target)
            
        if hasattr(self.base_loss, 'metrics'):
            self.metrics = self.base_loss.metrics
        
        return loss
    
class FrameStableModel(nn.Module):
    """Training wrapper that runs `model` on two noisy versions of the same input.

    In training mode, two independent noisy copies of input channel 0 (see `noise`)
    go through `model`, and the outputs are concatenated along dim 1. That is the
    layout `FrameStableLoss` splits back into two halves. In eval mode the input is
    passed straight to `model`.
    """
    def __init__(self, model):
        super().__init__()
        self.model = model

    def noise(self, x, pct_destroy=0.15, shift=0.3):
        """Return a noisy copy of channel 0 of `x`, repeated to 3 channels.

        Two random subsets of `pct_destroy` of the pixels are multiplied by `shift`
        and by `-shift` respectively. `x` itself is not modified.
        """
        # Keep the channel axis (0:1) so the result is 4-D, and clone so the caller's
        # tensor is not mutated. Both calls in `forward` must start from the same clean input.
        x1 = x[:, 0:1, :, :].clone()
        b, c, h, w = x1.shape
        n = b*c*h*w
        num_destroy = int(pct_destroy*n)
        flat = x1.view(-1)
        idx1 = torch.randperm(n, device=x.device)[:num_destroy]
        idx2 = torch.randperm(n, device=x.device)[:num_destroy]
        flat[idx1] *= shift
        flat[idx2] *= -shift
        return x1.repeat((1,3,1,1))
    
    def forward(self, x):
        if self.training:
            x1, x2 = self.noise(x), self.noise(x)
            y1 = self.model(x1)
            y2 = self.model(x2)
            # Concatenate along channels to match FrameStableLoss, which splits dim 1 in half.
            return torch.cat((y1,y2), dim=1)
        else:
            return self.model(x)
    
    def get_old_model(self):
        "Return the wrapped model."
        return self.model

In [None]:
# Stage-1 learner: U-Net with a ResNet-18 encoder, trained with `feat_loss`.
# NOTE(review): `feat_loss` and `superres_metrics` are not defined in this notebook;
# presumably they come from `from superres import *` -- confirm.
arch = models.resnet18
wd = 1e-3
learn = unet_learner(data, arch, wd=wd, loss_func=feat_loss,
                     metrics=superres_metrics, #callback_fns=[VideoStability], 
                     blur=True, blur_final=True, norm_type=NormType.Weight, 
                     self_attention=True, last_cross=True, bottle=True,
                     #y_range=(0.,1.),
                     model_dir=model_path)
gc.collect()

In [None]:
# Wrap the U-Net so it runs once per input frame (MultiXModel), and wrap the loss so it
# adds a frame-to-frame stability term (MultiYStabilityLoss). The isinstance guards make
# the cell safe to re-run: wrapping twice would nest the wrappers and break forward().
if not isinstance(learn.model, MultiXModel):
    learn.model = MultiXModel(learn.model)
if not isinstance(learn.loss_func, MultiYStabilityLoss):
    learn.loss_func = MultiYStabilityLoss(learn.loss_func)

In [None]:
# Learning-rate range test on the wrapped model/loss; `lr` below is picked from this plot.
learn.lr_find()
learn.recorder.plot()

In [None]:
#%debug

In [None]:
# Base learning rate. It is also baked into do_fit's default `lrs` when do_fit is defined.
lr = 1e-3

In [None]:
def do_fit(save_name, lrs=slice(lr), pct_start=0.9, cycle_len=10):
    """Train the global `learn` for one cycle, save it as `save_name`, and show results.

    `lrs` defaults to ``slice(lr)`` using the value of the global `lr` at the time this
    function is defined. Up to 3 rows of results are shown (fewer if the batch is smaller).
    """
    # The current model and loss are captured for the optional FrameStableModel /
    # FrameStableLoss wrapping, which is disabled:
    #   learn.model = FrameStableModel(orig_model)
    #   learn.loss_func = FrameStableLoss(learn.model, orig_loss)
    #   ...fit...
    #   learn.model = learn.model.get_old_model(); learn.loss_func = orig_loss
    orig_model, orig_loss = learn.model, learn.loss_func
    learn.fit_one_cycle(cycle_len, lrs, pct_start=pct_start)
    learn.save(save_name)
    learn.show_results(rows=min(learn.data.batch_size, 3), imgsize=5)

In [None]:
# Stage 1 (128 px): 6 epochs at `lr`, before the unfreeze() below.
do_fit(f'{nb_name}.0', lr, cycle_len=6)

In [None]:
# Make every layer group trainable for the following fits.
learn.unfreeze()

In [None]:
# Whole-network fit, rates spread from 1e-5 to `lr` (10 epochs, do_fit's default).
do_fit(f'{nb_name}.1', slice(1e-5,lr))

In [None]:
# 4 more epochs at a single lower rate (lr/100).
do_fit(f'{nb_name}.2', lr/100, cycle_len=4)

In [None]:
# NOTE(review): appears redundant with the unfreeze() above -- confirm it is needed.
learn.unfreeze()

In [None]:
# Checkpoint .3 is the one reloaded by the 512 px learner below.
do_fit(f'{nb_name}.3', slice(1e-5,lr/10), cycle_len=4)

In [None]:
# Completion marker for stage 1.
print('cool')

In [None]:
# Stage 2 of progressive resizing: rebuild the data and learner at 512 px and continue
# from the stage-1 checkpoint.
bs = 4
size = 512
data = get_data(bs, size, max_zoom=2.)

arch = models.resnet18
wd = 1e-3
learn = unet_learner(data, arch, wd=wd, loss_func=feat_loss, 
                     metrics=superres_metrics, #callback_fns=LossMetrics, 
                     blur=True, blur_final=True, norm_type=NormType.Weight, 
                     self_attention=True, last_cross=True, bottle=True,
                     #y_range=(0.,1.),
                     model_dir=model_path)

# The checkpoint was saved from the MultiXModel / MultiYStabilityLoss-wrapped learner.
# Re-apply the same wrappers before loading: the saved state-dict keys carry the
# wrapper's `model.` prefix, and get_data's multi-frame batches need MultiXModel.
learn.model = MultiXModel(learn.model)
learn.loss_func = MultiYStabilityLoss(learn.loss_func)
learn = learn.load(f'{nb_name}.3')
gc.collect()

In [None]:
# Stage 2 (512 px): 10 epochs (do_fit's default) at lr/100.
do_fit(f'{nb_name}.4', lr/100)

In [None]:
# Make every layer group trainable for the next fit.
learn.unfreeze()

In [None]:
# Whole-network fit, rates spread from 1e-6 to lr/100. Checkpoint .5 is reloaded at 1024 px below.
do_fit(f'{nb_name}.5', slice(1e-6,lr/100))

In [None]:
# Stage 3 of progressive resizing: rebuild the data and learner at 1024 px and continue
# from the stage-2 checkpoint.
bs = 2
size = 1024
data = get_data(bs, size, max_zoom=2.)

arch = models.resnet18
wd = 1e-3
learn = unet_learner(data, arch, wd=wd, loss_func=feat_loss, 
                     metrics=superres_metrics, callback_fns=LossMetrics, 
                     blur=True, blur_final=True, norm_type=NormType.Weight, 
                     self_attention=True, last_cross=True, bottle=True,
                     #y_range=(0.,1.),
                     model_dir=model_path)
gc.collect()

# The checkpoint was saved from the MultiXModel / MultiYStabilityLoss-wrapped learner.
# Re-apply the wrappers before loading so the state-dict keys (with the wrapper's `model.`
# prefix) match and the multi-frame batches from get_data can be consumed.
learn.model = MultiXModel(learn.model)
learn.loss_func = MultiYStabilityLoss(learn.loss_func)
learn = learn.load(f'{nb_name}.5')

In [None]:
# Stage 3 (1024 px): 4 epochs at lr/100.
do_fit(f'{nb_name}.6', lr/100, cycle_len=4)

In [None]:
# Make every layer group trainable for the final fit.
learn.unfreeze()

In [None]:
# Final whole-network fit; writes checkpoint .7.
do_fit(f'{nb_name}.7', slice(1e-6,lr/100), cycle_len=2)

In [None]:
# Completion marker for training.
print('cool')

In [None]:
# List this notebook's checkpoints in /scratch/bpho/models (the same directory as model_path).
!ls /scratch/bpho/models/{nb_name}*

In [None]:
#movie_files = list(Path('/scratch/bpho/datasets/movies_001/test').glob('*.czi'))
# Low-res confocal test movies to super-resolve. Sorted because glob order is arbitrary;
# this keeps `movie_files[0]` and the processing order below reproducible.
movie_files = sorted(Path('/scratch/bpho/datasources/low_res_test/').glob('low res confocal*.czi'))

In [None]:
# Pick the first test movie for the single-frame look below; fail with a clear message
# (rather than a bare IndexError) when the glob found nothing.
assert movie_files, 'no test movies matched the low-res confocal glob'
fn = movie_files[0]
len(movie_files)

In [None]:
# Load the first movie and pull out its first frame (T=0, C=0, Z=0) as float32.
with czifile.CziFile(fn) as czi_f:
    proc_axes, proc_shape = get_czi_shape_info(czi_f)
    channels = proc_shape['C']
    depths = proc_shape['Z']
    times = proc_shape['T']
    x,y = proc_shape['X'], proc_shape['Y']
    # Named czi_data, not `data`, so the global DataBunch `data` is not silently replaced.
    czi_data = czi_f.asarray()
    preds = []
    origs = []
    idx = build_index(proc_axes, {'T': 0, 'C': 0, 'Z':0, 'X':slice(0,x),'Y':slice(0,y)})
    img = czi_data[idx].astype(np.float32)
    # NOTE(review): maps the brightest pixel to 1/1.5. czi_predict_movie divides by the
    # first frame's max without the 1.5 factor -- confirm which scaling is intended.
    img /= (img.max() * 1.5)

In [None]:
def image_from_tiles(learn, img, tile_sz=128, scale=4):
    """Upscale a single-channel image by `scale` and run `learn` over it tile by tile.

    `img` is a 2-D float array, presumably scaled to [0, 1]. It is clipped to that
    range, converted to 8-bit, expanded to RGB and bicubically resized by `scale`. The
    result is cut into `tile_sz` x `tile_sz` tiles and each tile goes through
    `learn.predict`. Edge tiles are zero-padded to the full tile size, and their
    predictions are cropped back.

    Returns a float tensor with the same shape as the upscaled input image.
    """
    # Clip first: values above 1 would otherwise wrap around in the uint8 cast.
    pimg = PIL.Image.fromarray((np.clip(img, 0., 1.)*255).astype(np.uint8), mode='L').convert('RGB')
    cur_size = pimg.size
    new_size = (cur_size[0]*scale, cur_size[1]*scale)
    in_img = Image(pil2tensor(pimg.resize(new_size, resample=PIL.Image.BICUBIC),np.float32).div_(255))
    c, rows, cols = in_img.shape

    out_img = torch.zeros((c, rows, cols))

    for row_start in range(0, rows, tile_sz):
        row_end = min(row_start + tile_sz, rows)
        for col_start in range(0, cols, tile_sz):
            col_end = min(col_start + tile_sz, cols)
            tile_rows, tile_cols = row_end - row_start, col_end - col_start

            # Fresh zero tile each time, so a partial edge tile is padded with zeros
            # and not with stale pixels left over from the previous tile.
            in_tile = torch.zeros((c, tile_sz, tile_sz))
            in_tile[:, :tile_rows, :tile_cols] = in_img.data[:, row_start:row_end, col_start:col_end]

            out_tile,_,_ = learn.predict(Image(in_tile))

            # Keep only the part of the prediction that covers real (unpadded) pixels.
            out_img[:, row_start:row_end, col_start:col_end] = out_tile.data[:, :tile_rows, :tile_cols]
    return out_img


In [None]:
def czi_predict_movie(learn, czi_in, orig_out='orig.tif', pred_out='pred.tif', size=128):
    """Super-resolve every time point of a CZI movie and save input and prediction stacks.

    For each T (with C=0, Z=0), the frame is divided by the first frame's maximum,
    clipped to [0, 1], passed to `image_from_tiles` with `size`-pixel tiles, and
    converted to 8-bit. The stacked predictions are written to `pred_out` and the
    stacked scaled inputs to `orig_out`, both with `imageio.mimwrite`.

    NOTE(review): this redefines the `czi_predict_movie` imported from
    `superres.helpers` at the top of the notebook.
    """
    with czifile.CziFile(czi_in) as czi_f:
        proc_axes, proc_shape = get_czi_shape_info(czi_f)
        times = proc_shape['T']
        x,y = proc_shape['X'], proc_shape['Y']
        data = czi_f.asarray()
        preds = []
        origs = []
        img_max = None
        for t in progress_bar(list(range(times))):
            idx = build_index(proc_axes, {'T': t, 'C': 0, 'Z':0, 'X':slice(0,x),'Y':slice(0,y)})
            img = data[idx].astype(np.float32)
            if img_max is None:
                img_max = img.max() * 1.0
                # An all-dark first frame would otherwise turn every frame into nan/inf.
                if not img_max > 0: img_max = 1.0
            img /= img_max
            # Later frames can be brighter than the first one. Clip so the 8-bit
            # conversions saturate instead of wrapping around.
            np.clip(img, 0., 1., out=img)
            out_img = image_from_tiles(learn, img, tile_sz=size).permute([1,2,0])
            pred = (out_img[None]*255).cpu().numpy().astype(np.uint8)
            preds.append(pred)
            orig = (img[None]*255).astype(np.uint8)
            origs.append(orig)

        imageio.mimwrite(pred_out, np.concatenate(preds)) #, fps=30, macro_block_size=None) # for mp4
        imageio.mimwrite(orig_out, np.concatenate(origs)) #, fps=30, macro_block_size=None)


In [None]:
# Learner for tiled inference on full-size movies (image_from_tiles / czi_predict_movie).
bs=1
size=1024
scale = 4

data = get_data(bs, size, tile_size=128)

arch = models.resnet18
wd = 1e-3
learn = unet_learner(data, arch, wd=wd, loss_func=feat_loss, 
                     metrics=superres_metrics, callback_fns=LossMetrics, 
                     blur=True, blur_final=True, norm_type=NormType.Weight, 
                     self_attention=True, last_cross=True, bottle=True,
                     #y_range=(0.,1.),
                     model_dir=model_path)
gc.collect()
# Checkpoints were saved from the MultiXModel / MultiYStabilityLoss-wrapped learner, so
# wrap before loading (the state-dict keys carry the wrapper's `model.` prefix)...
learn.model = MultiXModel(learn.model)
learn.loss_func = MultiYStabilityLoss(learn.loss_func)
# NOTE(review): loads stage-2 checkpoint .5, not the final .7 -- confirm this is intended.
learn = learn.load(f'{nb_name}.5')
# ...then unwrap, because image_from_tiles sends single images through learn.predict.
learn.model = learn.model.model
learn.loss_func = learn.loss_func.base_loss


In [None]:
#learn.export(model_path/'paired_001_unet.8.pkl')
#learn = load_learner(model_path, 'paired_001_unet.8.pkl')

In [None]:
# Super-resolve every test movie. Outputs go to the current working directory as
# <stem>_orig.tif and <stem>_pred.tif (relative paths).
for fn in movie_files:
    pred_name = f'{fn.stem}_pred.tif'
    orig_name = f'{fn.stem}_orig.tif'
    czi_predict_movie(learn, fn, size=size, orig_out=orig_name, pred_out=pred_name )

In [None]:
# NOTE(review): the cells from here on are leftover interactive scratch work (ad-hoc
# calls and help lookups); nothing else in this notebook uses their results.
learn.pred_batch()

In [None]:
import skimage.util as u

In [None]:
u.img_as_ubyte

In [None]:
# Interactive docstring lookup.
torch.stack?

In [None]:
import pytorch_ssim

In [None]:
# Interactive source lookup.
pytorch_ssim.ssim??